import os
import time

from PIL import Image
import torch
from torchvision import transforms
import cnn
import matplotlib.pyplot as plt


# Split an image
def __split_img_for_cnn(img):
    """Placeholder for splitting ``img`` before recognition.

    Not implemented yet: the body is empty, so the call returns ``None``
    and leaves ``img`` untouched.
    """
    # Cut the image (TODO: not implemented)
    pass


if __name__ == '__main__':
    # Script entry point: load every labelled PNG from ``path``, run the
    # saved CNN over the whole batch ten times, and print the predictions,
    # the expected labels and the per-run time in milliseconds.
    path = 'C:/Users/Administrator/Desktop/ocr/codes/'

    # One transform instance is enough; it is stateless across images.
    to_tensor = transforms.ToTensor()

    imgs = []
    labels = []
    for name in sorted(os.listdir(path)):
        # Match on the extension itself; a substring test would also accept
        # names such as "x.png.bak".
        if not name.endswith(".png"):
            continue
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(os.path.join(path, name)) as raw:
            img = to_tensor(raw.convert('L'))
        imgs.append(img)
        # The first character of the file name is parsed as the label;
        # int() raises ValueError if it is not a digit.
        labels.append(int(name[0]))

    # torch.stack() rejects an empty list with an opaque error, so fail early
    # with a message that names the directory.
    if not imgs:
        raise SystemExit(f"no .png images found in {path}")

    imgs = torch.stack(imgs, 0)
    # Expected labels do not change between runs, so build the tensor once.
    true = torch.LongTensor(labels)

    # %% Load the model
    model = cnn.CNN()
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
    model.eval()

    for _ in range(10):
        # perf_counter is monotonic and high-resolution, which suits
        # measuring short intervals.
        start_time = time.perf_counter()
        with torch.no_grad():
            output = model(imgs)
        pred = output.argmax(1)
        # Stop the clock before printing so console I/O is not counted.
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        # Print the results
        print(pred)
        print(true)
        print("识别时间:", elapsed_ms)

    # %% Display the results
    # plt.figure(figsize=(10, 4))
    # for i in range(len(imgs)):
    #     plt.subplot(2, 5, i + 1)
    #     plt.title(f'pred {pred[i]} | true {true[i]}')
    #     plt.axis('off')
    #     plt.imshow(imgs[i].squeeze(0), cmap='gray')
    #
    # plt.savefig('test.png')
    # plt.show()