当前位置: 首页 > news >正文

网站春节放假/网上企业推广

网站春节放假,网上企业推广,iis发布php网站,wordpress 香港繁体在辛辛苦苦训练好模型之后,我们想将它保存起来,或我们想使用已经训练完成的模型。那么该如何是实现呢? 本文参考:https://pytorch.org/tutorials/beginner/saving_loading_models.html 本文将以一个CNN模型演示如何保存或加载以训…

在辛辛苦苦训练好模型之后,我们想将它保存起来,或我们想使用已经训练完成的模型。那么该如何是实现呢?
本文参考:https://pytorch.org/tutorials/beginner/saving_loading_models.html
本文将以一个CNN模型演示如何保存或加载以训练好的模型。
首先给训练过程:

import torch
import torch.nn
import torch.optim
import torch.utils.data
import torchvision
import numpy
import matplotlibfrom torch.autograd import Variable
from torchvision import datasets
from torchvision import transformsnum_epoch = 5
batch_size = 100
learning_rate = 0.001
# -------------------------------------------------------------------------
# root 用于指定数据集在下载之后的存放路径
# transform 用于指定导入数据集需要对数据进行那种变化操作
# train是指定在数据集下载完成后需要载入那部分数据,
# 如果设置为True 则说明载入的是该数据集的训练集部分
# 如果设置为FALSE 则说明载入的是该数据集的测试集部分
data_train = datasets.MNIST(root="./data/",transform=transforms.ToTensor(),train=True,download=True)
data_test = datasets.MNIST(root="./data/",transform=transforms.ToTensor(),train=False)# ______________________________________________________________________________
# 下面对数据进行装载,我们可以将数据的载入理解为对图片的处理,
# 在处理完成后,我们就需要将这些图片打包好送给我们的模型进行训练 了  而装载就是这个打包的过程
# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
#  在装载的过程会将数据随机打乱顺序并进打包
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size=batch_size,shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size=batch_size,shuffle=True)class CNN(torch.nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 16, kernel_size=5, padding=2),# 用于搭建卷积神经网络的卷积层,主要的输入参数有输入通道数、输出通道数、# 卷积核大小、卷积核移动步长和Paddingde值。其中,输入通道数的数据类型是# 整型,用于确定输入数据的层数;输出通道数的数据类型也是整型,用于确定# 输出数据的层数;卷积核大小的数据类型是整型,用于确定卷积核的大小;# 卷积核移动步长的数据类型是整型,用于确定卷积核每次滑动的步长;# Paddingde 的数据类型是整型,值为0时表示不进行边界像素的填充,# 如果值大于0,那么增加数字所对应的边界像素层数。torch.nn.BatchNorm2d(16),torch.nn.ReLU(),torch.nn.MaxPool2d(2)# 用于实现卷积神经网络中的最大池化层,主要的输入参数是池化窗口大小、# 池化窗口移动步长和Padding的值。)self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(16, 32, kernel_size=5, padding=2),torch.nn.BatchNorm2d(32),torch.nn.ReLU(),torch.nn.MaxPool2d(2))self.fc = torch.nn.Linear(7 * 7 * 32, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)x = self.fc(x)return xcnn = CNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cnn.to(device)
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)for epoch in range(num_epoch):for i, data in enumerate(data_loader_train):images, labels = data[0].to(device), data[1].to(device)outputs = cnn(images)loss = loss_func(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'% (epoch + 1, num_epoch, i + 1, len(data_train), loss.item()))

- 方法一(推荐)

训练完成之后:使用如下的方式保存模型:

torch.save(cnn.state_dict(), 'cnn.pkl')

可以看到在同一级目录下出现了一个名为cnn.pkl的文件这个就是模型本尊了。
这种方法仅仅保存了模型的重要参数而非整个模型
通过运行如下代码可以显示cnn.state_dict()中包含了那些内容(这里由于篇幅的原因,只显示了大小,没有显示具体数值):

print("Model's state_dict:")
for param_tensor in cnn.state_dict():print(param_tensor, "\t", cnn.state_dict()[param_tensor].size())

输出结果为:

Model's state_dict:
conv1.0.weight 	 torch.Size([16, 1, 5, 5])
conv1.0.bias 	 torch.Size([16])
conv1.1.weight 	 torch.Size([16])
conv1.1.bias 	 torch.Size([16])
conv1.1.running_mean 	 torch.Size([16])
conv1.1.running_var 	 torch.Size([16])
conv1.1.num_batches_tracked 	 torch.Size([])
conv2.0.weight 	 torch.Size([32, 16, 5, 5])
conv2.0.bias 	 torch.Size([32])
conv2.1.weight 	 torch.Size([32])
conv2.1.bias 	 torch.Size([32])
conv2.1.running_mean 	 torch.Size([32])
conv2.1.running_var 	 torch.Size([32])
conv2.1.num_batches_tracked 	 torch.Size([])
fc.weight 	 torch.Size([10, 1568])
fc.bias 	 torch.Size([10])

可以发现这些都是学习的参数信息。

加载这种方式保存得模型时,使用如下的方式:

cnn_new=CNN()
cnn_new.load_state_dict(torch.load('cnn.pkl'))
cnn_new.eval()

必须调用model.eval()将dropout和批处理规范化层设置为评估模式。

  • 方法二
    还有一种方法可以保存整个模型
torch.save(model, 'cnn.pkl')
cnn_new=CNN()
cnn_new= torch.load('cnn.pkl')
cnn_new.eval()

- 方法三:

保存多个通用检查点

保存

torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss},'cnn.pkl')

加载

cnn_new=CNN()
optimizer_new=torch.optim.Adam(cnn.parameters(), lr=learning_rate)checkpoint = torch.load('cnn.pkl')
cnn_new.load_state_dict(checkpoint['model_state_dict'])
optimizer_new.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_new = checkpoint['epoch']
loss_new = checkpoint['loss']
cnn_new.eval()

官网还给出了很多方式,本文会在实际操作之后再更新文章。

http://www.jmfq.cn/news/5212567.html

相关文章:

  • 没有静态ip可以做网站服务器/广告语
  • 棋牌网站搭建公司/360优化大师下载官网
  • 中国最新军事新闻 新闻/哈尔滨网络推广优化
  • 沈阳网站推广优化/百度seo白皮书
  • 那些平台可以给网站做外链/chrome浏览器
  • 如何用织梦猫做网站和后台/自媒体
  • 焦作做网站/在线推广企业网站的方法有哪些
  • 做调查问卷的网站/上海做关键词推广企业
  • 做演示的网站/海外广告优化师
  • 做团购网站的公司/山东济南最新消息
  • 网站建设销售员工作内容/平台推广渠道
  • 微网站可以做商城吗/总推荐榜总点击榜总排行榜
  • 如何安装织梦做的网站/草根站长工具
  • 制作网站账号系统/友链价格
  • 网站需要哪些备案/seo全网推广
  • 大型公司网站建设/百度一下你就知道首页
  • 如何做网站/做一个私人网站需要多少钱
  • 一个公司可以做几个网站/百度搜索排名规则
  • 国内wordpress有名的网站/如何创建自己的网址
  • 会计软件定制开发包括/西安seo
  • 天津网站建设渠道/长沙网站关键词排名公司
  • 腾讯云做网站选哪个/湖南竞价优化哪家好
  • 代理做网站的合同/国家税务总局网
  • 公众号怎么制作教程/福州seo关键字推广
  • 露兜博客 wordpress/seo站内优化培训
  • 域名为www.com的网站/新网站百度多久收录
  • 网站制作公司中/最近新闻大事
  • 网站建设公司的企业特色有哪些/软文标题写作技巧
  • 西部数码手机网站/网站seo网络优化
  • 教育类网站 前置审批/百度刷搜索词