最新消息:雨落星辰是一个专注网站SEO优化、网站SEO诊断、搜索引擎研究、网络营销推广、网站策划运营及站长类的自媒体原创博客

你的时序模型为何越训越慢?GRU两门控机制如何突破LSTM算力瓶颈

网站源码admin1浏览0评论

你的时序模型为何越训越慢?GRU两门控机制如何突破LSTM算力瓶颈

1 简介

GRU(Gated Recurrent Unit)也称门控循环单元结构,传统RNN的变体,同LSTM一样有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象。

但其结构和计算比LSTM更简单,核心结构可分两部分:

  • 更新门
  • 重置门

2 内部结构

2.1 示意图和计算公式

2.2 更新门和重置门

结构图:

2.3 内部结构分析

类似LSTM门控,先计算更新门和重置门的门值,分别是z(t)、r(t),用X(t)与h(t-1)拼接进行线性变换,再经sigmoid激活。

之后重置门门值作用在h(t-1),代表控制上一时间步传来的信息有多少可被利用。接着用这重置后的h(t-1)进行基本的RNN计算,即与x(t)拼接进行线性变化,经过tanh激活,得到新的h(t)。最后更新门的门值会作用在新的h(t),而1-门值会作用在h(t-1)上, 随后将两者的结果相加, 得到最终的隐含状态输出h(t), 这个过程意味着更新门有能力保留之前的结果, 当门值趋于1时, 输出就是新的h(t), 而当门值趋于0时, 输出就是上一时间步的h(t-1)。

Bi-GRU与Bi-LSTM逻辑相同。

3 工程实践

Pytorch中GRU工具在torch.nn包,通过torch.nn.GRU调用。

核心类

nn.GRU类初始化主要参数解释
  • input_size: 输入张量x中特征维度的大小.
  • hidden_size: 隐层张量h中特征维度的大小.
  • num_layers: 隐含层的数量.
  • bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.
nn.GRU类实例化对象主要参数
  • input: 输入张量x.
  • h0: 初始化的隐层张量h.

nn.GRU使用示例

代码语言:python代码运行次数:0运行复制
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.GRU(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> output, hn = rnn(input, h0)
>>> output
tensor([[[-0.2097, -2.2225,  0.6204, -0.1745, -0.1749, -0.0460],
         [-0.3820,  0.0465, -0.4798,  0.6837, -0.7894,  0.5173],
         [-0.0184, -0.2758,  1.2482,  0.5514, -0.9165, -0.6667]]],
       grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.6578, -0.4226, -0.2129, -0.3785,  0.5070,  0.4338],
         [-0.5072,  0.5948,  0.8083,  0.4618,  0.1629, -0.1591],
         [ 0.2430, -0.4981,  0.3846, -0.4252,  0.7191,  0.5420]],

        [[-0.2097, -2.2225,  0.6204, -0.1745, -0.1749, -0.0460],
         [-0.3820,  0.0465, -0.4798,  0.6837, -0.7894,  0.5173],
         [-0.0184, -0.2758,  1.2482,  0.5514, -0.9165, -0.6667]]],
       grad_fn=<StackBackward>)

4 GRU评价

优势

GRU和LSTM作用相同,捕捉长序列语义关联时,能有效抑制梯度消失或爆炸,效果都优于传统RNN且计算复杂度相比LSTM要小。

缺点

GRU仍然不能完全解决梯度消失问题,同时其作用RNN的变体,有着RNN结构本身的一大弊端, 即不可并行计算, 这在数据量和模型体量逐步增大的未来,是RNN发展的关键瓶颈。

本文已收录在Github,关注我,紧跟本系列专栏文章,咱们下篇再续!

发布评论

评论列表(0)

  1. 暂无评论