欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

用tensorflow构建动态RNN

程序员文章站 2024-03-25 10:41:12
...

直接看代码

def create_cell():
    cell = rnn.LSTMCell(num_units)
    return rnn.DropoutWrapper(cell, input_keep_prob=0.5)

rnn_cell = rnn.MultiRNNCell([create_cell() for _ in range(2)])
output, states = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)

相关API:

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

参数

cell:一种rnn 的cell,本实例中传入了一个多层的rnncell,每层cell的基本单元是LSTMCell,并且使用了dropout

inputs:输入数据

如果 time_major == False (default)
input的形状必须为 [batch_size, max_time, embed_size]

如果 time_major == True
input输入的形状必须为 [max_time, batch_size, embed_size]

其中batch_size是批大小,max_time是每个序列的大小,而embed_size是序列里面每个分量的大小


返回的是一个元组 (outputs, state)

outputs:RNN的最后一层的输出,是一个tensor
如果为time_major== False,则shape [batch_size,max_time,cell.output_size]。如果为time_major== True,则shape: [max_time,batch_size,cell.output_size]。cell.output_size就是num_units

state: RNN最后时间步的state,如果cell.state_size是一个整数(一般是单层的RNNCell),则state的shape:[batch_size,cell.state_size]。如果它是一个元组(一般这里是 多层的RNNCell),那么它将是一个具有相应形状的元组。注意:如果若RNNCell是 LSTMCells,则state将为每层cell的LSTMStateTuple的元组Tuple(LSTMStateTuple,LSTMStateTuple,LSTMStateTuple)
 

相关标签: rnn