欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

NLP之简单笔记:LSTM

程序员文章站 2022-07-01 19:09:20
...

一. LSTM简介

  • LSTM:即Long Short-tem Memory,长短期记忆神经网络,1997年就被提出来了。

  • 传统RNN的缺点:
         1. 每个时刻都会更新掉上一时刻memory的信息;LSTM通过增加三个门,来选择保存更多时刻的信息。
         2.会出现梯度爆炸或者消失的情况;LSTM可以解决梯度消失的情况。(严格上来讲,只能缓解梯度消失,而不能完全解决)

  • LSTM结构如下图所示,参考https://apaszke.github.io/lstm-explained.html
    NLP之简单笔记:LSTM

二. LSTM的计算过程

现在有初始状态的输入xt, ht-1 ,ct-1参数说明:

  • x就是你的输入;
  • h表示hidden layer的神经元个数,就是你在定义LSTM结构时设置的参数;
  • c表示LSTM模型中记忆单元存储的状态。

LSTM的计算过程:

step1: 将xt和ht-1并起来,得到X;
step2:X分别和四个权值矩阵相乘,得到z, zi ,zf , zo

     z = tanh(WX),
     zi = sigmoid(Wi X)
     zf = sigmoid(Wf X)
     zo = sigmoid(Wo X)

step3:更新记忆单元状态,求ct
     ct = z⋅zi + ct-1 ⋅ zf

step4: 更新ht
     ht = zo ⋅ tanh(ct)

step5: 输出yt
      yt = f(ht)
(可以根据自己的需要选择f,例如选用softmax,sigmoid或者tanh函数等。)

三. LSTM的参数计算

从上面可以看到,

  • LSTM有4个输入,1个输出;
  • 输入的维度是X的维度,不是xt的维度,X_dim = x_dim + h _dim;
  • 就是4倍的 输入乘以输出 + 偏置项;
  • 因此,LSTM的参数计算为:p = 4*[ (x_dim + h_dim)*h_dim + h_dim ]

下面通过keras的示例验证一下。

"""
	基于 LSTM 的序列分类
"""
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.layers import Embedding
from keras.layers import LSTM

max_features = 1024

model = Sequential()
model.add(Embedding(max_features, output_dim=256))
model.add(LSTM(128))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])
model.summary()
#model.fit(x_train, y_train, batch_size=16, epochs=10)
#score = model.evaluate(x_test, y_test, batch_size=16)

模型的结构和参数如下图所示:
NLP之简单笔记:LSTM
可以看到LSTM这层的参数个数为197120。

输入x_dim = 256, h_dim = 128, 套用上述公式计算:4*((256+128)*128+128) = 197120

验证正确。

四. LSTM为什么能解决梯度消失的问题

4.1 RNN出现梯度消失或者爆炸的原因

RNN的结构如下图所示:
(图来自https://zhuanlan.zhihu.com/p/28687529)
NLP之简单笔记:LSTM

  • 从上图可以看到,隐藏层的状态St+1是前一个时刻St和Ws的函数。
  • 现在假设损失为L,则反向求导L对Ws的导数时,会出现St+1对St求导的连乘,St+1对St求导等于Ws,即会出现很多个Ws连乘的情况。(假设**函数就是1)
  • 如果t足够大,则当Ws小于1时,Ws * Ws * Ws…*Ws趋近于0,导致梯度消失;当Ws大于1时,Ws * Ws * Ws…*Ws会得到很大的值,导致梯度爆炸。
    NLP之简单笔记:LSTM

总结:RNN中的hidden layer的weight随着t被反复的使用。

4.2 LSTM缓解梯度消失的原因

LSTM中,memory存于cell中,类比RNN,求ct对ct-1的导数,有第二节可知,
     X = [ht-1, xt]
     z = tanh(WX),
     zi = sigmoid(Wi X)
     zf = sigmoid(Wf X)
     zo = sigmoid(Wo X)
     ct = z⋅zi + ct-1 ⋅ zf
因为有:
     ht = zo ⋅ tanh(ct)
所以:
     ht-1 = zo ⋅ tanh(ct-1)
可以得到的信息:

  • ct是z, zi, zf, ct-1的函数
  • z, zi, zf是ht-1的函数
  • ht-1是ct-1的函数

因此,ct对ct-1的求导除了要计算ct = z⋅zi + ct-1 ⋅ zf这一项之外,还受到z, zi, zf这几项的影响,所以ct对ct-1的求导可能是大于1,也可能在[0,1]之间。
假如求得的gradient开始趋于0,我们可以通过设置z, zi, zf的值,让ct对ct-1的导数往1靠拢,从而解决梯度消失的问题。
那么,如何设置z, zi, zf的值呢?这几个参数是网络学习的呀!!!通过学习,决定什么时候让梯度消失,什么时候该保留。这就是LSTM多出这几个门厉害的地方了。