admin
2025-06-10 568c763084b926a6f2d632b7ac65b9ec8280752f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
 
print(torch.__version__)
print(torch.cuda.is_available())
 
 
# 创建网络
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            # [BATCH_SIZE, 1, 28, 28]
            # (通道大小,特征图数量,卷积核大小,步长,边缘填充像素)
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # [BATCH_SIZE, 32, 14, 14]
            nn.Conv2d(32, 64, 3, 1, 2),
            # [BATCH_SIZE, 64, 14, 14]
            nn.ReLU(),
            nn.MaxPool2d(2)
            # [BATCH_SIZE, 64, 7, 7]
        )
        self.fc = nn.Linear(64 * 7 * 7, 10)
 
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        y = self.fc(x)
        return y
 
 
# 评估函数
def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1]
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights, len(labels)
 
 
def __train():
    net = CNN()
    # 损失函数
    criterion = nn.CrossEntropyLoss()
    #优化器
    optimizer = torch.optim.Adam(net.parameters(),lr=0.001)
    num_epochs= 10
    # 开始循环训练
    for epoch in range(num_epochs):
        pass