
[TensorFlow] Digit Recognition with RNN

  In the earlier digit recognition with softmax regression article, a plain softmax regression classified the MNIST images with an accuracy of about 92%. How can the accuracy be improved? This article uses an RNN (a two-layer LSTM) instead: each 28x28 image is treated as a sequence of 28 rows, one row of 28 pixels per time step, and the hidden state at the last time step is fed into a softmax layer for classification.
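
  Before walking through the full script, here is a minimal sketch of the data layout (using dummy zero data rather than the real MNIST loader): a batch of flattened 784-pixel images can be viewed as 28 time steps of 28 features each.

import numpy as np

# Two dummy "images", each a flattened 28*28 = 784 vector (placeholder data)
batch = np.zeros((2, 784), dtype=np.float32)

# Reshape to [batch, time_step_size, input_size]: each row of the image
# becomes one time step of 28 pixel values
sequence = batch.reshape(-1, 28, 28)
print(sequence.shape)  # (2, 28, 28)

  The full training script follows.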

# TensorFlow 1.x is assumed: tf.contrib.rnn and the MNIST tutorial loader
# were removed in TensorFlow 2.x
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST dataset (downloaded to data/mnist if not already present)
mnist = input_data.read_data_sets('data/mnist', one_hot=True)

# Each image is 28*28; one row of the image is fed per time step, so time_step_size = 28 and input_size = 28
time_step_size = 28
input_size = 28
hidden_size = 256
layer_size = 2

_X = tf.placeholder(tf.float32, [None, 784])          # flattened 28*28 images
y = tf.placeholder(tf.float32, [None, 10])             # one-hot labels
X = tf.reshape(_X, [-1, time_step_size, input_size])   # [batch, time_step, input]
keep_prob = tf.placeholder(tf.float32)                 # dropout keep probability
batch_size = tf.placeholder(tf.int32, [])              # run-time batch size for the LSTM initial state

# Define a single LSTM cell wrapped with dropout on its output
def lstm_cell():
    lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, reuse=tf.get_variable_scope().reuse)
    return rnn.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)


# Stack layer_size LSTM cells into a multi-layer RNN
mlstm_cell = rnn.MultiRNNCell([lstm_cell() for _ in range(layer_size)], state_is_tuple=True)

# All-zero initial state, sized by the run-time batch_size
init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)

# Run the RNN; outputs has shape [batch_size, time_step_size, hidden_size]
outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)

# Use the output at the last time step as the input to the softmax layer
h_state = outputs[:, -1, :]

W = tf.Variable(tf.truncated_normal([hidden_size, 10], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[10]))
prediction = tf.nn.softmax(tf.matmul(h_state, W) + b)

# Cross-entropy loss, summed over the batch
cross_entropy = -tf.reduce_sum(y * tf.log(prediction))
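
# Note (not from the original article): applying softmax and then computing
# -sum(y * log(prediction)) can become numerically unstable when the network
# is very confident. A common, equivalent alternative fuses both steps:
#   logits = tf.matmul(h_state, W) + b
#   cross_entropy = tf.reduce_sum(
#       tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
# The code here keeps the article's original formulation.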

# Optimize with the Adam optimizer (learning rate 1e-3)
train_op = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)

# Boolean vector marking which predictions are correct
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))

# Accuracy is the mean of the correct-prediction vector
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(2001):
    xs, ys = mnist.train.next_batch(100)
    if i % 200 == 0:
        # Evaluate on the current batch with dropout disabled (keep_prob = 1.0)
        train_accuracy = sess.run(accuracy, feed_dict={_X: xs, y: ys, keep_prob: 1.0, batch_size: 100})
        print("Epoch %d, step %d, training accuracy %g" % (mnist.train.epochs_completed, i, train_accuracy))
    sess.run(train_op, feed_dict={_X: xs, y: ys, keep_prob: 0.5, batch_size: 100})
print("test accuracy %g" % sess.run(accuracy, feed_dict={_X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0,batch_size: mnist.test.labels.shape[0]}))

  The test accuracy is 98.31%.