欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

用bert训练模型并转换为pb格式

程序员文章站 2023-11-06 23:45:46
具体代码在github:https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Train.pydef serving_input_fn(): # 保存模型为SaveModel格式 # 采用最原始的feature方式,输入是feature Tensors。 # 如果采用build_parsing_serving_input_receiver_fn,则输入是tf.Examples df = pd.read_...

具体代码在github:
https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Train.py

def serving_input_fn():
    # 保存模型为SaveModel格式
    # 采用最原始的feature方式,输入是feature Tensors。
    # 如果采用build_parsing_serving_input_receiver_fn,则输入是tf.Examples
    df = pd.read_csv(FLAGS.data_dir, delimiter="\t", names=['labels', 'text'], header=None)

    dense_units = len(df.labels.unique())
    label_ids = tf.placeholder(tf.int32, [None, dense_units], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, 128], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, 128], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, 128], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn

本文地址:https://blog.csdn.net/qq236237606/article/details/107078973