【Pytorch】torch.view与torch.reshape的区别
从功能上来看,它们的作用是相同的,都是用来重塑 Tensor 的 shape的,view 只适合对满足连续性条件 (contiguous) 的 Tensor进行操作,而reshape 同时还可以对不满足连续性条件的 Tensor 进行操作,具有更好的鲁棒性。简而言之,view 能干的 reshape都能干,如果 view 不能干就可以用 reshape 来处理。torch的view()与resha
文章目录
一. 简介:
从功能上来看,它们的作用是相同的,都是用来重塑 Tensor 的 shape的,view 只适合对满足连续性条件 (contiguous) 的 Tensor进行操作,而reshape 同时还可以对不满足连续性条件的 Tensor 进行操作,具有更好的鲁棒性。
简而言之,view 能干的 reshape都能干,如果 view 不能干就可以用 reshape 来处理。
二. Pytorch中Tensor的存储方式
想要深入理解view与reshape的区别,首先要理解一些有关PyTorch张量存储的底层原理,比如tensor的
头信息区(Tensor)
和存储区 (Storage)
以及tensor的步长Stride
。
2.1 Pytorch中张量存储的底层原理
tensor数据采用头信息区(Tensor)
和存储区 (Storage)
分开存储的形式,如图下图所示。变量名及存储的数据是分为两个数据分别存储的。
例如,我们定义并初始化一个tensor ,名称为A,形状为size、步长为stride、数据的索引等信息都存储在头信息区,而A存储的真实数据则存储在存储区。另外,如果我们对A进行截取、转置或者修改等操作后赋值给B,则张量B的数据共享A的存储区,存储区的数据数量没变,变化的是B的头信息区对数据的索引方式
。这种方式其实就是浅拷贝
举个例子:
import torch
A = torch.arange(10) # 初始化张量A
B = A[2:] # 截取张量A的部分赋值给张量B
print("A:", A)
print("B:", B)
print("ptr og storage of A:", A.storage().data_ptr()) # 打印A的存储地址区
print("ptr og storage of B:", B.storage().data_ptr()) # 打印B的存储地址区
print("===============================================")
print("A:", A)
print("B:", B)
print("ptr og storage of A:", A.storage().data_ptr()) # 打印A的存储地址区
print("ptr og storage of B:", B.storage().data_ptr()) # 打印B的存储地址区
代码输出:
A: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
B: tensor([2, 3, 4, 5, 6, 7, 8, 9])
ptr og storage of A: 2808642026176
ptr og storage of B: 2808642026176
===============================================
A: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
B: tensor([2, 3, 4, 5, 6, 7, 8, 9])
ptr og storage of A: 2808642026176
ptr og storage of B: 2808642026176
2.2 Pytorch张量步长(stride)属性
pytorch中的Tensor也是有步长属性的,Tensor的步长可以理解为从索引中一个维度跨到下个维度中间的跨度,如下图所示:
代码示例:
import torch
A = torch.arange(6).reshape(2,3) # 初始化张量A
B = torch.arange(6).reshape(3,2) # 截取张量A的部分赋值给张量B
print("A:", A)
print("stride of A:", A.stride()) # 打印A的stride
print("B:", B)
print("stride of B:", B.stride()) # 打印B的stride
代码输出:
A: tensor([[0, 1, 2],
[3, 4, 5]])
stride of A: (3, 1)
B: tensor([[0, 1],
[2, 3],
[4, 5]])
stride of B: (2, 1)
我们可以看到对于Tensor A的stride为(3,1),其中3表示从第零个维度中的第一个元素[0, 1, 2]到下一个元素[3, 4, 5]所需要的步长。1指的是第一个维度[0,1,2]中一个元素0到下一个元素1所需要的步长为1。
三. 对视图(view)的理解
视图是数据的一个别称或者引用,通过该别称或者引用亦便可访问、操作原有数据,但是原有数据不会产生拷贝,如果我们对视图进行修改,它会影响到原始数据,物理内存在同一位置,这样避免了重新创建张量的高内存开销。由上面介绍的PyTorch的张量存储方式可以理解为:对张量的大部分操作就是视图操作。
与之对应的概念就是副本。副本是一个数据的完整的拷贝
,如果我们对副本进行修改,它不会影响到原始数据,物理内存不在同一位置。
代码示例:
import torch
import copy
a = torch.arange(5) # 初始化张量 a 为 [0, 1, 2, 3, 4]
b = copy.deepcopy(a) # b 是 a 的副本
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr()) # 打印a的存储区地址
print('ptr of storage of b:', b.storage().data_ptr()) # 打印b的存储区地址,可以发现两者不是共用存储区
b[0]+=1
print('a:', a)
print('b:', b)
代码输出:
a: tensor([0, 1, 2, 3, 4])
b: tensor([0, 1, 2, 3, 4])
ptr of storage of a: 1780203768064
ptr of storage of b: 1780203764352
a: tensor([0, 1, 2, 3, 4])
b: tensor([1, 1, 2, 3, 4])
上述代码中,b是a的副本,物理内存不在同一个位置,即使修改b也不会影响到a
四. view()与reshape()的比较
4.1 对view()的理解
pytorch官方定义:
https://pytorch.ac.cn/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
4.1.1 (1)如何理解满足条件 stride[i] = stride[i+1] * size[i+1]
这个公式是描述多维张量在内存中如何排布的关键部分:
stride[i]:
表示在第 i 个维度的步幅,即要移动多少个内存位置,才能跳到当前维度的下一个元素。stride[i+1]:
表示在第 i+1 个维度的步幅,即要移动多少个内存位置,才能跳到下一个维度的下一个元素。size[i+1]:
表示第 i+1 维的大小,即该维度包含的元素数量。
例如:
假设我们有一个形状为 (3, 4, 5) 的张量,这个张量有 3 个维度,分别是 d0=3, d1=4, d2=5。
- 对于最后一维(d2=5),从一个元素到下一个元素的步幅就是 1(即每个元素之间的内存间隔是 1 字节)。
- 对于倒数第二维(d1=4),要从当前维度的一个元素跳到下一个元素,需要跨越 5 个元素(因为 size[d2] = 5),每个元素占用 1 字节。所以步幅是 5。
- 对于第一维(d0=3),要从当前维度的一个元素跳到下一个元素,需要跨越 4 * 5 = 20 个元素(因为 size[d1] = 4 和 size[d2] = 5)。因此,步幅是 20。
- stride[2] = 1
- stride[1] = stride[2] * size[2] = 1 * 5 = 5
- stride[0] = stride[1] * size[1] = 5 * 4 = 20
- 综上所述: 张量X的的步幅为(20, 5, 1)
代码示例1:
import torch
import copy
a = torch.arange(9).reshape(3,3)
print("a:", a)
print("shape of a:", a.size())
print("stride of a:", a.stride())
代码输出:
a: tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
shape of a: torch.Size([3, 3])
stride of a: (3, 1)
将结果带入stride[i] = stride[i+1] * size[i+1]
中,满足连续性条件
代码示例2:
接下来,我们看将a转置后的结果:
import torch
import copy
a = torch.arange(9).reshape(3,3)
b = a.permute(1, 0) # 对a进行转置
print("b:", b)
print("shape of b:", b.size())
print("stride if b:", b.stride())
代码输出:
b: tensor([[0, 3, 6],
[1, 4, 7],
[2, 5, 8]])
shape of b: torch.Size([3, 3])
stride if b: (1, 3)
发现将a进行转置后,结果不成立,因此就不满足连续性条件,无法使用view()操作
代码示例3:
我们查看一下转置前后的代码存储区有什么区别
import torch
a = torch.arange(9).reshape(3, 3) # 初始化张量a
print('ptr of storage of a: ', a.storage().data_ptr()) # 查看a的storage区的地址
print('storage of a: \n', a.storage()) # 查看a的storage区的数据存放形式
b = a.permute(1, 0) # 转置
print('ptr of storage of b: ', b.storage().data_ptr()) # 查看b的storage区的地址
print('storage of b: \n', b.storage()) # 查看b的storage区的数据存放形式
代码输出:
ptr of storage of a: 1421123582720
storage of a:
0
1
2
3
4
5
6
7
8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
ptr of storage of b: 1421123582720
storage of b:
0
1
2
3
4
5
6
7
8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
因此可以看出,张量a,b仍然共用存储区,并且存储区数据存放的顺序没有发生变化,张量b只是改变了数据索引方式,为什么b不符合连续性条件呢?转置后的tensor只是对storage区数据索引方式的重映射,但原始的存放方式并没有变化.因此,这时再看tensor b的stride,从b第一行的元素1到第二行的元素2,显然在索引方式上已经不是原来+1了,而是变成了新的+3了。所以这时候就不能用view来对b进行shape的改变了,不然就报错。
这种情况下,直接用view不行,先用contiguous()方法将原始tensor转换为满足连续条件的tensor,在使用view进行shape变换,
原理是contiguous()方法开辟了一个新的存储区给b,并改变了b原始存储区数据的存放顺序
4.1.2 (2)contiguous() 的作用
- contiguous() 是一个重要的函数,它会返回一个连续内存的张量副本,确保数据按内存顺序排列。如果原张量的数据布局不是连续的,那么通过 view() 创建视图时,可能会导致数据访问不正确(因为内存中的元素顺序不符合预期)。此时,需要先调用 contiguous() 来确保数据是连续的,然后再进行 view() 操作。
- 例如,如果一个张量经过转置(transpose())操作后,它的内存布局可能会变得不连续。如果你想要对它进行 view() 操作,首先需要调用 contiguous() 以确保其内存是连续的。
代码示例:
import torch
a = torch.arange(9).reshape(3,3)
print("storage of a:\n", a.storage())
print("====================================================================================")
b = a.permute(1, 0).contiguous()
print("size of b:", b.size())
print("stride of b:", b.stride())
print("viewd b:\n",b.view(9))
print("====================================================================================")
print("storage of a:\n", a.storage())
print("storage of b:\n", b.storage())
print("====================================================================================")
print("ptr of a\n", a.storage().data_ptr())
print("ptr of b\n", b.storage().data_ptr())
代码输出:
storage of a:
0
1
2
3
4
5
6
7
8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
====================================================================================
size of b: torch.Size([3, 3])
stride of b: (3, 1)
viewd b:
tensor([0, 3, 6, 1, 4, 7, 2, 5, 8])
====================================================================================
storage of a:
0
1
2
3
4
5
6
7
8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
storage of b:
0
3
6
1
4
7
2
5
8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
====================================================================================
ptr of a
2745201369152
ptr of b
2745201365952
由上述结果可以看出,张量a与b已经是两个存在于不同存储区的张量了。也印证了contiguous()方法开辟了一个新的存储区给b,并改变了b原始存储区数据的存放顺序。对应文章开头提到的浅拷贝,这种开辟一个新的内存区的方式其实就是深拷贝
4.2 对reshape()的理解
pytorch官方定义:https://pytorch.ac.cn/docs/stable/generated/torch.reshape.html#torch.reshape
作用: 与view类似,将输入tensor转换为新的shape格式。
但是reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()。
即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同。
五. 总结
torch的view()与reshape()方法都可以用来重塑tensor的shape,区别就是使用的条件不一样。view()方法只适用于满足连续性条件的tensor,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。而reshape()方法的返回值既可以是视图,也可以是副本,当满足连续性条件时返回view,否则返回副本[ 此时等价于先调用contiguous()方法在使用view() ]。因此当不确能否使用view时,可以使用reshape。如果只是想简单地重塑一个tensor的shape,那么就是用reshape,但是如果需要考虑内存的开销而且要确保重塑后的tensor与之前的tensor共享存储空间,那就使用view()。
为什么没把view废除?
- 1、在PyTorch不同版本的更新过程中,view先于reshape方法出现,后来出现了鲁棒性更好的reshape方法,但view方法并没因此废除。其实不止PyTorch,其他一些框架或语言比如OpenCV也有类似的操作。
- 2、view的存在可以显示地表示对这个tensor的操作只能是视图操作而非拷贝操作。这对于代码的可读性以及后续可能的bug的查找比较友好。
本文参考:
更多推荐
所有评论(0)