本篇主要记录一下一些基础的pytorch知识,接上一篇文章的tensor基础,索引(index)与切片

各种索引切片方法介绍

import torch

a = torch.rand(4, 3, 28, 28)
print(a[0].shape)
print(a[0, 0].shape)
# 这里返回的是一个标量
print(a[0, 0, 2, 4])

# 索引最前/后的图片
# 前两张图片 0,1, 因此这里的输出第一维是 2
print(a[:2].shape)
# 同样,这里的输出第一维是 2,第二维是1
print(a[:2, :1, :, :].shape)
# 注意,这里的输出第一维是 2, 第二晚是从1开始(包括1)到最末,因此维数是2
print(a[:2, 1:, :, :].shape)
# 索引下标
# [a, b, c]
# [0, 1, 2] 正序
# [-3, -2, -1] 倒序
# 所以下面是从-1开始(包括-1),一共 1 维
print(a[:2, -1:, :, :].shape)

# select by step
# step = 2
print(a[:, :, 0:28:2, 0:28:2].shape)
print(a[:, :, ::2, ::2].shape)

# 特定索引 .index_select()
# 在第0个维度,取 0,2
print(a.index_select(0, torch.tensor([0, 2])).shape)
# 在第1个维度,取 1,2
print(a.index_select(1, torch.tensor([1, 2])).shape)
# 在第二个维度,按索引取0~27
print(a.index_select(2, torch.arange(28)).shape)
# 在第二个维度,按索引取0~7
print(a.index_select(2, torch.arange(8)).shape)

# ... : 任意多的维度
print(a[...].shape)
# 会根据实际情况推测
print(a[0, ...].shape)
print(a[:, 1, ...].shape)
print(a[..., :2].shape)

# select by mask
x = torch.randn(3, 4)
# 将x中大于0.5的位置置1(true)
mask = x.ge(0.5)
print(mask)
y = torch.masked_select(x, mask)

# 打平之后取
src = torch.tensor([[4, 3, 5],
                    [6, 7, 8]])
y = torch.take(src, torch.tensor([0, 2, 5]))
print(y)