view()
在不复制内存的情况下重塑张量,类似于 numpy 的 reshape()
。
给定一个有 16 个元素的张量 a
:
import torch
a = torch.range(1, 16)
要重塑此张量以使其成为 4 x 4
张量,请使用:
a = a.view(4, 4)
现在 a
将成为 4 x 4
张量。 请注意,重塑后的元素总数需要保持不变。将张量 a
重塑为 3 x 5
张量是不合适的。
参数-1是什么意思?
如果在任何情况下您不知道需要多少行但确定列数,则可以使用 -1 指定。 (请注意,您可以将其扩展到具有更多维度的张量。只有一个轴值可以是-1)。这是告诉库的一种方式:“给我一个包含这么多列的张量,然后计算实现这一点所需的适当行数”。
这可以在 this model definition code 中看到。在 forward 函数中的第 x = self.pool(F.relu(self.conv2(x)))
行之后,您将拥有一个 16 深度的特征图。您必须将其展平以将其提供给全连接层。因此,您告诉 PyTorch 重塑您获得的张量以具有特定的列数,并告诉它自己决定行数。
让我们做一些例子,从简单到更难。
view 方法返回一个与 self 张量具有相同数据的张量(这意味着返回的张量具有相同数量的元素),但具有不同的形状。例如: a = torch.arange(1, 17) # a 的形状是 (16,) a.view(4, 4) # 下面输出 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 [torch .FloatTensor of size 4x4] a.view(2, 2, 4) # 下面的输出 (0 ,.,.) = 1 2 3 4 5 6 7 8 (1 ,.,.) = 9 10 11 12 13 14 15 16 [torch.FloatTensor of size 2x2x4] 假设-1不是参数之一,当你将它们相乘时,结果必须等于张量中的元素个数。如果您这样做:a.view(3, 3),它将引发 RuntimeError,因为形状 (3 x 3) 对于 16 个元素的输入无效。换句话说:3 x 3 不等于 16,而是 9。您可以使用 -1 作为传递给函数的参数之一,但只能使用一次。所发生的只是该方法将为您计算如何填充该维度。例如 a.view(2, -1, 4) 等价于 a.view(2, 2, 4)。 [16 / (2 x 4) = 2] 请注意,返回的张量共享相同的数据。如果您在“视图”中进行更改,您将更改原始张量的数据: b = a.view(4, 4) b[0, 2] = 2 a[2] == 3.0 False 现在,对于更复杂的用例。文档说每个新的视图维度必须要么是原始维度的子空间,要么只跨越 d, d + 1, ..., d + k 满足以下类似连续的条件,即对于所有 i = 0, 。 .., k - 1, stride[i] = stride[i + 1] x size[i + 1]。否则,需要在查看张量之前调用 contiguous()。例如: a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2) a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2 , 4) # 下面的注释行将引发 RuntimeError,因为一个维度 # 跨越两个连续的子空间 # a_t.view(-1, 4) # 改为:a_t.contiguous().view(-1, 4) #要查看为什么第一个不起作用而第二个起作用,# 比较 a.stride() 和 a_t.stride() a.stride() # (24, 6, 2, 1) a_t.stride() # (24 , 2, 1, 6) 请注意,对于 a_t,stride[0] != stride[1] x size[1] since 24 != 2 x 3
view()
通过将张量的元素“拉伸”或“挤压”成您指定的形状来重塑张量:
https://i.stack.imgur.com/ORqaP.png
view() 是如何工作的?
首先让我们看看引擎盖下的张量是什么:
张量及其底层存储,例如右侧张量(形状 (3,2))可以从左侧张量计算,其中 t2 = t1.view(3,2)
在这里,您可以看到 PyTorch 通过添加 shape
和 stride
属性将底层连续内存块转换为类似矩阵的对象来生成张量:
shape 表示每个维度有多长
stride 说明在到达每个维度中的下一个元素之前,您需要在内存中执行多少步
view(dim1,dim2,...) 返回相同基础信息的视图,但重新整形为形状为 dim1 x dim2 x ... 的张量(通过修改 shape 和 stride 属性)。
请注意,这隐含地假设新维度和旧维度具有相同的乘积(即新旧张量具有相同的体积)。
火炬 -1
-1
是 PyTorch 的别名,用于“在其他维度都已指定的情况下推断此维度”(即原始产品与新产品的商)。这是取自 numpy.reshape()
的约定。
因此,我们示例中的 t1.view(3,2)
将等价于 t1.view(3,-1)
或 t1.view(-1,2)
。
火炬.Tensor.view()
简单地说,受 numpy.ndarray.reshape()
或 numpy.reshape()
启发的 torch.Tensor.view()
会创建张量的新视图,只要新形状与原始张量的形状兼容。
让我们通过一个具体的例子来详细理解这一点。
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
使用形状为 (18,)
的张量 t
,可以仅为以下形状创建新的视图:
(1, 18)
或等效 (1, -1)
或 (-1, 18)
(2, 9)
或等效 >(2, -1)
或 (-1, 9)
(3, 6)
或等效的 (3, -1)
或 {9 }
(6, 3)
或等效的 (6, -1)
或 (-1, 3)
(9, 2)
或等效 (9, -1)
或 (-1, 2)
(18, 1)
或等效 (18, -1)
强> 或 (-1, 1)
正如我们已经从上面的形状元组中观察到的那样,形状元组的元素(例如 2*9
、3*6
等)的乘积必须始终等于原始张量(在我们的示例中为 18
)。
要观察的另一件事是,我们在每个形状元组的一个位置使用了 -1
。通过使用 -1
,我们懒于自己进行计算,而是将任务委托给 PyTorch,以便在创建新 视图 时计算形状的该值。需要注意的重要一点是,我们可以仅在形状元组中使用单个 -1
。其余值应由我们明确提供。否则 PyTorch 会通过抛出 RuntimeError
来抱怨:
RuntimeError:只能推断一维
因此,对于上述所有形状,PyTorch 将始终返回原始张量 t
的 新视图。这基本上意味着它只是为请求的每个新视图更改张量的步幅信息。
下面是一些示例,说明了张量的步幅如何随着每个新视图的变化而变化。
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
现在,我们将看到新视图的进步:
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
这就是 view()
函数的神奇之处。只要新 view 的形状与原始形状兼容,它只会更改每个新 view 的(原始)张量的步幅。
从 strides 元组中可能观察到的另一件有趣的事情是,第 0 位置的元素的值等于形状元组的第 1 位置的元素的值。
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
这是因为:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
步幅 (6, 1)
表示要沿着第 0th 维度从一个元素移动到下一个元素,我们必须 跳转 或采取 6 步。 (即从 0
到 6
,需要 6 步。)但是要从第 1st 维度中的一个元素到下一个元素,我们只需要一步(例如从 2
到 3
)。
因此,步长信息是如何从内存访问元素以执行计算的核心。
火炬.reshape()
只要新形状与原始张量的形状兼容,此函数将返回一个 view 并且与使用 torch.Tensor.view()
完全相同。否则,它将返回一个副本。
但是,torch.reshape()
的注释警告说:
连续输入和具有兼容步幅的输入可以在不复制的情况下进行重塑,但不应依赖于复制与查看行为。
我发现 x.view(-1, 16 * 5 * 5)
等价于 x.flatten(1)
,其中参数 1 表示展平过程从第一维开始(不是展平“样本”维)如您所见,后一种用法在语义上更清晰,并且更容易使用,所以我更喜欢 flatten()
。
让我们尝试通过以下示例来理解视图:
a=torch.range(1,16)
print(a)
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16.])
print(a.view(-1,2))
tensor([[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]])
print(a.view(2,-1,4)) #3d tensor
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.]],
[[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]])
print(a.view(2,-1,2))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]]])
print(a.view(4,-1,2))
tensor([[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.]]])
-1 作为参数值是计算 x 值的简单方法,前提是我们知道 y、z 的值,或者在 3d 和 2d 的情况下反过来又是计算 x 值的简单方法,前提是我们知道 y 的值,反之亦然..
参数-1是什么意思?
您可以将 -1
解读为参数的动态数量或“任何内容”。因此,view()
中只能有一个参数 -1
。
如果您询问 x.view(-1,1)
,这将根据 x
中的元素数量输出张量形状 [anything, 1]
。例如:
import torch
x = torch.tensor([1, 2, 3, 4])
print(x,x.shape)
print("...")
print(x.view(-1,1), x.view(-1,1).shape)
print(x.view(1,-1), x.view(1,-1).shape)
将输出:
tensor([1, 2, 3, 4]) torch.Size([4])
...
tensor([[1],
[2],
[3],
[4]]) torch.Size([4, 1])
tensor([[1, 2, 3, 4]]) torch.Size([1, 4])
weights.reshape(a, b)
将返回一个新张量,其数据与大小为 (a, b) 的权重相同,因为它将数据复制到内存的另一部分。
weights.resize_(a, b)
返回具有不同形状的相同张量。但是,如果新形状导致的元素少于原始张量,则将从张量中删除一些元素(但不会从内存中删除)。如果新形状导致的元素多于原始张量,则新元素将在内存中未初始化。
weights.view(a, b)
将返回一个新张量,其数据与大小为 (a, b) 的权重相同
我真的很喜欢@Jadiel de Armas 的例子。
我想对 .view(...) 元素的排序方式添加一点见解
对于形状为 (a,b,c) 的张量,其元素的顺序由编号系统确定:其中第一位数字为 a,第二位数字为 b,第三位数字为 c。
.view(...) 返回的新张量中元素的映射保留了原始张量的这个顺序。
reshape
?!