pytorch学习笔记-高阶篇(RNN训练难题)
本篇主要是说明RNN在训练过程中会遇到的一些问题,包括梯度弥散和梯度爆炸。以及怎么用改进的LSTM来解决
梯度弥散和梯度爆炸简单理解就是0.99^100和1.01^100这种概念。
一、梯度爆炸
如图所示,在正常的梯度下降方向,到达某个位置的时候,一个微小的动作都会让梯度偏离原来的路线。为此,我们可以设置一个阈值,当大于这个阈值的时候,就硬性地把它扳回原来的方向。
loss = criteon(output, y)
model.zero_grad()
loss.backward()
for p in module,parameters():
print(p.grad.norm())
# 最大值设置为10,会把模限制在10以内
torch.nn.utils.clip_grad_norm_(p, 10)
optimizer.step()
一、梯度弥散=>LSTM
如图所示。对此,提出了改进的LSTM网络,与RNN的short term memery相比,不仅改善了梯度弥散,还延长了语境记忆的长度。所以称为long short term memery,第一个long为动词,意为延长
1. LSTM原理(三门)
组合逻辑
2. LSTM实现
# 1.LSTM
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
print(lstm)
# 3个句子,每个句子有10个单词,每个单词是100维的
x = torch.randn(10, 3, 100)
out, (h, c) = lstm(x)
print(out.shape, h.shape, c.shape)
# torch.Size([10, 3, 20]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])
# 2.LSTMcell
# 一层LSTM
cell = nn.LSTMCell(input_size=100, hidden_size=20)
h = torch.zeros(3, 20)
c = torch.zeros(3, 20)
for xt in x:
h, c = cell(xt, [h, c])
print(h.shape, c.shape)
# torch.Size([3, 20]) torch.Size([3, 20])
# 两层LSTM
cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
h1 = torch.zeros(3, 30)
c1 = torch.zeros(3, 30)
h2 = torch.zeros(3, 20)
c2 = torch.zeros(3, 20)
for xt in x:
h1, c1 = cell1(xt, [h1, c1])
h2, c2 = cell2(h1, [h2, c2])
print(h2.shape, c2.shape)
# torch.Size([3, 20]) torch.Size([3, 20])
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!