欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

对LSTM中间变量形状shape的理解, 附keras中LSTM的各个变量的shape理解

程序员文章站 2022-07-01 19:10:26
...

假设输入的shape是[bs, length, d], bs是批数量, length是预定义的最大序列长度, d是序列中每个step的维度(对于图像序列,可以理解为每一帧的特征向量维度).

下面说对于bs中1个样本的情况, 也就是shape为[1, length, d]
LSTM(或者RNN)有多个cell, 1个cell对应1个step(1个时刻的状态), 这些cell之间的网络层是共享的, 即对于1个LSTM层, 所有的参数数量等于1个cell的参数数量(下图的LSTM由3个cell)
对LSTM中间变量形状shape的理解, 附keras中LSTM的各个变量的shape理解
下图是1个样本在1个cell中(step=t)的工作原理和中间变量的shape:
对LSTM中间变量形状shape的理解, 附keras中LSTM的各个变量的shape理解
图中的六边形表示神经网络层(通常是全连接层), 表示矩阵乘积操作, 里面是这个网络层的权重W(shape:[d+m, m]), 可以分为Wa和Ua, Wa(shape: [d, m])是与Xt相乘的, Ua(shape: [m, m])是与ht-1相乘的, 红色是权重的shape. 
4个六边形对应3个门(输入门涉及了两个六边形, 由it, ct共同决定).
m是LSTM的unit个数, 就是每个网络层的神经元个数(全连接层的输出向量维度)
具体的各个操作如下:
对LSTM中间变量形状shape的理解, 附keras中LSTM的各个变量的shape理解
更多的关于LSTM的原理可以参考这里这里这里.
(个人原创,转载请注明出处https://blog.csdn.net/ying86615791/article/details/103085269,谢谢!)

每个cell都由1个Xt,1个Ct-1和1个ht-1作为输入,
生成1个ht和1个Ct.

因为有length个step, 就有length个cell, 
那么1个样本(X: [length, d])最终可以产生的各个变量的shape如下
h: [length, m], 隐藏状态
c: [length, m], 细胞状态

如果网络只有1个输出
通常取最后1个cell(即最后1个step, 即最后1个时刻)输出的h_last[1, m]作为最终LSTM的输出, 因为该此时的h已经融合了之前的信息了.

上面是1个样本的情况, 对于bs个样本, 该LSTM的各个变量的shape就是
h: [bs, length, m], 隐藏状态
c: [bs, length, m], 细胞状态
最后LSTM的输出就是
h_last: [bs, 1, m]==>[bs, m]

下面从从keras中1个LSTM使用例子来理解这几个变量的shape

import numpy as np
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Input

bs = 5
length = 60
d = 1
m = 30
x = np.random.rand(bs, length, d)
input_ = Input(shape=(length, d))
lstm, hidden, cell = LSTM(units=m, return_state=True, return_sequences=True)(input_)
model = Model(inputs=input_, outputs=(lstm, hidden, cell))

print('input shape ',x.shape)
output, hidden, cell = model(x)
print('output shape ',output.shape)
wx, ux, b = model.layers[1].trainable_weights
print('wx shape ',wx.shape)
print('ux shape ',ux.shape)
print('b shape ',b.shape)

length=60表示序列长度为60个step,过程中会有60个cell
运行后, 各个变量的shape如下
对LSTM中间变量形状shape的理解, 附keras中LSTM的各个变量的shape理解
这里的wx, ux, b是把4个全连接层的权重拼在一起的结果参考
也就是, 对于其中1层的话,
wx: [1, 30]
ux:  [30, 30]
b: [30]
这里的hidden是最后1个step(也就是最后1个cell)的h, cell也是最后1个step的c
output shape的意思是包含了60个cell的h, 每个h的维度是m=30, output由所有cell输出的h组成
如果只取最后时刻的h作为输出的话, 就是
output[:, -1, :]==>shape: [5, 30]
注意,其实这个时候output[:,-1, :]的值就是就是hidden了, 验证如下:对LSTM中间变量形状shape的理解, 附keras中LSTM的各个变量的shape理解
通过上面可以知道所有cell的h都在output里面, 那么keras中如何拿到所有cell的细胞状态c呢?看这里好像还不能实现?