import os
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
import torchvision.models as models
# Directory that contains this script. abspath() first turns __file__ into an
# absolute path, then the trailing file name is split off.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if __name__ == "__main__":
    # Visualise a pretrained AlexNet in TensorBoard:
    #   1. the convolution kernels of the first conv layers, and
    #   2. the feature maps the first conv layer produces for one sample image.
    # Event files are written to ../results relative to this script.
    log_dir = os.path.join(BASE_DIR, "..", "results")

    # Option 1 (downloads the weights):
    #   alexnet = models.alexnet(pretrained=True)
    # Option 2: load the weights from a local checkpoint. ".." steps out of the
    # script directory to reach the sibling "data" directory.
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")
    alexnet = models.alexnet()
    # map_location="cpu" lets the checkpoint load on machines without CUDA even
    # if it was saved from a GPU; the model itself is kept on the CPU below.
    pretrained_state_dict = torch.load(path_state_dict, map_location="cpu")
    alexnet.load_state_dict(pretrained_state_dict)

    # ----------------------------------- kernel visualization -----------------------------------
    vis_max = 1  # index (0-based) of the last conv layer to visualise
    conv_layers = (m for m in alexnet.modules() if isinstance(m, nn.Conv2d))

    # The context manager closes the writer (flushing its events) even if
    # logging fails part-way; previously this writer was never closed.
    with SummaryWriter(log_dir=log_dir, filename_suffix="_kernel") as writer:
        for kernel_num, conv_layer in enumerate(conv_layers):
            if kernel_num > vis_max:
                break

            # Only the values are needed for plotting, so drop the autograd link.
            kernels = conv_layer.weight.detach()
            c_out, c_in, k_h, k_w = tuple(kernels.shape)

            # One grid per output channel: the (c_in, h, w) kernel is reshaped
            # to (c_in, 1, h, w) because make_grid expects a BCHW batch, i.e.
            # every input-channel slice becomes its own grayscale image.
            for o_idx in range(c_out):
                kernel_idx = kernels[o_idx].unsqueeze(1)
                kernel_grid = vutils.make_grid(
                    kernel_idx, normalize=True, scale_each=True, nrow=c_in
                )
                writer.add_image(
                    f"{kernel_num}_Convlayer_split_in_channel",
                    kernel_grid,
                    global_step=o_idx,
                )

            # All kernels of the layer in one grid, regrouped into 3-channel
            # (RGB) images. This requires c_out * c_in to be divisible by 3.
            kernel_all = kernels.view(-1, 3, k_h, k_w)
            kernel_grid = vutils.make_grid(
                kernel_all, normalize=True, scale_each=True, nrow=8
            )  # -> (c, h, w)
            writer.add_image(f"{kernel_num}_all", kernel_grid, global_step=620)

            print(f"{kernel_num}_convlayer shape:{tuple(kernels.shape)}")

    # ----------------------------------- feature map visualization -----------------------------------
    # Input image
    path_img = os.path.join(BASE_DIR, "..", "data", "tiger cat.jpg")  # your path to image
    # NOTE(review): these look like CIFAR-10 channel statistics rather than the
    # ImageNet ones usually paired with pretrained AlexNet weights - confirm.
    norm_mean = [0.49139968, 0.48215827, 0.44653124]
    norm_std = [0.24703233, 0.24348505, 0.26158768]
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    img_pil = Image.open(path_img).convert('RGB')
    img_tensor = img_transforms(img_pil).unsqueeze(0)  # chw --> bchw

    # Forward pass through the first convolution layer only; no gradients are
    # needed for visualisation.
    convlayer1 = alexnet.features[0]
    with torch.no_grad():
        fmap_1 = convlayer1(img_tensor)

    # make_grid treats dim 0 as the image index, so swap batch and channel:
    # bchw=(1, 64, 55, 55) --> (64, 1, 55, 55), i.e. 64 single-channel images.
    fmap_1 = fmap_1.transpose(0, 1)
    fmap_1_grid = vutils.make_grid(fmap_1, normalize=True, scale_each=True, nrow=8)

    with SummaryWriter(log_dir=log_dir, filename_suffix="_feature map") as writer:
        writer.add_image('feature map in conv1', fmap_1_grid, global_step=620)