天天看点

ONNX 推理模型

# Instantiate ResNet-18 from `models` (presumably torchvision.models — confirm
# against the file's imports) and restore the weights saved at pth_model_path.
resnet18 = models.resnet18()
net = resnet18  # second name for the same module object; both remain bound
net.load_state_dict(torch.load(pth_model_path))
net.eval()  # evaluation mode before running the forward pass

# Inference only: no gradients are needed, so skip autograd graph construction.
with torch.no_grad():
    output = net(img)

print("pth weights", output.detach().cpu().numpy())
print("HOST GPU prediction", output.argmax(dim=1)[0].item())

# ONNX test: run the same input through the model at onnx_model_path
# using ONNX Runtime.
resnet_session = onnxruntime.InferenceSession(onnx_model_path)
# Compute the ONNX Runtime prediction: feed the image under the name of the
# session's first declared input.
inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
# Passing None as the output-name list asks for every output; keep the first.
outs = resnet_session.run(None, inputs)[0]

print("onnx weights", outs)
print("onnx prediction", outs.argmax(axis=1)[0])
           

继续阅读