pytorch模板
torch.nn.Module
我们有时候不需要很复杂的网络或者想自己搭建一个网络,这时候我们就可以继承torch.nn.Module类,快速构建一个前向传播的网络结构,当然torch.nn.Module类还可以构建损失函数:
import torch
class net_name(torch.nn.Module):
def __init__(self,other_arguments):
super(net_name,self).__init__()
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
# other network layer
# torch.nn.Linear()
# torch.nn.ReLU()
# torch.nn.MaxPool2d()
# torch.nn.Dropout()
# torch.nn.BatchNorm2d()
def forward(self, x):
x = self.conv1(x)
return x
torch.nn.Module
也很容易构建损失函数,如多分类交叉熵,二分类交叉熵,均方误差等,以下列举几种:
import torch
loss_function = torch.nn.CrossEntropyLoss()
loss_function = torch.nn.BCELoss()
loss_function = torch.nn.MSELoss()
loss_function = torch.nn.L1Loss()
# 具体用时
loss = loss_function(output, target)
loss.backward()
torch.optim
torch.optim
里面有很多的可以实现模型参数自动优化的类,我们可以很方便地调用,例如有随机梯度下降SGD(stochastic gradient descent),适应性矩估计Adam(adaptive moment estimation),适应性梯度算法(AdaGrad),均方根传播(RMSProp):
import torch
optimizer = torch.optim.SGD(models.parameters,lr)
optimizer = torch.optim.Adam(models.parameters,lr)
optimizer = torch.optim.Adagrad(models.parameters,lr)
optimizer = torch.optim.RMSprop(models.parameters,lr)
# 训练时
optimizer.zero_grad()
torch.autograd
torch.autograd
包主要的功能是完成神经网络后向传播中的链式求导,在前向传播的时候构建了一张计算图,在后向传播的时候完成对参数的更新。
import torch
from torch.autograd import Variable
use_gpu = torch.cuda.is_available()
if use_gpu:
data,y = Variable(data.cuda()),Variable(y.cuda())
else:
data,y = Variable(data),Variable(y)
torchvision
在pytorch里面有两个核心的包,分别为torch
和torchvision
,torchvision
包的主要功能是实现数据的处理,导入,里面也有预训练的常见模型,例如下面的datasets
,transforms
:
from torchvision import datasets,models,transforms
data = datasets.MNIST(root=save_path,train=True,transform=self_defined_transform)
model = models.vgg16(pretrained=False)
transform = transforms.Normalize(mean=mean,std=std)
torch.save & torch.load
训练模型完事之后肯定就是要保存模型了呀,之后调用就可以之间测试了呀,在pytorch里面模型的保存和加载都有两种方法:
import torch
#保存整个模型,包括结构信息和参数信息
torch.save(model_name,"model_saved_path/model_name.pth")
#只保存参数信息
torch.save(model_name.state_dict(),"model_saved_path/model_name.pth")
#对应第一种模型保存方法,加载完整整个模型
load_model = torch.load("model_saved_path/model_name.pth")
#对应第二种模型保存方法,只加载模型参数,结构需要在上面先定义或导入
model.load_state_dict(torch.load("model_saved_path/model_name.pth"))
另外一种可以参考:
pytorch代码模板