pytorch加载数据集多次(Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法)
需求
Pytorch加载模型时 ,只导入部分层权重 ,跳过部分指定网络层 。(权重文件存储为dict形式)
方法一
常见方法:加载权重时用if对网络层进行筛选
# model为定义的网络结构: class model(nn.Module): def __init__(self): super(model,self).__init__() …… def forward(self,x): …… return x model = model() # load存在的模型参数(权重文件) ,后缀名可能不同 pretrained_dict = torch.load(model.pkl) model_dict = model.state_dict() # 关键在于下面这句 ,从model_dict中读取key 、value时 ,用if筛选掉不需要的网络层 pretrained_dict = {key: value for key, value in pretrained_dict.items() if (key in model_dict and Prediction not in key)} model_dict.update(pretrained_dict) model.load_state_dict(model_dict)方法二
不完全匹配 ,只加载权重中存在的参数 ,不匹配就跳过
# load_state_dict() 默认strict=True ,需要完全匹配 ,否则报错 # 修改为strict=False后 ,只匹配存在的参数 pretrained_dict = torch.load(weight_path) model.load_state_dict(pretrained_dict, strict=False)方法三
不使用原有权重文件训练 ,对原有权重文件进行拷贝 ,拷贝文件中只包含需要的网络层,后续直接利用拷贝权重文件进行训练 。
# 对原有权重文件进行拷贝 ,拷贝文件中只包含需要的网络层 , # 后续直接利用拷贝文件进行训练 。 import pickle model = model() net = model path_weight = R-50.pkl path_weight2 = R2-50.pkl with open(path_weight,rb) as f: obj=f.read() # 用pickle.loads()加载权重信息 la_obj=pickle.loads(obj,encoding=latin1) # 用if进行筛选 weights= {key: value for key, value in la_obj.items()} #if key in la_obj and backbone.bottom_up.stem.conv1.weight not in key} # 使用print查看权重文件信息 print(weights) # 再深拷贝一份文件保存 state_dict = copy.deepcopy(weights) with open(path_weight2,wb) as f2: pickle.dump(state_dict, f2) # 可以写入txt,便于查看信息 path_weight2 = R2-101.txt inf = str(state_dict) ff = open(path_weight2,w) ff.write(inf)下面是对载入参数的优化有特殊要求:参数固定 、或者参数更新速度不同 。
方法四
如果载入的这些参数中 ,有些参数不要求被更新 ,即固定不变 ,不参与训练 ,需要手动设置这些参数的梯度属性为Fasle ,并且在optimizer传参时筛选掉这些参数:
# 载入预训练模型参数后... for name, value in model.named_parameters(): if name 满足某些条件: value.requires_grad = False # setup optimizer params = filter(lambda p: p.requires_grad, model.parameters()) optimizer = torch.optim.Adam(params, lr=1e-4)方法五
如果载入的这些参数中 ,所有参数都更新 ,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样 ,最好知道这些参数的名称都有什么:
# 载入预训练模型参数后... for name, value in model.named_parameters(): print(name) # 或 print(model.state_dict().keys())假设该模型中有encoder ,viewer和decoder两部分 ,参数名称分别是:
encoder.visual_emb.0.weight, encoder.visual_emb.0.bias, viewer.bd.Wsi, viewer.bd.bias, decoder.core.layer_0.weight_ih, decoder.core.layer_0.weight_hh,假设要求encode 、viewer的学习率为1e-6 , decoder的学习率为1e-4 ,那么在将参数传入优化器时:
ignored_params = list(map(id, model.decoder.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) optimizer = torch.optim.Adam([{params:base_params,lr:1e-6}, {params:model.decoder.parameters()} ], lr=1e-4, momentum=0.9)代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的learning_rate=1e-6 。
在传入optimizer时 ,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同 ,参数部分用了一个list, list的每个元素有params和lr两个键值 。如果没有 lr则应用Adam的lr属性 。Adam的属性除了lr , 其他都是参数所共有的(比如momentum) 。
遇见的问题
torch.load 加载权重文件时报错 Magic Number Error
有时候使用 torch.load 加载比较古老的权重文件时可能报错 Magic Number Error ,这有可能是因为该文件使用 pickle 存储并且编码使用了 latin1 ,此时可以这样加载:
若要进行筛选 ,同理可以在后面加上if进行判断 。
import pickle with open(weights_path, rb) as f: obj = f.read() # 用pickle进行load ,编码方式为latin1 weights = {key: weight_dict for key, weight_dict in pickle.loads(obj,encoding=latin1).items()} # 同理 ,可以用if判断进行筛选 # weights = {key: value for key, value in pickle.loads(obj,encoding=latin1).items() if (key in model_dict and Prediction not in key)} model.load_state_dict(weights)TypeError: a bytes-like object is required, not str
python3和python2在套接字返回值解码上有区别 。
套接字就是 socket ,用于描述 IP 地址和端口 ,应用程序通过套接字向网络发出请求或者应答网络请求 ,可以认为是计算机网络的数据接口 。目前套接字分为两种:基于文件型和基于网络型 。
解决方法:
使用函数 encode() 和decode():
str 通过 encode() 函数编码变为 bytes bytes 通过 decode() 函数编码变为 str。(当我们从网络或磁盘上读取了字节流 ,则读到的数据就是 bytes)补充:
str --> bytes
# 声明一个字符串s: >>> s = abc >>> type(s) <class str> # 四种转换方式: >>> b1 = s.encode() >>> type(b1) <class bytes> >>> b2 = str.encode(s) >>> type(b2) <class bytes> >>> b3 = s.encode(encoding=utf-8) >>> type(b3) <class bytes> >>> b4 = bytes(s,encoding=utf-8) >>> type(b4) <class bytes>bytes --> str
# 声明一个bytes: >>> b = babc >>> type(b) <class bytes> # 三种转换方式: >>> s1 = bytes.decode(b) >>> type(s1) <class str> >>> s2 = b.decode() >>> type(s2) <class str> >>> s3 = str(b,encoding=utf-8) >>> type(s3) <class str>参考博客
Pytorch中只导入部分层权重的方法_汐梦聆海的博客-CSDN博客_pytorch加载部分权重
pytorch微调模型—只加载预训练模型的某些层_农夫山泉2号的博客-CSDN博客
Pytorch加载模型不完全匹配 & 只加载部分参数权重 load_hxxjxw的博客-CSDN博客_pytorch加载模型不匹配跳过
pytorch载入预训练模型后 ,只想训练个别层怎么办?_慕白-的博客-CSDN博客_pytorch只训练最后一层
PyTorch | 保存和加载模型 - 知乎 (zhihu.com)
torch.load加载权重时报错 Magic Number Error - 仰望高端玩家的小清新 - 博客园 (cnblogs.com)
Python报错:TypeError: a bytes-like object is required, not ‘str‘_程序媛三妹的博客-CSDN博客
创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!