import os
|
import time
|
|
from PIL import Image
|
import torch
|
from torchvision import transforms
|
import cnn
|
import matplotlib.pyplot as plt
|
|
|
# Split an image into pieces for the CNN (not implemented yet).
def __split_img_for_cnn(img):
    """Split ``img`` into pieces suitable for the CNN.

    Placeholder only: the image is ignored and ``None`` is returned.
    """
    # TODO: implement the image cutting.
    return None
|
|
|
if __name__ == '__main__':
    # Directory of test images. Each file name must start with the digit it
    # depicts (its first character is parsed as the ground-truth label).
    path = 'C:/Users/Administrator/Desktop/ocr/codes/'
    model_path = 'model.pt'

    # Build the transform once instead of once per file.
    to_tensor = transforms.ToTensor()
    imgs = []
    labels = []
    for name in sorted(os.listdir(path)):
        # Require a real .png extension; a substring search would also accept
        # names such as "a.png.bak".
        if not name.lower().endswith('.png'):
            continue
        # The context manager releases the file handle once the grayscale
        # pixels have been converted to a tensor.
        with Image.open(os.path.join(path, name)) as im:
            imgs.append(to_tensor(im.convert('L')))
        try:
            labels.append(int(name[0]))
        except ValueError as err:
            raise ValueError(
                f'Cannot derive a digit label from file name {name!r}'
            ) from err

    # torch.stack fails with an opaque error on an empty list; report the
    # actual cause instead.
    if not imgs:
        raise SystemExit(f'No .png images found in {path!r}')

    # Batch all images along a new leading dimension.
    imgs = torch.stack(imgs, 0)
    # Ground-truth labels do not change between runs, so build them once.
    true = torch.tensor(labels, dtype=torch.long)

    # %% Load the model
    model = cnn.CNN()
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()

    # Run inference several times so that the per-run timing can be compared
    # (the first pass typically includes one-off warm-up cost).
    for _ in range(10):
        start_time = time.perf_counter()
        with torch.no_grad():
            output = model(imgs)
        pred = output.argmax(1)
        # Stop the clock before printing so console I/O is not counted as
        # recognition time.
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        # %% Print the results
        print(pred)
        print(true)
        print("识别时间:", elapsed_ms)

    # %% Display the results
    # plt.figure(figsize=(10, 4))
    # for i in range(len(imgs)):
    #     plt.subplot(2, 5, i + 1)
    #     plt.title(f'pred {pred[i]} | true {true[i]}')
    #     plt.axis('off')
    #     plt.imshow(imgs[i].squeeze(0), cmap='gray')
    # # plt.savefig('test.png')
    # plt.show()
|