參考文章
- C++部署pytorch模型
- 利用LibTorch部署PyTorch模型
- 官方文檔
問題
pytorch 的神經網絡模型有很多,但 libtorch 就特别少。現在面臨的問題是要在 C++ 環境下應用神經網絡模型,肯定不能直接使用 pytorch 模型。解決辦法有兩個:
- 方法一是用 TorchScript 工具導出模型 poolnet.pt,模型中包含網絡結構和參數權重,是以可以直接在 C++ 裡面生成神經網絡。
- 方法二是用 C++ 複現網絡結構,封裝為為類對象,再從 poolnet.pt 中導入參數權重。
對于神經網絡模型 PoolNet ,将其應用到 C++ 環境下進行視訊處理,下面這是前 10 幀畫面處理時間。明顯看出,方法一前兩次運作時間很長,從第三幀開始,兩種方法的處理時間幾乎相同。但是,方法一相當簡單,導出模型即可,方法二需要複現網絡結構,工程量巨大。下面重點介紹方法一。
TorchScript 工具介紹
必定要看 官方文檔。上面介紹了 trace 和 script 的差別。
PyTorch 導出模型
resnet50
編輯 export.py 檔案,以 pytorch 提供的 resnet50 為例,分别使用 trace 和 script 導出模型。trace 需要提供一個輸入樣例,script 則不需要。但是複雜的模型使用 script 一般會失敗,但 trace 可以。trace 和 script 導出的模型幾乎沒有差別,缺點是前兩次處理時間都格外久。
import torch
from torchvision.models import resnet50
net = resnet50(pretrained=True)
net = net.cuda()
net.eval()
for key, value in net.named_parameters():
print(key)
# trace
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
traced_module = torch.jit.trace(net, x)
traced_module.save("resnet50_trace.pt")
# script
scripted_module = torch.jit.script(net)
scripted_module.save("resnet50_script.pt")
在 python3+pytorch 的虛拟環境下執行
python export.py
net
上面的例子是 pytorch 提供的 resnet50,如果是自己寫的模型,可以按照下面的方式來。其中,net.pth 是訓練後儲存的參數,net.pt 則是期望導出的模型,使用 trace 方法。
import torch
import torchvision
# 初始化神經網絡
net = Net()
net.load_state_dict(torch.load("net.pth"))
net.cuda()
net.eval()
# 導出模型
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
m = torch.jit.trace(net, x)
m.save("net.pt")
其中
也可寫為
C++ 中調用模型
C++ 使用 libtorch 時,一般使用 CMake 進行管理(參考 Pytorch 官網教程)。下面是在 C++ 環境中調用模型的方法。
#include <torch/torch.h>
#include <torch/script.h>
torch::Device device(torch::kCUDA);
// image.rows, image.cols 高在前,寬在後
torch::Tensor img_tensor = torch::from_blob(img.data, {1, image.rows, image.cols, 3}, torch::kByte).to(device);
img_tensor = img_tensor.permute({0, 3, 1, 2});
img_tensor = img_tensor.toType(torch::kFloat);
img_tensor = img_tensor.div(255.0);
torch::jit::script::Module net = torch::jit::load("../models/net.pt");
// 列印模型中的參數
for (const auto& pair : net.named_parameters()) {
std::cout << pair.name << " " << pair.value.requires_grad() << std::endl;
}
net.to(device)
torch::NoGradGuard no_grad;
torch::Tensor output = net.forward({img_tensor}).toTensor();