admin
2025-06-10 568c763084b926a6f2d632b7ac65b9ec8280752f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import time
 
from PIL import Image
import torch
from torchvision import transforms
import cnn
import matplotlib.pyplot as plt
 
 
# 分隔图片
def __split_img_for_cnn(img):
    # 切割图片
    pass
 
 
if __name__ == '__main__':
    path = 'C:/Users/Administrator/Desktop/ocr/codes/'
    imgs = []
    labels = []
    for name in sorted(os.listdir(path)):
        if name.find(".png") < 0:
            continue
        img = Image.open(path + name).convert('L')
        img = transforms.ToTensor()(img)
        imgs.append(img)
        labels.append(int(name[0]))
    imgs = torch.stack(imgs, 0)
    # %% 加载模型
    model = cnn.CNN()
    model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
    model.eval()
    for i in range(0, 10):
        start_time = time.time()
        with torch.no_grad():
            output = model(imgs)
        # %% 打印结果
        pred = output.argmax(1)
        true = torch.LongTensor(labels)
        print(pred)
        print(true)
        print("识别时间:", (time.time() - start_time) * 1000)
 
        # %% 结果显示
        # plt.figure(figsize=(10, 4))
        # for i in range(len(imgs)):
        #     plt.subplot(2, 5, i + 1)
        #     plt.title(f'pred {pred[i]} | true {true[i]}')
        #     plt.axis('off')
        #     plt.imshow(imgs[i].squeeze(0), cmap='gray')
        # # plt.savefig('test.png')
        # plt.show()