网站春节放假/网上企业推广
在辛辛苦苦训练好模型之后,我们想将它保存起来,或我们想使用已经训练完成的模型。那么该如何是实现呢?
本文参考:https://pytorch.org/tutorials/beginner/saving_loading_models.html
本文将以一个CNN模型演示如何保存或加载以训练好的模型。
首先给训练过程:
import torch
import torch.nn
import torch.optim
import torch.utils.data
import torchvision
import numpy
import matplotlibfrom torch.autograd import Variable
from torchvision import datasets
from torchvision import transformsnum_epoch = 5
batch_size = 100
learning_rate = 0.001
# -------------------------------------------------------------------------
# root 用于指定数据集在下载之后的存放路径
# transform 用于指定导入数据集需要对数据进行那种变化操作
# train是指定在数据集下载完成后需要载入那部分数据,
# 如果设置为True 则说明载入的是该数据集的训练集部分
# 如果设置为FALSE 则说明载入的是该数据集的测试集部分
data_train = datasets.MNIST(root="./data/",transform=transforms.ToTensor(),train=True,download=True)
data_test = datasets.MNIST(root="./data/",transform=transforms.ToTensor(),train=False)# ______________________________________________________________________________
# 下面对数据进行装载,我们可以将数据的载入理解为对图片的处理,
# 在处理完成后,我们就需要将这些图片打包好送给我们的模型进行训练 了 而装载就是这个打包的过程
# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size=batch_size,shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size=batch_size,shuffle=True)class CNN(torch.nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 16, kernel_size=5, padding=2),# 用于搭建卷积神经网络的卷积层,主要的输入参数有输入通道数、输出通道数、# 卷积核大小、卷积核移动步长和Paddingde值。其中,输入通道数的数据类型是# 整型,用于确定输入数据的层数;输出通道数的数据类型也是整型,用于确定# 输出数据的层数;卷积核大小的数据类型是整型,用于确定卷积核的大小;# 卷积核移动步长的数据类型是整型,用于确定卷积核每次滑动的步长;# Paddingde 的数据类型是整型,值为0时表示不进行边界像素的填充,# 如果值大于0,那么增加数字所对应的边界像素层数。torch.nn.BatchNorm2d(16),torch.nn.ReLU(),torch.nn.MaxPool2d(2)# 用于实现卷积神经网络中的最大池化层,主要的输入参数是池化窗口大小、# 池化窗口移动步长和Padding的值。)self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(16, 32, kernel_size=5, padding=2),torch.nn.BatchNorm2d(32),torch.nn.ReLU(),torch.nn.MaxPool2d(2))self.fc = torch.nn.Linear(7 * 7 * 32, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)x = self.fc(x)return xcnn = CNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cnn.to(device)
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)for epoch in range(num_epoch):for i, data in enumerate(data_loader_train):images, labels = data[0].to(device), data[1].to(device)outputs = cnn(images)loss = loss_func(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'% (epoch + 1, num_epoch, i + 1, len(data_train), loss.item()))
- 方法一(推荐)
训练完成之后:使用如下的方式保存模型:
torch.save(cnn.state_dict(), 'cnn.pkl')
可以看到在同一级目录下出现了一个名为cnn.pkl的文件这个就是模型本尊了。
这种方法仅仅保存了模型的重要参数而非整个模型
通过运行如下代码可以显示cnn.state_dict()中包含了那些内容(这里由于篇幅的原因,只显示了大小,没有显示具体数值):
print("Model's state_dict:")
for param_tensor in cnn.state_dict():print(param_tensor, "\t", cnn.state_dict()[param_tensor].size())
输出结果为:
Model's state_dict:
conv1.0.weight torch.Size([16, 1, 5, 5])
conv1.0.bias torch.Size([16])
conv1.1.weight torch.Size([16])
conv1.1.bias torch.Size([16])
conv1.1.running_mean torch.Size([16])
conv1.1.running_var torch.Size([16])
conv1.1.num_batches_tracked torch.Size([])
conv2.0.weight torch.Size([32, 16, 5, 5])
conv2.0.bias torch.Size([32])
conv2.1.weight torch.Size([32])
conv2.1.bias torch.Size([32])
conv2.1.running_mean torch.Size([32])
conv2.1.running_var torch.Size([32])
conv2.1.num_batches_tracked torch.Size([])
fc.weight torch.Size([10, 1568])
fc.bias torch.Size([10])
可以发现这些都是学习的参数信息。
加载这种方式保存得模型时,使用如下的方式:
cnn_new=CNN()
cnn_new.load_state_dict(torch.load('cnn.pkl'))
cnn_new.eval()
必须调用model.eval()将dropout和批处理规范化层设置为评估模式。
- 方法二
还有一种方法可以保存整个模型
torch.save(model, 'cnn.pkl')
cnn_new=CNN()
cnn_new= torch.load('cnn.pkl')
cnn_new.eval()
- 方法三:
保存多个通用检查点
保存
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss},'cnn.pkl')
加载
cnn_new=CNN()
optimizer_new=torch.optim.Adam(cnn.parameters(), lr=learning_rate)checkpoint = torch.load('cnn.pkl')
cnn_new.load_state_dict(checkpoint['model_state_dict'])
optimizer_new.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_new = checkpoint['epoch']
loss_new = checkpoint['loss']
cnn_new.eval()
官网还给出了很多方式,本文会在实际操作之后再更新文章。