pytorch学习笔记-高阶篇(where和gather)
本篇主要记录一下一些相对高阶的tensor操作,where和gather
API介绍
1.where
详细介绍看官方链接:https://pytorch.org/docs/stable/generated/torch.where.html?highlight=where
与 for i: for j: if cond: op 这种写法的区别是这种写法是完全不并行的,只能用CPU跑,而pytorch提供的where这个API则可以利用GPU来完成,包括cond的生成,既可以用CPU也可以用GPU
# 1.where
cond = torch.rand(2, 2)
'''
Out[10]:
tensor([[0.0532, 0.3245],
[0.5223, 0.5285]])
'''
a = torch.full([2, 2], 0)
b = torch.full([2, 2], 1)
torch.where(cond>0.5, a, b)
'''
Out[16]:
tensor([[1, 1],
[0, 0]])
'''
2.gather
详细介绍看官方链接:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather
与 for i: for j: if cond: op 这种写法的区别是这种写法是完全不并行的,只能用CPU跑,而pytorch提供的where这个API则可以利用GPU来完成,包括cond的生成,既可以用CPU也可以用GPU
# 2.gather 查表的过程
# 例子:[dog, cat, whale]
# [1, 0, 1, 2]
# => [cat, dog, cat, whale]
# 在某些情况下,神经网络的label和其对应的编号并不一定相同,此时就需要这种索引查表的操作
# 这里的意义是 4张照片,每张照片是0,1...9的概率
prob = torch.randn(4, 10)
# 这里的idx是最有可能的三种数字
idx = prob.topk(dim=1, k=3)
'''
Out[18]:
torch.return_types.topk(
values=tensor([[1.2061, 0.4638, 0.2821],
[1.8109, 1.6480, 1.5306],
[1.2217, 0.4735, 0.3213],
[1.4982, 1.0012, 0.8686]]),
indices=tensor([[7, 5, 1],
[1, 6, 5],
[1, 0, 5],
[3, 0, 7]]))
'''
# 取出indices
idx = idx[1]
# 这里就是label和实际的idx不同的情况(这里强行加了100)
label = torch.arange(10)+100
# Out[22]: tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
torch.gather(label.expand(4, 10), dim=1, index=idx)
'''
Out[27]:
tensor([[107, 105, 101],
[101, 106, 105],
[101, 100, 105],
[103, 100, 107]])
'''
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!