本篇主要记录一下一些基础的pytorch知识,tensor之间的拼接与拆分,主要涉及4个经典的API

一、API介绍

1. Cat(不新加维度)

# 1. cat
# 场景:两份关于成绩单的数据,现在需要合并两份成绩单
# [class 1~4, stu, scores]
# [class 5~9, stu, scores]
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
# 第一个参数的所有tensor的List,第二个参数是拼接的维度,如果不能拼接会报错
torch.cat([a, b], dim=0).shape
'''
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
torch.cat([a, b], dim=0).shape
Out[5]: torch.Size([9, 32, 8])
'''

2. Stack(新加维度)

# 2. Stack
# stack 会创建一个新的维度
a1 = torch.rand(4, 3, 16, 32)
a2 = torch.rand(4, 3, 16, 32)
# 对于Stack来说,它的参数要求几个tensor的shape是一样的
# 对于两个班的成绩的合并,stack会生成一个意义是班号的维度,这种场景则用stack
torch.cat([a1, a2], dim=2).shape
torch.stack([a1, a2], dim=2).shape
'''
torch.cat([a1, a2], dim=2).shape
Out[9]: torch.Size([4, 3, 32, 32])
torch.stack([a1, a2], dim=2).shape
Out[10]: torch.Size([4, 3, 2, 16, 32])
'''

3. Split(按长度拆分)

# 3. Split(按长度拆分)
a = torch.rand(32, 8)
b = torch.rand(32, 8)
a.shape
b.shape
c = torch.stack([a, b], dim=0)
c.shape
aa, bb = c.split([1, 1], dim=0)
aa.shape, bb.shape
c = torch.cat([c, c], dim=0)
c.shape
# 第一种, 可以指定每段拆分后的长度
aa, bb = c.split([3, 1], dim=0)
aa.shape, bb.shape
# 第二种, 可以指定拆分的单位,比如这里是2,就是每个拆分所得都是2
aa, bb = c.split(2, dim=0)
aa.shape, bb.shape
# 如果不想要其余的,可以这么做,算是一个小技巧
result = c.split(1, dim=0)
result[0].shape, result[1].shape
'''
a = torch.rand(32, 8)
b = torch.rand(32, 8)
a.shape
Out[12]: torch.Size([32, 8])
b.shape
Out[13]: torch.Size([32, 8])
c = torch.stack([a, b], dim=0)
c.shape
Out[15]: torch.Size([2, 32, 8])
aa, bb = c.split([1,1], dim=0)
aa.shape, bb.shape
Out[17]: (torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
c = torch.stack([a, c], dim=0)
c = torch.cat([c, c], dim=0)
c.shape
Out[21]: torch.Size([4, 32, 8])
aa, bb = c.split([3,1], dim=0)
aa.shape, bb.shape
Out[23]: (torch.Size([3, 32, 8]), torch.Size([1, 32, 8]))
result = c.split(1, dim=0)
result[0].shape, result[1].shape
Out[27]: (torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
'''

4. Chunk(按数量拆分)

# 4. Chunk(按数量拆分)
# 接上面的c
c.shape
# 即把C等分为两个tensor
aa, bb = c.chunk(2, dim=0)
aa.shape, bb.shape
'''
c.shape
Out[28]: torch.Size([4, 32, 8])
aa, bb = c.chunk(2, dim=0)
aa.shape, bb.shape
Out[30]: (torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))
'''