Tensorflow学习笔记----循环神经网络RNN原理及实现
程序员文章站
2024-03-14 12:24:58
...
循环神经网络(Recurrent Neural Network,RNN) 是一类具有短期记忆能力的神经网络,在循环神经网络中,神经元不仅可以接受其他神经元的信息,还可以接受自身的信息,形成一个环路结构。在很多现实任务中,网络的输出不仅和当前的输入有关,也和过去一段时间的输出相关。
从网络结构上,循环神经网络会记忆之前的信息,并利用之前的信息影响后面结点的输出。即:循环神经网络的隐藏层之间的结点是有连接的,隐藏层的输入不仅包括输入层的输出,还包括上一时刻隐藏层的输出。
常用于文本填充、时间序列、语音识别等序列数据。循环单元如图所示:
- 我们从一个分析来理解学习RNN:对于网站上的评价,我们对其进行语句分析并判断分类这个语句时好评还是差评;在这个问题中我们遇到两个问题:1.长句子单词太多参数量过大,2.可能没有相关的语义;对于第一个问题,我们可以使用卷积神经网络中学过的权重共享的思想,对于第二个问题,我们可以在原来的模型上额外添加一个memory,对于后续单词的听取时,我们同时应用前一个单词的信息,同时也得到一个新的信息传入下一个单词。
所以对于一个循环神经单元来说,它的表达式为:
ht = fW(ht-1 , xt) = tanh(Whhht-1 + Wxhxt) 这里的tanh可以为relu
yt = Whyht
对于SimpleRNNCell:out,h1 = call(x,h0),其中x:[b , seq len , word vec];h0/h1:[b , h dim];out:[b , h dim];
单个循环单元代码:
x = tf.random.normal([4,80,100])
ht0 = x[:,0,:]
cell = tf.keras.layers.SimpleRNNCell(64)
out,ht1 = cell(ht0,[tf.zeros([4,64])])
print(out.shape,ht1[0].shape) #out:(4, 64) (4, 64)
多层循环单元:
x = tf.random.normal([4,80,100])
xt0 = x[:,0,:]
cell = tf.keras.layers.SimpleRNNCell(64)
cell2 = tf.keras.layers.SimpleRNNCell(64)
state0 = [tf.zeros([4,64])] #相当于memory0
state1 = [tf.zeros([4,64])]
out0,state0 = cell(xt0,state0) #memory0训练中也在更新自己;out0为第一层输出
out2,state2 = cell2(out,state0)
print(out2.shape,state2[0].shape) #out:(4, 64) (4, 64)
对于一个句子,我们可以使用循环语句,使得对于每一个单词都执行一次out0,state0 = cell(xt0,state0); out2,state2 = cell2(out,state0),从而得出最终的out及state。
接下来我们用RNN做一个情感分类实战:情感分类就是输入一个语句,再经过一系列训练,对于这个语句进行判断情感偏向(积极/消极),这是一个二分类问题。
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
batchsz = 128
# the most frequest words
total_words = 10000 #常用单词规定一万个
max_review_len = 80 #最大长度
embedding_len = 100
#x_train:输入语句;y_train:情感偏向
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)
# x_train:[b, 80];x_test: [b, 80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len) #将x_train和x_test都设置为最大长度
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len) #大于maxlen将会减掉,小于maxlen将会填充
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)
print('x_train shape:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))
#x_train中一共有25k的句子,每个句子有80个单词,y_train为0时差评,y_train为1时为好评
print('x_test shape:', x_test.shape) #x_test中一共有25k的句子,每个句子有80个单词,与x_train一样
class MyRNN(keras.Model):
def __init__(self, units):
super(MyRNN, self).__init__()
self.state0 = [tf.zeros([batchsz, units])] # [b, 64]
self.state1 = [tf.zeros([batchsz, units])]
#将句子转化为embedding的表示,即[b, 80] => [b, 80, 100]
self.embedding = layers.Embedding(total_words, embedding_len,
input_length=max_review_len)
# [b, 80, 100] -> h_dim: 64
# RNN: cell1 ,cell2, cell3
self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.5) #dropout防止过拟合
self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)
# fc全连接层输出, [b, 80, 100] => [b, 64] => [b, 1]
self.outlayer = layers.Dense(1)
def call(self, inputs, training=None): #前向运算过程
# [b, 80]
x = inputs
# embedding: [b, 80] => [b, 80, 100]
x = self.embedding(x)
# rnn cell compute
# [b, 80, 100] => [b, 64]
state0 = self.state0
state1 = self.state1
for word in tf.unstack(x, axis=1): # word: [b, 100],依次遍历一个句子中的每个单词
# h1 = x*wxh+h0*whh
# out0: [b, 64]
out0, state0 = self.rnn_cell0(word, state0, training) #更新state0,而且新的会覆盖掉原来的
# out1: [b, 64]
out1, state1 = self.rnn_cell1(out0, state1, training) #更新state1
# out: [b, 64] => [b, 1]
x = self.outlayer(out1)
# p(y is pos|x)
prob = tf.sigmoid(x) #将x压缩到(0,1)
return prob
def main():
units = 64
epochs = 4
model = MyRNN(units)
model.compile(optimizer = keras.optimizers.Adam(0.001),
loss = tf.losses.BinaryCrossentropy(),
metrics=['accuracy'])
model.fit(db_train, epochs=epochs, validation_data=db_test)
model.evaluate(db_test)
if __name__ == '__main__':
main()