Pytorch tutorials 学习(一)

Pytorch教程文档学习笔记

主要学习顺序:模型的构建顺序

Tensor and Numpy

pytorch的一大优点是与numpy完美融合,我们能使用torch的包生成张量,整个过程类似于生成numpy的ndarray。

1
2
3
4
5
6
7
8
9
import torch as np ###皮一下
empty=np.empty(3,5)
print(empty)
ones=np.ones(3,5)
print(ones)
rand=np.rand(3,5)
print(rand)
zeros=np.zeros(3,5)
print(zeros)

输出的结果的形式与我们在调用numpy时的一样,唯一不同的是此时输出的变量为张量tensor形式。


pytorch的另一个优点是能实现ndarry/list 与tensor之间的快速转换。

1
2
3
4
5
6
7
8
9
10
11
import torch
import numpy as np
list=torch.tensor([3,4,4])
print(list)
y=np.array([1,2,3])
ndarray=torch.tensor(y)
print(ndarray)
print(list.numpy())
print(ndarray.numpy())
print(type(list.numpy()))
print(type(ndarray.numpy()))

torch中提供了tensor()函数对ndarray/list进行转换,而且对标量也能转换。并且现在tensor转换和Variable融合在了一起。
另外torch中还有一个from_numpy()函数仅对ndarray进行转换成tensor。

而从tensor变回ndarray的方法比较单一,只有torch.numpy(),并且返回的结果只能是ndarray。

数据处理加载与处理

这一部分主要介绍数据的载入与预处理里方法。笼统地讲分为两部分:1.制作数据集(Dataset)2.数据装载(DataLoader).这就好比我们手上有制作“子弹”的原料:图片和标签。我们先要将其制作成子弹,然后将他装上枪(DataLoader)。

Dataset和DataLoader都是torch.utils.data下的类

Dataset

制作Dataset,我们往往要处理图片和标签,这需要用到两个库(1.Image(像skimage,cv2都可以) 2.pandas)。
Image 是用于读取图像,而pandas是用于解析csv。因为我们的目的是制作“子弹”,我们要将图片与标签一一对应起来,这就是我们要用Dataset的目的。
torch.utils.data.Dataset是数据集抽象类,我们需要继承Dataset来自定义数据集

1
2
3
4
5
6
7
8
from torch.utils.data import Dataset,DataLoader
class Mydataset(Dataset):
def __init__(self,image_path,label_path,transform):
pass
def __len__(self):
pass
def __getitem__(self,index):
pass

上面是我们自定义数据集的基本形式,里面往往包含着Image和pandas的调用,具体内容不书写了。
当我们写好了这个自定义数据集类之后,我们可以生成一个实例。通过它我们可以调用图片与标签配对成功后的子弹。

在自定义数据集的时候我们往往会进行数据增强的操作。
这一步操作在初始化时被传入,在getitem时会进行。
如何定义数据增强呢??

transform

安装pytorch的同时我们也安装了一个叫torchvision的包,该包提供了一些常见的数据集以及一些转换。
我们这里要用到的是torchvision 下的transforms (加‘s’)

1
2
3
4
5
6
7
8
from torchvision import transforms
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])

上面用到的是最一般的些数据增强方法,要注意像transforms.RandomSizedCrop(224)等的这些方法是对图像进行处理的,而transforms.Normalize()是对tensor进行处理的。因此要注意顺序,一定要把ToTensor()和Normalize()放最后。

DataLoader

生成DataLoader就比较简单,不需要自己再继承类。

1
2
3
4
train_dataloader=torch.utils.data.DataLoader(dataset,batch_size,shuffle,num_workers,drop_last)
##使用的时候
for i,(input,target) in enumerate(train_dataloader) :
........

具体详细内容可以看源码:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

模型搭建

torch下的nn.Module提供了神经网络模型的自定义模块

1
2
3
4
5
6
import torch
class Mymodel(torch.nn.Module):
def __init__(self):
pass
def forward(self,x):
pass

上面这个神经网络类就两部分:1.初始化时定义的网络子模块2.数据流的处理流程(前向传播)
这也是pytorch的一大特点,灵活便捷。
至于该网络的反向传播过程,我们定义好了前向传播,反向传播仅需要输出数据的backward()函数就能实现。

很显然,只要与网络模块有关的,包括反向传播的,我们都可以通过torch.nn.Module来自定义。
比如我们要自定义一个损失函数也可以通过nn.Module。

训练过程的其他部件

Optim

我们定义了数据集,定义了网络甚至可以定义loss.但最重要的深度学习的学习步骤我们还没有处理。
这里需要用到一些最优化方法,如AdaGrad、RMSProp、ADAM。

这些优化器一般存在于torch.nn.Optim里

1
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

优化器的初始化参数包括 1.网络的参数 2.learning_rate 3.weight_decay等
该优化器的作用就是更新参数,因此它必须与网络参数相互关联。

我们一般是在什么时候需要更新参数?
一个batch的数据的loss计算完了以后,需要我们进行网络参数的更新。而这网络参数的更新又与梯度有关,所以我们在更新参数的同时需要知道梯度。

怎么才能知道梯度?我们上一节介绍时提到了backward()反向传播,这是用来计算我们各个参数的梯度的。网络中的每个参数都是Variable型的,它同时也存储着此时他们各自的梯度。loss.backward()就是帮这些参数计算梯度并且存进参数的grad区域。在loss.backward()后,我们进行optimizer.step()就是将参数里的grad区域的梯度取出用于参数的更新。

pytorch有一个不人道的地方,就是在每次loss.backward()前或者后要进行optimizer.zero_grad(),用于清空网络参数grad区域的梯度值。

loss

前面介绍了自定义loss可以通过nn.Module来定义。这里我们可以用torch.nn里的loss类计算每次的loss.

1
2
3
4
5
criterion = torch.nn.crossentropy()
.......
for i,(input,target) in enumerate(train_dataloader):
.....
loss = criterion(y_pred, y)

注意这里y不是onehot的类型,这一点与keras不一样。

总结

乱七八糟地写了这一堆,最难的还是记住各个类在哪个包里。

简单总结一下,torch有点像numpy很万能,一般torch下的函数就能像numpy一样直接处理数据。而如果要处理数据集,因为cv和nlp都会用到,所以存于torch.utils.data下。而与网络相关的类和函数,比如loss,optim,网络模块等都存在于torch.nn下。而常用于数据增强的操作则在torchvision.tranforms里。至于数据集为什么也在torchvision下就很莫名其妙了

-------------本文结束感谢您的阅读-------------

本文标题:Pytorch tutorials 学习(一)

文章作者:Yif Du

发布时间:2018年10月29日 - 21:10

最后更新:2018年10月30日 - 21:10

原始链接:http://yifdu.github.io/2018/10/29/Pytorch-tutorials-学习(一)/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。