pytorch学习笔记-基础篇(broadcast自动扩展)
本篇主要记录一下一些基础的pytorch知识,broadcast相关内容,主要特点是不复制数据地expand,但是是自动的
api简介
特点
1. expand(自动)
2. without copying data
实施方法
1. 从最小的维度开始匹配,如果前面没有维度,则在前面加 1 个维度
2. 把新建的1维数的 size 1 扩展成和目标相同的维数
以上图第二排为例,为了符合第一个tensor [4,3] 的size,第二size是[3]的tensor 经过broadcast,会首先把自己加一维变成[1,3], 然后再做维数的扩展,变成[4,3]的tensor,好和第一个可以相加
broadcast存在的意义
- 实际需求
比如[class,stu,scores] 这个tensor [4,32,8],现在要给所有学生加5分,我们就希望[4,32,8]和一个维度是1的tensor[5]可以直接相加,此时,使用broadcast相当于两次 unsqueeze 和一次 expand_as
- 内存消耗
[4,32,8] => 1024个数据,如果用expand会增加1000倍(相较于[5])的内存消耗,broadcast则不会。
broadcast使用场景
1. 当前维数是 1,想要扩展成和目标相同维数
2. 如果没有维度,则插入1 维,再扩张
代码实操
# A [4, 32, 14, 14]
# B [1, 32, 1, 1] => [4, 32, 14, 14]
a = torch.rand([4, 32, 14, 14])
b = torch.rand([1])
b = torch.broadcast_to(b, a.size())
'''
a = torch.rand([4, 32, 14, 14])
b = torch.rand([1])
b = torch.broadcast_to(b, a.size())
b.shape
Out[20]: torch.Size([4, 32, 14, 14])
'''
# C [14, 14] => [1, 1, 14, 14] => [4, 32, 14, 14]
# D [2, 32, 14, 14] => [4, 32, 14, 14] 不符合使用条件(以下)
# 1. 0 dim 有维数,不能插入且扩张
# 2. 0 dim 维数不是1
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!