天天看点

tensorflow 2.x下checkpoint转为pb模型

查询了很多,主要是使用了tf1的方式,此处给出tf2的方式供大家参考:

这里模型会存为4个文件:assets/variables/keras_metadata.pb/saved_model.pb

import tensorflow as tf
from model_1 import Aggre_model

def saved_model():

    SEQUENCE = 4000
    model = my_model(24, 20)
    model.build(input_shape=[(None, SEQUENCE), (None, SEQUENCE)]) # 这里是我的输入维度
    # 将学习率跟优化器也初始进来
    lr = 0.0003
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    # 加载ckpt
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, 'Model/checkmodel/', max_to_keep=5)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    # model.compute_output_shape(input_shape=[(None, SEQUENCE), (None, SEQUENCE)]) # 这一步可有可无
	# 模型保存为db文件
    model.save('Model/my_db/') 
    return # 加入return可以避免返回其他错误

if __name__ =='__main__':
    saved_model()

           

模型的调用也比较简单:

model_path = '' # pb文件位置
model = tf.keras.models.load_model(model_path)
           

后面的输入和输出就按照自己的数据加载方式去写就可以了

继续阅读