本篇主要记录一下一些链式法则相关知识,和导数的加减复合的法则类似

一、常见梯度求导法则

图片描述

二、链式法则

其实很容易理解,和导数的复合是一致的
图片描述
下图是针对神经网络的一个具体的例子
图片描述

# 验证链式法则
x = torch.tensor(1.)
w1 = torch.tensor(2., requires_grad=True)
b1 = torch.tensor(1.)
w2 = torch.tensor(2., requires_grad=True)
b2 = torch.tensor(1.)

y1 = x*w1 + b1
y2 = y1*w2 + b2

dy2_dy1 = torch.autograd.grad(y2, [y1], retain_graph=True)[0]
dy1_dw1 = torch.autograd.grad(y1, [w1], retain_graph=True)[0]
dy2_dw1 = torch.autograd.grad(y2, [w1], retain_graph=True)[0]

dy2_dw1*dy1_dw1
# Out[44]: tensor(2.)
dy2_dw1
# Out[45]: tensor(2.)
# 二者结果一致