首页IT科技torch检查什么项目(【记录】torch.nn.CrossEntropyLoss报错及解决)

torch检查什么项目(【记录】torch.nn.CrossEntropyLoss报错及解决)

时间2025-09-18 19:07:36分类IT科技浏览7138
导读:报错 在多分类语义分割问题中使用torch.nn.CrossEntropyLoss的时候,遇到的报错有:...

报错

在多分类语义分割问题中使用torch.nn.CrossEntropyLoss的时候                    ,遇到的报错有:

1. Assertion `t >= 0 && t < n_classes` failed. 2. RuntimeError: Expected floating point type for target with class probabilities, got Long

通过官方文档了解到                          ,torch.nn.CrossEntropyLoss分为两种情况:

直接使用class进行分类          ,此时的label为0               ,1                          ,2…的整数                    。对于这类情况              ,torch.nn.CrossEntropyLoss中添加了LogSoftmax以及 NLLLoss          ,因此不用在网络的最后添加 softmax和argmax 将输出结果转换为整型                          。 使用每一类的概率          。这种标签通常情况下效果比直接使用class进行分类要好一些                           ,但在少样本 && 在每一类上使用标签过于严格 的时候                  ,才推荐使用概率作为标签               。

解决

假设传入torch.nn.CrossEntropyLoss的参数为torch.nn.CrossEntropyLoss(pred, label)     ,其中pred为模型预测的输出                            ,label为标签                          。

这两个报错都是因为pred输入的维度错误导致的 根据官网文档                      ,如果直接使用class进行分类,pred的维度应该是[batchsize, class, dim 1, dim 2, ... dim K]                        ,label的维度应该是[batchsize, dim 1, dim 2, ... dim K]              。注意在网络输出的channel中加入class number

的维度          。不然softmax无法计算                          ,及model的output channel = class number                           。

另     ,如果想直接使用class进行分类                    ,需要讲label的type转换成long格式:labels = labels.to(device, dtype=torch.long)

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
datepicker当前日期(解决el-date-picker日期选择控件少一天的问题) 苹果4充电接口型号(苹果14换接口吗详细介绍)