#

Pytorch模板


pytorch模板

torch.nn.Module

我们有时候不需要很复杂的网络或者想自己搭建一个网络,这时候我们就可以继承torch.nn.Module类,快速构建一个前向传播的网络结构,当然torch.nn.Module类还可以构建损失函数:

import torch
class net_name(torch.nn.Module):
    def __init__(self,other_arguments):
        super(net_name,self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
        # other network layer
        # torch.nn.Linear()
        # torch.nn.ReLU()
        # torch.nn.MaxPool2d()
        # torch.nn.Dropout()
        # torch.nn.BatchNorm2d()
        
    def forward(self, x):
        x = self.conv1(x)
        return x

torch.nn.Module也很容易构建损失函数,如多分类交叉熵,二分类交叉熵,均方误差等,以下列举几种:

import torch
loss_function = torch.nn.CrossEntropyLoss()
loss_function = torch.nn.BCELoss()
loss_function = torch.nn.MSELoss()
loss_function = torch.nn.L1Loss()
# 具体用时
loss = loss_function(output, target)
loss.backward()

torch.optim

torch.optim里面有很多的可以实现模型参数自动优化的类,我们可以很方便地调用,例如有随机梯度下降SGD(stochastic gradient descent),适应性矩估计Adam(adaptive moment estimation),适应性梯度算法(AdaGrad),均方根传播(RMSProp):

import torch
optimizer = torch.optim.SGD(models.parameters,lr)
optimizer = torch.optim.Adam(models.parameters,lr)
optimizer = torch.optim.Adagrad(models.parameters,lr)
optimizer = torch.optim.RMSprop(models.parameters,lr)
# 训练时
optimizer.zero_grad()

torch.autograd

torch.autograd包主要的功能是完成神经网络后向传播中的链式求导,在前向传播的时候构建了一张计算图,在后向传播的时候完成对参数的更新。

import torch
from torch.autograd import Variable
use_gpu = torch.cuda.is_available()
if use_gpu:
    data,y = Variable(data.cuda()),Variable(y.cuda())
else:
    data,y = Variable(data),Variable(y)

torchvision

在pytorch里面有两个核心的包,分别为torchtorchvisiontorchvision包的主要功能是实现数据的处理,导入,里面也有预训练的常见模型,例如下面的datasets,transforms:

from torchvision import datasets,models,transforms
data = datasets.MNIST(root=save_path,train=True,transform=self_defined_transform)
model = models.vgg16(pretrained=False)
transform = transforms.Normalize(mean=mean,std=std)

torch.save & torch.load

训练模型完事之后肯定就是要保存模型了呀,之后调用就可以之间测试了呀,在pytorch里面模型的保存和加载都有两种方法:

import torch
#保存整个模型,包括结构信息和参数信息
torch.save(model_name,"model_saved_path/model_name.pth")
#只保存参数信息
torch.save(model_name.state_dict(),"model_saved_path/model_name.pth")

#对应第一种模型保存方法,加载完整整个模型
load_model = torch.load("model_saved_path/model_name.pth")
#对应第二种模型保存方法,只加载模型参数,结构需要在上面先定义或导入
model.load_state_dict(torch.load("model_saved_path/model_name.pth"))

另外一种可以参考:
pytorch代码模板


文章作者: 王胜鹏
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 王胜鹏 !
评论
 上一篇
Pytorch案例 Pytorch案例
加载数据 PyTorch有两个处理数据的库:torch.utils.data.DataLoader和torch.utils.data.Dataset。数据集存储样本及其对应的标签,DataLoader在数据集周围包装一个可迭代对象。 imp
2021-10-29
下一篇 
组合数 组合数
组合数一列等式 求解: ∑k=1n−1k(n−1k−1)mn−k\sum_{k=1}^{n-1}{k\left( \begin{array}{c} n-1\\ k-1\\ \end{array} \right)}m^{n-k} k=1∑
2021-10-20
  目录