天天看點

0702-計算機視覺工具包torchvision

0702-計算機視覺工具包torchvision

目錄

  • 一、torchvision 概述
  • 二、通過 torchvision 加載模型
  • 三、通過 torchvision 加載并處理資料集
  • 四、通過 torchvision 拼接并儲存圖檔

pytorch完整教程目錄:https://www.cnblogs.com/nickchen121/p/14662511.html

一、torchvision 概述

計算機視覺是深度學習中最重要的一類應用,為了友善研究者使用,torch 專門開發了一個視覺工具包 torchvision,這個包獨立于 torch,需要使用

pip install torchvision

進行安裝。

之前的我們已經使用過它的部分功能,在這裡我們在做一個系統的介紹,它主要包含以下三個功能:

  • models:提供深度學習中各種經典網絡的網絡結構以及訓練好的模型,包括 Alex-Net、VGG 系列、ResNet 系列、Inception 系列等
  • datasets:提供常用的資料集加載,設計上都是內建 torch.utils.data.Dataset,主要包括 MNIST、CIFAR10/100、ImageNet、COCO 等
  • transforms:提供常用的資料預處理操作,主要包括對 Tensor 以及 PIL Image 對象的操作

二、通過 torchvision 加載模型

from torchvision import models
from torch import nn

# 加載預訓練好的模型,如果不存在會下載下傳
# 預訓練好的模型儲存在 ~/.torch/modes/ 下面
resnet34 = models.resnet34(pretrained=True, num_classes=1000)

# 修改最後的全連接配接層為 10 分類問題(預設是 ImageNet 上的 1000 分類)
resnet34.fc = nn.Linear(512, 10)
           

三、通過 torchvision 加載并處理資料集

from torchvision import datasets
from torchvision import transforms as T
# 指定資料集路徑為 data,如果資料集不存在則進行下載下傳
# 通過 train=False 擷取測試集

normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),  # 把圖檔轉成 Tensor,歸一化至 [0,1]
    T.Lambda(lambda x: x.repeat(3, 1, 1)),  # 把圖檔轉為 3 通道的
    normalize,
])

dataset = datasets.MNIST('data/',
                         download=True,
                         train=False,
                         transform=transform)
           

Transforms 中涵蓋了大部分對 Tensor 和 PIL Image 的常用處理,這個轉換通常分為兩步:

  1. 第一步:建構轉換操作,例如

    transf = transforms.Normalize(mean=x, std=y)

  2. 第二步:執行轉換操作,例如

    otuput = transf(inp)

import torch as t

# 建構随機噪聲,圖檔如下圖所示
to_pil = T.ToPILImage()
to_pil(t.rand(3, 64, 64))
           
0702-計算機視覺工具包torchvision

四、通過 torchvision 拼接并儲存圖檔

torchvision 還提供了兩個常用的函數:

  1. make_grid,它能把多張圖檔拼接在一個網格中
  2. save_img,它能把 Tensor 儲存成圖檔
len(dataset)
           
10000
           
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
dataiter
img = make_grid(next(dataiter)[0], 4)  # 拼接成 4*4 網格圖檔,并且會轉成 3 通道,如下圖所示
to_img = T.ToPILImage()
to_img(img)
           
0702-計算機視覺工具包torchvision
save_image(img, 'a.png')
from PIL import Image
Image.open('a.png')
           
0702-計算機視覺工具包torchvision