首页IT科技torch检查什么项目(【记录】torch.nn.CrossEntropyLoss报错及解决)

torch检查什么项目(【记录】torch.nn.CrossEntropyLoss报错及解决)

时间2025-04-30 20:22:31分类IT科技浏览4833
导读:报错 在多分类语义分割问题中使用torch.nn.CrossEntropyLoss的时候,遇到的报错有:...

报错

在多分类语义分割问题中使用torch.nn.CrossEntropyLoss的时候               ,遇到的报错有:

1. Assertion `t >= 0 && t < n_classes` failed. 2. RuntimeError: Expected floating point type for target with class probabilities, got Long

通过官方文档了解到                   ,torch.nn.CrossEntropyLoss分为两种情况:

直接使用class进行分类      ,此时的label为0           ,1                    ,2…的整数               。对于这类情况         ,torch.nn.CrossEntropyLoss中添加了LogSoftmax以及 NLLLoss       ,因此不用在网络的最后添加 softmax和argmax 将输出结果转换为整型                   。 使用每一类的概率      。这种标签通常情况下效果比直接使用class进行分类要好一些                     ,但在少样本 && 在每一类上使用标签过于严格 的时候            ,才推荐使用概率作为标签           。

解决

假设传入torch.nn.CrossEntropyLoss的参数为torch.nn.CrossEntropyLoss(pred, label)   ,其中pred为模型预测的输出                     ,label为标签                    。

这两个报错都是因为pred输入的维度错误导致的 根据官网文档               ,如果直接使用class进行分类,pred的维度应该是[batchsize, class, dim 1, dim 2, ... dim K]                  ,label的维度应该是[batchsize, dim 1, dim 2, ... dim K]         。注意在网络输出的channel中加入class number

的维度       。不然softmax无法计算                  ,及model的output channel = class number                     。

另   ,如果想直接使用class进行分类               ,需要讲label的type转换成long格式:labels = labels.to(device, dtype=torch.long)

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
方正笔记本如何进入bios(方正Founder笔记本电脑开机进入BIOS的方法(delete)) 3d怎样用鼠标旋转模型(让交互更加生动!有意思的鼠标跟随 3D 旋转动效)