Pytorch-tutorials-学习(六)

Pytorch中如何处理RNN变长序列padding

为什么RNN需要变长输入

假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是如下所示:
pic1
思路比较简单,但是当我们进行batch个训练数据进行计算的时候,会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为最长的句子一样。

比如向下图这样:
pic2
但是这会有一个问题,什么问题?比如上图,句子”Yes”只有一个单词,但是padding了5个pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,直观表示如下:
pic1

这就引出Pytorch中RNN需要处理变长输入的需要了。在上面这个例子,我们想要得到的表示仅仅是LSTM过完单词”Yes”之后的表示,而不是通过了多个无用的”Pad”得到的表示:如下图:
pic3

pytorch 中RNN如何处理变长padding

主要是用函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这两个函数的用法。
这里的pack,理解成压紧比较好。将一个填充过的变长序列压紧。(填充时候,会有冗余,所以压紧一下)
输入的形状可以是(T×B× ).这里T是最长序列长度,B是batch_size, 代表任意维度(可以是0).如果batch_first=True,那么相应的input_size就是(B×T×* )

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后(特别注意需要进行排序).即input[:,0]代表的是最长的序列

  1. packed_padded_sequence
    先填充后压紧
    参数说明:
    input(Variable)——变长序列被填充后的batch
    lengths(list[int])——Variable中每个序列的长度(知道了每个序列的长度,才能知道每个序列处理到多长停止)
    batch_first(bool)
    返回值:
    一个PackedSequence对象,一个PackedSequence表示如下所示:
    pic4

具体代码如下:

1
2
3
embed_input_x_packed=pack_padded_sequence(embed_input_x,sentence_lens,batch_first=True)

encoder_outputs_packed,(h_last,c_last)=self.lstm(embed_input_x_packed)

此时返回的h_last和c_last就是剔出padding字符后的hidden state和cell state,都是Variable类型的。代表的意思如下(各个句子的表示,lstm只会作用到它实际长度的句子,而不是通过无用的padding字符,下图用红色的打勾来表示):
pic5

  1. pad_packed_sequence
    先压紧后填充
    参数说明:
  • sequence(PackedSequence)——将要被填充的batch
  • batch_first(bool)

返回的Varaible的值的size是 T×B× , T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×
batch 中的元素将会以它们的长度逆序排列

PackedSequence输入RNN后输出的仍是PackedSequence

补充

之所以要引入pack_padded_sequence是因为序列是变长的,我们在预处理的时候会把序列加pad使其等长,而在训练的时候我们是不希望处理pad.所以之前对变长序列的处理方法是for循环,一个一个放入model中.而现在有了pack_padded_sequence就不需要for了,直接输入一个pack_padded_sequence后的数据,然后输入这个数据的每一条的长度,得到的输出再通过pad_packed_sequence变回原来的形式.

-------------本文结束感谢您的阅读-------------

本文标题:Pytorch-tutorials-学习(六)

文章作者:Yif Du

发布时间:2019年03月28日 - 21:03

最后更新:2019年04月05日 - 19:04

原始链接:http://yifdu.github.io/2019/03/28/Pytorch-tutorials-学习(六)/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。