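I. Introduction
Visualizing feature maps is often necessary, especially when writing a paper or comparing algorithms. It also helps build familiarity with how the whole CNN model behaves. Without further ado, here it is.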
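II. Results
The code is given below. The output comes in two forms: a grid of single-channel feature maps, and a fused map in which all channels are summed with equal (1:1) weight.
I generated many feature maps; only two are shown here as examples.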
![](https://img.laitimes.com/img/__Qf2AjLwojIjJCLyojI0JCLiAzNfRHLGZkRGZkRfJ3bs92YsYTMfVmepNHL0MmeOBTSE50MNpHW4Z0MMBjVtJWd0ckW65UbM5WOHJWa5kHT20ESjBjUIF2X0hXZ0xCMx81dvRWYoNHLrdEZwZ1Rh5WNXp1bwNjW1ZUba9VZwlHdssmch1mclRXY39CXldWYtlWPzNXZj9mcw1ycz9WL49zZuBnL4AzMxATNzgTM2AzNwAjMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
Single-channel feature maps
Fused (summed) feature map
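III. Code
To restate what I need and what I already have: I have the network architecture and its pretrained weights, and I want to feed in an image and obtain its feature maps at specific layers of the network.
Start reading from the program entry point (the if __name__ == "__main__" block at the bottom, called main() here). I will go into some detail so that the code is easy to follow and easy to adapt.
I will not cover the paths used for reading and saving images; adjust them to your own setup.
1. First look at the imports. DepthCompletionFrontNet is my network. You need to have your own network defined before anything else (this is required).
2. In main(), follow the call to get_feature().
3. get_feature() does three things. First, it reads the image that will be fed to the network (my network has two branches, so it reads two images; for a single-input network, read only img_rgb and comment out everything related to img_pc). Second, it instantiates the network and loads the pretrained weights. Third, it lists the layers to extract in exact_list; these names must match the layer names used in your network exactly, because they are the keys looked up in all_dict.
4. The existing network definition needs a small modification. Suppose your network is defined as follows: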
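# Example only, typed by hand; the layer arguments are placeholders, not my real network
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x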
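The layer definitions stay unchanged. Only forward needs to be modified: add a dictionary, all_dict, that stores the output tensor of each layer. The modified forward looks like this: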
    def forward(self, x):
        all_dict = {}  # maps layer name -> output tensor of that layer
        x = self.conv1(x)
        all_dict['conv1'] = x
        x = self.conv2(x)
        all_dict['conv2'] = x
        x = self.conv3(x)
        all_dict['conv3'] = x
        return x, all_dict
That completes the modification.
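As a self-contained check of the idea, here is a minimal sketch. TinyNet and its layer sizes are hypothetical stand-ins for your own network, not the network used in this post; only the all_dict pattern comes from the text above.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical three-layer CNN; the channel counts are placeholders."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        all_dict = {}  # layer name -> output tensor
        x = self.conv1(x)
        all_dict['conv1'] = x
        x = self.conv2(x)
        all_dict['conv2'] = x
        x = self.conv3(x)
        all_dict['conv3'] = x
        return x, all_dict


model = TinyNet().eval()
dummy = torch.rand(1, 3, 128, 256)  # random stand-in for a real image batch
with torch.no_grad():
    out, all_dict = model(dummy)

# Collect the shape of every stored feature map
shapes = {name: tuple(t.shape) for name, t in all_dict.items()}
for name, shape in shapes.items():
    print(name, shape)
```

Keep in mind that every caller of the model, including any training loop, now has to unpack two return values.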
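To summarize: load the model and the image; during the forward pass, a dictionary stores the output tensor of each layer; to inspect a given layer, fetch its tensor from the dictionary and visualize it.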
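The visualization step amounts to two array operations on a (1, C, H, W) tensor: laying the C channels out on a roughly square grid, and summing them into one fused map. A minimal NumPy-only sketch follows. The random array stands in for a real layer output, and tile_channels and fuse_channels are hypothetical helper names used for illustration only.

```python
import math
import numpy as np


def fuse_channels(feature):
    """Sum the channels of a (1, C, H, W) array into one (H, W) map."""
    return feature[0].sum(axis=0)


def tile_channels(feature):
    """Arrange the C channels of a (1, C, H, W) array on a square grid.

    Returns a 2-D mosaic of shape (wid * H, wid * W), where
    wid = ceil(sqrt(C)); unused grid cells stay zero.
    """
    _, c, h, w = feature.shape
    wid = math.ceil(math.sqrt(c))
    mosaic = np.zeros((wid * h, wid * w), dtype=feature.dtype)
    for i in range(c):
        row, col = divmod(i, wid)
        mosaic[row * h:(row + 1) * h, col * w:(col + 1) * w] = feature[0, i]
    return mosaic


rng = np.random.default_rng(0)
feature = rng.random((1, 10, 4, 6), dtype=np.float32)  # stand-in feature map

fused = fuse_channels(feature)
mosaic = tile_channels(feature)
print(fused.shape, mosaic.shape)
```

Either array can then be passed to plt.imshow, for example plt.imshow(mosaic, cmap='jet').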
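If you have questions, leave a comment and I will reply when I see it. If the code runs for you, a like would be appreciated. I hope this helps.
The complete code follows. My network structure is quite complex, so it is not included; modify your own network as described above. You can read just img_rgb and pass img_rgb to the model's forward pass. My network has two branches, so I pass in a dictionary that holds both images.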
import torch
import torchvision.transforms as transforms
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
from completion_segmentation_model import DepthCompletionFrontNet
# from completion_segmentation_model_v3_eca_attention import DepthCompletionFrontNet
import math
#https://blog.csdn.net/missyougoon/article/details/85645195
# https://blog.csdn.net/grayondream/article/details/99090247
# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Preprocessing: ToTensor converts an H x W x C numpy array into a C x H x W PyTorch tensor
transform = transforms.ToTensor()


def get_picture(picture_dir, transform):
    '''
    Read an image from disk, resize it to 128 x 256, and convert it to a Tensor.
    '''
    img = skimage.io.imread(picture_dir)
    img256 = skimage.transform.resize(img, (128, 256))
    img256 = np.asarray(img256)
    img256 = img256.astype(np.float32)
    return transform(img256)


def get_picture_rgb(picture_dir):
    '''
    Display an image (the commented-out code shows its R, G and B channels separately).
    '''
    img = skimage.io.imread(picture_dir)
    img256 = skimage.transform.resize(img, (256, 256))
    skimage.io.imsave('4.jpg', img256)
    # Show each channel on its own
    # for i in range(3):
    #     img = img256[:,:,i]
    #     ax = plt.subplot(1, 3, i + 1)
    #     ax.set_title('Feature {}'.format(i))
    #     ax.axis('off')
    #     plt.imshow(img)
    # r = img256.copy()
    # r[:,:,0:2]=0
    # ax = plt.subplot(1, 4, 1)
    # ax.set_title('B Channel')
    # # ax.axis('off')
    # plt.imshow(r)
    # g = img256.copy()
    # g[:,:,0]=0
    # g[:,:,2]=0
    # ax = plt.subplot(1, 4, 2)
    # ax.set_title('G Channel')
    # # ax.axis('off')
    # plt.imshow(g)
    # b = img256.copy()
    # b[:,:,1:3]=0
    # ax = plt.subplot(1, 4, 3)
    # ax.set_title('R Channel')
    # # ax.axis('off')
    # plt.imshow(b)
    # img = img256.copy()
    # ax = plt.subplot(1, 4, 4)
    # ax.set_title('image')
    # # ax.axis('off')
    # plt.imshow(img)
    img = img256.copy()
    ax = plt.subplot()
    ax.set_title('image')
    # ax.axis('off')
    plt.imshow(img)
    plt.show()


def visualize_feature_map_sum(item, name):
    '''
    Sum all channels of one layer's feature map into a single image.
    :param item: feature tensor of shape (1, C, H, W)
    :param name: layer name, used in the saved file name
    :return: None
    '''
    feature_map = item.squeeze(0)  # (1, C, H, W) -> (C, H, W)
    c = item.shape[1]
    print(feature_map.shape)
    feature_map_combination = []
    for i in range(0, c):
        feature_map_split = feature_map.data.cpu().numpy()[i, :, :]
        feature_map_combination.append(feature_map_split)
    # Element-wise sum over channels, each with equal weight
    feature_map_sum = sum(one for one in feature_map_combination)
    # feature_map = np.squeeze(feature_batch,axis=0)
    plt.figure()
    plt.title("combine figure")
    plt.imshow(feature_map_sum)
    plt.savefig('E:/Dataset/qhms/feature_map/feature_map_sum_' + name + '.png')  # save the figure to disk
    plt.show()


def get_feature():
    # Input data
    root_path = 'E:/Dataset/qhms/data/small_data/'
    pic_dir = 'test_umm_000067.png'
    pc_path = root_path + 'knn_pc_crop_0.6/' + pic_dir
    rgb_path = root_path + 'train_image_2_lane_crop_0.6/' + pic_dir
    img_rgb = get_picture(rgb_path, transform)
    # Add the batch dimension: (C, H, W) -> (1, C, H, W)
    img_rgb = img_rgb.unsqueeze(0)
    img_rgb = img_rgb.to(device)
    img_pc = get_picture(pc_path, transform)
    # Add the batch dimension
    img_pc = img_pc.unsqueeze(0)
    img_pc = img_pc.to(device)
    # Load the model; map_location lets a checkpoint saved on GPU load on a CPU-only machine
    checkpoint = torch.load('E:/Dataset/qhms/all_result/v3/crop_0.6_old/hah/checkpoint-195.pth.tar',
                            map_location=device)
    args = checkpoint['args']
    print(args)
    model = DepthCompletionFrontNet(args)
    # nn.Module has no keys() method; list the parameter names through state_dict()
    print(model.state_dict().keys())
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()  # inference behavior for layers such as BatchNorm and Dropout
    exact_list = ["conv1","conv2","conv3","conv4","convt4","convt3","convt2_","convt1_","lane"]
    # myexactor = FeatureExtractor(model, exact_list)
    img1 = {
        'pc': img_pc, 'rgb': img_rgb
    }
    # print(img1['pc'])
    # x = myexactor(img1)
    with torch.no_grad():  # gradients are not needed for visualization
        result, all_dict = model(img1)
    outputs = []
    # Pick out the layers named in exact_list
    for item in exact_list:
        x = all_dict[item]
        outputs.append(x)
    # Visualize the feature maps
    x = outputs
    k = 0
    print(x[0].shape[1])
    for item in x:
        c = item.shape[1]
        plt.figure()
        name = exact_list[k]
        plt.suptitle(name)
        wid = math.ceil(math.sqrt(c))  # side length of the square subplot grid
        for i in range(c):
            ax = plt.subplot(wid, wid, i + 1)
            ax.set_title('Feature {}'.format(i))
            ax.axis('off')
            figure_map = item.data.cpu().numpy()[0, i, :, :]
            plt.imshow(figure_map, cmap='jet')
        plt.savefig('E:/Dataset/qhms/feature_map/feature_map_' + name + '.png')  # save the figure to disk
        visualize_feature_map_sum(item, name)
        k = k + 1
    plt.show()


# Entry point
if __name__ == "__main__":
    # get_picture_rgb(pic_dir)
    get_feature()
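If editing forward is not an option (for example with a third-party model), PyTorch's register_forward_hook can capture the same intermediate outputs without touching the network code. This is an alternative technique, not the method used in the script above. The stand-in model and the helper name capture_outputs are hypothetical.

```python
import torch
import torch.nn as nn


def capture_outputs(model, layer_names):
    """Register forward hooks that store each named submodule's output.

    Returns (features, handles): a dict that is filled during the next
    forward pass, and the hook handles so the hooks can be removed later.
    """
    features = {}
    handles = []
    modules = dict(model.named_modules())
    for name in layer_names:
        # name=name binds the current layer name to this hook
        def hook(module, inputs, output, name=name):
            features[name] = output.detach()
        handles.append(modules[name].register_forward_hook(hook))
    return features, handles


# Hypothetical stand-in network; nn.Sequential names its children '0', '1', '2'
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
).eval()

features, handles = capture_outputs(model, ['0', '2'])
with torch.no_grad():
    model(torch.rand(1, 3, 32, 32))

for handle in handles:
    handle.remove()  # detach the hooks once the features are captured

for name, tensor in features.items():
    print(name, tuple(tensor.shape))
```

The captured tensors have the same (1, C, H, W) layout as the entries of all_dict, so they can be plotted in the same way.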
References:
https://blog.csdn.net/missyougoon/article/details/85645195
https://blog.csdn.net/grayondream/article/details/99090247