pytorch转rknn(Pytorch复习笔记–导出Onnx模型为动态输入和静态输入)
目录
1--动态输入和静态输入
2--Pytorch API
3--完整代码演示
4--模型可视化
5--测试动态导出的Onnx模型
1--动态输入和静态输入
当使用 Pytorch 将网络导出为 Onnx 模型格式时 ,可以导出为动态输入和静态输入两种方式 。动态输入即模型输入数据的部分维度是动态的 ,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变 ,当用户使用模型时只能输入指定维度的数据进行推理 。
显然 ,动态输入的通用性比静态输入更强 。
2--Pytorch API
在 Pytorch 中 ,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入 ,dynamic_axes 的默认值为 None ,即默认为静态输入 。
以下展示动态导出的用法 ,通过定义 dynamic_axes 参数来设置动态导出输入 。dynamic_axes 中的 0 、2 、3 表示相应的维度设置为动态值;
# 导出为动态输入 input_name = input output_name = output torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: batch_size, 2: input_height, 3: input_width}, output_name: {0: batch_size, 2: output_height, 3: output_width}})3--完整代码演示
在以下代码中 ,定义了一个网络 ,并使用动态导出和静态导出两种方式 ,将网络导出为 Onnx 模型格式 。
import torch import torch.nn as nn class Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return data if __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = input output_name = output torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: batch_size, 2: input_height, 3: input_width}, output_name: {0: batch_size, 2: output_height, 3: output_width}})4--模型可视化
通过 netron 库可视化导出的静态模型和动态模型,代码如下:
import netron netron.start("./Dynamics_InputNet.onnx")静态模型可视化:
动态模型可视化:
5--测试动态导出的Onnx模型
import numpy as np import onnx import onnxruntime if __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=[TensorrtExecutionProvider, CUDAExecutionProvider, CPUExecutionProvider]) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print(output1.shape: , np.squeeze(np.array(output1), 0).shape) print(output2.shape: , np.squeeze(np.array(output2), 0).shape)由输出结果可知 ,对应动态输入 Onnx 模型 ,其输出维度也是动态的,并且为对应关系 ,则表明导出的 Onnx 模型无误 。
创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!