欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow ckpt模型转pb格式

程序员文章站 2022-06-26 15:29:33
...

Tensorflow ckpt模型转pb格式

1.ckpt转pb

def ckpt2pb():
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    MODEL_PATH = BASE_DIR + '/log/train/PB_30_1/model.ckpt'     # .ckpt路径
    output_graph = BASE_DIR + '/log/train/PB_30_1/model_pb.pb'  # 保存.pb格式路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "pre_1, pre_2"   # !!!必须是自己网络中节点

    saver = tf.train.import_meta_graph(MODEL_PATH+'.meta', clear_devices=True)  # .meta路径
    graph = tf.get_default_graph()  # 获得默认的图
    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图

    with tf.Session() as sess:
        saver.restore(sess, MODEL_PATH)  # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with tf.io.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

2.可以直接在训练过程中保存为pb格式

from tensorflow.python.framework import graph_util

......
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['pre_1','pre_2'])  # ['pre_1','pre_2']是你要保存的输出节点
......
# 写入序列化的 PB 文件
with tf.gfile.FastGFile(os.path.join(LOG_DIR, "train/%s/model.pb" % tt), mode='wb') as f:
	f.write(constant_graph.SerializeToString())

但是我在测试的时候用直接保存为pb格式和ckpt测试结果有所不同。不清楚具体问题在哪。哎????。

测试代码

def eval():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = True
    graph = tf.get_default_graph()
    with graph.as_default():
        with tf.Session(config=config) as sess:
            with gfile.FastGFile('D:/usprogram/lyt/tegcnn/models/log/train/PB_30_1/' + 'model_pb.pb', 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')

                # 需要有一个初始化的过程
                sess.run(tf.global_variables_initializer())

                # 需要先复原变量
                # print(sess.run('X:0'))
                # 1
                is_training = sess.graph.get_tensor_by_name('is_training:0')
                # 输入
                input_x = sess.graph.get_tensor_by_name('X:0')
                keep_prob = sess.graph.get_tensor_by_name('keep_prob:0')

                op1 = sess.graph.get_tensor_by_name('pre_1:0')

                pred_1_val, pred_2_val = sess.run([op1, op2],
                                                  feed_dict={input_x: img1,
                                                             is_training: False,
                                                             keep_prob: 1.0})
                pred_label_1 = np.argmax(pred_1_val)
                print('class: %s' % (num2tai[pred_label_1]))
相关标签: 机器学习