Refs:
https://www.zhihu.com/question/66532235
一种思路是:ONNX + Caffe2,现将pytorch模型转为caffe2模型,然后再操作。看起来很复杂。
参考:https://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html
另一种思路是:基于Pytorch 1.0 Preview版+(注意当前该版本暂没有windows发行版),现将模型编译成c++可读的形式,然后重写py成c++代码调用模型。以下转Pytorch文档:Loading a PyTorch Model in C++
在C ++中加载PYTORCH模型
本教程需要PyTorch 1.0(预览版)或更高版本。有关安装信息,请访问http://pytorch.org/get-started。
顾名思义,PyTorch的主要接口是Python编程语言。虽然Python是许多需要动态和易于迭代的场景的合适和首选语言,但同样很多情况下,Python的这些属性恰好是不利的。后者经常适用的一个环境是生产 - 低延迟和严格部署要求的土地。对于生产场景,C ++通常是首选语言,即使只是将其绑定到另一种语言,如Java,Rust或Go。以下段落将概述PyTorch提供的路径,从现有的Python模型转换为可以加载和执行的序列化表示 纯粹来自C ++,不依赖于Python。
第1步:将PYTORCH模型转换为TORCH脚本
PyTorch模型从Python到C ++的旅程由Torch Script实现,Torch Script是PyTorch模型的一种表示,可以由Torch Script编译器理解,编译和序列化。如果您从使用vanilla“eager”API编写的现有PyTorch模型开始,则必须先将模型转换为Torch Script。在下面讨论的最常见的情况下,这只需要很少的努力。如果您已有Torch脚本模块,则可以跳到本教程的下一部分。
有两种方法可以将PyTorch模型转换为Torch Script。第一种称为跟踪,一种机制,通过使用示例输入一次评估模型的结构,并通过模型记录这些输入的流量来捕获模型的结构。这适用于限制使用控制流的模型。第二种方法是向模型添加显式注释,以通知Torch脚本编译器它可以直接解析和编译模型代码,受Torch脚本语言强加的约束。
提示:
您可以在官方的Torch脚本参考中找到这两种方法的完整文档,以及有关使用哪些方法的更多指导。
通过跟踪转换为Torch脚本
要通过跟踪将PyTorch模型转换为Torch脚本,必须将模型的实例以及示例输入传递给
torch.jit.trace
函数。这将生成一个
torch.jit.ScriptModule
对象,其模型的
forward
方法中嵌入了模型评估的跟踪:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
跟踪
ScriptModule
现在可以与常规PyTorch模块相同地进行评估:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
通过注释转换为Torch脚本
在某些情况下,例如,如果您的模型使用特定形式的控制流,您可能希望直接在Torch脚本中编写模型并相应地注释您的模型。例如,假设您有以下vanilla Pytorch模型:
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
因为
forward
此模块的方法使用依赖于输入的控制流,所以它不适合跟踪。相反,我们可以
ScriptModule
通过子类化并将注释
torch.jit.ScriptModule
添加
@torch.jit.script_method
到模型的
forward
方法来将其转换为a :
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
@torch.jit.script_method
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_script_module = MyModule()
MyModule
现在直接创建一个新对象会生成一个
ScriptModule
可以进行序列化的实例 。
第2步:将脚本模块序列化为文件
一旦
ScriptModule
掌握了PyTorch模型的跟踪或注释,就可以将其序列化为文件。稍后,您将能够使用C ++从该文件加载模块并执行它而不依赖于Python。假设我们要序列化
ResNet18
跟踪示例中前面显示的模型。要执行此序列化,只需 在模块上调用 save并将其传递给文件名:
traced_script_module.save("model.pt")
这将
model.pt
在您的工作目录中生成一个文件。我们现在正式离开了Python的领域,并准备跨越到C ++领域。
第3步:使用C ++加载脚本模块
要在C ++中加载序列化的PyTorch模型,您的应用程序必须依赖于PyTorch C ++ API - 也称为LibTorch。LibTorch发行版包含一组共享库,头文件和CMake构建配置文件。虽然CMake不是依赖LibTorch的要求,但它是推荐的方法,并且将来会得到很好的支持。在本教程中,我们将使用CMake和LibTorch构建一个最小的C ++应用程序,它只需加载并执行序列化的PyTorch模型。
最小的C ++应用程序
我们首先讨论加载模块的代码。以下内容已经做到:
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
// Deserialize the ScriptModule from a file using torch::jit::load().
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
assert(module != nullptr);
std::cout << "ok\n";
}
该
<torch/script.h>
首标包括由运行示例所必需的库LibTorch所有相关包括。我们的应用程序接受序列化PyTorch的文件路径
ScriptModule
作为其唯一的命令行参数,然后使用该
torch::jit::load()
函数继续反序列化模块,该函数将此文件路径作为输入。作为回报,我们收到一个指向a的共享指针
torch::jit::script::Module
,相当于
torch.jit.ScriptModule
C ++中的a。目前,我们只验证此指针不为null。我们将研究如何在一瞬间执行它。
取决于LibTorch和构建应用程序
假设我们将上面的代码存储到一个名为的文件中
example-app.cpp
。
CMakeLists.txt
构建它的最小化可能看起来很简单:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
我们构建示例应用程序所需的最后一件事是LibTorch发行版。您可以随时从PyTorch网站的下载页面获取最新的稳定版本。如果下载并解压缩最新存档,则应收到具有以下目录结构的文件夹:
libtorch/
bin/
include/
lib/
share/
- 该
文件夹包含您必须链接的共享库,lib/
- 该
文件夹包含程序需要包含的头文件,include/
- 该
文件夹包含必要的CMake配置,以启用share/
上面的简单命令。find_package(Torch)
最后一步是构建应用程序。为此,假设我们的示例目录布局如下:
example-app/
CMakeLists.txt
example-app.cpp
我们现在可以运行以下命令从
example-app/
文件夹中构建应用程序 :
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
哪里
/path/to/libtorch
应该是解压缩的LibTorch发行版的完整路径。如果一切顺利,它将看起来像这样:
[email protected]:/example-app# mkdir build
[email protected]:/example-app# cd build
[email protected]:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
[email protected]:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
如果我们提供
ResNet18
我们之前为生成的
example-app
二进制文件创建的序列化模型的路径,我们应该得到友好的“ok”奖励:
[email protected]:/example-app/build# ./example-app model.pt
ok
第4步:在C ++中执行脚本模块
成功加载了我们
ResNet18
在C ++中的序列化后,我们现在只需执行几行代码!让我们将这些行添加到我们的C ++应用程序的
main()
函数中:
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
auto output = module->forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
前两行设置了我们模型的输入。我们创建一个向量
torch::jit::IValue
(类型擦除值类型
script::Module
方法接受并返回)并添加单个输入。要创建输入张量,我们使用
torch::ones()
相当于
torch.ones
C ++ API 的输入张量 。然后我们运行
script::Module
's
forward
方法,将它传递给我们创建的输入向量。作为回报,我们得到一个新的
IValue
,我们通过调用转换为张量
toTensor()
。
小费
要了解有关函数
torch::ones
和PyTorch C ++ API的更多信息,请参阅https://pytorch.org/cppdocs上的文档。PyTorch C ++ API提供与Python API近似的特性奇偶校验,允许您像在Python中一样进一步操纵和处理张量。
在最后一行中,我们打印输出的前五个条目。由于我们在本教程前面的Python中为我们的模型提供了相同的输入,因此理想情况下我们应该看到相同的输出。让我们通过重新编译我们的应用程序并使用相同的序列化模型运行它来尝试:
[email protected]:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
[email protected]:/example-app/build# ./example-app model.pt
-0.2698 -0.0381 0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]
作为参考,Python之前的输出是:
tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
看起来很不错!
第5步:获取帮助和探索API
本教程希望能让您对PyTorch模型从Python到C ++的路径有一个大致的了解。使用本教程中描述的概念,您应该能够从一个普通的,“渴望”的PyTorch模型,到用
ScriptModule
Python 编译,到磁盘上的序列化文件,以及 - 关闭循环 - 到
script::Module
C ++中的可执行文件。
当然,我们没有涵盖很多概念。例如,您可能会发现自己希望
ScriptModule
使用在C ++或CUDA中实现的自定义运算符进行扩展,并
ScriptModule
在纯C ++生产环境中的加载中执行此自定义运算符 。好消息是:这是可能的,并得到很好的支持!现在,您可以浏览此文件夹以获取示例,我们将很快跟进一个教程。目前,以下链接通常可能有所帮助:
- Torch脚本参考:https://pytorch.org/docs/master/jit.html
- PyTorch C ++ API文档:https://pytorch.org/cppdocs/
- Pytorch Python API文档:https://pytorch.org/docs/
与往常一样,如果您遇到任何问题或有疑问,您可以使用我们的论坛或GitHub问题与我们 联系。