pytorch加载模型继续训练(Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理)
1 数据集Dataset
Dataset类是Pytorch中图像数据集操作的核心类,Pytorch中所有数据集加载类都继承自Dataset父类。当我们自定义数据集处理时,必须实现Dataset类中的三个接口:
初始化 def __init__(self)构造函数,定义一些数据集的公有属性,如数据集下载地址、名称等
数据集大小 def __len__(self)返回数据集大小,不同的数据集有不同的衡量数据量的方式
数据集索引 def __getitem__(self, index):支持数据集索引功能,以实现形如dataset[i]得到数据集中的第i + 1个数据的功能。__getitem__是后期迭代数据时执行的具体函数,其返回值决定了循环变量,例如
class data(Dataset) ... def __getitem__(self, idx: int): if self.transforms: img = self.transforms(img) return img, label # 返回的值即为后续迭代的循环变量 for images, labels in dataLoader: ...2 数据加载DataLoader
为什么有了数据集Dataset还需要数据加载器DataLoader呢?原因在于神经网络需要进一步借助DataLoader对数据进行划分,也就是我们常说的batch,此外DataLoader还实现了打乱数据集、多线程等操作。
DataLoader本质是一个可迭代对象,可以使用形如
for inputs, labels in dataloaders进行可迭代对象的访问。
我们一般不需要去实现DataLoader的接口,只需要在构造函数中指定相应的参数即可,比如常见的batch_size,shuffle等参数。
下面这张图非常好地说明了Dataset和DataLoader的关系
接下来总结数据构造的三步法
继承Dataset对象,并实现__len__()、__getitem__()魔法方法,该步骤的主要目的在于将文件形式的数据集处理为模型可用的标准数据格式,并加载到内存中; 用DataLoader对象封装Dataset,使其成为可迭代对象; 遍历DataLoader对象以将数据加载到模型中进行训练。3 常用预处理方法
在数据集Dataset的__getitem__()中利用torchvision.transforms进行数据预处理与变换
常见的数据预处理变换方法总结如下表
序号 变换 含义 1 RandomCrop(size, ...) 对输入图像依据给定size随机裁剪 2 CenterCrop(size, ...) 对输入图像依据给定size从中心裁剪 3 RandomResizedCrop(size, ...) 对输入图像随机长宽比裁剪,再放缩到给定size 4 FiveCrop(size, ...) 对输入图像进行上下左右及中心裁剪,返回五张图像(size)组成的四维张量 5 TenCrop(size, vertical_flip=False) 对输入图像进行上下左右及中心裁剪,再全部翻转(水平或垂直),返回十张图像(size)组成的四维张量 6 RandomHorizontalFlip(p=0.5) 对输入图像按概率p随机进行水平翻转 7 RandomVerticalFlip(p=0.5) 对输入图像按概率p随机进行垂直翻转 8 RandomRotation(degree, ...) 对输入图像在degree内随机旋转某角度 9 Resize(size, ...) 对输入图像重置分辨率 10 Normalize(mean, std) 对输入图像各通道进行标准化 11 ToTensor() 将输入图像或ndarray 转换为tensor并归一化 12 Pad(padding, fill=0, padding_mode=‘constant’) 对输入图像进行填充 13 ColorJitter(brightness=0, contrast=0, saturation=0, hue=0) 对输入图像修改亮度、对比度、饱和度、色度等 14 Grayscale(num_output_channels=1) 对输入图像转灰度 15 LinearTransformation(matrix) 对输入图像进行线性变换 16 RandomAffine(...) 对输入图像进行仿射变换 17 RandomGrayscale(p=0.1) 对输入图像按概率p随机转灰度 18 ToPILImage(mode=None) 对输入图像转PIL格式图像 19 RandomOrder() 随机打乱transforms操作顺序4 模型处理
考虑以下场景:
网络的部分层级结构已经收敛、无需调整;大型复杂网络需要微调(Fine-tune)某些结构或参数;希望基于已训练好的模型进行改善或其他研究工作。
这些场景下重新通过数据集训练整个神经网络并无必要,甚至会使模型不稳定,因此引入预训练(pretrained)。Pytorch允许用户保存已训练好的模型,或加载其他模型,避免往复的无谓重训练,其中模型参数文件以.pth为后缀
# 保存已训练模型 torch.save(model.state_dict(), path) # 加载预训练模型 model.load_state_dict(torch.load(path), device)通过设置模型某些层可学习参数的requires_grad属性为False即可固定这部分参数不被后续学习过程影响。深度学习框架应用优势之一在于预设了对GPU的支持,大大提高模型处理与训练的效率。Pytorch中通过mode.to(device)方法将模型部署到指定设备上(CPU/GPU),范式如下:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(device)工程上也常使用torch.nn.DataParallel(model, devices)来处理多GPU并行运算,其原理是:首先将模型加载到主GPU上,再将模型从主GPU产生若干副本到其余GPU,随后将一个batch中的数据按维度划分为不同的子任务给各GPU进行前向传播,得到的损失会被累积到主GPU上并由主GPU反向传播更新参数,最后将更新参数拷贝到其余GPU以开始下一轮训练。
5 实例:MNIST数据集处理
下面给出了处理MNIST手写数据集的完整代码,可以用于加深对数据处理流程的理解
from abc import abstractmethod import numpy as np from torchvision.datasets import mnist from torch.utils.data import Dataset from PIL import Image class mnistData(Dataset): * @breif: MNIST数据集抽象接口 * @param[in]: dataPath -> 数据集存放路径 * @param[in]: transforms -> 数据集变换 def __init__(self, dataPath: str, transforms=None) -> None: super().__init__() self.dataPath = dataPath self.transforms = transforms self.data, self.label = [], [] def __len__(self) -> int: return len(self.label) def __getitem__(self, idx: int): img = self.data[idx] if self.transforms: img = self.transforms(img) return img, self.label[idx] @abstractmethod def plot(self, index: int) -> None: pass @abstractmethod def load(self) -> list: pass def plotData(self, index: int, info: str=None) -> None: * @breif: 可视化训练数据 * @param[in]: index -> 数据集索引 * @param[in]: info -> 备注信息 * @retval: None print(info, " --index:", index, "--label:", self.label[index]) if info else \ print(" --index:", index, "--label:", self.label[index]) img = Image.fromarray(np.uint8(self.data[index])) img.show() def loadData(self, train: bool) -> list: * @breif: 下载与加载数据集 * @param[in]: train -> 是否为训练集 * @retval: 数据与标签列表 # 如果指定目录下不存在数据集则下载 dataSet = mnist.MNIST(self.dataPath, train=train, download=True) # 初始化数据与标签 data = [ i[0] for i in dataSet ] label = [ i[1] for i in dataSet ] return data, label class mnistTrainData(mnistData): * @breif: MNIST训练集 * @param[in]: dataPath -> 数据集存放路径 * @param[in]: transforms -> 数据集变换 def __init__(self, dataPath: str, transforms=None) -> None: super().__init__(dataPath, transforms=transforms) self.data, self.label = self.load() def plot(self, index: int) -> None: self.plotData(index, "trainSet data") def load(self) -> list: return self.loadData(train=True) class mnistTestData(mnistData): * @breif: MNIST测试集 * @param[in]: dataPath -> 数据集存放路径 * @param[in]: transforms -> 数据集变换 def __init__(self, dataPath: str, transforms=None) -> None: super().__init__(dataPath, transforms=transforms) self.data, self.label = self.load() def plot(self, index: int) -> None: self.plotData(index, "testSet data") def load(self) -> list: return self.loadData(train=False)创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!