You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
4.9 KiB
161 lines
4.9 KiB
import cv2
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from ultralytics import YOLO
|
|
from flask import Flask, request, send_file, make_response
|
|
from flask_cors import CORS
|
|
import time
|
|
from io import BytesIO
|
|
|
|
# Default inputs: image path, IoU threshold and confidence threshold.
# NOTE(review): the request handlers visible here do not read these three names;
# checkPng takes its thresholds from the request form instead.
img_path = "img0002.jpg"
iou = 0.1
conf = 0.25

# Paths of the most recently written images (original upload, result 1,
# result 2); checkPng reassigns all three on every request.
originalImgPath = result1Path = result2Path = ""
|
|
|
app = Flask(__name__)

# Enable CORS so the front-end origin http://localhost:5173 may call the
# /api/* endpoints cross-origin with GET and POST.
cors = CORS(
    app,
    resources={
        r"/api/*": {
            "origins": ["http://localhost:5173"],
            "methods": ["GET", "POST"],
        }
    },
)

# Result dict of the most recent /api/checkPng call, served by /api/checkResult.
# Starts as an empty dict rather than "" so that endpoint always answers with
# JSON (Flask serialises a returned dict), even before any detection has run.
resResult = {}
|
|
|
|
|
|
def split_image(img_path, size=(800, 800)):
    """Cut an image into a grid of tiles of at most ``size`` pixels.

    Args:
        img_path: path of the image to load with ``cv2.imread``. As a
            backward-compatible extension, an already-loaded ``numpy.ndarray``
            is also accepted and tiled as-is (no disk round trip).
        size: tile size as ``(width, height)``. Tiles in the last column/row
            are narrower/shorter when the image is not an exact multiple.

    Returns:
        ``(tiles, indexes, (height, width))``:
        ``tiles`` are slices (numpy views) of the image in row-major order;
        ``indexes`` holds the ``(row, col)`` grid cell of each tile;
        ``(height, width)`` is the full image size.

    Raises:
        FileNotFoundError: if ``img_path`` is a path that ``cv2.imread`` could
            not read. ``cv2.imread`` returns ``None`` rather than raising, which
            previously surfaced as an opaque AttributeError on ``.shape``.
    """
    if isinstance(img_path, np.ndarray):
        img = img_path
    else:
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"could not read image: {img_path}")

    height, width = img.shape[:2]
    tile_w, tile_h = size[0], size[1]
    # Ceiling division so partial tiles at the right/bottom edges are kept.
    rows = (height + tile_h - 1) // tile_h
    cols = (width + tile_w - 1) // tile_w

    img_list = []
    indexes = []
    for r in range(rows):
        y1 = r * tile_h
        y2 = min(y1 + tile_h, height)
        for c in range(cols):
            x1 = c * tile_w
            x2 = min(x1 + tile_w, width)
            img_list.append(img[y1:y2, x1:x2])
            indexes.append((r, c))

    return img_list, indexes, (height, width)
|
|
|
|
|
|
def combine_images(pred_imgs, indexes, size=(800, 800), img_shape=(3000, 4000)):
    """Stitch tiles back into one 3-channel uint8 image (the inverse of split_image).

    Args:
        pred_imgs: tile images; ``pred_imgs[i]`` is placed at grid cell ``indexes[i]``.
        indexes: ``(row, col)`` grid cell for each tile.
        size: tile size as ``(width, height)``.
        img_shape: ``(height, width)`` of the output canvas.

    Returns:
        A ``(height, width, 3)`` uint8 array that is zero wherever no tile was
        placed. A tile larger than its (possibly clipped) cell is cropped to fit.
    """
    out_h, out_w = img_shape[0], img_shape[1]
    tile_w, tile_h = size[0], size[1]
    canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)

    for position, (row, col) in enumerate(indexes):
        top = row * tile_h
        left = col * tile_w
        bottom = min(top + tile_h, out_h)
        right = min(left + tile_w, out_w)
        tile = pred_imgs[position]
        canvas[top:bottom, left:right] = tile[: bottom - top, : right - left]

    return canvas
|
|
|
|
|
|
# Load both YOLO weight files once, at import time; every request reuses them.
#   follicle_groups_detector: its box count is reported as the follicle-group count.
#   follicles_detector: boxes of class 0 / 1 are counted as small / big follicles.
follicle_groups_detector, follicles_detector = map(
    YOLO, ("follicle_groups.pt", "follicles.pt")
)
|
|
|
|
|
|
def _form_float(name, default):
    """Return form field *name* of the current request as a float.

    A missing or blank field yields *default*; a non-numeric value raises
    ValueError.
    """
    raw = request.form.get(name, "").strip()
    return float(raw) if raw else default


def _detect_follicle_groups(src_path, dst_path, iou, conf):
    """Run the follicle-group model on the image at *src_path*.

    Saves the plotted detections to *dst_path* and returns the number of
    detected boxes.
    """
    num_groups = 0
    for result in follicle_groups_detector(src_path, iou=iou, conf=conf):
        num_groups += len(result.boxes)
        # plot() returns a BGR array; flip the channels to RGB for PIL.
        Image.fromarray(result.plot()[..., ::-1]).save(dst_path)
    return num_groups


def _detect_follicles(src_path, dst_path, iou, conf, tile_size=(800, 800)):
    """Run the follicle model tile by tile over the image at *src_path*.

    The image is cut into ``tile_size`` tiles and each tile is detected
    separately. The plotted tiles are stitched back together and saved to
    *dst_path*. Returns ``(num_small, num_big)``, the numbers of class-0 and
    class-1 boxes.
    """
    tiles, indexes, (height, width) = split_image(src_path, size=tile_size)
    print(f"Number of image blocks: {len(tiles)}")

    num_small = 0
    num_big = 0
    plotted_tiles = []
    for tile in tiles:
        for result in follicles_detector(tile, iou=iou, conf=conf):
            num_small += torch.sum(result.boxes.cls == 0).item()
            num_big += torch.sum(result.boxes.cls == 1).item()
            plotted_tiles.append(result.plot())

    # NOTE(review): tiles do not overlap and nothing merges detections across
    # tile borders. A follicle cut by a border may therefore be counted twice
    # or missed. Confirm this is acceptable.
    combined = combine_images(
        plotted_tiles, indexes, size=tile_size, img_shape=(height, width)
    )
    Image.fromarray(combined[..., ::-1]).save(dst_path)
    return num_small, num_big


# Save the uploaded image and the detection result images.
@app.route("/api/checkPng", methods=["POST"])
def checkPng():
    """Detect follicle groups and follicles in the uploaded image.

    Expects a multipart form with:
      * ``file`` -- the image to analyse;
      * ``iou`` / ``conf`` -- optional thresholds passed to both YOLO models.
        They default to 0.1 / 0.25, the same values as the module-level
        ``iou`` / ``conf``.

    Pipeline:
      1. Save the upload as ``images/original<timestamp>.png``.
      2. Run the follicle-group model on it and save the plotted result as
         ``images/result1<timestamp>.png``.
      3. Tile that annotated image into 800x800 blocks and run the follicle
         model on each block (class 0 = small, class 1 = big). Stitch the
         plotted blocks and save ``images/result2<timestamp>.png``.

    Returns the counts and the three image paths as a JSON dict. The same dict
    is stored in the module-level ``resResult`` for /api/checkResult. A
    non-numeric ``iou``/``conf`` yields a 400 response with ``hasError: True``.
    """
    import os

    global originalImgPath, result1Path, result2Path, resResult

    try:
        iou = _form_float("iou", 0.1)
        conf = _form_float("conf", 0.25)
    except ValueError:
        return {"hasError": True, "message": "iou and conf must be numbers"}, 400

    # Second resolution only: two uploads within the same second overwrite
    # each other's files.
    nowtime = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    originalImgPath = "images/original" + nowtime + ".png"
    result1Path = "images/result1" + nowtime + ".png"
    result2Path = "images/result2" + nowtime + ".png"

    os.makedirs("images", exist_ok=True)
    with Image.open(request.files["file"].stream) as upload:
        upload.save(originalImgPath)

    num_follicle_groups = _detect_follicle_groups(
        originalImgPath, result1Path, iou, conf
    )

    # NOTE(review): follicles are detected on result1, i.e. the image that
    # already has the group boxes and labels drawn on it, so those drawings are
    # part of the model input. Confirm this is intended rather than detecting
    # on originalImgPath.
    num_small_follicles, num_big_follicles = _detect_follicles(
        result1Path, result2Path, iou, conf
    )

    # The three result texts: follicle groups, big follicles, small follicles.
    print("毛囊群数量:", num_follicle_groups)
    print("大毛囊数量:", num_big_follicles)
    print("小毛囊数量:", num_small_follicles)

    resResult = {
        "hasError": False,
        "num_follicle_groups": num_follicle_groups,
        "num_big_follicles": num_big_follicles,
        "num_small_follicles": num_small_follicles,
        "originalImgPath": originalImgPath,
        "result1Path": result1Path,
        "result2Path": result2Path,
    }
    return resResult
|
|
|
|
|
|
# Return the latest detection result.
@app.route("/api/checkResult", methods=["GET"])
def checkResult():
    """Return the result dict stored by the most recent /api/checkPng call.

    Before any detection has produced a result (``resResult`` is empty), answer
    with a JSON error object instead of an empty, non-JSON body.
    """
    if not resResult:
        return {"hasError": True, "message": "no detection result yet"}
    return resResult
|
|
|
|
|
|
# Image lookup: serve a previously saved image from images/.
@app.route("/api/getPng/<pngPath>", methods=["GET"])
def getPng(pngPath):
    """Return ``images/<pngPath>`` resized to 800x600, encoded as a PNG attachment.

    Responds with 404 when *pngPath* is not a plain file name (it contains a
    path separator or is "." / "..") or when it cannot be opened as an image.
    """
    from pathlib import Path

    # Flask's default <string> converter rejects "/", but on Windows a
    # backslash is also a separator. Only accept a bare file name so the
    # lookup cannot leave images/.
    if pngPath in (".", "..") or Path(pngPath).name != pngPath:
        return make_response("Not found", 404)

    try:
        # The context manager closes the underlying file handle.
        with Image.open(Path("images") / pngPath) as img:
            # Fixed output size; the aspect ratio is not preserved.
            resized = img.resize((800, 600))
    except OSError:
        # Missing file, a directory, or data PIL cannot identify as an image.
        return make_response("Not found", 404)

    buffer = BytesIO()
    resized.save(buffer, format="PNG")

    response = make_response(buffer.getvalue())
    # The body is PNG-encoded, so the Content-Type must say image/png.
    response.headers.set("Content-Type", "image/png")
    response.headers.set("Content-Disposition", "attachment", filename=pngPath)
    return response
|
|
|
|
|
|
if __name__ == "__main__":
    # Flask development server on the loopback interface, port 8006, with
    # debug mode (interactive debugger and auto-reloader) switched on.
    # Not intended for production use.
    app.run(port=8006, debug=True, host="127.0.0.1")
|
|
|