本篇主要记录一下一些相对高阶的tensor操作,where和gather

API介绍

1.where

详细介绍看官方链接:https://pytorch.org/docs/stable/generated/torch.where.html?highlight=where
where
  与 for i: for j: if cond: op 这种写法的区别是这种写法是完全不并行的,只能用CPU跑,而pytorch提供的where这个API则可以利用GPU来完成,包括cond的生成,既可以用CPU也可以用GPU

# 1.where
cond = torch.rand(2, 2)
'''
Out[10]: 
tensor([[0.0532, 0.3245],
        [0.5223, 0.5285]])
'''
a = torch.full([2, 2], 0)
b = torch.full([2, 2], 1)
torch.where(cond>0.5, a, b)
'''
Out[16]: 
tensor([[1, 1],
        [0, 0]])
'''

2.gather

详细介绍看官方链接:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather
gather
  与 for i: for j: if cond: op 这种写法的区别是这种写法是完全不并行的,只能用CPU跑,而pytorch提供的where这个API则可以利用GPU来完成,包括cond的生成,既可以用CPU也可以用GPU

# 2.gather 查表的过程
# 例子:[dog, cat, whale]
#      [1, 0, 1, 2]
#  =>  [cat, dog, cat, whale]
# 在某些情况下,神经网络的label和其对应的编号并不一定相同,此时就需要这种索引查表的操作
# 这里的意义是 4张照片,每张照片是0,1...9的概率
prob = torch.randn(4, 10)
# 这里的idx是最有可能的三种数字
idx = prob.topk(dim=1, k=3)
'''
Out[18]: 
torch.return_types.topk(
values=tensor([[1.2061, 0.4638, 0.2821],
        [1.8109, 1.6480, 1.5306],
        [1.2217, 0.4735, 0.3213],
        [1.4982, 1.0012, 0.8686]]),
indices=tensor([[7, 5, 1],
        [1, 6, 5],
        [1, 0, 5],
        [3, 0, 7]]))

'''
# 取出indices
idx = idx[1]
# 这里就是label和实际的idx不同的情况(这里强行加了100)
label = torch.arange(10)+100
# Out[22]: tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
torch.gather(label.expand(4, 10), dim=1, index=idx)
'''
Out[27]: 
tensor([[107, 105, 101],
        [101, 106, 105],
        [101, 100, 105],
        [103, 100, 107]])
'''