pytorch学习笔记-高阶篇(卷积神经网络实战2)
本篇主要是卷积神经网络的实战,网络结构是resnet,数据集用的是CIFAR-10
一、resnet
resnet前面已经介绍过,这里不再赘述,需要和前面的Lenet5作区分,最显著的区别是一个短接 short cut的过程
二、resnet类
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Basic residual block: two 3x3 conv+BN layers plus a shortcut.

    The shortcut is the identity when the input already matches the main
    path's output shape; otherwise a 1x1 conv + BN projects the input to
    the right channel count and spatial size.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        Args:
            ch_in: number of input channels.
            ch_out: number of output channels.
            stride: stride of the first conv; values > 1 downsample H and W.
        """
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.extra = nn.Sequential()
        # BUGFIX: the projection shortcut is needed not only when the channel
        # count changes but also when stride > 1; otherwise an identity
        # shortcut keeps the original spatial size while the main path is
        # downsampled, and the element-wise add relies on accidental
        # broadcasting (or fails outright).
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        Args:
            x: input tensor of shape [b, ch_in, h, w].
        Returns:
            Tensor of shape [b, ch_out, h/stride, w/stride].
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # short cut: element-wise add of the (possibly projected) input
        out = self.extra(x) + out
        out = F.relu(out)
        return out
class ResNet18(nn.Module):
    """Small ResNet-18-style classifier for CIFAR-10-sized images.

    A stem conv maps the 3-channel input to 64 channels, four residual
    blocks grow the channels to 512 while downsampling, and a linear head
    maps the pooled features to class logits.
    """

    def __init__(self, num_classes=10):
        """
        Args:
            num_classes: size of the output logit vector
                (default 10, matching CIFAR-10).
        """
        super(ResNet18, self).__init__()
        # Stem: convert the 3-channel input image to 64 channels.
        # NOTE(review): stride=3 with padding=0 shrinks H and W by ~3x here;
        # a standard CIFAR ResNet stem uses stride=1 — kept as-is to
        # preserve behavior.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # Followed by 4 residual blocks, each with stride 2.
        # [b, 64, h, w]  => [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h, w] => [b, 512, h/2, w/2] (channels unchanged here)
        self.blk4 = ResBlk(512, 512, stride=2)
        self.outlayer = nn.Linear(512, num_classes)

    def forward(self, x):
        """
        Args:
            x: input image batch of shape [b, 3, h, w].
        Returns:
            Logit tensor of shape [b, num_classes].
        """
        x = F.relu(self.conv1(x))
        # [b, 64, h', w'] => [b, 512, h'', w''] through the four blocks
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # Global average pool: [b, 512, h'', w''] => [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # Flatten to [b, 512] and classify.
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x
def main():
    """Smoke-test a single residual block and then the full network."""
    # A stride of 2 halves the last two spatial dimensions:
    # a [2, 64, 32, 32] input comes out as [2, 128, 16, 16].
    block = ResBlk(64, 128, stride=2)
    sample = torch.randn(2, 64, 32, 32)
    print('block:', block(sample).shape)

    images = torch.randn(2, 3, 32, 32)
    net = ResNet18()
    print('resnet:', net(images).shape)


if __name__ == '__main__':
    main()
三、main
主函数部分和之前的Lenet5基本没有区别,只需把加载的模型换成ResNet18即可
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from resnet import ResNet18
def main():
    """Train ResNet18 on CIFAR-10 and report test accuracy each epoch."""
    batchsz = 4

    # Training set: the dataset object yields one sample at a time.
    cifar_train = datasets.CIFAR10("./dataset", train=True, transform=transforms.Compose([
        transforms.Resize([32, 32]),
        transforms.ToTensor(),
        # Optional data-augmentation steps:
        # transforms.RandomRotation(5)
        # Per-channel RGB mean/std statistics for normalization:
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                      std=[0.229, 0.224, 0.225])
    ]), download=True)
    # The DataLoader batches samples together.
    cifar_train = DataLoader(dataset=cifar_train, batch_size=batchsz, shuffle=True)

    # Test set, same preprocessing (without augmentation).
    cifar_test = datasets.CIFAR10("./dataset", train=False, transform=transforms.Compose([
        transforms.Resize([32, 32]),
        transforms.ToTensor()
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                      std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(dataset=cifar_test, batch_size=batchsz, shuffle=True)

    # Peek at one batch to confirm tensor shapes.
    # BUGFIX: the iterator method `.next()` was removed in Python 3;
    # use the builtin next() instead.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Robustness: fall back to CPU when CUDA is unavailable instead of
    # crashing on a CPU-only machine.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)
    # CrossEntropyLoss expects raw logits, so the model has no softmax.
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):
        model.train()
        for x, label in cifar_train:
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)
            logits = model(x)  # [b, 10]
            # loss: scalar tensor
            loss = criteon(logits, label)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last batch of this epoch.
        print(epoch, loss.item())

        # Evaluate on the held-out test set.
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)
                logits = model(x)            # [b, 10]
                pred = logits.argmax(dim=1)  # [b]
                # [b] vs [b] => number of correct predictions in the batch
                total_correct = total_correct + torch.eq(pred, label).float().sum().item()
                total_num = total_num + x.size(0)
            acc = total_correct / total_num
            print(epoch, acc)


if __name__ == '__main__':
    main()
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 不听话的兔子君!