You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
8.3 KiB
217 lines
8.3 KiB
from flask import Flask,request,jsonify,json
|
|
from flask_cors import CORS
|
|
from ultralytics import YOLO
|
|
import base64
|
|
import os
|
|
import numpy as np
|
|
import cv2
|
|
from paddleocr import PaddleOCR
|
|
import re
|
|
import logging
|
|
# import time
|
|
|
|
app = Flask(__name__)

# Only needed by session/socket extensions. Allow overriding the hard-coded
# development value from the environment; the default is unchanged.
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "secret!")

# Allow cross-origin requests from any origin.
CORS(app)

# Directory containing this file. Model and log paths are resolved against it
# so the service works regardless of the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))

# os.path.join inserts the path separator. Plain string concatenation
# (current_dir + 'app.log') would write "<parent>/<dirname>app.log", which is
# outside this directory.
logging.basicConfig(
    filename=os.path.join(current_dir, "app.log"),
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
|
|
# Load the YOLO detector that locates the fields to OCR in the uploaded photo.
model = YOLO(os.path.join(current_dir, "models", "best.pt"))
print("模型加载成功")


def _build_ocr(model_subdir, det_name, rec_name):
    """Create a CPU PaddleOCR engine from a local detection/recognition model pair.

    Args:
        model_subdir: Sub-directory of ``<current_dir>/ocr`` holding the models
            (``"simple"`` or ``"complex"``).
        det_name: Directory name of the text-detection inference model.
        rec_name: Directory name of the text-recognition inference model.

    Returns:
        A configured ``PaddleOCR`` instance.
    """
    base = os.path.join(current_dir, "ocr", model_subdir)
    return PaddleOCR(
        use_gpu=False,
        use_angle_cls=True,
        det_model_dir=os.path.join(base, det_name),
        rec_model_dir=os.path.join(base, rec_name),
        # NOTE(review): TensorRT is a GPU inference engine. With use_gpu=False
        # this flag is expected to have no effect; it is kept only to match the
        # previous configuration. Confirm against the installed PaddleOCR version.
        use_tensorrt=True,
    )


# NOTE(review): each engine is shared by all requests. Confirm that PaddleOCR
# predictors are safe to call concurrently on the threaded server.

# Default engine; the /startOcr route uses it unless ocrType == "complex".
ocrSimple = _build_ocr(
    "simple", "ch_PP-OCRv4_det_infer", "ch_PP-OCRv4_rec_infer"
)
# Server-variant models; selected when the request's ocrType is "complex".
ocrComplex = _build_ocr(
    "complex", "ch_PP-OCRv4_det_server_infer", "ch_PP-OCRv4_rec_server_infer"
)
|
|
# Recognition endpoint and its helpers.


def _decode_image(base64_str):
    """Decode a base64 image, optionally a ``data:...;base64,`` URL, to a BGR array.

    Returns None when the value is missing, is not a string, is not valid
    base64, or is not an image format OpenCV can decode.
    """
    if not isinstance(base64_str, str) or not base64_str:
        return None
    # Accept both data URLs ("data:image/png;base64,<payload>") and bare base64.
    # A strict `_, data = s.split(",")` raises ValueError on the latter.
    payload = base64_str.rpartition(",")[2]
    try:
        raw = base64.b64decode(payload)
    except ValueError:  # binascii.Error is a subclass of ValueError
        return None
    if not raw:
        return None
    # imdecode returns None for bytes that are not a decodable image.
    return cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)


def _crop(img, box, pad=0):
    """Return the region of *img* inside an ``(x1, y1, x2, y2)`` box, grown by *pad* px.

    Coordinates are clamped at 0. A negative slice start would otherwise wrap
    around to the far edge of the array and produce a wrong or empty crop.
    """
    x1, y1, x2, y2 = (int(v) for v in box[:4])
    return img[max(y1 - pad, 0):max(y2 + pad, 0),
               max(x1 - pad, 0):max(x2 + pad, 0)]


def _binarize(img, threshold=127):
    """Convert a BGR crop to a black/white image with Otsu thresholding.

    Returns None for an empty crop (e.g. a degenerate detection box), which
    cv2.cvtColor would reject.
    """
    if img is None or img.size == 0:
        return None
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # With THRESH_OTSU, OpenCV computes the threshold itself and the explicit
    # `threshold` value is ignored.
    _, binary = cv2.threshold(gray, threshold, 255,
                              cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary


def _ocr_texts(ocr, img, det=True):
    """Run *ocr* on *img* and return the recognised text strings in order.

    Returns an empty list when *img* is None or the engine finds nothing
    (PaddleOCR may return ``[None]``), instead of raising TypeError.
    """
    if img is None:
        return []
    result = ocr.ocr(img, det=det)
    texts = []
    for page in result or []:
        for line in page or []:
            # With detection a line is [box, (text, score)]; with det=False it
            # is just (text, score).
            texts.append(line[1][0] if det else line[0])
    return texts


def _format_reading(text):
    """Drop recognised dots and re-insert one before the last two characters.

    Assumes the reading has exactly two decimal places, as the original
    formatting did. Returns "" when nothing was recognised.
    """
    digits = text.strip().replace(".", "")
    if not digits:
        return ""
    return digits[:-2] + "." + digits[-2:]


@app.route("/startOcr", methods=["POST"])
def startOcr():
    """Detect the meter fields in an uploaded photo and OCR each of them.

    JSON body:
        base64Str: the image as base64, optionally prefixed "data:...;base64,".
        ocrType:   "complex" selects ocrComplex; any other value selects ocrSimple.
        pngName:   echoed back unchanged in the response.

    Returns:
        JSON with keys resultsObj, message, pngName and hasError. On success,
        resultsObj maps Address, Name, lastPower, currentMeterId, Qrcode1 and
        Qrcode2 to the recognised text. If the required boxes are not all
        detected, hasError is True and the response is 200. If the image is
        missing or undecodable, the response is 400.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    ocr_type = payload.get("ocrType")
    png_name = payload.get("pngName")

    # Use the string key "pngName". Using the variable as the key produced
    # {<name>: <name>}, and a None key broke JSON serialization.
    return_obj = {"resultsObj": {}, "message": "", "pngName": png_name}

    image = _decode_image(payload.get("base64Str"))
    if image is None:
        logging.warning("startOcr: missing or undecodable image (pngName=%s)",
                        png_name)
        return_obj["message"] = "图片数据缺失或无法解码"
        return_obj["hasError"] = True
        return jsonify(return_obj), 400

    # Save the latest upload; every request overwrites the same file.
    cv2.imwrite(os.path.join(current_dir, "ocrCurrent.jpg"), image,
                [int(cv2.IMWRITE_JPEG_QUALITY), 100])

    # Keep the engine choice local. A module global would let concurrent
    # requests (Flask's built-in server is threaded) switch each other's engine.
    ocr = ocrComplex if ocr_type == "complex" else ocrSimple

    results = model.predict(image, device="cpu")

    # Group the (x1, y1, x2, y2) boxes by YOLO class id 0-3; other ids are ignored.
    boxes_by_cls = {0: [], 1: [], 2: [], 3: []}
    for r in results:
        classes = np.array(r.boxes.cls).astype(int)
        points = np.array(r.boxes.xyxy).astype(int)
        for cls, point in zip(classes, points):
            bucket = boxes_by_cls.get(int(cls))
            if bucket is not None:
                bucket.append(point)
    target_0, target_1, target_2, target_3 = (boxes_by_cls[k] for k in range(4))

    # Every field must be present. Otherwise ask the user to retake the photo.
    if not (len(target_0) == 2 and len(target_1) == 1
            and len(target_2) in (1, 2) and len(target_3) == 2):
        return_obj["message"] = "图像不清晰或要素不全请重新拍摄或人工记录"
        return_obj["hasError"] = True
        return jsonify(return_obj), 200

    fields = {}

    # Class 0: two boxes, left to right = address, name.
    left, right = sorted(target_0, key=lambda b: b[0])
    for box, key in ((left, "Address"), (right, "Name")):
        texts = _ocr_texts(ocr, _binarize(_crop(image, box)))
        fields[key] = "".join(texts).strip() if texts else "未识别到文字"

    # Class 1: single box, cropped with a 5 px margin and read with
    # recognition only (det=False). The last recognised line wins.
    texts = _ocr_texts(ocr, _binarize(_crop(image, target_1[0], pad=5)),
                       det=False)
    fields["lastPower"] = _format_reading(texts[-1] if texts else "")

    # Class 2: with two boxes, use the upper one (smallest y1). Keep the
    # longest recognised line; on ties the first one wins.
    top = min(target_2, key=lambda b: b[1])
    texts = _ocr_texts(ocr, _binarize(_crop(image, top)))
    fields["currentMeterId"] = max(texts, key=len, default="").strip()

    # Class 3: two boxes, left to right = qrcode1, qrcode2. A crop that is not
    # wider than it is tall is rotated 90 degrees clockwise (transpose plus a
    # horizontal flip) before OCR.
    left, right = sorted(target_3, key=lambda b: b[0])
    for box, key in ((left, "Qrcode1"), (right, "Qrcode2")):
        crop = _crop(image, box)
        height, width = crop.shape[:2]
        if crop.size and width <= height:
            crop = cv2.flip(cv2.transpose(crop), flipCode=1)
        texts = _ocr_texts(ocr, _binarize(crop))
        fields[key] = "".join(texts).strip()

    return_obj["resultsObj"] = fields
    return_obj["message"] = "识别成功"
    return_obj["hasError"] = False
    return jsonify(return_obj), 200
|
|
|
|
if __name__ == "__main__":
    # When executed directly, start Flask's built-in server on every network
    # interface at port 7003, with debug mode off.
    listen_host, listen_port = "0.0.0.0", 7003
    app.run(host=listen_host, port=listen_port, debug=False)
|
|
|