首页IT科技loss at(loss.item()用法和注意事项详解)

loss at(loss.item()用法和注意事项详解)

时间2025-09-15 09:27:57分类IT科技浏览12031
导读:.item( 方法是,取一个元素张量里面的具体元素值并返回该值,可以将一个零维张量转换成...

.item()方法是                  ,取一个元素张量里面的具体元素值并返回该值                              ,可以将一个零维张量转换成int型或者float型          ,在计算loss             ,accuracy时常用到                    。

作用:

1.item()取出张量具体位置的元素元素值

2.并且返回的是该位置元素值的高精度值

3.保持原元素类型不变;必须指定位置

4.节省内存(不会计入计算图)

import torch loss = torch.randn(2, 2) print(loss) print(loss[1,1]) print(loss[1,1].item())

输出结果

tensor([[-2.0274, -1.5974],

        [-1.4775,  1.9320]])

tensor(1.9320)

1.9319512844085693

其它:

loss = criterion(out, label) loss_sum += loss # <--- 这里

运行着就发现显存炸了                             ,观察发现随着每个batch显存消耗在不断增大…因为输出的loss的数据类型是Variable                            。PyTorch的动态图机制就是通过Variable来构建图          。主要是使用Variable计算的时候               ,会记录下新产生的Variable的运算符号         ,在反向传播求导的时候进行使用               。如果这里直接将loss加起来                            ,系统会认为这里也是计算图的一部分                    ,也就是说网络会一直延伸变大     ,那么消耗的显存也就越来越大                            。

正确的loss一般是这样写 

loss_sum += loss.data[0]

其它注意事项:

使用loss += loss.detach()来获取不需要梯度回传的部分              。

使用loss.item()直接获得对应的python数据类型          。

补充阅读                            ,pytorch 计算图

Pytorch的计算图由节点和边组成                         ,节点表示张量或者Function,边表示张量和Function之间的依赖关系                             。

Pytorch中的计算图是动态图                  。这里的动态主要有两重含义     。

第一层含义是:计算图的正向传播是立即执行的                              。无需等待完整的计算图创建完毕                       ,每条语句都会在计算图中动态添加节点和边                              ,并立即执行正向传播得到计算结果                       。

第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图                         。如果在程序中使用了backward方法执行了反向传播     ,或者利用torch.autograd.grad方法计算了梯度                  ,那么创建的计算图会被立即销毁                              ,释放存储空间          ,下次调用需要重新创建                            。

1             ,计算图的正向传播是立即执行的     。

import torch w = torch.tensor([[3.0,1.0]],requires_grad=True) b = torch.tensor([[3.0]],requires_grad=True) X = torch.randn(10,2) Y = torch.randn(10,1) Y_hat = X@w.t() + b # Y_hat定义后其正向传播被立即执行                             ,与其后面的loss创建语句无关 loss = torch.mean(torch.pow(Y_hat-Y,2)) print(loss.data) print(Y_hat.data) tensor(17.8969) tensor([[3.2613], [4.7322], [4.5037], [7.5899], [7.0973], [1.3287], [6.1473], [1.3492], [1.3911], [1.2150]])

2               ,计算图在反向传播后立即销毁                    。

import torch w = torch.tensor([[3.0,1.0]],requires_grad=True) b = torch.tensor([[3.0]],requires_grad=True) X = torch.randn(10,2) Y = torch.randn(10,1) Y_hat = X@w.t() + b # Y_hat定义后其正向传播被立即执行         ,与其后面的loss创建语句无关 loss = torch.mean(torch.pow(Y_hat-Y,2)) #计算图在反向传播后立即销毁                            ,如果需要保留计算图, 需要设置retain_graph = True loss.backward() #loss.backward(retain_graph = True) #loss.backward() #如果再次执行反向传播将报错

参考链接:pytorch学习:loss为什么要加item()_dlvector的博客-CSDN博客_loss.item()

https://blog.csdn.net/cs111211/article/details/126221102

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
怎么做思维导图 漂亮又漂亮(怎么做出好看的思维导图_详细制作方法看这里) 挂机软件使用(挂机软件有哪些-如何用挂机软件刷趣头条APP赚钱)