Deep Learning - PyTorch - Classic CNN Models - LeNet
The LeNet Model
Architecture Diagram
Code Implementation
First, define our LeNet model:
```python
import torch
from torch import nn
```
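Only the opening import of the original listing survived extraction. The block below is a minimal sketch of a LeNet that reproduces the layer sequence and shapes printed in the next section. The `Reshape` helper, the variable name `net`, and the classic 5×5 kernels (with padding 2 on the first convolution so the 28×28 input stays 28×28) are assumptions inferred from those printed shapes, not the author's verified code.

```python
import torch
from torch import nn


class Reshape(nn.Module):
    """Hypothetical helper: reshape a batch to (N, 1, 28, 28) single-channel images."""

    def forward(self, x):
        return x.reshape(-1, 1, 28, 28)


# LeNet-style network, assuming classic 5x5 kernels. Padding 2 on the first
# convolution keeps a 28x28 input at 28x28; the second convolution has no padding.
net = nn.Sequential(
    Reshape(),
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),                      # 16 channels x 5 x 5 -> 400 features
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10),                 # 10 class scores (logits)
)

# Quick sanity check: one random 28x28 image produces 10 outputs.
out = net(torch.rand(1, 1, 28, 28))
print(out.shape)  # torch.Size([1, 10])
```

Using `nn.Sequential` keeps the layers iterable in order, which is what makes the layer-by-layer shape check below possible.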
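Next, check the model by passing a random single-channel 28×28 input through it one layer at a time and printing each layer's output shape:
```python
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
```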
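The loop that walks `X` through the network did not survive extraction. Below is a sketch of it wrapped in a small function; the name `print_layer_shapes` is hypothetical, and it assumes the model is an `nn.Sequential` (so iterating over it yields the layers in order), stored in a variable such as `net`.

```python
import torch
from torch import nn


def print_layer_shapes(net: nn.Sequential, X: torch.Tensor) -> torch.Tensor:
    """Feed X through each layer of `net` in turn, printing every output shape.

    Hypothetical helper; returns the final output so the caller can inspect it.
    """
    for layer in net:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:\t', X.shape)
    return X
```

Calling `print_layer_shapes(net, X)` on a LeNet model and the `X` above prints one line per layer, in the format of the listing that follows.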
```
Reshape output shape: torch.Size([1, 1, 28, 28])
Conv2d output shape: torch.Size([1, 6, 28, 28])
Sigmoid output shape: torch.Size([1, 6, 28, 28])
AvgPool2d output shape: torch.Size([1, 6, 14, 14])
Conv2d output shape: torch.Size([1, 16, 10, 10])
Sigmoid output shape: torch.Size([1, 16, 10, 10])
AvgPool2d output shape: torch.Size([1, 16, 5, 5])
Flatten output shape: torch.Size([1, 400])
Linear output shape: torch.Size([1, 120])
Sigmoid output shape: torch.Size([1, 120])
Linear output shape: torch.Size([1, 84])
Sigmoid output shape: torch.Size([1, 84])
Linear output shape: torch.Size([1, 10])
```
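The spatial sizes in this trace follow from the standard output-size rule for a convolution or pooling window, floor((size + 2·padding − kernel) / stride) + 1. The pure-Python sketch below replays that arithmetic; the (kernel, padding, stride) settings are the same assumptions as in the model sketch (5×5 convolutions, padding 2 on the first one, 2×2 pooling with stride 2). Sigmoid and Reshape do not change the shape, so they are omitted.

```python
def out_size(size, kernel, padding=0, stride=1):
    """Spatial output size of a conv/pool window: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed (name, kernel, padding, stride) for the shape-changing layers.
layers = [
    ("Conv2d",    5, 2, 1),   # 1 -> 6 channels
    ("AvgPool2d", 2, 0, 2),
    ("Conv2d",    5, 0, 1),   # 6 -> 16 channels
    ("AvgPool2d", 2, 0, 2),
]

sizes = [28]
for name, k, p, s in layers:
    sizes.append(out_size(sizes[-1], k, p, s))
    print(f"{name}: {sizes[-2]} -> {sizes[-1]}")

flat = 16 * sizes[-1] * sizes[-1]   # 16 channels of 5x5 feature maps
print("Flatten:", flat)
```

This prints `Conv2d: 28 -> 28`, `AvgPool2d: 28 -> 14`, `Conv2d: 14 -> 10`, `AvgPool2d: 10 -> 5` and `Flatten: 400`, which is why the first fully connected layer takes 400 inputs.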
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless otherwise stated.