1000
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

161 lines
4.9 KiB

import os
import time
import uuid
from io import BytesIO

import cv2
import numpy as np
import torch
from flask import Flask, request, send_file, make_response
from flask_cors import CORS
from PIL import Image
from ultralytics import YOLO
# Module-level input parameters: sample image path, IoU threshold, confidence
# threshold. NOTE(review): checkPng reads its iou/conf from the request form.
img_path, iou, conf = "img0002.jpg", 0.1, 0.25
# Paths of the three images written by /api/checkPng (empty until it runs).
originalImgPath = result1Path = result2Path = ""
# Latest detection result returned by /api/checkResult; starts out empty.
resResult = ""

app = Flask(__name__)
# Allow browser requests from http://localhost:5173 to the /api/* routes (CORS).
cors = CORS(
    app,
    resources={
        r"/api/*": {
            "origins": ["http://localhost:5173"],
            "methods": ["GET", "POST"],
        },
    },
)
def split_image(img_path, size=(800, 800)):
    """Split an image into a row-major grid of tiles of at most ``size`` pixels.

    Args:
        img_path: Path of the image to read with ``cv2.imread``, or an
            already-loaded image array (``H x W`` or ``H x W x C``), which is
            used as-is.
        size: Tile size as ``(width, height)`` in pixels. Tiles in the last
            row/column are cropped at the image edge, so they may be smaller.

    Returns:
        ``(img_list, indexes, (height, width))``: the tile arrays in row-major
        order, the ``(row, col)`` grid position of each tile, and the size of
        the full image.

    Raises:
        FileNotFoundError: if ``img_path`` is a path OpenCV cannot read.
        ValueError: if either tile dimension is not positive.
    """
    if isinstance(img_path, np.ndarray):
        img = img_path
    else:
        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"cannot read image: {img_path}")
    tile_w, tile_h = size
    if tile_w <= 0 or tile_h <= 0:
        raise ValueError(f"tile size must be positive, got {size}")
    height, width = img.shape[:2]
    # Ceiling division so a partial tile at the right/bottom edge is kept.
    rows = (height + tile_h - 1) // tile_h
    cols = (width + tile_w - 1) // tile_w
    img_list = []
    indexes = []
    for r in range(rows):
        y1 = r * tile_h
        y2 = min(y1 + tile_h, height)
        for c in range(cols):
            x1 = c * tile_w
            x2 = min(x1 + tile_w, width)
            img_list.append(img[y1:y2, x1:x2])
            indexes.append((r, c))
    return img_list, indexes, (height, width)
def combine_images(pred_imgs, indexes, size=(800, 800), img_shape=(3000, 4000)):
    """Stitch tiles (as produced by ``split_image``) back into one image.

    Args:
        pred_imgs: Tile arrays, one per entry of ``indexes`` and in the same
            order. Each tile is cropped to its grid cell, so a tile larger
            than its cell is trimmed at the bottom/right.
        indexes: ``(row, col)`` grid position of each tile.
        size: Tile size as ``(width, height)``; must match the size used
            when splitting.
        img_shape: ``(height, width)`` of the output image.

    Returns:
        A ``uint8`` array of shape ``(height, width) + tile_channels`` --
        ``(H, W, 3)`` for 3-channel tiles, ``(H, W)`` for 2-D tiles. Cells
        that receive no tile stay zero (black).

    Raises:
        ValueError: if ``pred_imgs`` and ``indexes`` differ in length.
    """
    if len(pred_imgs) != len(indexes):
        raise ValueError(
            f"got {len(pred_imgs)} tiles for {len(indexes)} grid positions"
        )
    # Take the channel layout from the tiles; fall back to 3 channels when
    # there is nothing to stitch.
    channels = pred_imgs[0].shape[2:] if pred_imgs else (3,)
    combined_img = np.zeros((img_shape[0], img_shape[1], *channels), dtype=np.uint8)
    tile_w, tile_h = size
    for tile, (r, c) in zip(pred_imgs, indexes):
        y1 = r * tile_h
        y2 = min(y1 + tile_h, img_shape[0])
        x1 = c * tile_w
        x2 = min(x1 + tile_w, img_shape[1])
        combined_img[y1:y2, x1:x2] = tile[: y2 - y1, : x2 - x1]
    return combined_img
# Load both YOLO models once at import time so every request reuses them:
# per their weight files, one detects follicle groups, the other follicles.
follicle_groups_detector, follicles_detector = (
    YOLO("follicle_groups.pt"),
    YOLO("follicles.pt"),
)
# Save the uploaded image, run both detectors on it, and return the counts.
@app.route("/api/checkPng", methods=["POST"])
def checkPng():
    """Detect follicle groups and follicles in an uploaded image.

    Form fields (multipart POST):
        file: the image to analyse; anything Pillow can open.
        iou, conf: optional float thresholds passed to both YOLO detectors.
            When a field is omitted, the module-level ``iou`` / ``conf``
            values are used.

    Side effects:
        Writes three PNGs under ``images/`` (the upload, the follicle-group
        overlay, and the stitched per-tile follicle overlay). Stores their
        paths and the result dict in module globals, which
        ``/api/checkResult`` reads.

    Returns:
        The result dict (Flask serialises it to JSON) with the three counts
        and the three image paths. Returns a ``{"hasError": True, ...}``
        payload with status 400 when a threshold is not a number or the upload
        is not a readable image.
    """
    global originalImgPath
    global result1Path
    global result2Path
    global resResult

    png_file = request.files["file"]
    try:
        # A missing field falls back to the module default instead of
        # reaching float(None).
        iou_value = float(request.form.get("iou", iou))
        conf_value = float(request.form.get("conf", conf))
    except (TypeError, ValueError):
        return {"hasError": True, "message": "iou and conf must be numbers"}, 400

    # A second-resolution timestamp alone collides when two uploads arrive in
    # the same second, so a short random suffix keeps each request's files apart.
    stamp = (
        time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        + "_"
        + uuid.uuid4().hex[:8]
    )
    original_path = "images/original" + stamp + ".png"
    result1_path = "images/result1" + stamp + ".png"
    result2_path = "images/result2" + stamp + ".png"
    os.makedirs("images", exist_ok=True)

    try:
        upload = Image.open(png_file.stream)
    except OSError:  # includes PIL.UnidentifiedImageError
        return {"hasError": True, "message": "uploaded file is not an image"}, 400
    with upload:
        upload.save(original_path)

    # Stage 1: follicle groups on the whole image -> result image 1.
    num_follicle_groups = 0
    for r in follicle_groups_detector(original_path, iou=iou_value, conf=conf_value):
        num_follicle_groups = len(r.boxes)
        # Reverse the channel axis (BGR plot -> RGB) before handing it to PIL.
        Image.fromarray(r.plot()[..., ::-1]).save(result1_path)

    # Stage 2: individual follicles, tile by tile.
    # NOTE(review): tiles are cut from result image 1, i.e. the image with the
    # group boxes already drawn on it, so the follicle detector also sees those
    # annotations -- confirm this is intended.
    tiles, indexes, (height, width) = split_image(result1_path)
    print(f"Number of image blocks: {len(tiles)}")
    num_small_follicles = 0
    num_big_follicles = 0
    pred_imgs = []
    for tile in tiles:
        for r in follicles_detector(tile, iou=iou_value, conf=conf_value):
            # Class id 0 is counted as a small follicle, class id 1 as a big one.
            num_small_follicles += torch.sum(r.boxes.cls == 0).item()
            num_big_follicles += torch.sum(r.boxes.cls == 1).item()
            pred_imgs.append(r.plot())
    print("毛囊群数量:", num_follicle_groups)
    print("大毛囊数量:", num_big_follicles)
    print("小毛囊数量:", num_small_follicles)

    combined_img = combine_images(
        pred_imgs, indexes, size=(800, 800), img_shape=(height, width)
    )
    Image.fromarray(combined_img[..., ::-1]).save(result2_path)  # result image 2

    result = {
        "hasError": False,
        "num_follicle_groups": num_follicle_groups,
        "num_big_follicles": num_big_follicles,
        "num_small_follicles": num_small_follicles,
        "originalImgPath": original_path,
        "result1Path": result1_path,
        "result2Path": result2_path,
    }
    # Publish to the shared globals only once this request's result is
    # complete, so concurrent requests never mix paths from different uploads.
    originalImgPath, result1Path, result2Path = original_path, result1_path, result2_path
    resResult = result
    return result
# Return the most recent detection result.
@app.route("/api/checkResult", methods=["GET"])
def checkResult():
    """Return the result dict stored by the most recent ``/api/checkPng`` call.

    Until a detection has run, ``resResult`` still holds its empty initial
    value. Returning that as-is would send an empty, non-JSON body, so a JSON
    payload with the same ``hasError`` key is returned instead.
    """
    if not resResult:
        return {"hasError": True, "message": "no detection result yet"}
    return resResult
# Serve a stored image by file name.
@app.route("/api/getPng/<pngPath>", methods=["GET"])
def getPng(pngPath):
    """Return an image from ``images/``, resized to 800x600 and encoded as PNG.

    Args:
        pngPath: File name inside ``images/``.

    Returns:
        A response whose body is the PNG-encoded image, sent as an attachment
        named ``pngPath``. Returns ``({"hasError": True, ...}, 400)`` for a
        name that could point outside ``images/``, and
        ``({"hasError": True, ...}, 404)`` if the file does not exist.
    """
    # The default route converter already rejects "/", but "\\" is a path
    # separator on Windows and "." / ".." name directories; refuse them so the
    # lookup cannot leave images/.
    if "/" in pngPath or "\\" in pngPath or pngPath in (".", ".."):
        return {"hasError": True, "message": "invalid image name"}, 400
    try:
        with Image.open("images/" + pngPath) as src:
            # Fixed 800x600 output; the aspect ratio is not preserved.
            img = src.resize((800, 600))
    except FileNotFoundError:
        return {"hasError": True, "message": "image not found"}, 404
    img_byte_arr = BytesIO()
    img.save(img_byte_arr, format="PNG")
    response = make_response(img_byte_arr.getvalue())
    # The body is PNG-encoded above, so the Content-Type must say image/png.
    response.headers.set("Content-Type", "image/png")
    response.headers.set("Content-Disposition", "attachment", filename=pngPath)
    return response
if __name__ == "__main__":
    # Development server on 127.0.0.1:8006. debug=True turns on the reloader
    # and the interactive debugger, so keep it off untrusted networks.
    app.run(debug=True, port=8006, host="127.0.0.1")