pytorch学习笔记-高阶篇(交叉划分)
本篇主要介绍在训练过程中过拟合与欠拟合的判断方法和减轻办法,主要是交叉划分的介绍
一、过拟合
第三张图片很直观地解释了什么是过拟合
二、split
1. split->Train Set,Test Set
把数据集分为Train Set和Test Set
如果在训练数据集上表现得很好,但是在测试数据集上表现地不好,此时就要考虑是不是过拟合了,在对测试数据集作测试的时候,如果loss曲线在本该一直下降的过程中开始上升。一般会设置一个保存最低点的量,标志着那个时候训练出来的是最好的模型
train_db = datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
train_loader = torch.utils.data.DataLoader(
train_db,
batch_size=batch_size, shuffle=True)
test_db = datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(test_db,
batch_size=batch_size, shuffle=True)
2. split->Train Set,Test Set,Val Set(更标准的做法)
把数据集分为Train Set,Test Set和Val Set
此时validation代替了原来test的功能,test数据集真正地作“测试” ,不能用作指导训练的过程
val_loader = torch.utils.data.DataLoader(
val_db,
batch_size=batch_size, shuffle=True)
3. 交叉验证
把Train Set和Val Set合并都用作数据的更新,验证数据集则每次从中随机抽取
三、减轻过拟合
- 更多的数据
- 限制模型复杂度降低
- shallow
- regularization
- Dropout
- data argumentation
- early stopping(使用验证数据集提前终止)
下面主要介绍regularization方法
regularization
L2-regularization 只需要加上weight_decay
device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
criteon = nn.CrossEntropyLoss().to(device)
L1-regularization 需要人为地写
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!