天天看點

(pytorch-深度學習系列)使用重複元素的網絡(VGG)使用重複元素的網絡(VGG)

使用重複元素的網絡(VGG)

VGG的名字來源于論文作者所在的實驗室Visual Geometry Group,VGG提出了可以通過重複使用簡單的基礎塊來建構深度模型的思路。

VGG Block(VGG 塊)

VGG塊的組成規律是:連續使用數個相同的填充為1、視窗形狀為 3 × 3 3\times 3 3×3的卷積層後接上一個步幅為2、視窗形狀為 2 × 2 2\times 2 2×2的最大池化層。

卷積層保持輸入的高和寬不變,而池化層則對其減半。

用函數來實作基礎的VGG塊,它可以指定卷積層的數量和輸入輸出通道數。

import time
import torch
from torch import nn, optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def vgg_block(num_convs, in_channels, out_channels):
    blk = []
    for i in range(num_convs):
        if i == 0:
            blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        else:
            blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        blk.append(nn.ReLU())
    blk.append(nn.MaxPool2d(kernel_size=2, stride=2)) # 這裡會使寬高減半
    return nn.Sequential(*blk)
           

對于給定的感受野(與輸出有關的輸入圖檔的局部大小),采用堆積的小卷積核優于采用大的卷積核,因為可以增加網絡深度來保證學習更複雜的模式,而且代價還比較小(參數更少)。例如,在VGG中,使用了3個3x3卷積核來代替7x7卷積核,使用了2個3x3卷積核來代替5*5卷積核,這樣做的主要目的是在保證具有相同感覺野的條件下,提升了網絡的深度,在一定程度上提升了神經網絡的效果。

VGG 網絡

VGG網絡由卷積層子產品後接全連接配接層子產品構成。卷積層子產品串聯數個VGG Block,其超參數由變量conv_arch定義。該變量指定了每個VGG塊裡卷積層個數和輸入輸出通道數。全連接配接子產品則跟AlexNet中的一樣。

構造一個VGG網絡。

  • 它有5個卷積塊,前2塊使用單卷積層,而後3塊使用雙卷積層。
  • 第一塊的輸入輸出通道分别是1和64
  • 之後每次對輸出通道數翻倍,直到變為512

因為這個網絡使用了8個卷積層和3個全連接配接層,是以經常被稱為VGG-11。

conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))
# 經過5個VGG Block, 寬高會減半5次, 變成 224/32 = 7
fc_features = 512 * 7 * 7 # c * w * h
fc_hidden_units = 4096 # 任意設定
           

實作VGG-11

def vgg(conv_arch, fc_features, fc_hidden_units=4096):
    net = nn.Sequential()
    # 卷積層部分
    for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):
        # 每經過一個vgg_block都會使寬高減半
        net.add_module("vgg_block_" + str(i+1), vgg_block(num_convs, in_channels, out_channels))
    # 全連接配接層部分
    net.add_module("fc", nn.Sequential(d2l.FlattenLayer(),
                                 nn.Linear(fc_features, fc_hidden_units),
                                 nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(fc_hidden_units, fc_hidden_units),
                                 nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(fc_hidden_units, 10)
                                ))
    return net
           

構造一個高和寬均為224的單通道資料樣本來觀察每一層的輸出形狀。

net = vgg(conv_arch, fc_features, fc_hidden_units)
X = torch.rand(1, 1, 224, 224)

# named_children擷取一級子子產品及其名字(named_modules會傳回所有子子產品,包括子子產品的子子產品)
for name, blk in net.named_children(): 
    X = blk(X)
    print(name, 'output shape: ', X.shape)
           

輸出:

vgg_block_1 output shape:  torch.Size([1, 64, 112, 112])
vgg_block_2 output shape:  torch.Size([1, 128, 56, 56])
vgg_block_3 output shape:  torch.Size([1, 256, 28, 28])
vgg_block_4 output shape:  torch.Size([1, 512, 14, 14])
vgg_block_5 output shape:  torch.Size([1, 512, 7, 7])
fc output shape:  torch.Size([1, 10])
           

每次輸入的高和寬減半,直到最終高和寬變成7後傳入全連接配接層。與此同時,輸出通道數每次翻倍,直到變成512。因為每個卷積層的視窗大小一樣,是以每層的模型參數尺寸和計算複雜度與輸入高、輸入寬、輸入通道數和輸出通道數的乘積成正比。VGG這種高和寬減半以及通道翻倍的設計使得多數卷積層都有相同的模型參數尺寸和計算複雜度。

擷取資料

ratio = 8
small_conv_arch = [(1, 1, 64//ratio), (1, 64//ratio, 128//ratio), (2, 128//ratio, 256//ratio), 
                   (2, 256//ratio, 512//ratio), (2, 512//ratio, 512//ratio)]
net = vgg(small_conv_arch, fc_features // ratio, fc_hidden_units // ratio)
print(net)
           

輸出:

Sequential(
  (vgg_block_1): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_2): Sequential(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_3): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_4): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vgg_block_5): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): FlattenLayer()
    (1): Linear(in_features=3136, out_features=512, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.5)
    (7): Linear(in_features=512, out_features=10, bias=True)
  )
)
           

訓練:

def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):
    """Download the fashion mnist dataset and then load into memory."""
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())
    
    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
    if sys.platform.startswith('win'):
        num_workers = 0  # 0表示不用額外的程序來加速讀取資料
    else:
        num_workers = 4
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_iter, test_iter
           
batch_size = 64
# 如出現“out of memory”的報錯資訊,可減小batch_size或resize
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)