1000
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

162 lines
4.9 KiB

9 months ago
import os
import time
from io import BytesIO

import cv2
import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO
from flask import Flask, request, send_file, make_response
from flask_cors import CORS
# The three input parameters: sample image path, IoU threshold and
# confidence threshold.
img_path = "img0002.jpg"
iou = 0.1
conf = 0.25

# File paths written by the most recent /api/checkPng request.
originalImgPath = ""
result1Path = ""
result2Path = ""

app = Flask(__name__)

# Cross-origin access: allow the front end served from http://localhost:5173
# to call the /api/* routes with GET or POST.
_api_cors_rules = {
    r"/api/*": {
        "origins": ["http://localhost:5173"],
        "methods": ["GET", "POST"],
    }
}
cors = CORS(app, resources=_api_cors_rules)

# Response of the most recent /api/checkPng call; empty string until one runs.
resResult = ""
def split_image(img_path, size=(800, 800)):
    """Split an image into a row-major grid of tiles.

    Args:
        img_path: Path of an image file to load with ``cv2.imread``, or an
            already-loaded image array (indexed ``[y, x, ...]``), which is
            tiled as-is without reading from disk.
        size: Tile size as ``(width, height)`` in pixels.

    Returns:
        A tuple ``(tiles, indexes, (height, width))``:
        ``tiles`` are slices of the image, ordered row by row;
        ``indexes[i]`` is the ``(row, col)`` grid cell of ``tiles[i]``;
        ``(height, width)`` is the size of the whole image. Tiles in the
        last row/column are smaller when the image size is not a multiple
        of the tile size.

    Raises:
        OSError: If ``img_path`` is a path OpenCV cannot read. ``cv2.imread``
            reports failure by returning ``None`` rather than raising, which
            would otherwise surface later as an opaque ``AttributeError``.
    """
    if isinstance(img_path, np.ndarray):
        img = img_path
    else:
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {img_path!r}")

    height, width = img.shape[:2]
    tile_w, tile_h = size[0], size[1]
    # Ceiling division so a partial tile at the right/bottom edge is kept.
    rows = (height + tile_h - 1) // tile_h
    cols = (width + tile_w - 1) // tile_w

    img_list = []
    indexes = []
    for r in range(rows):
        # Row bounds depend only on r, so compute them once per row.
        y1 = r * tile_h
        y2 = min(y1 + tile_h, height)
        for c in range(cols):
            x1 = c * tile_w
            x2 = min(x1 + tile_w, width)
            img_list.append(img[y1:y2, x1:x2])
            indexes.append((r, c))
    return img_list, indexes, (height, width)
def combine_images(pred_imgs, indexes, size=(800, 800), img_shape=(3000, 4000)):
    """Paste grid tiles back onto one 3-channel uint8 canvas.

    This is the inverse layout of ``split_image``: ``pred_imgs[i]`` is placed
    at grid cell ``indexes[i] == (row, col)``. Cells are ``size`` =
    ``(width, height)`` pixels, clipped at the canvas edge. Each tile is
    cropped to its clipped cell before pasting. Canvas pixels not covered by
    any tile stay zero.

    Args:
        pred_imgs: Tile arrays, aligned element-wise with ``indexes``.
        indexes: ``(row, col)`` grid position of each tile.
        size: Tile size as ``(width, height)``.
        img_shape: ``(height, width)`` of the output canvas.

    Returns:
        A ``(height, width, 3)`` ``np.uint8`` array.
    """
    canvas_h, canvas_w = img_shape[0], img_shape[1]
    tile_w, tile_h = size[0], size[1]
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)

    for position, (row, col) in enumerate(indexes):
        top = row * tile_h
        bottom = min(top + tile_h, canvas_h)
        left = col * tile_w
        right = min(left + tile_w, canvas_w)
        tile = pred_imgs[position]
        canvas[top:bottom, left:right] = tile[: bottom - top, : right - left]

    return canvas
# Both YOLO models are loaded once at import time and shared by all requests.
# Stage 1: run by /api/checkPng on the whole uploaded image.
follicle_groups_detector = YOLO("follicle_groups.pt")
# Stage 2: run by /api/checkPng on each tile of the stage-1 result image.
follicles_detector = YOLO("follicles.pt")
def _read_threshold(name, default):
    """Parse form field ``name`` as a float in [0, 1]; use ``default`` if absent or blank."""
    raw = request.form.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = float(raw)
    except ValueError:
        raise ValueError(f"'{name}' must be a number, got {raw!r}") from None
    # The comparison is also False for NaN, so NaN is rejected here too.
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"'{name}' must be between 0 and 1, got {raw!r}")
    return value


def _detect_follicle_groups(image_path, annotated_path, iou_value, conf_value):
    """Run stage 1 on the full image, save the annotated image, return the group count."""
    num_groups = 0  # stays 0 (instead of unbound) if the detector yields no results
    for result in follicle_groups_detector(image_path, iou=iou_value, conf=conf_value):
        num_groups = len(result.boxes)
        # Ultralytics' plot() returns a BGR array; reverse channels for PIL (RGB).
        Image.fromarray(result.plot()[..., ::-1]).save(annotated_path)
    return num_groups


def _detect_follicles(tiles, iou_value, conf_value):
    """Run stage 2 on every tile; return (num_small, num_big, annotated_tiles)."""
    num_small = 0
    num_big = 0
    annotated_tiles = []
    for tile in tiles:
        for result in follicles_detector(tile, iou=iou_value, conf=conf_value):
            classes = result.boxes.cls
            num_small += torch.sum(classes == 0).item()  # class 0: small follicle
            num_big += torch.sum(classes == 1).item()  # class 1: big follicle
            annotated_tiles.append(result.plot())
    return num_small, num_big, annotated_tiles


# Save the uploaded image, run both detection stages, and report the counts.
@app.route("/api/checkPng", methods=["POST"])
def checkPng():
    """Run two-stage follicle detection on an uploaded image.

    Multipart form fields:
        file: the image to analyse (any format Pillow can open).
        iou:  IoU threshold for both detectors; optional, falls back to the
              module-level ``iou`` when absent or blank.
        conf: confidence threshold for both detectors; optional, falls back
              to the module-level ``conf`` when absent or blank.

    Stage 1 runs ``follicle_groups_detector`` on the whole image and saves the
    annotated image to ``result1Path``. Stage 2 cuts that annotated image into
    800x800 tiles and runs ``follicles_detector`` on each one. It counts class
    0 as small and class 1 as big follicles, then stitches the annotated tiles
    back together into ``result2Path``.

    Returns:
        On success, a dict with ``hasError`` False, the three counts and the
        three saved image paths. The same dict is stored in the global
        ``resResult``. On invalid input it returns
        ``({"hasError": True, "message": ...}, 400)`` and leaves the globals
        untouched.
    """
    global originalImgPath
    global result1Path
    global result2Path
    global resResult

    upload = request.files.get("file")
    if upload is None:
        return {"hasError": True, "message": "missing 'file' upload"}, 400
    try:
        iou_value = _read_threshold("iou", iou)
        conf_value = _read_threshold("conf", conf)
    except ValueError as err:
        return {"hasError": True, "message": str(err)}, 400
    try:
        uploaded = Image.open(upload.stream)
    except OSError:  # includes PIL.UnidentifiedImageError
        return {"hasError": True, "message": "uploaded file is not a readable image"}, 400

    nowtime1 = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
    originalImgPath = "images/original" + nowtime1 + ".png"
    result1Path = "images/result1" + nowtime1 + ".png"
    result2Path = "images/result2" + nowtime1 + ".png"
    # Image.save does not create directories; make sure the output folder exists.
    os.makedirs("images", exist_ok=True)
    with uploaded:
        uploaded.save(originalImgPath)

    # Stage 1: follicle groups on the full image -> result image 1.
    num_follicle_groups = _detect_follicle_groups(
        originalImgPath, result1Path, iou_value, conf_value
    )

    # Stage 2: individual follicles per tile -> result image 2.
    # NOTE(review): tiles are cut from the stage-1 *annotated* image, so the
    # group boxes drawn in stage 1 are visible to the follicle detector and
    # appear in result 2 -- confirm this is intended.
    tile_size = (800, 800)  # the same size must be used to split and to combine
    img_list, indexes, (height, width) = split_image(result1Path, size=tile_size)
    print(f"Number of image blocks: {len(img_list)}")
    num_small_follicles, num_big_follicles, pred_imgs = _detect_follicles(
        img_list, iou_value, conf_value
    )

    # The three text results (follicle groups / big follicles / small follicles).
    print("毛囊群数量:", num_follicle_groups)
    print("大毛囊数量:", num_big_follicles)
    print("小毛囊数量:", num_small_follicles)

    combined_img = combine_images(
        pred_imgs, indexes, size=tile_size, img_shape=(height, width)
    )
    # Tiles come from plot() (BGR); reverse channels for PIL before saving.
    Image.fromarray(combined_img[..., ::-1]).save(result2Path)

    resResult = {
        "hasError": False,
        "num_follicle_groups": num_follicle_groups,
        "num_big_follicles": num_big_follicles,
        "num_small_follicles": num_small_follicles,
        "originalImgPath": originalImgPath,
        "result1Path": result1Path,
        "result2Path": result2Path,
    }
    return resResult
# Detection result lookup.
@app.route("/api/checkResult", methods=["GET"])
def checkResult():
    """Return the stored response of the most recent /api/checkPng call.

    Before any detection has run, this is the module's initial ``resResult``.
    """
    # Read-only access to the module-level name; no ``global`` needed.
    latest = resResult
    return latest
# Image lookup: serve a saved image from images/ as an 800x600 PNG.
@app.route("/api/getPng/<pngPath>", methods=["GET"])
def getPng(pngPath):
    """Return ``images/<pngPath>`` resized to 800x600 and encoded as PNG.

    Args:
        pngPath: Bare file name inside the ``images/`` directory.

    Returns:
        A response with the PNG bytes, ``Content-Type: image/png`` and an
        attachment ``Content-Disposition`` carrying ``pngPath`` as the file
        name. A ``404`` response is returned if ``pngPath`` is not a plain
        file name or the file does not exist.
    """
    # Serve only bare file names that exist inside images/; anything with a
    # directory component, or a missing file, is reported as not found
    # instead of raising an error (HTTP 500).
    file_name = os.path.basename(pngPath)
    image_path = os.path.join("images", file_name)
    if file_name != pngPath or not os.path.isfile(image_path):
        return make_response("Not Found", 404)

    # The context manager closes the source file. resize() loads the pixel
    # data and returns an independent image, so it stays usable afterwards.
    with Image.open(image_path) as img:
        resized = img.resize((800, 600))

    buffer = BytesIO()
    resized.save(buffer, format="PNG")
    response = make_response(buffer.getvalue())
    # The body is PNG-encoded, so the declared media type must match it.
    response.headers.set("Content-Type", "image/png")
    response.headers.set("Content-Disposition", "attachment", filename=pngPath)
    return response
if __name__ == "__main__":
    # Start Flask's built-in server in debug mode, listening on localhost only.
    app.run(debug=True, port=8006, host="127.0.0.1")