用tensorflow构建动态RNN
直接看代码
def create_cell():
cell = rnn.LSTMCell(num_units)
return rnn.DropoutWrapper(cell, input_keep_prob=0.5)
rnn_cell = rnn.MultiRNNCell([create_cell() for _ in range(2)])
output, states = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)
相关API:
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
参数
cell:一种rnn 的cell,本实例中传入了一个多层的rnncell,每层cell的基本单元是LSTMCell,并且使用了dropout
inputs:输入数据
如果 time_major == False (default)
input的形状必须为 [batch_size, max_time, embed_size]
如果 time_major == True
input输入的形状必须为 [max_time, batch_size, embed_size]
其中batch_size是批大小,max_time是每个序列的大小,而embed_size是序列里面每个分量的大小
返回的是一个元组 (outputs, state)
outputs:RNN的最后一层的输出,是一个tensor
如果为time_major== False,则shape [batch_size,max_time,cell.output_size]。如果为time_major== True,则shape: [max_time,batch_size,cell.output_size]。cell.output_size就是num_units
state: RNN最后时间步的state,如果cell.state_size是一个整数(一般是单层的RNNCell),则state的shape:[batch_size,cell.state_size]。如果它是一个元组(一般这里是 多层的RNNCell),那么它将是一个具有相应形状的元组。注意:如果若RNNCell是 LSTMCells,则state将为每层cell的LSTMStateTuple的元组Tuple(LSTMStateTuple,LSTMStateTuple,LSTMStateTuple)
上一篇: 视觉SLAM理论与实践学习笔记
下一篇: 【学习记录1】视觉SLAM