天天看點

【深度學習】使用NETRON工具可視化PYTORCH模型

netron是微軟小哥lutzroeder的一個廣受好評的開源項目,位址https://github.com/lutzroeder/Netro

支援衆多模型:

【深度學習】使用NETRON工具可視化PYTORCH模型

1. 安裝NETRON

pip install netron
           

2. 測試代碼

由于不支援預設的pytorch模型格式(.pth),是以需要存為onnx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
 
import netron
 
 
class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64)
        )
 
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.output = nn.Sequential(
            nn.Conv2d(64, 1, 3, padding=1, bias=True),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        x = self.conv1(x)
        identity = x
        x = F.relu(self.block1(x) + identity)
        x = self.output(x)
        return x
 
 
d = torch.rand(1, 3, 416, 416)
m = model()
o = m(d)
 
onnx_path = "onnx_model_name.onnx"
torch.onnx.export(m, d, onnx_path)
 
netron.start(onnx_path)
           

3. 結果

執行上面代碼後,會調用本地浏覽器打開,形式和tensorboard差不多

【深度學習】使用NETRON工具可視化PYTORCH模型

繼續閱讀