VGG 网络模型的结构图
代码实现 1 2 3 4 5 6 7 import torchfrom torch import nnfrom d2l import torch as d2lfrom torch.utils import datafrom torchvision import transformsimport torchvision%matplotlib inline
def vgg_block(num_convs, in_channels, out_channels):
    """Return one VGG block as an ``nn.Sequential``.

    The block is ``num_convs`` pairs of (3x3 convolution, ReLU) followed by a
    2x2 max-pool with stride 2. Each convolution uses padding 1, so it keeps
    height and width unchanged; the trailing pool halves them. The first
    convolution maps ``in_channels`` -> ``out_channels``; every later one maps
    ``out_channels`` -> ``out_channels``.
    """
    modules = []
    channels_in = in_channels
    for _ in range(num_convs):
        modules += [
            nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # After the first conv, subsequent convs consume out_channels.
        channels_in = out_channels
    modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*modules)
# (num_convs, out_channels) for each of the five VGG-11 style blocks.
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))


def vgg(conv_arch, in_channels=1, num_classes=10):
    """Build a VGG network from a sequence of ``(num_convs, out_channels)`` pairs.

    Args:
        conv_arch: iterable of ``(num_convs, out_channels)`` pairs; one
            ``vgg_block`` is created per pair, chained so each block's input
            channel count is the previous block's output channel count.
        in_channels: channel count of the input images (default 1, i.e.
            single-channel images as in the original code).
        num_classes: number of outputs of the final linear layer (default 10).

    Returns:
        ``nn.Sequential`` of the conv blocks, a flatten, and a three-layer
        fully connected classifier (4096 -> 4096 -> ``num_classes``) with
        ReLU and dropout(0.5) between the linear layers.

    Raises:
        ValueError: if ``conv_arch`` yields no blocks (previously this failed
            with an unhelpful ``NameError`` on ``out_channels``).
    """
    conv_blks = []
    out_channels = None
    for num_convs, out_channels in conv_arch:
        conv_blks.append(vgg_block(num_convs=num_convs,
                                   in_channels=in_channels,
                                   out_channels=out_channels))
        in_channels = out_channels
    if not conv_blks:
        raise ValueError(
            "conv_arch must contain at least one (num_convs, out_channels) pair")
    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        # 7 * 7 assumes the final feature map is 7x7, e.g. a 224x224 input
        # halved by five blocks; other input sizes or block counts need a
        # different value here.
        nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, num_classes))


net = vgg(conv_arch=conv_arch)
用一个随机生成的 1×1×224×224 输入，观察每一层输出的形状：
# Push a random 1x1x224x224 input through the top-level layers of `net` one
# at a time, printing each layer's class name and its output shape.
X = torch.randn(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(type(layer).__name__, 'output shape:\t', X.shape)
Sequential output shape: torch.Size([1, 64, 112, 112])
Sequential output shape: torch.Size([1, 128, 56, 56])
Sequential output shape: torch.Size([1, 256, 28, 28])
Sequential output shape: torch.Size([1, 512, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
Flatten output shape: torch.Size([1, 25088])
Linear output shape: torch.Size([1, 4096])
ReLU output shape: torch.Size([1, 4096])
Dropout output shape: torch.Size([1, 4096])
Linear output shape: torch.Size([1, 4096])
ReLU output shape: torch.Size([1, 4096])
Dropout output shape: torch.Size([1, 4096])
Linear output shape: torch.Size([1, 10])
# Start training.
# Divide every block's output channel count by `ratio`, giving a narrower
# network than the full `conv_arch`.
ratio = 4
small_conv_arch = [(num_convs, channels // ratio) for num_convs, channels in conv_arch]
net = vgg(small_conv_arch)
def get_dataloader_workers():
    """Return the number of worker processes each DataLoader uses."""
    num_workers = 4
    return num_workers


def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and load it into memory.

    Both splits are stored under ``./data`` (downloaded if missing).

    Args:
        batch_size: mini-batch size used by both loaders.
        resize: optional size passed to ``transforms.Resize``; a falsy value
            (e.g. ``None``) skips resizing.

    Returns:
        ``(train_loader, test_loader)`` -- the training loader shuffles, the
        test loader does not.
    """
    # ToTensor always runs; when requested, Resize is placed in front of it.
    steps = ([transforms.Resize(resize)] if resize else []) + [transforms.ToTensor()]
    trans = transforms.Compose(steps)

    # Build the training split first, then the test split.
    mnist_train, mnist_test = (
        torchvision.datasets.FashionMNIST(
            root="./data", train=is_train, transform=trans, download=True)
        for is_train in (True, False)
    )

    workers = get_dataloader_workers()
    train_loader = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                   num_workers=workers)
    test_loader = data.DataLoader(mnist_test, batch_size, shuffle=False,
                                  num_workers=workers)
    return train_loader, test_loader
# Training hyper-parameters.
lr = 0.05
num_epochs = 10
batch_size = 128

# Images are resized to 224x224, the input size the 7x7 classifier head of
# `vgg` is built around.
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(
    net, train_iter, test_iter,
    num_epochs, lr,
    d2l.try_gpu(),
)