Let's Talk: Convert A PyTorch Model to Tensorflow Using ONNX
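A word up front, so that later readers can avoid the pitfall: ONNX was originally launched by Facebook (together with Microsoft, with AWS joining shortly afterwards) as a counterweight to TensorFlow. ONNX-TF is therefore bound to be an illicit affair: ONNX and TF are seeing each other on the sly, and neither platform will vouch for the result. PyTorch and TensorFlow are each evolving independently, and the contest between dynamic-graph and static-graph optimization is not going to stop. If you are trying to convert a model, I think you should consider the following:
1. Get the serving/deployment platform to support PyTorch (preferred);
2. Treat the conversion as a one-off job that will not be maintained once the feature works;
3. Move the training platform to TF.
Code: cinastanbean/pytorch-onnx-tensorflow-pb (github.com)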
1. Pre-installation
2. Conversion Process
2.1 Step 1.2.3.
2.2 Verification
3. Related Info
3.1 ONNX
3.2 Microsoft/MMdnn
Reference
1. Pre-installation
Version Info
pytorch      0.4.0   py27_cuda0.0_cudnn0.0_1   pytorch
torchvision  0.2.1   py27_1                    pytorch
tensorflow   1.8.0   <pip>
onnx         1.2.2   <pip>
onnx-tf      1.1.2   <pip>
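Notes:
- ONNX 1.1.2 is too old and triggers a BatchNormalization error. pip now offers version 1.3.0; installing from source is another option:
  pip install -U git+https://github.com/onnx/[email protected]
- This experiment confirmed that ONNX 1.2.2 works correctly.
- onnx-tf is installed from source; it requires TensorFlow >= 1.5.0.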
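The version constraints in the notes above can be checked before starting a conversion. The following is only a minimal sketch using the standard library: the helper names `parse_version` and `meets_minimum` are hypothetical, and it assumes plain dotted numeric version strings like the ones in the Version Info listing.

```python
def parse_version(text):
    """Turn a dotted version string such as '1.8.0' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def meets_minimum(installed, minimum):
    """Return True if the installed version is at least the minimum."""
    return parse_version(installed) >= parse_version(minimum)


# Versions taken from the Version Info listing above.
installed = {"tensorflow": "1.8.0", "onnx": "1.2.2"}

# onnx-tf requires TensorFlow >= 1.5.0.
print(meets_minimum(installed["tensorflow"], "1.5.0"))  # True
# ONNX 1.1.2 is reported as too old; 1.2.2 is the version verified here.
print(meets_minimum("1.1.2", "1.2.2"))  # False
```

In a real environment the installed strings would typically come from attributes such as `tensorflow.__version__` or `onnx.__version__`. Development builds can carry suffixes that this simple parser does not handle.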
2. Conversion Process
2.1 Step 1.2.3.
Pipeline: PyTorch model --> ONNX model --> TensorFlow graph (.pb).
import os

import onnx
import tensorflow as tf
import torch
from onnx_tf.backend import prepare
from PIL import Image
from torch.autograd import Variable

# Note: DIY_Model and class_numbers are not defined in this listing.

# step 1, load the PyTorch model and export it to ONNX by running it once.
modelname = 'resnet18'
weightfile = 'models/model_best_checkpoint_resnet18.pth.tar'
modelhandle = DIY_Model(modelname, weightfile, class_numbers)
model = modelhandle.model
#model.eval() # useless
dummy_input = Variable(torch.randn(1, 3, 224, 224))  # NCHW
onnx_filename = os.path.split(weightfile)[-1] + ".onnx"
torch.onnx.export(model, dummy_input,
                  onnx_filename,
                  verbose=True)

# step 2, load the ONNX model with TensorFlow as the backend, check the outputs, and export the graph.
onnx_model = onnx.load(onnx_filename)
# Install onnx-tensorflow from GitHub and call prepare() with strict=False;
# without strict=False, prepare(onnx_model) leads to KeyError: 'pyfunc_0'.
# Reference: https://github.com/onnx/onnx-tensorflow/issues/167
tf_rep = prepare(onnx_model, strict=False)

image = Image.open('pants.jpg')
# debug: feed the same input to PyTorch and to the ONNX-TF backend.
output_pytorch, img_np = modelhandle.process(image)
print('output_pytorch = {}'.format(output_pytorch))
output_onnx_tf = tf_rep.run(img_np)
print('output_onnx_tf = {}'.format(output_onnx_tf))

# onnx --> tf.graph.pb
tf_pb_path = onnx_filename + '_graph.pb'
tf_rep.export_graph(tf_pb_path)

# step 3, check that the exported tf.pb gives the same outputs.
with tf.Graph().as_default():
    graph_def = tf.GraphDef()
    with open(tf_pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")
    with tf.Session() as sess:
        #init = tf.initialize_all_variables()
        init = tf.global_variables_initializer()
        #sess.run(init)

        # Print all ops to check the input/output tensor names.
        # Uncomment this if you do not know the I/O tensor names.
        # print('-------------ops---------------------')
        # op = sess.graph.get_operations()
        # for m in op:
        #     print(m.values())
        # print('-------------ops done.---------------------')

        input_x = sess.graph.get_tensor_by_name("0:0")  # input
        outputs1 = sess.graph.get_tensor_by_name('add_1:0')  # 5
        outputs2 = sess.graph.get_tensor_by_name('add_3:0')  # 10
        output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np})
        #output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: np.random.randn(1, 3, 224, 224)})
        print('output_tf_pb = {}'.format(output_tf_pb))
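The strings passed to get_tensor_by_name in step 3 ("0:0", 'add_1:0', 'add_3:0') follow TensorFlow's tensor naming convention, `<op name>:<output index>`. The sketch below splits such names in plain Python, which may help when reading the printed op listing to decide which names to feed and fetch. The helper name `split_tensor_name` is hypothetical, and the specific names are the ones from this post's exported graph; another model will most likely produce different ones.

```python
def split_tensor_name(tensor_name):
    """Split '<op name>:<output index>' into (op_name, output_index)."""
    op_name, _, index = tensor_name.rpartition(":")
    if not op_name or not index.isdigit():
        raise ValueError("not a tensor name: {!r}".format(tensor_name))
    return op_name, int(index)


# Names used in step 3 of this post; they are specific to this exported graph.
for name in ["0:0", "add_1:0", "add_3:0"]:
    op_name, output_index = split_tensor_name(name)
    print("op {!r}, output #{}".format(op_name, output_index))
```

Passing only the op name (for example 'add_1' without ':0') raises ValueError here, mirroring the fact that an op name and a tensor name are different things in a TensorFlow graph.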
2.2 Verification
Confirm that the outputs agree:
output_pytorch = [array([ 2.5359073 , -1.4261041 , -5.2394 , -0.62402934, 4.7426634 ], dtype=float32), array([ 7.6249304, 5.1203837, 1.8118637, 1.5143847, -4.9409146, 1.1695148, -6.2375665, -1.6033885, -1.4286405, -2.964429 ], dtype=float32)]
output_onnx_tf = Outputs(_0=array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269, 4.7426634]], dtype=float32), _1=array([[ 7.6249285, 5.12038 , 1.811865 , 1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32))
output_tf_pb = [array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269, 4.7426634]], dtype=float32), array([[ 7.6249285, 5.12038 , 1.811865 , 1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32)]
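Rather than comparing the printed arrays by eye, the agreement can be checked numerically. The sketch below copies the PyTorch and TensorFlow values printed above into plain Python lists and reports the largest absolute difference per output head. The helper name `max_abs_diff` and the tolerance of 1e-5 are assumptions made for this illustration, not values from the original experiment.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length lists."""
    assert len(a) == len(b)
    return max(abs(x - y) for x, y in zip(a, b))


# Values copied from the printed outputs above.
pytorch_head1 = [2.5359073, -1.4261041, -5.2394, -0.62402934, 4.7426634]
tf_head1 = [2.5359051, -1.4261056, -5.239397, -0.6240269, 4.7426634]

pytorch_head2 = [7.6249304, 5.1203837, 1.8118637, 1.5143847, -4.9409146,
                 1.1695148, -6.2375665, -1.6033885, -1.4286405, -2.964429]
tf_head2 = [7.6249285, 5.12038, 1.811865, 1.5143874, -4.940915,
            1.1695154, -6.237564, -1.6033876, -1.4286422, -2.964428]

TOLERANCE = 1e-5  # assumed threshold for float32 round-off, not from the post

for label, ref, got in [("head 1", pytorch_head1, tf_head1),
                        ("head 2", pytorch_head2, tf_head2)]:
    diff = max_abs_diff(ref, got)
    print("{}: max abs diff = {:.2e}, ok = {}".format(label, diff, diff < TOLERANCE))
```

Both heads differ only in the sixth decimal place, which is the kind of discrepancy one would expect from float32 arithmetic being carried out in a different order by the two frameworks.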
Standalone TF verification program:
import numpy as np
import tensorflow as tf
from PIL import Image


def get_img_np_nchw(filename):
    """Load an image and return a normalized array of shape (1, 3, 224, 224)."""
    try:
        image = Image.open(filename).convert('RGB').resize((224, 224))
        miu = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        #miu = np.array([0.5, 0.5, 0.5])
        #std = np.array([0.22, 0.22, 0.22])
        # img_np.shape = (224, 224, 3)
        img_np = np.array(image, dtype=float) / 255.
        r = (img_np[:, :, 0] - miu[0]) / std[0]
        g = (img_np[:, :, 1] - miu[1]) / std[1]
        b = (img_np[:, :, 2] - miu[2]) / std[2]
        img_np_t = np.array([r, g, b])
        img_np_nchw = np.expand_dims(img_np_t, axis=0)
        return img_np_nchw
    except Exception:
        print("RuntimeError: get_img_np_nchw({}).".format(filename))
        # falls through and returns None


if __name__ == '__main__':
    tf_pb_path = 'model_best_checkpoint_resnet18.pth.tar.onnx_graph.pb'
    filename = 'pants.jpg'
    img_np_nchw = get_img_np_nchw(filename)

    # step 3, check that the exported tf.pb gives the expected outputs.
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(tf_pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name="")
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            #init = tf.initialize_all_variables()
            sess.run(init)

            # Print all ops to check the input/output tensor names.
            # Uncomment this if you do not know the I/O tensor names.
            # print('-------------ops---------------------')
            # op = sess.graph.get_operations()
            # for m in op:
            #     print(m.values())
            # print('-------------ops done.---------------------')

            input_x = sess.graph.get_tensor_by_name("0:0")  # input
            outputs1 = sess.graph.get_tensor_by_name('add_1:0')  # 5
            outputs2 = sess.graph.get_tensor_by_name('add_3:0')  # 10
            output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np_nchw})
            print('output_tf_pb = {}'.format(output_tf_pb))
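The per-channel normalization in get_img_np_nchw can also be written with NumPy broadcasting and a single transpose. The sketch below is an alternative formulation, not the author's code: the function names are hypothetical, and a small random array stands in for a decoded image so that no image file is needed. It also checks that the result matches the channel-by-channel version.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])  # same mean/std as get_img_np_nchw
STD = np.array([0.229, 0.224, 0.225])


def hwc_to_nchw(img_hwc):
    """Normalize an (H, W, 3) uint8 image and return a (1, 3, H, W) array."""
    img = np.asarray(img_hwc, dtype=float) / 255.0
    img = (img - MEAN) / STD           # broadcasts over the last (channel) axis
    chw = img.transpose(2, 0, 1)       # HWC -> CHW
    return np.expand_dims(chw, axis=0) # add the batch dimension


def per_channel_reference(img_hwc):
    """Channel-by-channel version, mirroring the loop-free code in the post."""
    img = np.asarray(img_hwc, dtype=float) / 255.0
    channels = [(img[:, :, c] - MEAN[c]) / STD[c] for c in range(3)]
    return np.expand_dims(np.array(channels), axis=0)


# A small fake "image" (4 x 5 pixels) stands in for a decoded RGB file.
fake_image = np.random.RandomState(0).randint(0, 256, size=(4, 5, 3)).astype(np.uint8)
batch = hwc_to_nchw(fake_image)
print(batch.shape)
print(np.allclose(batch, per_channel_reference(fake_image)))
```

The layout matters because the exported graph keeps PyTorch's NCHW input convention, which is why the input fed in the verification program has shape (1, 3, 224, 224) rather than TensorFlow's more usual NHWC.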
3. Related Info
3.1 ONNX
Open Neural Network Exchange
https://github.com/onnx
https://onnx.ai/
The ONNX exporter is a trace-based exporter, which means that it operates by executing your model once, and exporting the operators which were actually run during this run. (See "Limitations" in the torch.onnx documentation.)
https://github.com/onnx/tensorflow-onnx
https://github.com/onnx/onnx-tensorflow
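To make the consequence of trace-based export concrete, here is a toy illustration in plain Python. It is not how PyTorch implements tracing; the names `toy_model`, `trace`, `replay`, and `eager` are all hypothetical. It only shows why a trace recorded with one example input may not reproduce a model whose control flow depends on the input.

```python
def toy_model(x, record):
    """A 'model' with data-dependent control flow."""
    if x > 0:
        return record("double", lambda v: v * 2, x)
    return record("negate", lambda v: -v, x)


def eager(name, op, value):
    """Run an operator immediately without recording it."""
    return op(value)


def trace(model, example_input):
    """Run the model once and keep only the operators that actually ran."""
    recorded = []

    def record(name, op, value):
        recorded.append((name, op))
        return op(value)

    model(example_input, record)
    return recorded


def replay(recorded, x):
    """Apply the recorded operators in order, ignoring the original branches."""
    for _name, op in recorded:
        x = op(x)
    return x


traced = trace(toy_model, 3)             # the example input takes the x > 0 branch
print([name for name, _ in traced])      # ['double']
print(toy_model(-5, eager))              # 5   (the real model negates)
print(replay(traced, -5))                # -10 (the trace still doubles)
```

The same reasoning applies to the dummy_input passed to torch.onnx.export in step 1: only the operators executed for that one input end up in the exported graph. Models without input-dependent branching are not affected by this particular limitation.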
3.2 Microsoft/MMdnn
So far I have not managed to get the network working with MMdnn.
https://github.com/Microsoft/MMdnn/blob/master/mmdnn/conversion/pytorch/README.md
Reference
- Open Neural Network Exchange https://github.com/onnx
- Exporting model from PyTorch to ONNX
- Importing ONNX models to Tensorflow(ONNX)
- TensorFlow + Tornado serving
- graph_def = tf.GraphDef() graph_def.ParseFromString(f.read())
- A Tool Developer's Guide to TensorFlow Model Files
- TensorFlow Study Notes: Retrain Inception_v3