天天看點

torch:)——PyTorch: transforms用法詳解常見的transform操作

PyTorch:transforms用法詳解

  • 常見的transform操作
    • 1.resize: `transforms.Resize`
    • 2.标準化: `transforms.Normalize`
    • 3.轉為Tensor: `transforms.ToTensor`

transforms用于圖形變換,在使用時我們還可以使用

transforms.Compose

将一系列的transforms操作連結起來。

  • torchvision.transforms.Compose([ ts,ts,ts... ])

    ts為transforms操作

i.e.

transforms.Compose([
     transforms.CenterCrop(10),
     transforms.ToTensor(), ])
           
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
           

常見的transform操作

1.resize:

transforms.Resize

将輸入PIL圖像的大小調整為給定大小。

  • size(sequence 或int)
  • 所需的輸出大小。如果size是類似(h,w)的序列,則輸出大小将與此比對。如果size是int,則圖像的較小邊緣将與此數字比對。即,如果高度>寬度,則圖像将重新縮放為(size*高度/寬度,size)
  • interpolation(int,optional) - 所需的插值。預設是 PIL.Image.BILINEAR

2.标準化:

transforms.Normalize

用平均值和标準偏差歸一化張量圖像。給定mean:(M1,…,Mn)和std:(S1,…,Sn)對于n通道,此變換将标準化輸入的每個通道,torch.*Tensor即 input[channel] = (input[channel] - mean[channel]) / std[channel]

  • mean(sequence) - 每個通道的均值序列。
  • std(sequence) - 每個通道的标準偏差序列。

i.e.

transform = transforms.Compose(
    						[transforms.ToTensor(),
    					 	transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
           

把輸入資料的值從0-1變換到(-1,1).具體地說,對每個通道而言,Normalize執行以下操作:

image = (image - mean) / std

其中mean和std分别通過(0.5,0.5,0.5)和(0.5,0.5,0.5)進行指定。原來的0-1最小值0則變成(0-0.5)/0.5=-1,而最大值1則變成(1-0.5)/0.5=1.

3.轉為Tensor:

transforms.ToTensor

torchvision.transforms.ToTensor
           

将PIL Image或者 ndarray 轉換為tensor,并且歸一化至[0-1] ;

注意事項: 歸一化至[0-1]是直接除以255,若自己的ndarray資料尺度有變化,則需要自行修改。

繼續閱讀