"""Flask service that runs two YOLO detectors over an uploaded image.

Stage 1 runs the follicle-group detector on the whole image; stage 2 cuts the
stage-1 output into tiles and runs the follicle detector on each tile.  Both
annotated images are written to the ``images/`` directory and the counts are
returned as JSON.
"""

import os
import time
from io import BytesIO

import cv2
import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO
from flask import Flask, abort, request, send_file, make_response
from flask_cors import CORS

# Three input parameters.  ``iou`` and ``conf`` double as the fallback values
# used by ``checkPng`` when the request form omits them.
img_path = "img0002.jpg"
iou = 0.1
conf = 0.25

# Paths of the files produced by the most recent ``checkPng`` call.
originalImgPath = ""
result1Path = ""
result2Path = ""

# Directory all uploaded and generated images live in.
IMAGE_DIR = "images"
# (width, height) of the tiles fed to the second-stage detector.
TILE_SIZE = (800, 800)

app = Flask(__name__)

# Resolve cross-origin issues: only the local front-end dev origin may call
# the API.
cors = CORS(
    app,
    resources={
        r"/api/*": {
            "origins": ["http://localhost:5173"],
            "methods": ["GET", "POST"],
        }
    },
)

# Payload of the most recent detection; an empty string until the first
# successful ``checkPng`` call.
resResult = ""


def split_image(img_path, size=(800, 800)):
    """Cut the image at ``img_path`` into a grid of tiles.

    Args:
        img_path: Path of the image file, read with ``cv2.imread``.
        size: ``(width, height)`` of a full tile.  Tiles on the right and
            bottom edges are smaller when the image is not an exact multiple.

    Returns:
        A tuple ``(img_list, indexes, (height, width))`` where ``img_list``
        holds the tiles in row-major order, ``indexes`` holds the matching
        ``(row, col)`` grid positions and the last item is the shape of the
        full image.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"cv2 could not read image: {img_path!r}")

    tile_w, tile_h = size
    height, width = img.shape[:2]
    # Ceiling division so a partial tile at the edge still gets a cell.
    rows = (height + tile_h - 1) // tile_h
    cols = (width + tile_w - 1) // tile_w

    img_list = []
    indexes = []
    for r in range(rows):
        y1 = r * tile_h
        y2 = min(y1 + tile_h, height)
        for c in range(cols):
            x1 = c * tile_w
            x2 = min(x1 + tile_w, width)
            img_list.append(img[y1:y2, x1:x2])
            indexes.append((r, c))
    return img_list, indexes, (height, width)


def combine_images(pred_imgs, indexes, size=(800, 800), img_shape=(3000, 4000)):
    """Paste tiles back into one 3-channel ``uint8`` image.

    Args:
        pred_imgs: Tile arrays, one per entry of ``indexes``.
        indexes: ``(row, col)`` grid position of each tile.
        size: ``(width, height)`` of a full tile; must match the value used
            when the tiles were cut.
        img_shape: ``(height, width)`` of the output image.

    Returns:
        The reassembled image; cells without a tile stay black.

    Raises:
        ValueError: If there are fewer tiles than grid positions.
    """
    if len(pred_imgs) < len(indexes):
        raise ValueError(
            f"expected {len(indexes)} tiles, got {len(pred_imgs)}"
        )

    tile_w, tile_h = size
    out_h, out_w = img_shape[0], img_shape[1]
    combined_img = np.zeros((out_h, out_w, 3), dtype=np.uint8)
    for tile, (r, c) in zip(pred_imgs, indexes):
        y1 = r * tile_h
        y2 = min(y1 + tile_h, out_h)
        x1 = c * tile_w
        x2 = min(x1 + tile_w, out_w)
        # Crop the tile to the cell so edge tiles cannot overrun the canvas.
        combined_img[y1:y2, x1:x2] = tile[: y2 - y1, : x2 - x1]
    return combined_img


# Initialize the models once at import time so every request reuses them.
follicle_groups_detector = YOLO("follicle_groups.pt")
follicles_detector = YOLO("follicles.pt")


# Save the uploaded image and run detection on it.
@app.route("/api/checkPng", methods=["POST"])
def checkPng():
    """Run both detectors on the uploaded file and return the counts.

    Form fields:
        file: The image to analyse (required).
        iou: IoU threshold passed to both detectors; falls back to the
            module-level ``iou`` when omitted.
        conf: Confidence threshold passed to both detectors; falls back to
            the module-level ``conf`` when omitted.

    Returns:
        A dict (serialised to JSON by Flask) with the three counts and the
        paths of the original image and the two annotated results.  The same
        dict is stored in the module-level ``resResult``.

    Responds with HTTP 400 when ``iou``/``conf`` are not numeric or the
    upload cannot be opened as an image.
    """
    global originalImgPath, result1Path, result2Path, resResult

    pngFile = request.files["file"]

    try:
        iou_threshold = float(request.form.get("iou", iou))
        conf_threshold = float(request.form.get("conf", conf))
    except (TypeError, ValueError):
        abort(400, description="'iou' and 'conf' must be numeric")

    try:
        uploaded = Image.open(pngFile.stream)
    except OSError:
        # PIL raises an OSError subclass for data it cannot identify.
        abort(400, description="uploaded file is not a readable image")

    os.makedirs(IMAGE_DIR, exist_ok=True)
    stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    originalImgPath = f"{IMAGE_DIR}/original{stamp}.png"
    result1Path = f"{IMAGE_DIR}/result1{stamp}.png"
    result2Path = f"{IMAGE_DIR}/result2{stamp}.png"
    uploaded.save(originalImgPath)

    # Stage 1: follicle groups on the full image.
    num_follicle_groups = 0
    stage1 = follicle_groups_detector(
        originalImgPath, iou=iou_threshold, conf=conf_threshold
    )
    for r in stage1:
        num_follicle_groups = len(r.boxes)
        im_array = r.plot()
        # Reverse the channel axis before handing the array to PIL
        # (the plotted array is presumably BGR -- confirm against ultralytics).
        Image.fromarray(im_array[..., ::-1]).save(result1Path)  # result image 1

    # Stage 2: individual follicles, tile by tile.  The tiles are cut from
    # the stage-1 annotated image, so its drawings are part of the input.
    img_list, indexes, (height, width) = split_image(result1Path, size=TILE_SIZE)
    print(f"Number of image blocks: {len(img_list)}")

    num_small_follicles = 0
    num_big_follicles = 0
    pred_imgs = []
    for tile in img_list:
        for r in follicles_detector(tile, iou=iou_threshold, conf=conf_threshold):
            # Class 0 is counted as small follicles, class 1 as big ones.
            num_small_follicles += torch.sum(r.boxes.cls == 0).item()
            num_big_follicles += torch.sum(r.boxes.cls == 1).item()
            pred_imgs.append(r.plot())

    # The three textual results.
    print("毛囊群数量:", num_follicle_groups)
    print("大毛囊数量:", num_big_follicles)
    print("小毛囊数量:", num_small_follicles)

    combined_img = combine_images(
        pred_imgs, indexes, size=TILE_SIZE, img_shape=(height, width)
    )
    Image.fromarray(combined_img[..., ::-1]).save(result2Path)  # result image 2

    resResult = {
        "hasError": False,
        "num_follicle_groups": num_follicle_groups,
        "num_big_follicles": num_big_follicles,
        "num_small_follicles": num_small_follicles,
        "originalImgPath": originalImgPath,
        "result1Path": result1Path,
        "result2Path": result2Path,
    }
    return resResult


# Return the latest detection result.
@app.route("/api/checkResult", methods=["GET"])
def checkResult():
    """Return the payload stored by the most recent ``checkPng`` call.

    Before any detection has run this is the initial empty string.
    """
    return resResult


# Image lookup.
@app.route("/api/getPng/<pngPath>", methods=["GET"])
def getPng(pngPath):
    """Serve an image from ``images/`` resized to 800x600 as a PNG download.

    Args:
        pngPath: File name inside the ``images/`` directory, taken from the
            URL.

    Returns:
        A response carrying the PNG bytes with an ``attachment``
        Content-Disposition.  Responds with HTTP 404 when the file does not
        exist.
    """
    # Keep only the final path component so the request cannot escape the
    # image directory.
    safe_name = os.path.basename(pngPath)
    file_path = os.path.join(IMAGE_DIR, safe_name)
    if not safe_name or not os.path.isfile(file_path):
        abort(404)

    # Open the image and resize it; the context manager closes the file.
    with Image.open(file_path) as src:
        resized = src.resize((800, 600))

    # Encode into an in-memory buffer rather than a temporary file.
    img_byte_arr = BytesIO()
    resized.save(img_byte_arr, format="PNG")

    response = make_response(img_byte_arr.getvalue())
    # The bytes are PNG-encoded, so advertise them as such.
    response.headers.set("Content-Type", "image/png")
    response.headers.set("Content-Disposition", "attachment", filename=safe_name)
    return response


if __name__ == "__main__":
    app.run(host="127.0.0.1", port=8006, debug=True)