首页IT科技pytorch(Pytorch DataLoader中的num_workers (选择最合适的num_workers值))

pytorch(Pytorch DataLoader中的num_workers (选择最合适的num_workers值))

时间2025-09-18 10:39:08分类IT科技浏览11830
导读:一、概念 num_workers是Dataloader的概念,默认值是0。是告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关,和GPU无关 ...

一                     、概念

num_workers是Dataloader的概念                    ,默认值是0                    。是告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关                               ,和GPU无关)

如果num_worker设为0            ,意味着每一轮迭代时               ,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了)                              ,而是在RAM中找batch                 ,找不到时再加载相应的batch                               。缺点当然是速度慢            。

当num_worker不为0时          ,每轮到dataloader加载数据时                              ,dataloader一次性创建num_worker个worker                      ,并用batch_sampler将指定batch分配给指定worker     ,worker将它负责的batch加载进RAM               。

num_worker设置得大                              ,好处是寻batch速度快                           ,因为下一轮迭代的batch很可能在上一轮/上上一轮…迭代时已经加载好了                              。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)                 。num_workers的经验设置值是自己电脑/服务器的CPU核心数                         ,如果CPU很强                                、RAM也很充足                                ,就可以设置得更大些          。

num_worker小了的情况      ,主进程采集完最后一个worker的batch                              。此时需要回去采集第一个worker产生的第二个batch                      。如果该worker此时没有采集完                    ,主线程会卡在这里等     。(这种情况出现在                               ,num_works数量少或者batchsize

比较小            ,显卡很快就计算完了               ,CPU对GPU供不应求                              。)

即                              ,num_workers的值和模型训练快慢有关                 ,和训练出的模型的performance无关

Detectron2的num_workers默认是4

二          、选择最合适的num_workers值

最合适的num_works值与数据集有关

最好是跑代码之前先用这段script跑一下          ,选择最合适的num_workers值 from time import time import multiprocessing as mp import torch import torchvision from torchvision import transforms transform = transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,)) ]) trainset = torchvision.datasets.MNIST( root=dataset/, train=True, #如果为True                              ,从 training.pt 创建数据                      ,否则从 test.pt 创建数据                           。 download=True, #如果为true     ,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集                              ,则不会再次下载                         。 transform=transform ) print(f"num of CPU: {mp.cpu_count()}") for num_workers in range(2, mp.cpu_count(), 2): train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True) start = time() for epoch in range(1, 3): for i, data in enumerate(train_loader, 0): pass end = time() print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

可以看到                           ,这个服务器24个CPU, 最合适的num_workers值是14

三               、可能出现的问题

linux系统中可以使用多个子进程加载数据,windows系统里是不可以的                         ,可以发现报错时产生在DataLoader文件中的                                。我们找到自己调用DataLoader的文件中num_workers的设置                                ,设置为0或者采用默认为0的设置      。

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
yolov4算法的优势(YOLO系列 — YOLOV7算法(二):YOLO V7算法detect.py代码解析) 网络安全算什么行业(为什么说网络安全是风口行业?是IT行业最后的红利?)