pytorch学习笔记-高阶篇(链式法则)
本篇主要记录一下一些链式法则相关知识,和导数的加减复合的法则类似
一、常见梯度求导法则
二、链式法则
其实很容易理解,和导数的复合是一致的
下图是针对神经网络的一个具体的例子
# 验证链式法则
x = torch.tensor(1.)
w1 = torch.tensor(2., requires_grad=True)
b1 = torch.tensor(1.)
w2 = torch.tensor(2., requires_grad=True)
b2 = torch.tensor(1.)
y1 = x*w1 + b1
y2 = y1*w2 + b2
dy2_dy1 = torch.autograd.grad(y2, [y1], retain_graph=True)[0]
dy1_dw1 = torch.autograd.grad(y1, [w1], retain_graph=True)[0]
dy2_dw1 = torch.autograd.grad(y2, [w1], retain_graph=True)[0]
dy2_dw1*dy1_dw1
# Out[44]: tensor(2.)
dy2_dw1
# Out[45]: tensor(2.)
# 二者结果一致
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!