torch检查什么项目(【记录】torch.nn.CrossEntropyLoss报错及解决)
导读:报错 在多分类语义分割问题中使用torch.nn.CrossEntropyLoss的时候,遇到的报错有:...
报错
在多分类语义分割问题中使用torch.nn.CrossEntropyLoss的时候 ,遇到的报错有:
1. Assertion `t >= 0 && t < n_classes` failed. 2. RuntimeError: Expected floating point type for target with class probabilities, got Long通过官方文档了解到 ,torch.nn.CrossEntropyLoss分为两种情况:
直接使用class进行分类 ,此时的label为0 ,1 ,2…的整数 。对于这类情况 ,torch.nn.CrossEntropyLoss中添加了LogSoftmax以及 NLLLoss ,因此不用在网络的最后添加 softmax和argmax 将输出结果转换为整型 。 使用每一类的概率 。这种标签通常情况下效果比直接使用class进行分类要好一些 ,但在少样本 && 在每一类上使用标签过于严格 的时候 ,才推荐使用概率作为标签 。解决
假设传入torch.nn.CrossEntropyLoss的参数为torch.nn.CrossEntropyLoss(pred, label) ,其中pred为模型预测的输出 ,label为标签 。
这两个报错都是因为pred输入的维度错误导致的 根据官网文档 ,如果直接使用class进行分类,pred的维度应该是[batchsize, class, dim 1, dim 2, ... dim K] ,label的维度应该是[batchsize, dim 1, dim 2, ... dim K] 。注意在网络输出的channel中加入class number的维度 。不然softmax无法计算 ,及model的output channel = class number 。
另,如果想直接使用class进行分类 ,需要讲label的type转换成long格式:labels = labels.to(device, dtype=torch.long)创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!