查询了很多,主要是使用了tf1的方式,此处给出tf2的方式供大家参考:
这里模型会存为4个文件:assets/variables/keras_metadata.pb/saved_model.pb
import tensorflow as tf
from model_1 import Aggre_model
def saved_model():
SEQUENCE = 4000
model = my_model(24, 20)
model.build(input_shape=[(None, SEQUENCE), (None, SEQUENCE)]) # 这里是我的输入维度
# 将学习率跟优化器也初始进来
lr = 0.0003
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
# 加载ckpt
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, 'Model/checkmodel/', max_to_keep=5)
ckpt.restore(ckpt_manager.latest_checkpoint)
# model.compute_output_shape(input_shape=[(None, SEQUENCE), (None, SEQUENCE)]) # 这一步可有可无
# 模型保存为db文件
model.save('Model/my_db/')
return # 加入return可以避免返回其他错误
if __name__ =='__main__':
saved_model()
模型的调用也比较简单:
model_path = '' # pb文件位置
model = tf.keras.models.load_model(model_path)
后面的输入和输出就按照自己的数据加载方式去写就可以了