你的时序模型为何越训越慢?GRU两门控机制如何突破LSTM算力瓶颈
1 简介
GRU(Gated Recurrent Unit)也称门控循环单元结构,传统RNN的变体,同LSTM一样有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象。
但其结构和计算比LSTM更简单,核心结构可分两部分:
- 更新门
- 重置门
2 内部结构
2.1 示意图和计算公式
2.2 更新门和重置门
结构图:
2.3 内部结构分析
类似LSTM门控,先计算更新门和重置门的门值,分别是z(t)、r(t),用X(t)与h(t-1)拼接进行线性变换,再经sigmoid激活。
之后重置门门值作用在h(t-1),代表控制上一时间步传来的信息有多少可被利用。接着用这重置后的h(t-1)进行基本的RNN计算,即与x(t)拼接进行线性变化,经过tanh激活,得到新的h(t)。最后更新门的门值会作用在新的h(t),而1-门值会作用在h(t-1)上, 随后将两者的结果相加, 得到最终的隐含状态输出h(t), 这个过程意味着更新门有能力保留之前的结果, 当门值趋于1时, 输出就是新的h(t), 而当门值趋于0时, 输出就是上一时间步的h(t-1)。
Bi-GRU与Bi-LSTM逻辑相同。
3 工程实践
Pytorch中GRU工具在torch.nn包,通过torch.nn.GRU调用。
核心类
nn.GRU类初始化主要参数解释
- input_size: 输入张量x中特征维度的大小.
- hidden_size: 隐层张量h中特征维度的大小.
- num_layers: 隐含层的数量.
- bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.
nn.GRU类实例化对象主要参数
- input: 输入张量x.
- h0: 初始化的隐层张量h.
nn.GRU使用示例
代码语言:python代码运行次数:0运行复制>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.GRU(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> output, hn = rnn(input, h0)
>>> output
tensor([[[-0.2097, -2.2225, 0.6204, -0.1745, -0.1749, -0.0460],
[-0.3820, 0.0465, -0.4798, 0.6837, -0.7894, 0.5173],
[-0.0184, -0.2758, 1.2482, 0.5514, -0.9165, -0.6667]]],
grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.6578, -0.4226, -0.2129, -0.3785, 0.5070, 0.4338],
[-0.5072, 0.5948, 0.8083, 0.4618, 0.1629, -0.1591],
[ 0.2430, -0.4981, 0.3846, -0.4252, 0.7191, 0.5420]],
[[-0.2097, -2.2225, 0.6204, -0.1745, -0.1749, -0.0460],
[-0.3820, 0.0465, -0.4798, 0.6837, -0.7894, 0.5173],
[-0.0184, -0.2758, 1.2482, 0.5514, -0.9165, -0.6667]]],
grad_fn=<StackBackward>)
4 GRU评价
优势
GRU和LSTM作用相同,捕捉长序列语义关联时,能有效抑制梯度消失或爆炸,效果都优于传统RNN且计算复杂度相比LSTM要小。
缺点
GRU仍然不能完全解决梯度消失问题,同时其作用RNN的变体,有着RNN结构本身的一大弊端, 即不可并行计算, 这在数据量和模型体量逐步增大的未来,是RNN发展的关键瓶颈。
本文已收录在Github,关注我,紧跟本系列专栏文章,咱们下篇再续!