首页IT科技torching翻译(关于torch.cat()与torch.stack())

torching翻译(关于torch.cat()与torch.stack())

时间2025-09-17 12:07:08分类IT科技浏览6632
导读:关于torch.cat( 与torch.stack( 整理...

关于torch.cat()与torch.stack()整理

代码中一直使用torch.cat()和torch.stack()进行tensor维度拼接                ,花点时间整理下                。方便使用🤷‍♂️:

1.用法

torch.cat(): 用于连接两个相同大小的张量

torch.stack(): 用于连接两个相同大小的张量                        ,并扩展维度

见代码示例更清晰:

import torch a = torch.tensor(torch.arange(10)).reshape(3, 3) b = torch.tensor(torch.arange(10, 100, 10)).reshape(3, 3) print(a) Out[7]: tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) print(b) Out[10]: tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])

对上面两个tensor进行操作

torch.cat() 拼接函数        ,将多个张量拼接成一个张量        ,保持维度不变                        。torch.cat()有两个参数                        ,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度        。

使用不同的参数                ,输出的结果不同        ,首先填入一个会返回错误的参数:从返回报错原因可以看到                        ,参数的返回必须是在[-2, 1]之间                。

d3 = torch.cat((a, b), dim=2) # 返回输出如下 Traceback (most recent call last): File "/home/franklinpan/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "<ipython-input-23-b2602bd6230f>", line 1, in <module> d3 = torch.cat((a, b), dim=2) IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

设置dim=-1                ,得到如下结果,当参数为-1时                        ,与dim=1的返回结果相同

dim=-1                        ,表示在第二维度进行拼接 d_1= torch.cat((a, b), dim=-1) print(d_1) Out[25]: tensor([[ 1, 2, 3, 10, 20, 30], [ 4, 5, 6, 40, 50, 60], [ 7, 8, 9, 70, 80, 90]]) d1 = torch.cat((a, b), dim=1) print(d1) Out[22]: tensor([[ 1, 2, 3, 10, 20, 30], [ 4, 5, 6, 40, 50, 60], [ 7, 8, 9, 70, 80, 90]])

设置dim=-2,与dim=0相同:

表示在第一维度进行拼接 d_2= torch.cat((a, b), dim=-2) print(d_2) Out[27]: tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 20, 30], [40, 50, 60], [70, 80, 90]]) d1 = torch.cat((a, b), dim=0) print(d1) Out[20]: tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 20, 30], [40, 50, 60], [70, 80, 90]])

可以看到                ,采用不同的参数                        ,输出的张量维度仍然与原来张量的维度保持一致                        。 若输入参数的维度不一样        ,会产生什么结果呢?

当输出张量保持一个维度一致时                ,若在相同维度的方向进行连接torch.cat操作                        ,则仍然可以张量的合并操作        ,若在维度不同的方向进行连接操作        ,会报错        。(🤦‍♀️torch.cat操作没有广播机制

**torch.stack()**操作

拼接函数                        ,是拼接以后                ,再扩展一维        。torch.stack()有两个参数        ,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度                        。

此处不再重复dim=-3 or -2等操作                        ,当dim=0时 c1 = torch.stack((a, b), dim=0) Out[12]: tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]]])

当dim=1时

c2 = torch.stack((a, b), dim=1) Out[15]: tensor([[[ 1, 2, 3], [10, 20, 30]], [[ 4, 5, 6], [40, 50, 60]], [[ 7, 8, 9], [70, 80, 90]]])

当 dim=2时

c3 = torch.stack((a, b), dim=2) Out[17]: tensor([[[ 1, 10], [ 2, 20], [ 3, 30]], [[ 4, 40], [ 5, 50], [ 6, 60]], [[ 7, 70], [ 8, 80], [ 9, 90]]])

若在torch.stack中使用不同维度的输入                ,得到报错的反馈

从实例可见,torch.stack操作将会增加合并后张量的维度                。

总结:

torch.cat()与torch.stack()操作都是对张量进行拼接操作                        ,不同点如下:

torch.stack()将对张量维度进行扩张

torch.cat()可以对只有一个方向维度相同的张量进行合并                        ,而torch.stack()要求输入张量的维度必须一样        。

stack与cat的区别在于,得到的张量的维度会比输入的张量的大小多1                ,并且多出的那个维度就是拼接的维度                        ,那个维度的大小就是输入张量的个数                        。见下面代码:

A=torch.tensor([[1,2,3],[4,5,6],[7,8,9]],dtype=torch.float) print("A:",A) B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float) print("B:",B) print("*********************************") c=torch.cat((A,B),dim=0)#保持维度不变 print(c) print(c.shape) d=torch.stack((A,B),dim=0)#多扩展一维度 print(d) print(d.shape)

运行结果:

扩展:

torch.cat和torch.stack()的拼接为[]数据时:

见拼接列表数据

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
网络副业做什么好赚钱(哪些网络**-5个实操性很强的副业推荐,建议收藏) 电脑网速测试网站在哪(电脑在线测速的方法有哪些?)