Tensorflow ckpt模型转pb格式
程序员文章站
2022-06-26 15:29:33
...
Tensorflow ckpt模型转pb格式
1.ckpt转pb
def ckpt2pb():
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = BASE_DIR + '/log/train/PB_30_1/model.ckpt' # .ckpt路径
output_graph = BASE_DIR + '/log/train/PB_30_1/model_pb.pb' # 保存.pb格式路径
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "pre_1, pre_2" # !!!必须是自己网络中节点
saver = tf.train.import_meta_graph(MODEL_PATH+'.meta', clear_devices=True) # .meta路径
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
with tf.Session() as sess:
saver.restore(sess, MODEL_PATH) # 恢复图并得到数据
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=input_graph_def, # 等于:sess.graph_def
output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开
with tf.io.gfile.GFile(output_graph, "wb") as f: # 保存模型
f.write(output_graph_def.SerializeToString()) # 序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点
2.可以直接在训练过程中保存为pb格式
from tensorflow.python.framework import graph_util
......
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['pre_1','pre_2']) # ['pre_1','pre_2']是你要保存的输出节点
......
# 写入序列化的 PB 文件
with tf.gfile.FastGFile(os.path.join(LOG_DIR, "train/%s/model.pb" % tt), mode='wb') as f:
f.write(constant_graph.SerializeToString())
但是我在测试的时候用直接保存为pb格式和ckpt测试结果有所不同。不清楚具体问题在哪。哎????。
测试代码
def eval():
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
config.log_device_placement = True
graph = tf.get_default_graph()
with graph.as_default():
with tf.Session(config=config) as sess:
with gfile.FastGFile('D:/usprogram/lyt/tegcnn/models/log/train/PB_30_1/' + 'model_pb.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# 需要有一个初始化的过程
sess.run(tf.global_variables_initializer())
# 需要先复原变量
# print(sess.run('X:0'))
# 1
is_training = sess.graph.get_tensor_by_name('is_training:0')
# 输入
input_x = sess.graph.get_tensor_by_name('X:0')
keep_prob = sess.graph.get_tensor_by_name('keep_prob:0')
op1 = sess.graph.get_tensor_by_name('pre_1:0')
pred_1_val, pred_2_val = sess.run([op1, op2],
feed_dict={input_x: img1,
is_training: False,
keep_prob: 1.0})
pred_label_1 = np.argmax(pred_1_val)
print('class: %s' % (num2tai[pred_label_1]))
推荐阅读
-
tensorflow三种模型的加载和保存的方法(.ckpt,.pb,SavedModel)
-
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
-
用bert训练模型并转换为pb格式
-
tensorflow三种模型的加载和保存的方法(.ckpt,.pb,SavedModel)
-
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
-
文字检测模型EAST应用详解 ckpt pb的tf加载,opencv加载
-
用bert训练模型并转换为pb格式
-
TensorFlow模型转ONNX格式-Part1
-
Tensorflow模型的格式
-
Tensorflow ckpt模型转pb格式