本篇主要介绍在训练过程中过拟合与欠拟合的判断方法和减轻办法,主要是交叉划分的介绍

一、过拟合

  第三张图片很直观地解释了什么是过拟合
图片描述

二、split

1. split->Train Set,Test Set

  把数据集分为Train Set和Test Set
图片描述
  如果在训练数据集上表现得很好,但是在测试数据集上表现地不好,此时就要考虑是不是过拟合了,在对测试数据集作测试的时候,如果loss曲线在本该一直下降的过程中开始上升。一般会设置一个保存最低点的量,标志着那个时候训练出来的是最好的模型

train_db = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size, shuffle=True)

test_db = datasets.MNIST('../data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(test_db,
    batch_size=batch_size, shuffle=True)

2. split->Train Set,Test Set,Val Set(更标准的做法)

  把数据集分为Train Set,Test Set和Val Set
图片描述
  此时validation代替了原来test的功能,test数据集真正地作“测试” ,不能用作指导训练的过程

val_loader = torch.utils.data.DataLoader(
    val_db,
    batch_size=batch_size, shuffle=True)

3. 交叉验证

  把Train Set和Val Set合并都用作数据的更新,验证数据集则每次从中随机抽取
图片描述

三、减轻过拟合

  1. 更多的数据
  2. 限制模型复杂度降低
    1. shallow
    2. regularization
  3. Dropout
  4. data argumentation
  5. early stopping(使用验证数据集提前终止)
      下面主要介绍regularization方法

regularization

图片描述
  L2-regularization 只需要加上weight_decay

device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
criteon = nn.CrossEntropyLoss().to(device)  

  L1-regularization 需要人为地写
图片描述