首页IT科技pytorch转rknn(Pytorch复习笔记–导出Onnx模型为动态输入和静态输入)

pytorch转rknn(Pytorch复习笔记–导出Onnx模型为动态输入和静态输入)

时间2025-07-29 16:16:23分类IT科技浏览7797
导读:目录...

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型

1--动态输入和静态输入

        当使用 Pytorch 将网络导出为 Onnx 模型格式时                  ,可以导出为动态输入和静态输入两种方式                  。动态输入即模型输入数据的部分维度是动态的                          ,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的         ,不能够改变                  ,当用户使用模型时只能输入指定维度的数据进行推理                           。

        显然                          ,动态输入的通用性比静态输入更强        。

2--Pytorch API

        在 Pytorch 中         ,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入         ,dynamic_axes 的默认值为 None                          ,即默认为静态输入         。

        以下展示动态导出的用法                  ,通过定义 dynamic_axes 参数来设置动态导出输入                           。dynamic_axes 中的 0                  、2                           、3 表示相应的维度设置为动态值;

# 导出为动态输入 input_name = input output_name = output torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: batch_size, 2: input_height, 3: input_width}, output_name: {0: batch_size, 2: output_height, 3: output_width}})

3--完整代码演示

        在以下代码中         ,定义了一个网络                          ,并使用动态导出和静态导出两种方式                  ,将网络导出为 Onnx 模型格式                 。

import torch import torch.nn as nn class Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return data if __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = input output_name = output torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: batch_size, 2: input_height, 3: input_width}, output_name: {0: batch_size, 2: output_height, 3: output_width}})

4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netron netron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型

import numpy as np import onnx import onnxruntime if __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=[TensorrtExecutionProvider, CUDAExecutionProvider, CPUExecutionProvider]) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print(output1.shape: , np.squeeze(np.array(output1), 0).shape) print(output2.shape: , np.squeeze(np.array(output2), 0).shape)

         由输出结果可知                          ,对应动态输入 Onnx 模型                          ,其输出维度也是动态的,并且为对应关系                  ,则表明导出的 Onnx 模型无误         。

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
java替换文本内容(java 替换list中值的方法分享) 香港云服务器10元一年怎么收费(香港服务器云主机购买需要注意什么)