有时候我们通过 PyTorch 导出 ONNX 模型后,需要修改 ONNX 的图结构,该怎么修改呢?
下面两个 Python 实例提供了修改思路。
Changing the graph is easier than recreating it with `make_graph`: just use `append`, `remove` and `insert` on the graph's repeated fields. 参考:https://github.com/onnx/onnx/issues/2259
# Example 1: patch individual nodes of an exported ONNX model in place.
# `onnxfile` must be defined before running: it is the path of the model to patch
# (the patched model is written back to the same path).
# The node indices (584, 322) and tensor names ('675', '676', '412', '414') are
# specific to the model from the referenced issue -- adjust them for your model.
import numpy as np
import onnx

onnx_model = onnx.load(onnxfile)
graph = onnx_model.graph

# Replace node 584 with a Pad node whose pads come from an all-zero INT64 initializer.
# NOTE(review): feeding pads as an *input* (not an attribute) presumably targets the
# opset>=11 Pad signature; 8 pad values presumably means a rank-4 input -- confirm.
pads = onnx.helper.make_tensor('avg_pads', onnx.TensorProto.INT64, [8], np.zeros(8, dtype=np.int64))
graph.initializer.append(pads)
pad_index = 584
node = graph.node[pad_index]
# Guard against silently replacing the wrong node when the indices do not match.
if node.op_type != 'Pad':
    raise ValueError('expected a Pad node at index {}, found {!r}'.format(pad_index, node.op_type))
new_node = onnx.helper.make_node(
    'Pad',
    name='__Pad_584_fixed',
    inputs=['675', 'avg_pads'],
    outputs=['676'],
    mode='constant',
)
# Remove, then insert at the same index so the node list keeps its original order.
graph.node.remove(node)
graph.node.insert(pad_index, new_node)

# Fix Equals (replace with Not)
equal_index = 322
node = graph.node[equal_index]
if node.op_type != 'Equal':
    raise ValueError('expected an Equal node at index {}, found {!r}'.format(equal_index, node.op_type))
new_node = onnx.helper.make_node(
    'Not',
    name='__Not__Equal_322',
    inputs=['412'],
    outputs=['414'],
)
graph.node.remove(node)
graph.node.insert(equal_index, new_node)

onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, onnxfile)
第二个实例的来源:https://github.com/saurabh-shandilya/onnx-utils
# ------------------------------------------------
# ONNX Model Editor and Graph Extractor
# Licensed under the MIT License
# Written by Saurabh Shandilya
# -----------------------------------------------
import onnx
from onnx import helper, checker
from onnx import TensorProto
import re
import argparse
def createGraphMemberMap(graph_member_list):
    """Index graph members (nodes, inputs, outputs, initializers) by their name.

    Args:
        graph_member_list: iterable of objects exposing a ``name`` attribute,
            e.g. ``graph.node`` or ``graph.input``.

    Returns:
        dict mapping ``member.name`` -> member. If several members share a
        name, the one appearing last in the iterable wins.
    """
    return {member.name: member for member in graph_member_list}
def split_io_list(io_list, new_names_all):
    """Partition existing graph inputs/outputs against a requested list of names.

    Args:
        io_list: existing graph inputs or outputs (objects with a ``name``).
        new_names_all: names the edited graph should expose.

    Returns:
        ``[removed_names, retained_names, new_names]`` where
        - removed_names: existing entries that were not requested,
        - retained_names: existing entries that were requested
          (both lists follow ``io_list`` order),
        - new_names: requested names with no existing entry, in request order
          with duplicates dropped.
    """
    requested = set(new_names_all)
    removed_names = []
    retained_names = []
    for item in io_list:
        if item.name in requested:
            retained_names.append(item.name)
        else:
            removed_names.append(item.name)
    retained = set(retained_names)
    # dict.fromkeys de-duplicates while keeping request order; a plain set
    # difference yields an arbitrary order, which would make the order of newly
    # added graph inputs/outputs differ from run to run.
    new_names = [n for n in dict.fromkeys(new_names_all) if n not in retained]
    return [removed_names, retained_names, new_names]
def traceDependentNodes(graph, name, node_input_names, node_map, initializer_map):
    """Collect the names of all nodes (and initializers) that tensor ``name`` depends on.

    Walks the graph backwards from ``name``. Each node that produces a visited
    tensor is recorded, and then its inputs are visited in turn. Initializer
    names reached along the way are recorded too, because they can be the
    terminal inputs on a path.

    Args:
        graph: graph whose ``node`` list is searched for producers.
        name: tensor name to start tracing from (e.g. a graph output name).
        node_input_names: list of names collected so far. It is extended in
            place, and nodes whose names are already present are not traced again.
        node_map: node name -> node, as built by ``createGraphMemberMap``.
        initializer_map: initializer name -> initializer.

    Returns:
        ``node_input_names`` (the same list object) extended with the names
        found; initializer names may appear more than once.
    """
    # Index tensor -> producing nodes once. Rescanning graph.node for every
    # visited tensor is quadratic, and recursing once per graph level can
    # exceed Python's recursion limit on deep graphs.
    producers = {}
    for node in graph.node:
        for out in node.output:
            producers.setdefault(out, []).append(node)

    seen = set(node_input_names)
    # Explicit depth-first stack. Inputs are pushed in reverse so that they are
    # popped (visited) in their declared order.
    stack = [name]
    while stack:
        tensor = stack.pop()
        for node in producers.get(tensor, ()):
            if node.name in seen:
                continue
            seen.add(node.name)
            node_input_names.append(node.name)
            if node.name in node_map:
                stack.extend(list(node_map[node.name].input)[::-1])
        # Don't forget the initializers: they can be terminal inputs on a path.
        if tensor in initializer_map:
            seen.add(tensor)
            node_input_names.append(tensor)
    return node_input_names
def onnx_edit(input_model, output_model, new_input_node_names, input_shape_map, new_output_node_names, output_shape_map, verify):
    """Edit an onnx model to extract a subgraph delimited by input/output names and shapes.

    Arguments:
        input_model: path of the input onnx model.
        output_model: path where the edited onnx model is saved.
        new_input_node_names: list of input tensor names, including the original
            inputs that are to be retained. If the list is empty, the original
            inputs are kept.
        input_shape_map: dict mapping input names to shapes; may be None or may
            omit names (those inputs are declared with shape None). Shapes are
            needed for the model checker to pass.
        new_output_node_names: list of output tensor names, including the
            original outputs that are to be retained. If the list is empty, the
            original outputs are kept.
        output_shape_map: dict mapping output names to shapes; may be None or
            may omit names.
        verify: set to true to run the onnx checker on the input and output models.

    Note:
        Newly created inputs and outputs are always declared as FLOAT tensors.
    """
    # parse_nodename_and_shape returns None instead of a dict when no shapes were
    # given; normalise so the lookups below cannot fail on None.
    input_shape_map = input_shape_map or {}
    output_shape_map = output_shape_map or {}

    # LOAD MODEL AND PREP MAPS
    model = onnx.load(input_model)
    graph = model.graph
    if verify:
        print("input model Errors: ", onnx.checker.check_model(model))
    node_map = createGraphMemberMap(graph.node)
    input_map = createGraphMemberMap(graph.input)
    output_map = createGraphMemberMap(graph.output)
    initializer_map = createGraphMemberMap(graph.initializer)
    if not new_input_node_names:
        new_input_node_names = list(input_map)
    if not new_output_node_names:
        new_output_node_names = list(output_map)

    # MODIFY INPUTS
    # Break the graph based on the new input names.
    removed_names, _, new_names = split_io_list(graph.input, new_input_node_names)
    for name in removed_names:
        if name in input_map:
            graph.input.remove(input_map[name])
    for name in new_names:
        # A new input replaces whatever currently defines that tensor: exactly
        # here the graph is broken. Node names and output tensor names are
        # separate fields, so drop both a node *named* like the input and any
        # node that *produces* it. Otherwise the tensor would be defined twice.
        producers = [n for n in graph.node if n.name == name or name in n.output]
        for n in producers:
            graph.node.remove(n)
        new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, input_shape_map.get(name))
        graph.input.extend([new_nv])
    node_map = createGraphMemberMap(graph.node)
    input_map = createGraphMemberMap(graph.input)

    # MODIFY OUTPUTS
    # Break the graph based on the new output names.
    removed_names, _, new_names = split_io_list(graph.output, new_output_node_names)
    for name in removed_names:
        if name in output_map:
            graph.output.remove(output_map[name])
    for name in new_names:
        new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, output_shape_map.get(name))
        graph.output.extend([new_nv])

    # CLEANUP NODES
    # Trace every node/initializer the requested outputs depend on; all others are invalid.
    valid_node_names = []
    for new_output_node_name in new_output_node_names:
        valid_node_names = traceDependentNodes(graph, new_output_node_name, valid_node_names, node_map, initializer_map)
    invalid_node_names = (set(node_map) | set(initializer_map)) - set(valid_node_names)
    # Remove all the invalid nodes (and unused initializers / their input entries).
    for name in invalid_node_names:
        if name in node_map:
            graph.node.remove(node_map[name])
        if name in initializer_map:
            graph.initializer.remove(initializer_map[name])
        if name in input_map:
            graph.input.remove(input_map[name])

    # SAVE MODEL
    if verify:
        print("output model Errors: ", onnx.checker.check_model(model))
    onnx.save(model, output_model)
def parse_nodename_and_shape(name):
    """Parse a comma-separated list of tensor names with optional shapes.

    Each entry is ``<name>`` or ``<name>[d0,d1,...]``. Names in most cases have
    the form ``name:0`` (0 being the output number), e.g. ``"a:0[1,28,28,3],b"``.
    Whitespace inside the brackets is tolerated (``a[1, 2]``), and ``<name>[]``
    denotes a scalar (empty) shape.

    Returns:
        (inputs, shapes): the list of names in order, and a dict mapping each
        name that carried a shape to its list of int dims, or None when no
        entry carried a shape.
    """
    name_pattern = r"(?:([\w/\-.:]+)(\[[\-\d,\s]*\])?),?"
    # With two capture groups, re.split yields
    # [sep, name, shape_or_None, sep, name, shape_or_None, ..., sep].
    splits = re.split(name_pattern, name)
    inputs = []
    shapes = {}
    for i in range(1, len(splits), 3):
        node_name, shape_text = splits[i], splits[i + 1]
        inputs.append(node_name)
        if shape_text is not None:
            body = shape_text[1:-1].strip()
            shapes[node_name] = [int(dim) for dim in body.split(",")] if body else []
    return inputs, (shapes or None)
if __name__ == "__main__":
    # Command-line entry point: extract a subgraph of an onnx model between the
    # given inputs and outputs and write it to a new file.
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="input onnx model")
    parser.add_argument("output", help="output onnx model")
    parser.add_argument("--inputs", help="comma separated model input names appended with shapes, e.g. --inputs <nodename>[1,2,3],<nodename1>[1,2,3] ")
    parser.add_argument("--outputs", help="comma separated model output names appended with shapes, e.g. --outputs <nodename>[1,2,3],<nodename1>[1,2,3] ")
    parser.add_argument('--skipverify', dest='skipverify', action='store_true',
                        help='skip verification of model. Useful if shapes are not known')
    args = parser.parse_args()

    if args.inputs:
        new_input_node_names, input_shape_map = parse_nodename_and_shape(args.inputs)
    else:
        new_input_node_names, input_shape_map = [], {}
    if args.outputs:
        new_output_node_names, output_shape_map = parse_nodename_and_shape(args.outputs)
    else:
        new_output_node_names, output_shape_map = [], {}

    # parse_nodename_and_shape returns None for the shape map when no shapes were
    # given (e.g. "--inputs foo"); onnx_edit looks names up in these maps.
    input_shape_map = input_shape_map or {}
    output_shape_map = output_shape_map or {}

    onnx_edit(args.input, args.output, new_input_node_names, input_shape_map,
              new_output_node_names, output_shape_map, not args.skipverify)