以下是一个保存和加载连接的TensorFlow模型的代码示例:
保存模型:
import tensorflow as tf
# 创建一个连接的TensorFlow模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=10)
# 保存模型
model.save('my_model.h5')
加载模型:
import tensorflow as tf
# 加载模型
model = tf.keras.models.load_model('my_model.h5')
# 使用模型进行预测
predictions = model.predict(x_test)
请确保在保存模型之前先训练模型。保存模型时,模型的权重、架构和训练配置都会被保存在名为'my_model.h5'的文件中。加载模型时,只需提供文件名即可。然后,您可以使用加载的模型进行预测或其他操作。
上一篇:保存和加载库存