欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

python深度学习TensorFlow神经网络模型的保存和读取

程序员文章站 2022-06-25 12:26:15
目录之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以...

之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。

tensorflow提供了一个非常简单的api,即tf.train.saver类来保存和还原一个神经网络模型。

下面代码给出了保存tensorflow模型的方法:

import tensorflow as tf

# 声明两个变量
v1 = tf.variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.saver(write_version=tf.train.saverdef.v1) # 声明tf.train.saver类用于保存模型
with tf.session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件
    print("model saved in file:", saver_path)

注:saver方法已经发生了更改,现在是v2版本,tf.train.saver(write_version=tf.train.saverdef.v1)括号里加入该参数可继续使用v1,但会报warning,可忽略。若使用saver = tf.train.saver()则默认使用当前的版本(v2),保存后在save这个文件夹中会出现4个文件,比v1版多出model.ckpt.data-00000-of-00001这个文件,这点感谢评论里那位朋友指出。至于这个文件的含义到目前我仍不是很清楚,也没查到具体资料,tensorflow15年底开源到现在很多类啊函数都一直发生着变动,或被更新或被弃用,可能一些代码在当时是没问题的,但过了一大段时间后再跑可能就会报错,在此注明事件时间:2017.4.30

这段代码中,通过saver.save函数将tensorflow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt",也就是保存到了当前程序所在文件夹里面的save文件夹中。

tensorflow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为tensorflow会将计算图的结构和图上参数取值分开保存。

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.saver类持久化的所有tensorflow模型文件的文件名。当某个保存的tensorflow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为checkpointstate protocol buffer.

model.ckpt.meta文件保存了tensorflow计算图的结构,可以理解为神经网络的网络结构
tensorflow通过元图(metagraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。tensorflow中元图是由metagraphdef protocol buffer定义的。metagraphdef 中的内容构成了tensorflow持久化时的第一个文件。保存metagraphdef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了tensorflow程序中每一个变量的取值,这个文件是通过sstable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过savedslice protocol buffer定义的。savedslice类型中保存了变量的名称、当前片段的信息以及变量取值。tensorflow提供了tf.train.newcheckpointreader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.newcheckpointreader类这里不做说明,自查。

python深度学习TensorFlow神经网络模型的保存和读取

下面代码给出了加载tensorflow模型的方法:

可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

import tensorflow as tf

# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.saver() # 声明tf.train.saver类用于保存模型
with tf.session() as sess:
    saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的session从保存路径再读取出来
    print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
    print("v2:", sess.run(v2))
    print("model restored")

运行结果:

v1: [[ 0.76705766  1.82217288]]
v2: [[-0.98012197  1.2369734   0.5797025 ]
 [ 2.50458145  0.81897354  0.07858191]]
model restored

这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了tensorflow计算图上所有的运算,并声明了一个tf.train.saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
也就是说使用tensorflow完成了一次模型的保存和读取的操作。

如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:

import tensorflow as tf
# 在下面的代码中,默认加载了tensorflow计算图上定义的全部变量
# 直接加载持久化的图
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.session() as sess:
    saver.restore(sess, "save/model.ckpt")
    # 通过张量的名称来获取张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))

运行程序,输出:

[[ 0.76705766  1.82217288]]

有时可能只需要保存或者加载部分变量。
比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。

为了保存或者加载部分变量,在声明tf.train.saver类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.saver([v1])命令来构建tf.train.saver类,那么只有变量v1会被加载进来。

以上就是python深度学习tensorflow神经网络模型的保存和读取的详细内容,更多关于tensorflow网络模型保存和读取的资料请关注其它相关文章!