首页IT科技pytorch加载数据集多次(Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法)

pytorch加载数据集多次(Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法)

时间2025-06-21 00:10:53分类IT科技浏览5870
导读:需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层。(权重文件存储为dict形式)...

需求

Pytorch加载模型时                ,只导入部分层权重                      ,跳过部分指定网络层              。(权重文件存储为dict形式)

方法一

常见方法:加载权重时用if对网络层进行筛选

# model为定义的网络结构: class model(nn.Module): def __init__(self): super(model,self).__init__() …… def forward(self,x): …… return x model = model() # load存在的模型参数(权重文件)       ,后缀名可能不同    pretrained_dict = torch.load(model.pkl) model_dict = model.state_dict() # 关键在于下面这句            ,从model_dict中读取key              、value时                       ,用if筛选掉不需要的网络层 pretrained_dict = {key: value for key, value in pretrained_dict.items() if (key in model_dict and Prediction not in key)} model_dict.update(pretrained_dict) model.load_state_dict(model_dict)

方法二

不完全匹配          ,只加载权重中存在的参数        ,不匹配就跳过

# load_state_dict() 默认strict=True                        ,需要完全匹配              ,否则报错 # 修改为strict=False后    ,只匹配存在的参数 pretrained_dict = torch.load(weight_path) model.load_state_dict(pretrained_dict, strict=False)

方法三

 不使用原有权重文件训练                        ,对原有权重文件进行拷贝                  ,拷贝文件中只包含需要的网络层,后续直接利用拷贝权重文件进行训练                       。

# 对原有权重文件进行拷贝                    ,拷贝文件中只包含需要的网络层                      , # 后续直接利用拷贝文件进行训练        。 import pickle model = model() net = model path_weight = R-50.pkl path_weight2 = R2-50.pkl with open(path_weight,rb) as f: obj=f.read() # 用pickle.loads()加载权重信息 la_obj=pickle.loads(obj,encoding=latin1) # 用if进行筛选 weights= {key: value for key, value in la_obj.items()} #if key in la_obj and backbone.bottom_up.stem.conv1.weight not in key} # 使用print查看权重文件信息 print(weights) # 再深拷贝一份文件保存 state_dict = copy.deepcopy(weights) with open(path_weight2,wb) as f2: pickle.dump(state_dict, f2) # 可以写入txt    ,便于查看信息 path_weight2 = R2-101.txt inf = str(state_dict) ff = open(path_weight2,w) ff.write(inf)

下面是对载入参数的优化有特殊要求:参数固定                       、或者参数更新速度不同           。

方法四

如果载入的这些参数中                ,有些参数不要求被更新                      ,即固定不变       ,不参与训练            ,需要手动设置这些参数的梯度属性为Fasle                       ,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后... for name, value in model.named_parameters(): if name 满足某些条件: value.requires_grad = False # setup optimizer params = filter(lambda p: p.requires_grad, model.parameters()) optimizer = torch.optim.Adam(params, lr=1e-4)

方法五

如果载入的这些参数中          ,所有参数都更新        ,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样                        ,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后... for name, value in model.named_parameters():     print(name) # 或 print(model.state_dict().keys())

假设该模型中有encoder              ,viewer和decoder两部分    ,参数名称分别是:

encoder.visual_emb.0.weight, encoder.visual_emb.0.bias, viewer.bd.Wsi, viewer.bd.bias, decoder.core.layer_0.weight_ih, decoder.core.layer_0.weight_hh,

假设要求encode        、viewer的学习率为1e-6                        , decoder的学习率为1e-4                  ,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) optimizer = torch.optim.Adam([{params:base_params,lr:1e-6},                               {params:model.decoder.parameters()}                               ],                               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的learning_rate=1e-6                      。

在传入optimizer时                    ,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同                      ,参数部分用了一个list    , list的每个元素有params和lr

两个键值            。如果没有 lr则应用Adam的lr属性        。Adam的属性除了lr                , 其他都是参数所共有的(比如momentum)                      。

 

遇见的问题

torch.load 加载权重文件时报错 Magic Number Error 

有时候使用 torch.load 加载比较古老的权重文件时可能报错 Magic Number Error                      ,这有可能是因为该文件使用 pickle 存储并且编码使用了 latin1       ,此时可以这样加载:

若要进行筛选            ,同理可以在后面加上if进行判断                。

import pickle with open(weights_path, rb) as f: obj = f.read() # 用pickle进行load                       ,编码方式为latin1 weights = {key: weight_dict for key, weight_dict in pickle.loads(obj,encoding=latin1).items()} # 同理          ,可以用if判断进行筛选 # weights = {key: value for key, value in pickle.loads(obj,encoding=latin1).items() if (key in model_dict and Prediction not in key)} model.load_state_dict(weights)

TypeError: a bytes-like object is required, not str

python3和python2在套接字返回值解码上有区别    。

套接字就是 socket        ,用于描述 IP 地址和端口                        ,应用程序通过套接字向网络发出请求或者应答网络请求              ,可以认为是计算机网络的数据接口                      。目前套接字分为两种:基于文件型和基于网络型                    。

解决方法

使用函数 encode() 和decode():

str 通过 encode() 函数编码变为 bytes bytes 通过 decode() 函数编码变为 str。(当我们从网络或磁盘上读取了字节流    ,则读到的数据就是 bytes)

补充:

str --> bytes

# 声明一个字符串s: >>> s = abc >>> type(s) <class str> # 四种转换方式: >>> b1 = s.encode() >>> type(b1) <class bytes> >>> b2 = str.encode(s) >>> type(b2) <class bytes> >>> b3 = s.encode(encoding=utf-8) >>> type(b3) <class bytes> >>> b4 = bytes(s,encoding=utf-8) >>> type(b4) <class bytes>

bytes --> str

# 声明一个bytes: >>> b = babc >>> type(b) <class bytes> # 三种转换方式: >>> s1 = bytes.decode(b) >>> type(s1) <class str> >>> s2 = b.decode() >>> type(s2) <class str> >>> s3 = str(b,encoding=utf-8) >>> type(s3) <class str>

参考博客

Pytorch中只导入部分层权重的方法_汐梦聆海的博客-CSDN博客_pytorch加载部分权重

pytorch微调模型—只加载预训练模型的某些层_农夫山泉2号的博客-CSDN博客

Pytorch加载模型不完全匹配 & 只加载部分参数权重 load_hxxjxw的博客-CSDN博客_pytorch加载模型不匹配跳过

pytorch载入预训练模型后                        ,只想训练个别层怎么办?_慕白-的博客-CSDN博客_pytorch只训练最后一层

PyTorch | 保存和加载模型 - 知乎 (zhihu.com)

torch.load加载权重时报错 Magic Number Error - 仰望高端玩家的小清新 - 博客园 (cnblogs.com)

Python报错:TypeError: a bytes-like object is required, not ‘str‘_程序媛三妹的博客-CSDN博客 

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
nksc文件(NkbMonitor.exe – NkbMonitor是什么进程文件 有什么作用)