# 二值化图像
|
import os
|
import random
|
|
import cv2
|
import matplotlib.pyplot as plt
|
import numpy
|
|
# Debug switch: when True, clip_ths_code_area() (and the batch script at the
# bottom of this file) display the extracted images with matplotlib.
SHOW_PLT = False
|
|
|
def gray_img(img):
    """Return a single-channel grayscale version of ``img``.

    A 2-D array is treated as already grayscale and handed back unchanged
    (the very same object, no copy). Any other array is converted with
    ``cv2.COLOR_BGR2GRAY``.
    """
    if img.ndim != 2:
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img
|
|
|
def is_col_empty(img, col, start_row, end_row, thresh_hold=0):
    """Tell whether column ``col`` holds no pixel brighter than ``thresh_hold``.

    Only rows ``start_row`` through ``end_row`` (both inclusive) are inspected.
    The scan stops at the first pixel strictly greater than ``thresh_hold``.
    """
    return not any(
        img[row][col] > thresh_hold for row in range(start_row, end_row + 1)
    )
|
|
|
def is_col_full(img, col, start_row, end_row, thresh_hold_start=64, thresh_hold_end=255, thresh_empty_count=1):
    """Tell whether column ``col`` is "full" between ``start_row`` and ``end_row``.

    A pixel is an outlier when it lies outside
    ``[thresh_hold_start, thresh_hold_end]``. The column counts as full while
    the number of outliers in rows ``start_row..end_row`` (inclusive) does
    not exceed ``thresh_empty_count``.

    Raises:
        ValueError: from the shape unpacking when ``img`` is not 2-D.
    """
    _height, _width = img.shape
    outliers = 0
    for row in range(start_row, end_row + 1):
        pixel = img[row][col]
        in_range = not (pixel < thresh_hold_start or pixel > thresh_hold_end)
        if in_range:
            continue
        outliers += 1
        if outliers > thresh_empty_count:
            return False
    return True
|
|
|
def is_row_full(img, row, start_col, end_col, thresh_hold_start=64, thresh_hold_end=255):
    """Tell whether row ``row`` is full between ``start_col`` and ``end_col``.

    The row is full when every pixel in columns ``start_col..end_col``
    (inclusive) lies within ``[thresh_hold_start, thresh_hold_end]``.

    Raises:
        ValueError: from the shape unpacking when ``img`` is not 2-D.
    """
    _height, _width = img.shape
    return all(
        not (img[row][col] < thresh_hold_start or img[row][col] > thresh_hold_end)
        for col in range(start_col, end_col + 1)
    )
|
|
|
def format_num_img(img, size=28):
    """Normalize a cropped digit image to a ``size`` x ``size`` square.

    The crop is first centered on a zero-filled (black) square canvas whose
    side equals the crop's longer dimension, which preserves the aspect ratio.
    The canvas is then resized to ``size`` x ``size`` with ``cv2.INTER_AREA``.

    Args:
        img: 2-D image of a single digit.
        size: side length of the returned image (default 28).

    Returns:
        A ``size`` x ``size`` uint8 array.
    """
    rows, cols = img.shape
    side = max(rows, cols)
    square = numpy.zeros((side, side), numpy.uint8)
    # Center along BOTH axes. Previously only the rows > cols case was copied,
    # so wide or square crops came back as an all-black image.
    top = (side - rows) // 2
    left = (side - cols) // 2
    square[top:top + rows, left:left + cols] = img
    if square.shape == (size, size):
        # Already the target size: cv2.resize would merely copy it.
        return square
    return cv2.resize(square, (size, size), interpolation=cv2.INTER_AREA)
|
|
|
def split_nums_img(img):
    """Split a grayscale image of a horizontal row of digits into digit images.

    A column contains ink when any of its pixels is brighter than two thirds
    of the image's maximum value. Each maximal run of consecutive ink columns
    is taken as one digit, so digits are expected to be bright on a dark
    background.

    Args:
        img: 2-D grayscale image.

    Returns:
        List of digit images normalized by ``format_num_img``, ordered left to
        right. Returns an empty list when ``img`` has no pixels.
    """
    rows, cols = img.shape
    if img.size == 0:
        return []

    # .item() converts to a Python number. With the numpy uint8 scalar,
    # ``max_color * 2`` wrapped around for values above 127 (255 * 2 -> 254),
    # which silently produced a far too low threshold.
    max_color = img.max().item()
    threshold = max_color * 2 // 3

    codes_pos = []
    start_index = -1
    for col in range(cols):
        if not is_col_empty(img, col, 0, rows - 1, threshold):
            if start_index < 0:
                start_index = col
        elif start_index >= 0:
            # The run ended at col - 1; store it as a half-open slice.
            codes_pos.append((start_index, col))
            start_index = -1
    if start_index >= 0:
        # A digit touching the right edge has no trailing empty column.
        codes_pos.append((start_index, cols))

    # Keep every row: the previous ``0:rows - 1`` slice dropped the bottom row.
    return [format_num_img(img[:, start:end]) for start, end in codes_pos]
|
|
|
def _ths_content_rows(img, sep_color):
    """Find the inclusive row range of the content band below a separator row.

    A separator row is one whose columns ``0..cols // 2`` all equal
    ``sep_color``. The band starts at the first non-separator row that
    directly follows a separator row. It ends just before the next separator
    row lying more than 10 rows below that start, or at the last row if there
    is none.

    Returns:
        ``(start_row, end_row)``, both inclusive.

    Raises:
        Exception: if no non-separator row follows a separator row.
    """
    rows, cols = img.shape
    prev_was_separator = False
    start_row = -1
    end_row = -1
    for r in range(rows):
        if is_row_full(img, r, 0, cols // 2, sep_color, sep_color):
            prev_was_separator = True
            # Only a separator more than 10 rows below the start ends the band.
            if start_row >= 0 and r - start_row > 10:
                end_row = r - 1
                break
        else:
            if prev_was_separator and start_row < 0:
                start_row = r
            prev_was_separator = False
    if end_row < 0:
        end_row = rows - 1
    if start_row < 0:
        raise Exception("没找到上分割线")
    return start_row, end_row


def _ths_content_start_col(img, start_row, end_row, sep_color):
    """Return the column just right of the right-most separator column.

    A separator column is ``sep_color`` over rows ``start_row..end_row``, with
    at most 2 off-colour pixels allowed. Returns ``cols - 1`` when no such
    column exists.
    """
    cols = img.shape[1]
    for c in range(cols - 1, -1, -1):
        if is_col_full(img, c, start_row, end_row, sep_color, sep_color, 2):
            return c + 1
    return cols - 1


def _ths_digit_columns(img, start_row, end_row, start_col, max_digits):
    """Locate digit column spans right of ``start_col``, scanning right to left.

    The ink threshold is two thirds of the maximum value found in rows
    ``start_row..end_row``, over the columns beyond the first third of the
    ``start_col..cols`` area.

    Returns:
        Up to ``max_digits`` right-most ``(left, right)`` slice bounds, sorted
        left to right. ``left`` includes one column before the first ink
        column.
    """
    cols = img.shape[1]
    ref_start = start_col + (cols - start_col) // 3 + 1
    ref_region = img[start_row:end_row + 1, ref_start:cols]
    # .item() converts to a Python number. With the numpy uint8 scalar,
    # ``max_color * 2`` wrapped around for values above 127 (255 * 2 -> 254).
    max_color = ref_region.max().item() if ref_region.size else 0
    threshold = max_color * 2 // 3

    spans = []
    right = left = -1
    for col in range(cols - 1, start_col, -1):
        if not is_col_empty(img, col, start_row, end_row, threshold):
            if right < 0:
                right = col
            left = col
        elif right >= 0:
            spans.append((left - 1, right + 1))
            right = left = -1
    if right >= 0:
        # The last run reached the scan boundary without a gap; keep it too.
        spans.append((left - 1, right + 1))
    return sorted(spans[:max_digits], key=lambda span: span[0])


def clip_ths_code_area(img, sep_color=38, max_digits=6):
    """Crop the code digits out of a THS (Tonghuashun) screenshot.

    Steps:
      1. Convert to grayscale.
      2. Find the content band below a ``sep_color`` separator row.
      3. Find the right-most ``sep_color`` separator column.
      4. Take the ``max_digits`` right-most bright column runs to its right.

    Args:
        img: screenshot, grayscale or BGR.
        sep_color: gray level of the separator lines (default 38).
        max_digits: maximum number of digits kept (default 6).

    Returns:
        ``(clip_img, img_detail)``. ``clip_img`` is the crop spanning all
        found digits. ``img_detail`` is the list of per-digit images
        normalized by ``format_num_img``, ordered left to right.

    Raises:
        Exception: if no upper separator line or no digit is found.
    """
    img = gray_img(img)
    start_row, end_row = _ths_content_rows(img, sep_color)
    start_col = _ths_content_start_col(img, start_row, end_row, sep_color)
    codes_pos = _ths_digit_columns(img, start_row, end_row, start_col, max_digits)
    if not codes_pos:
        raise Exception("没找到代码数字")

    # end_row is inclusive. The old [start_row:end_row] crops dropped the last
    # content row.
    bottom = end_row + 1
    img_detail = [
        format_num_img(img[start_row:bottom, left:right]) for left, right in codes_pos
    ]

    if SHOW_PLT:
        plt.figure(figsize=(10, 4))
        for i, digit_img in enumerate(img_detail):
            plt.subplot(2, 5, i + 1)
            plt.title(f'pred {i}')
            plt.axis('off')
            plt.imshow(digit_img, cmap='gray')
        plt.show()

    clip_img = img[start_row:bottom, codes_pos[0][0]:codes_pos[-1][1]]
    return clip_img, img_detail
|
|
# print(clip_img.shape)
|
# ret1, p1 = cv2.threshold(src=clip_img, thresh=100, maxval=255, type=cv2.THRESH_BINARY)
|
# cv2.imwrite("D:/workspace/GP/trade_desk/test3.png", p1)
|
|
|
def __test4():
    """Split every image in ``datas/test4/`` into digits, then save and plot them.

    The file name without its extension is used as the prefix of each saved
    digit image (``<name>_<random>.png``). Files that ``cv2.imread`` cannot
    decode are skipped.
    """
    for file in os.listdir("datas/test4/"):
        code = os.path.splitext(file)[0]
        img = cv2.imread(f"datas/test4/{file}", cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Not a readable image (e.g. a sub-directory or stray file).
            continue
        img_details = split_nums_img(img)
        for digit_img in img_details:
            cv2.imwrite(f"C:/Users/Administrator/Desktop/ocr/codes/{code}_{random.randint(0, 100000)}.png",
                        digit_img)
        plt.figure(figsize=(10, 4))
        for i, digit_img in enumerate(img_details):
            plt.subplot(2, 5, i + 1)
            plt.title(f'pred {i}')
            plt.axis('off')
            plt.imshow(digit_img, cmap='gray')
        plt.show()
|
|
|
def __test3():
    """Invert, split, save and plot every image in ``datas/test3/``.

    Images are inverted first because ``split_nums_img`` detects ink as
    pixels brighter than a threshold. The file name without its extension is
    used as the prefix of each saved digit image. Files that ``cv2.imread``
    cannot decode are skipped.
    """
    for file in os.listdir("datas/test3/"):
        code = os.path.splitext(file)[0]
        img = cv2.imread(f"datas/test3/{file}", cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Not a readable image (e.g. a sub-directory or stray file).
            continue
        # Vectorized inversion; replaces the per-pixel Python double loop.
        img = 255 - img
        img_details = split_nums_img(img)
        for digit_img in img_details:
            cv2.imwrite(f"C:/Users/Administrator/Desktop/ocr/codes/{code}_{random.randint(0, 100000)}.png",
                        digit_img)
        plt.figure(figsize=(10, 4))
        for i, digit_img in enumerate(img_details):
            plt.subplot(2, 5, i + 1)
            plt.title(f'pred {i}')
            plt.axis('off')
            plt.imshow(digit_img, cmap='gray')
        plt.show()
|
|
|
if __name__ == '__main__':
    # Do not name this variable ``gray_img``. At module scope that rebinds the
    # gray_img() function to an array, and clip_ths_code_area() then fails
    # with "'numpy.ndarray' object is not callable".
    img = cv2.imread('D:/test2.png', cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise SystemExit("cannot read image: D:/test2.png")
    clip_ths_code_area(img)
|
|
# Batch script, disabled: ``__name__`` never equals "__main__1" under normal
# execution. Change the string to "__main__" to enable it. It clips the code
# digits from every screenshot in datas/test/ and saves each digit image as
# "<first 6 chars of file name>_<random>.png".
if __name__ == "__main__1":
    for file in os.listdir("datas/test/"):
        code = file[:6]
        img = cv2.imread(f"datas/test/{file}", cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Not a readable image (e.g. a sub-directory or stray file).
            continue
        img, img_details = clip_ths_code_area(img)
        for digit_img in img_details:
            cv2.imwrite(f"C:/Users/Administrator/Desktop/ocr/codes/{code}_{random.randint(0, 100000)}.png",
                        digit_img)
        if SHOW_PLT:
            plt.figure(figsize=(1, 1))
            plt.subplot(1, 1, 1)
            plt.title("test")
            plt.axis('off')
            # The clip is single-channel; without cmap='gray' matplotlib
            # renders it in false colour.
            plt.imshow(img, cmap='gray')
            plt.show()
|
# img = gray_img(cv2.imread("C:\\Users\\Administrator\\Desktop\\ocr\\code_test.png"))
|
# h = img.shape[0]
|
# w = img.shape[1]
|
# img = cv2.resize(img, (int(w * 2), int(h * 2)), interpolation=cv2.INTER_AREA)
|
# cv2.imwrite("C:\\Users\\Administrator\\Desktop\\ocr\\code_test_gray.png", img)
|