欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Keras入门--Mnist手写数字识别(源自Keras官方样例)

程序员文章站 2022-09-20 14:59:11
Keras入门--Mnist手写数字识别--源自Keras官方样例1. Keras官方Mnist例子2. 改造模型2.1 加载keras相关模块2.2 准备训练和测试Mnist数据2.3 构建模型2.4 训练模型并保存模型3. 测试1. Keras官方Mnist例子https://keras.io/examples/vision/mnist_convnet/2. 改造模型2.1 加载keras相关模块import osimport tempfileimport numpy as np# fr...

1. Keras官方Mnist例子

https://keras.io/examples/vision/mnist_convnet/

2. 改造模型

2.1 加载keras相关模块

import os
import tempfile
import numpy as np
# from tensorflow import keras
# from tensorflow.keras import layers
# 以上注释掉的,是基于tensorflow 2.x版本的高级接口,因我的tensorflow版本为1.x的,所以改用如下接口
from tensorflow.contrib.keras.python.keras.datasets.mnist import load_data
from tensorflow.contrib.keras.python.keras.utils import to_categorical
from tensorflow.contrib.keras.python import keras

2.2 准备训练和测试Mnist数据

keras有load_data接口下载(https://s3.amazonaws.com/img-datasets/mnist.npz)数据。但是本地网络接口下载数据太慢,半天没有反应。所以先把mnist.npz下载下来,然后用numpy解析。

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# the data, split between train and test sets
# (x_train, y_train), (x_test, y_test) = load_data() #此接口下载太慢,提前下载用网页,然后用如下方式处理数据。
with np.load('../datasets/Keras/mnist.npz') as data:
    x_train = data['x_train']
    y_train = data['y_train']
    x_test = data['x_test']
    y_test = data['y_test']

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

2.3 构建模型

官方源码使用keras.models.Input(shape=input_shape)方式将输入数据的shape传入模型,但是编译老报错,说:必须是layer类的。于是将输入数据的shape放到二维卷积Conv2D类属性中。

model = keras.models.Sequential(
    [
        # keras.models.Input(shape=input_shape), # 源码使用的
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=input_shape), # 进行了改装,添加input_shape参数值
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)
model.summary()

2.4 训练模型并保存模型

保存模型是自己加的,为了拿实际图片进行测试,以验证识别的准确性。更接近生活中的实际应用。

# train the model
batch_size = 128
epochs = 15

# loss function and Gradient descent, get the optimal parameters value
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# the same as scikit-learn, predict interface input the data, predict the class
out = model.predict(x_train)

# start iteration to train
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

# Evaluate the trained model
score = model.evaluate(x_test, y_test, verbose=0)

# save model to .h5 file
fname = "E:/work/src/Keras/Minist/model/model-new.h5"
keras.models.save_model(model, fname)

print("Test loss:", score[0])
print("Test accuracy:", score[1])

3. 测试

import os
import cv2
import numpy as np
from tensorflow.contrib.keras.python import keras

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def img_prepare(file_name):
    """handle the png image"""
    image_raw = cv2.imread(file_name)
    rgb_to_gray = cv2.cvtColor(image_raw, cv2.COLOR_RGB2GRAY)
    data_type_convert = rgb_to_gray.astype(np.float32)
    data_normalized = (data_type_convert) * 1.0/255
    data = np.reshape(data_type_convert, [1,28,28,1])
    return data

if __name__ == '__main__':
    # load model
    model_file = "../src/Keras/Minist/model/model-new.h5"
    new_model = keras.models.load_model(model_file)

    # prepare image
    file_name = "../src/Keras/Minist/00007.png"
    test_x = img_prepare(file_name)
    out2 = new_model.predict(test_x)
    print(out2)


打印是形如:
[0 0 0 0 0 0 0 1 0 0]第八个类别

本文地址:https://blog.csdn.net/duanyuwangyuyan/article/details/109611236

相关标签: 深度学习