首页IT科技pytorch模型的保存与加载(PyTorch模型的保存与加载)

pytorch模型的保存与加载(PyTorch模型的保存与加载)

时间2025-06-16 12:51:16分类IT科技浏览4588
导读:载入muti-GPU模型: pretrain_model = torch.load(muti_gpu_model.pth ...

载入muti-GPU模型:

pretrain_model = torch.load(muti_gpu_model.pth) # 网络+权重 # 载入为single-GPU模型 gpu_model = pretrained_model.module # 载入为CPU模型 model = ModelArch() pretained_dict = pretained_model.module.state_dict() model.load_satte_dict(pretained_dict)

载入muti-GPU权重:

model = ModelArch().cuda() model = torch.nn.DataParallel(model, device_ids=[0]) # 将model转为muti-GPU模式 checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage) model.load_state_dict(checkpoint) # 载入为single-GPU模型 gpu_model = model.module # 载入为CPU模型 model = ModelArch() model.load_state_dict(gpu_model.state_dict()) torch.save(cpu_model.state_dict(), cpu_model.pth)

载入CPU权重:

# 载入为CPU模型 model = ModelArch() checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage) # 载入为single-GPU模型 model = ModelArch().cuda() checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage.cuda(0)) model.load_state_dict(checkpoint) # 载入为muti-GPU模型 model = ModelArch().cuda() model = torch.nn.DataParallel(model, device_ids=[0, 1]) checkpoint = torch.load(model_path, map_location=lambda storage, loc:storage.cuda(0)) model.module.load_state_dict(checkpoint)

1. PyTorch中保存的模型文件.pth

模型保存的格式:pytorch中最常见的模型保存使用 .pt 或者是 .pth 作为模型文件扩展名            ,其他方式还有.t7/.pkl格式                   ,t7文件是沿用torch7中读取模型权重的方式       ,而在keras中则是使用.h5文件

.pth 文件基本信息

四个键值:model(OrderedDict),optimizer(Dict),scheduler(Dict),iteration(int)

1)net["model"]   相当于net.state_dict() 返回的字典

键model所对应的值是一个OrderedDict            ,OrderedDict字典存储着所有的每一层的参数名称以及对应的参数值

Eg. module.backbone.body.stem.conv1.weight                  ,参数名称很长       ,是因为搭建网络结构的时候采用了组件式的设计      ,即整个模型里面构造了一个backbone的容器组件                  ,backbone里面又构造了一个body容器组件             ,body里面又构造了一个stem容器      ,stem里面的第一个卷积层的权重

2)net["optimizer"]    相当于optimizer.state_dict() 返回的字典

返回的是一个一般的字典 Dict 对象                  ,这个字典只有两个key:state和param_groups

param_groups对应的值是一个列表; state对应的值是一个字典类型             ,和param_groups有着对应关系,每一个元素的键值就是param_groups中每一个元素的params;

3)net["scheduler"] 返回一个字典

4)net["iteration"]  返回一个具体的数字

2. torch.save()函数:保存模型文件

注意:.pt, .pth, .pkl并不是在格式上有区别                  ,只是后缀不同而已(仅此而已)

pytorch模型保存的两种方式:一种是保存整个模型                   ,另一种是只保存模型的参数

torch.save(model.state_dict(), "my_model.pth") # 只保存模型的参数 torch.save(model, "my_model.pth") # 保存整个模型

保存的模型参数:一个字典类型,通过key-value的形式来存储模型的所有参数

3. torch.load()函数:用来加载torch.save()保存的模型文件

torch.load()先在CPU上加载            ,不会依赖于保存模型的设备            。如果加载失败                   ,可能是因为没有包含某些设备       ,比如在gpu上训练保存的模型            ,而在cpu上加载                  ,可能会报错       ,此时      ,需要使用map_location来将存储动态重新映射到可选设备上

4. torch.nn.Module类model.state_dict()方法

state_dict 是一个简单的python的字典对象                  ,将每一层与它的对应参数建立映射关系

注意:

1)只有参数可以训练的layer才会被保存到模型的state_dict中             ,如卷积层            、线性层等      ,池化层这些本身没有参数的层没有在这个字典中;

2)作用:方便查看某一个层的权值和偏置数据;在模型保存的时候使用                   。

优化器对象Optimizer也有state_dict                  ,包含了Optimizer状态以及超参数(如lr, momentum,weight_decay等)

5. torch.nn.Module类model.parameters()方法:获得模型的参数信息

model.parameters()方法返回的是一个生成器generator             ,每一个元素是从开头到结尾的参数,parameters没有对应的key                  ,是一个由纯参数组成的generator                   ,而state_dict是一个字典,包含了key; parameters是通过named_parameters来实现的            ,也是Module一个与parameters类似的函数       。 # 查看model的参数量:先load model的weight                   ,然后再使用parameters() n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

总结:model.state_dict()                   、model.parameters()       、model.named_parameters()这三个方法都可以查看Module的参数信息       ,用于更新参数            ,或者用于模型的保存

6. checkpoint:保存模型的参数                  ,优化器参数       ,loss      ,epoch等

(相当于一个保存模型的文件夹)

checkpoint的机制:在模型训练的过程中                  ,不断地保存训练结果(包括但不限于EPOCH            、模型权重                  、优化器状态       、调度器状态)             ,即便模型训练中断      ,也可以基于checkpoint接续训练

在反向传播时重新计算深度神经网络的中间值(而通常情况是在前向传播时存储的)                  ,这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)

7. 内存开销

神经网络使用的总内存:

静态内存             ,尽管 PyTorch 模型中内置了一些固定开销,但总的来说几乎完全由模型权重决定 模型的计算图所占用的动态内存                  ,在训练模式下                   ,每次通过神经网络的前向传播都为网络中的每个神经元计算一个激活值,这个值随后被存储在所谓的计算图中            。必须为批次中的每个单个训练样本存储一个值            ,因此数量会迅速的累积起来                  。总成本取决于模型大小和批处理大小                   ,并设置适用于GPU内存的最大批处理大小的限制

PyTorch 通过torch.utils.checkpoint.checkpoint和torch.utils.checkpoint.checkpoint_sequential提供梯度检查点       ,在前向传播时            ,PyTorch 将保存模型中的每个函数的输入元组       。在反向传播过程中                  ,对于每个函数       ,输入元组和函数的组合以实时的方式重新计算      ,插入到每个需要它的函数的梯度公式中                  ,然后丢弃(显存中只保存输入数据和函数)      。网络计算开销大致相当于每个样本通过模型前向传播开销的两倍                  。

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
数据采集软件的使用实验报告(数据采集器软件-数据采集技术有哪些) 知乎界面设计风格(知乎创意总监、Dine 设计团队创始人 @disinfeqt :设计、音乐)