首页IT科技pytorch转rknn(Pytorch复习笔记–导出Onnx模型为动态输入和静态输入)

pytorch转rknn(Pytorch复习笔记–导出Onnx模型为动态输入和静态输入)

时间2025-06-15 03:57:29分类IT科技浏览6093
导读:目录...

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型

1--动态输入和静态输入

        当使用 Pytorch 将网络导出为 Onnx 模型格式时                ,可以导出为动态输入和静态输入两种方式                。动态输入即模型输入数据的部分维度是动态的                       ,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的        ,不能够改变                ,当用户使用模型时只能输入指定维度的数据进行推理                        。

        显然                       ,动态输入的通用性比静态输入更强       。

2--Pytorch API

        在 Pytorch 中        ,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入        ,dynamic_axes 的默认值为 None                       ,即默认为静态输入        。

        以下展示动态导出的用法                ,通过定义 dynamic_axes 参数来设置动态导出输入                        。dynamic_axes 中的 0                、2                        、3 表示相应的维度设置为动态值;

# 导出为动态输入 input_name = input output_name = output torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: batch_size, 2: input_height, 3: input_width}, output_name: {0: batch_size, 2: output_height, 3: output_width}})

3--完整代码演示

        在以下代码中        ,定义了一个网络                       ,并使用动态导出和静态导出两种方式                ,将网络导出为 Onnx 模型格式               。

import torch import torch.nn as nn class Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return data if __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = input output_name = output torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: batch_size, 2: input_height, 3: input_width}, output_name: {0: batch_size, 2: output_height, 3: output_width}})

4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netron netron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型

import numpy as np import onnx import onnxruntime if __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=[TensorrtExecutionProvider, CUDAExecutionProvider, CPUExecutionProvider]) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print(output1.shape: , np.squeeze(np.array(output1), 0).shape) print(output2.shape: , np.squeeze(np.array(output2), 0).shape)

         由输出结果可知                       ,对应动态输入 Onnx 模型                       ,其输出维度也是动态的,并且为对应关系                ,则表明导出的 Onnx 模型无误        。

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
tsconfig-paths(tsconfig.json详细配置) 什么是分布式技术,其解决什么问题(《分布式技术原理与算法解析》学习笔记Day18)