import torch.nn.functional as F
class SimpleCNN(nn.Module):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) # 输入通道数为3(RGB图像),输出通道数为32
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.fc1 = nn.Linear(64 * 8 * 8, 128) # 假设输入图像大小为32x32,经过两次2x2池化后大小为8x8
self.fc2 = nn.Linear(128, 10) # CIFAR-10数据集有10个类别
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = x.view(-1, 64 * 8 * 8) # 展平