pytorch学习笔记-高阶篇(多分类实战)
本篇主要介是一个LR多分类问题的实战
一、网络部分
# 3个线性层
# 在pytorch中[CH_out,CH_in]的格式
w1, b1 = torch.randn(200, 784, requires_grad=True),\
torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True),\
torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\
torch.zeros(10, requires_grad=True)
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
def forward(x):
# 这里的 x 输入是[?,784] 输出是[?, 200],b会自动作一个broadcast
x = x@w1.t() + b1
x = F.relu(x)
x = x@w2.t() + b2
x = F.relu(x)
x = x@w3.t() + b3
x = F.relu(x)
return x
```
### 二、初始化
  在运行过程中,loss始终保持不变,出现了梯度离散的情况。可以尝试加上初始化;很多时候,训练的效果不好,可能不是网络的原因,而仅仅是初始化的位置不对。
``` python
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
```
### 三、训练部分
``` python
optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
# Crossentropy包含softmax,log等操作
criteon = nn.CrossEntropyLoss()
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):
# -1表示不清楚维度,第0维不变保持
data = data.view(-1, 28*28)
logits = forward(data)
loss = criteon(logits, target)
optimizer.zero_grad()
loss.backward()
# print(w1.grad.norm(), w2.grad.norm())
optimizer.step()
```
### 四、测试部分
  需要注意的是,对于深度学习来说,训练的次数并不是越多越好,可以看到图中随着训练次数的增加,测试集的准确率开始出现波动,并且直接影响到loss。出现这种“过拟合”现象的原因是,次数过多,会导致过多的对训练数据集sample的关注从而忽略了共性,导致测试其他数据集可能准确率并不理想。

``` python
test_loss = 0
correct = 0
for data, target in test_loader:
data = data.view(-1, 28 * 28)
logits = forward(data)
test_loss += criteon(logits, target).item()
pred = logits.data.max(1)[1]
correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
```
### 五、可视化(Visdom)
- 首先自然是安装
- 然后是打开web服务器
python -m visdom.server
- 代码中体现
``` python
viz = Visdom()
viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.', legend=['loss', 'acc.']))
生成的图片:
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!