一. 简介:

  从功能上来看,它们的作用是相同的,都是用来重塑 Tensor 的 shape的,view 只适合对满足连续性条件 (contiguous) 的 Tensor进行操作,而reshape 同时还可以对不满足连续性条件的 Tensor 进行操作,具有更好的鲁棒性。简而言之,view 能干的 reshape都能干,如果 view 不能干就可以用 reshape 来处理。

二. Pytorch中Tensor的存储方式

  想要深入理解view与reshape的区别,首先要理解一些有关PyTorch张量存储的底层原理,比如tensor的头信息区(Tensor)存储区 (Storage)以及tensor的步长Stride

2.1 Pytorch中张量存储的底层原理

  tensor数据采用头信息区(Tensor)存储区 (Storage)分开存储的形式,如图下图所示。变量名及存储的数据是分为两个数据分别存储的。
  例如,我们定义并初始化一个tensor ,名称为A,形状为size、步长为stride、数据的索引等信息都存储在头信息区,而A存储的真实数据则存储在存储区。另外,如果我们对A进行截取、转置或者修改等操作后赋值给B,则张量B的数据共享A的存储区,存储区的数据数量没变,变化的是B的头信息区对数据的索引方式。这种方式其实就是浅拷贝

在这里插入图片描述

举个例子:

import torch

A = torch.arange(10)    # 初始化张量A
B = A[2:]               # 截取张量A的部分赋值给张量B

print("A:", A)
print("B:", B)
print("ptr og storage of A:", A.storage().data_ptr())   # 打印A的存储地址区
print("ptr og storage of B:", B.storage().data_ptr())   # 打印B的存储地址区

print("===============================================")

print("A:", A)
print("B:", B)
print("ptr og storage of A:", A.storage().data_ptr())   # 打印A的存储地址区
print("ptr og storage of B:", B.storage().data_ptr())   # 打印B的存储地址区

代码输出:

A: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
B: tensor([2, 3, 4, 5, 6, 7, 8, 9])
ptr og storage of A: 2808642026176
ptr og storage of B: 2808642026176
===============================================
A: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
B: tensor([2, 3, 4, 5, 6, 7, 8, 9])
ptr og storage of A: 2808642026176
ptr og storage of B: 2808642026176

2.2 Pytorch张量步长(stride)属性

   pytorch中的Tensor也是有步长属性的,Tensor的步长可以理解为从索引中一个维度跨到下个维度中间的跨度,如下图所示:

在这里插入图片描述

代码示例:

import torch

A = torch.arange(6).reshape(2,3)    # 初始化张量A
B = torch.arange(6).reshape(3,2)    # 截取张量A的部分赋值给张量B

print("A:", A)
print("stride of A:", A.stride())   # 打印A的stride
print("B:", B)
print("stride of B:", B.stride())   # 打印B的stride

代码输出:

A: tensor([[0, 1, 2],
        [3, 4, 5]])
stride of A: (3, 1)
B: tensor([[0, 1],
        [2, 3],
        [4, 5]])
stride of B: (2, 1)

  我们可以看到对于Tensor A的stride为(3,1),其中3表示从第零个维度中的第一个元素[0, 1, 2]到下一个元素[3, 4, 5]所需要的步长。1指的是第一个维度[0,1,2]中一个元素0到下一个元素1所需要的步长为1。

三. 对视图(view)的理解

  视图是数据的一个别称或者引用,通过该别称或者引用亦便可访问、操作原有数据,但是原有数据不会产生拷贝,如果我们对视图进行修改,它会影响到原始数据,物理内存在同一位置,这样避免了重新创建张量的高内存开销。由上面介绍的PyTorch的张量存储方式可以理解为:对张量的大部分操作就是视图操作。
  与之对应的概念就是副本。副本是一个数据的完整的拷贝,如果我们对副本进行修改,它不会影响到原始数据,物理内存不在同一位置。

代码示例:

import torch
import copy

a = torch.arange(5)  		# 初始化张量 a 为 [0, 1, 2, 3, 4]
b = copy.deepcopy(a)       	# b 是 a 的副本
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr())  # 打印a的存储区地址
print('ptr of storage of b:', b.storage().data_ptr())  # 打印b的存储区地址,可以发现两者不是共用存储区
b[0]+=1
print('a:', a)
print('b:', b)

代码输出:

a: tensor([0, 1, 2, 3, 4])
b: tensor([0, 1, 2, 3, 4])
ptr of storage of a: 1780203768064
ptr of storage of b: 1780203764352
a: tensor([0, 1, 2, 3, 4])
b: tensor([1, 1, 2, 3, 4])

上述代码中,b是a的副本,物理内存不在同一个位置,即使修改b也不会影响到a

四. view()与reshape()的比较

4.1 对view()的理解

pytorch官方定义:
https://pytorch.ac.cn/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view

在这里插入图片描述

4.1.1 (1)如何理解满足条件 stride[i] = stride[i+1] * size[i+1]

这个公式是描述多维张量在内存中如何排布的关键部分:

  • stride[i]: 表示在第 i 个维度的步幅,即要移动多少个内存位置,才能跳到当前维度的下一个元素。
  • stride[i+1]: 表示在第 i+1 个维度的步幅,即要移动多少个内存位置,才能跳到下一个维度的下一个元素。
  • size[i+1]:表示第 i+1 维的大小,即该维度包含的元素数量。

例如:

假设我们有一个形状为 (3, 4, 5) 的张量,这个张量有 3 个维度,分别是 d0=3, d1=4, d2=5。

  • 对于最后一维(d2=5),从一个元素到下一个元素的步幅就是 1(即每个元素之间的内存间隔是 1 字节)。
  • 对于倒数第二维(d1=4),要从当前维度的一个元素跳到下一个元素,需要跨越 5 个元素(因为 size[d2] = 5),每个元素占用 1 字节。所以步幅是 5。
  • 对于第一维(d0=3),要从当前维度的一个元素跳到下一个元素,需要跨越 4 * 5 = 20 个元素(因为 size[d1] = 4 和 size[d2] = 5)。因此,步幅是 20。
  • stride[2] = 1
  • stride[1] = stride[2] * size[2] = 1 * 5 = 5
  • stride[0] = stride[1] * size[1] = 5 * 4 = 20
  • 综上所述: 张量X的的步幅为(20, 5, 1)

代码示例1:

import torch
import copy

a = torch.arange(9).reshape(3,3)

print("a:", a)
print("shape of a:", a.size())
print("stride of a:", a.stride())

代码输出:

a: tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
shape of a: torch.Size([3, 3])
stride of a: (3, 1)

将结果带入stride[i] = stride[i+1] * size[i+1]中,满足连续性条件

代码示例2:

接下来,我们看将a转置后的结果:

import torch
import copy

a = torch.arange(9).reshape(3,3)
b = a.permute(1, 0) # 对a进行转置

print("b:", b)
print("shape of b:", b.size())
print("stride if b:", b.stride())

代码输出:

b: tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])
shape of b: torch.Size([3, 3])
stride if b: (1, 3)

发现将a进行转置后,结果不成立,因此就不满足连续性条件,无法使用view()操作

代码示例3:

我们查看一下转置前后的代码存储区有什么区别

import torch
a = torch.arange(9).reshape(3, 3)             # 初始化张量a
print('ptr of storage of a: ', a.storage().data_ptr())  # 查看a的storage区的地址
print('storage of a: \n', a.storage())        # 查看a的storage区的数据存放形式
b = a.permute(1, 0)                           # 转置
print('ptr of storage of b: ', b.storage().data_ptr())  # 查看b的storage区的地址
print('storage of b: \n', b.storage())        # 查看b的storage区的数据存放形式

代码输出:

ptr of storage of a:  1421123582720
storage of a:
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
ptr of storage of b:  1421123582720
storage of b:
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]

    因此可以看出,张量a,b仍然共用存储区,并且存储区数据存放的顺序没有发生变化,张量b只是改变了数据索引方式,为什么b不符合连续性条件呢?转置后的tensor只是对storage区数据索引方式的重映射,但原始的存放方式并没有变化.因此,这时再看tensor b的stride,从b第一行的元素1到第二行的元素2,显然在索引方式上已经不是原来+1了,而是变成了新的+3了。所以这时候就不能用view来对b进行shape的改变了,不然就报错。

在这里插入图片描述

    这种情况下,直接用view不行,先用contiguous()方法将原始tensor转换为满足连续条件的tensor,在使用view进行shape变换,原理是contiguous()方法开辟了一个新的存储区给b,并改变了b原始存储区数据的存放顺序

4.1.2 (2)contiguous() 的作用

  • contiguous() 是一个重要的函数,它会返回一个连续内存的张量副本,确保数据按内存顺序排列。如果原张量的数据布局不是连续的,那么通过 view() 创建视图时,可能会导致数据访问不正确(因为内存中的元素顺序不符合预期)。此时,需要先调用 contiguous() 来确保数据是连续的,然后再进行 view() 操作。
  • 例如,如果一个张量经过转置(transpose())操作后,它的内存布局可能会变得不连续。如果你想要对它进行 view() 操作,首先需要调用 contiguous() 以确保其内存是连续的。

代码示例:

import torch
a = torch.arange(9).reshape(3,3)
print("storage of a:\n", a.storage())
print("====================================================================================")

b = a.permute(1, 0).contiguous()
print("size of b:", b.size())
print("stride of b:", b.stride())
print("viewd b:\n",b.view(9))

print("====================================================================================")
print("storage of a:\n", a.storage())
print("storage of b:\n", b.storage())

print("====================================================================================")
print("ptr of a\n", a.storage().data_ptr())
print("ptr of b\n", b.storage().data_ptr())

代码输出:

storage of a:
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
====================================================================================
size of b: torch.Size([3, 3])
stride of b: (3, 1)
viewd b:
 tensor([0, 3, 6, 1, 4, 7, 2, 5, 8])
====================================================================================
storage of a:
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
storage of b:
  0
 3
 6
 1
 4
 7
 2
 5
 8
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 9]
====================================================================================
ptr of a
 2745201369152
ptr of b
 2745201365952

  由上述结果可以看出,张量a与b已经是两个存在于不同存储区的张量了。也印证了contiguous()方法开辟了一个新的存储区给b,并改变了b原始存储区数据的存放顺序。对应文章开头提到的浅拷贝,这种开辟一个新的内存区的方式其实就是深拷贝

4.2 对reshape()的理解

pytorch官方定义:https://pytorch.ac.cn/docs/stable/generated/torch.reshape.html#torch.reshape

作用: 与view类似,将输入tensor转换为新的shape格式。

    但是reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()。

    即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同。

在这里插入图片描述

五. 总结

    torch的view()与reshape()方法都可以用来重塑tensor的shape,区别就是使用的条件不一样。view()方法只适用于满足连续性条件的tensor,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。而reshape()方法的返回值既可以是视图,也可以是副本,当满足连续性条件时返回view,否则返回副本[ 此时等价于先调用contiguous()方法在使用view() ]。因此当不确能否使用view时,可以使用reshape。如果只是想简单地重塑一个tensor的shape,那么就是用reshape,但是如果需要考虑内存的开销而且要确保重塑后的tensor与之前的tensor共享存储空间,那就使用view()。

为什么没把view废除?

  • 1、在PyTorch不同版本的更新过程中,view先于reshape方法出现,后来出现了鲁棒性更好的reshape方法,但view方法并没因此废除。其实不止PyTorch,其他一些框架或语言比如OpenCV也有类似的操作。
  • 2、view的存在可以显示地表示对这个tensor的操作只能是视图操作而非拷贝操作。这对于代码的可读性以及后续可能的bug的查找比较友好。

本文参考:

  1. https://www.zhihu.com/search?type=content&q=torch.view%E4%B8%8Etorch.reshape%E5%8C%BA%E5%88%AB

  2. https://blog.csdn.net/Flag_ing/article/details/109129752

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐