
Visualizing a reference model: convolution kernel visualization

The script below loads torchvision's AlexNet from a local weight file and uses TensorBoard to visualize the kernels of the first convolutional layers as image grids, followed by the feature maps that conv1 produces for a sample image.

import os
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory of this script; abspath resolves the full file path first
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == "__main__":

    log_dir = os.path.join(BASE_DIR, "..", "results")  # where the TensorBoard event files are stored (relative to this script)
    # ----------------------------------- kernel visualization -----------------------------------
    writer = SummaryWriter(log_dir=log_dir, filename_suffix="_kernel")  # writer for the kernel visualizations

    # method 1: download torchvision's pretrained weights automatically
    # alexnet = models.alexnet(pretrained=True)

    # method 2: load the weights from a local state-dict file
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")  # ".." goes up one level: this script lives in src, next to data
    alexnet = models.alexnet()
    pretrained_state_dict = torch.load(path_state_dict)  # load the pretrained weights
    alexnet.load_state_dict(pretrained_state_dict)

    kernel_num = -1
    vis_max = 1  # visualize conv layers with index up to vis_max (here: the first two)
    for sub_module in alexnet.modules():  # iterate over every module in the network
        if not isinstance(sub_module, nn.Conv2d):  # only convolutional layers are of interest
            continue
        kernel_num += 1
        if kernel_num > vis_max:  # stop once the configured number of conv layers has been visualized
            break

        kernels = sub_module.weight  # the kernel weights, shape (out_channels, in_channels, k_h, k_w)
        c_out, c_int, k_h, k_w = tuple(kernels.shape)  # unpack the dimensions from the shape
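        # e.g. for AlexNet's first conv layer, nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
        # this unpacks to c_out=64, c_int=3, k_h=11, k_w=11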

        # split the kernels by input channel

        for o_idx in range(c_out):  # one grid per output kernel
            kernel_idx = kernels[o_idx, :, :, :].unsqueeze(1)  # (c_in, h, w) -> (c_in, 1, h, w): make_grid expects BCHW, so each input channel becomes one single-channel image
            kernel_grid = vutils.make_grid(kernel_idx, normalize=True, scale_each=True, nrow=c_int)  # normalize=True rescales values to [0, 1] for display, scale_each=True does so per image, nrow sets images per row
            writer.add_image('{}_Convlayer_split_in_channel'.format(kernel_num), kernel_grid, global_step=o_idx)

        kernel_all = kernels.view(-1, 3, k_h, k_w)  # regroup all kernels into 3-channel (RGB-like) images so the whole layer fits into one grid
        kernel_grid = vutils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=8)  # c, h, w
        writer.add_image('{}_all'.format(kernel_num), kernel_grid, global_step=620)

        print("{}_convlayer shape:{}".format(kernel_num, tuple(kernels.shape)))

    # ----------------------------------- feature map visualization -----------------------------------
    writer = SummaryWriter(log_dir=log_dir, filename_suffix="_feature map")

    # data
    path_img = os.path.join(BASE_DIR, "..", "data", "tiger cat.jpg")  # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]
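    # note: these mean/std values are the CIFAR-10 statistics, not the ImageNet statistics
    # ([0.485, 0.456, 0.406] / [0.229, 0.224, 0.225]) that torchvision's AlexNet was trained with;
    # the grids still render, but the conv1 activations will differ slightly from an ImageNet-normalized input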
    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)  # chw --> bchw

    # model
    # alexnet = models.alexnet(pretrained=True)

    # forward
    convlayer1 = alexnet.features[0]  # take the first conv layer
    fmap_1 = convlayer1(img_tensor)  # run the image through it to get the feature maps

    # preprocessing for make_grid
    fmap_1.transpose_(0, 1)  # bchw = (1, 64, 55, 55) --> (64, 1, 55, 55): the 64 feature maps become 64 single-channel images
    fmap_1_grid = vutils.make_grid(fmap_1, normalize=True, scale_each=True, nrow=8)  # make_grid treats the first dimension as the number of images, hence the transpose

    writer.add_image('feature map in conv1', fmap_1_grid, global_step=620)
    writer.close()
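To inspect the results, point TensorBoard at the log directory created above (the exact path depends on where the script lives), e.g. tensorboard --logdir=./results, and open the IMAGES tab to browse the kernel grids and the conv1 feature maps.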
           
