You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
5.1 KiB
130 lines
5.1 KiB
"""Batch field extraction from meter photos with YOLO detection + PaddleOCR.

For every image in ``input_dir`` the YOLO model locates regions of four
classes; each region is cropped, saved to ``output_dir`` for inspection and
read with PaddleOCR.  The recognised text of an image is written to
``output_dir/results_<filename>.txt``.  Images whose detections do not match
the expected layout are reported on stdout instead.

Expected detections per image (as enforced in ``process_image``):
    class 0 -- exactly 2 boxes: left = address (地址), right = name (姓名)
    class 1 -- exactly 1 box: current active-energy reading (当前有功)
    class 2 -- 1 or 2 boxes: the topmost one holds the meter asset number (电表资产号)
    class 3 -- exactly 2 boxes: seals (left = 封印1, right = 封印2)
"""
import os
import re

import cv2
import numpy as np
from paddleocr import PaddleOCR
from ultralytics import YOLO

# OCR engine: CPU only, with text-direction classification, using the local
# PP-OCRv4 server detection/recognition models.
ocr = PaddleOCR(
    use_gpu=False,
    use_angle_cls=True,
    det_model_dir='./code/ocr/ch_PP-OCRv4_det_server_infer',
    rec_model_dir='./code/ocr/ch_PP-OCRv4_rec_server_infer')

# YOLO detector with custom weights; classes 0-3 are interpreted below.
model = YOLO('./code/best.pt')

# Images are read from input_dir; crops and per-image result files go to output_dir.
input_dir = 'test'
output_dir = 'results'
os.makedirs(output_dir, exist_ok=True)

# Supported image extensions, matched case-insensitively.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Extra pixels kept around the class-1 reading box before OCR.
READING_PAD = 5

# Placeholder written when OCR finds no text for a field.
NO_TEXT = '未识别到文字'


def _crop(img, box, pad=0):
    """Return the part of *img* inside ``box = (x1, y1, x2, y2)``, grown by *pad* px.

    Coordinates are clamped to the image bounds: a negative start index would
    otherwise make NumPy count from the far edge and return a wrong (often
    empty) crop for boxes that touch the top/left border.
    """
    height, width = img.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in box[:4])
    return img[max(y1 - pad, 0):min(y2 + pad, height),
               max(x1 - pad, 0):min(x2 + pad, width)]


def _save_crop(name, filename, crop):
    """Write *crop* to ``output_dir/<name>_<filename>`` for later inspection.

    Empty crops are skipped because ``cv2.imwrite`` raises on an empty image.
    """
    if crop.size:
        cv2.imwrite(os.path.join(output_dir, f'{name}_{filename}'), crop)


def _ocr_texts(image, det=True):
    """OCR *image* and return the recognised strings in PaddleOCR's order.

    With ``det=True`` each recognised line is ``[box, (text, score)]``; with
    ``det=False`` (recognition only) it is ``(text, score)``.  PaddleOCR may
    return ``None`` or ``[None]`` when it finds no text, so missing pages are
    treated as empty instead of raising ``TypeError``.
    """
    if image.size == 0:
        return []
    texts = []
    for page in ocr.ocr(image, det=det) or []:
        for line in page or []:
            texts.append(line[1][0] if det else line[0])
    return texts


def _read_address_and_name(img, boxes, filename):
    """Class 0: OCR the left box as 地址 and the right box as 姓名."""
    lines = []
    left, right = sorted(boxes, key=lambda b: b[0])
    for box, name in zip((left, right), ('地址', '姓名')):
        crop = _crop(img, box)
        _save_crop(name, filename, crop)
        out = ''.join(_ocr_texts(crop)) or NO_TEXT
        lines.append(f"{name}: {out}")
    return lines


def _read_active_energy(img, boxes, filename):
    """Class 1: recognise the reading and re-insert its decimal point.

    Any dots produced by OCR are dropped and a point is placed before the last
    two characters, i.e. the reading is assumed to carry two decimals.
    """
    lines = []
    for box in boxes:
        crop = _crop(img, box, pad=READING_PAD)
        _save_crop('当前有功', filename, crop)
        texts = _ocr_texts(crop, det=False)
        # Only the last recognised string is used as the reading.
        digits = re.sub(r'\.', '', texts[-1]) if texts else ''
        # Without any characters there is nothing to format; avoid emitting a bare '.'.
        out = digits[:-2] + '.' + digits[-2:] if digits else NO_TEXT
        lines.append(f"当前有功: {out}")
    return lines


def _read_asset_number(img, boxes, filename):
    """Class 2: OCR the topmost box and keep its longest recognised line."""
    top = min(boxes, key=lambda b: b[1])
    crop = _crop(img, top)
    _save_crop('电表资产号', filename, crop)
    longest = max(_ocr_texts(crop), key=len, default='')
    return [f"电表资产号: {longest}"]


def _read_seals(img, boxes, filename):
    """Class 3: OCR both seals (left = 封印1, right = 封印2).

    Crops that are not wider than tall are rotated 90 degrees clockwise
    before OCR (presumably so the seal text reads horizontally).
    """
    lines = []
    left, right = sorted(boxes, key=lambda b: b[0])
    for box, name in zip((left, right), ('封印1', '封印2')):
        crop = _crop(img, box)
        height, width = crop.shape[:2]
        if crop.size and width <= height:
            crop = cv2.rotate(crop, cv2.ROTATE_90_CLOCKWISE)
        _save_crop(name, filename, crop)
        out = ''.join(_ocr_texts(crop))
        lines.append(f"{name}: {out}")
    return lines


def process_image(filename):
    """Detect, crop and OCR all fields of one image from ``input_dir``.

    Writes ``results_<filename>.txt`` to ``output_dir`` when the detections
    match the expected layout; otherwise prints a request to retake the photo
    or record it manually.
    """
    img = cv2.imread(os.path.join(input_dir, filename))
    if img is None:
        # cv2.imread returns None for unreadable/corrupt files instead of raising.
        print(f"无法读取图像 {filename},已跳过")
        return

    results_summary = {'target_0': [], 'target_1': [], 'target_2': [], 'target_3': []}

    # Reuse the decoded (BGR) image instead of letting YOLO read the file again.
    for r in model.predict(img, device='cpu'):
        classes = np.array(r.boxes.cls.cpu()).astype(int)
        points = np.array(r.boxes.xyxy.cpu()).astype(int)

        # Group boxes by class id; ids outside 0-3 are ignored.
        by_class = {0: [], 1: [], 2: [], 3: []}
        for cls, point in zip(classes, points):
            if int(cls) in by_class:
                by_class[int(cls)].append(point)

        layout_ok = (len(by_class[0]) == 2 and len(by_class[1]) == 1
                     and len(by_class[2]) in (1, 2) and len(by_class[3]) == 2)
        if not layout_ok:
            print(f"图像 {filename} 不清晰或要素不全,请重新拍摄或人工记录")
            continue

        results_summary['target_0'] += _read_address_and_name(img, by_class[0], filename)
        results_summary['target_1'] += _read_active_energy(img, by_class[1], filename)
        results_summary['target_2'] += _read_asset_number(img, by_class[2], filename)
        results_summary['target_3'] += _read_seals(img, by_class[3], filename)

        # One line per field, grouped by class in 0..3 order.
        with open(os.path.join(output_dir, f'results_{filename}.txt'), 'w', encoding='utf-8') as f:
            for category_lines in results_summary.values():
                f.writelines(line + '\n' for line in category_lines)


# Process every supported image in the input folder (sorted for a stable order).
for filename in sorted(os.listdir(input_dir)):
    if filename.lower().endswith(IMAGE_EXTENSIONS):
        process_image(filename)