tensorflow2保存和加载模型 (tensorflow2.0官方教程翻译)

模型进度可以在训练期间和训练后保存。这意味着模型可以在它停止的地方继续,并避免长时间的训练。保存还意味着您可以共享您的模型,其他人可以重新创建您的工作。当发布研究模型和技术时,大多数机器学习实践者共享:

共享此数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。

注意:小心不受信任的代码(TensorFlow模型是代码)。有关详细信息,请参阅安全使用TensorFlow

选项

保存TensorFlow模型有多种方法,具体取决于你使用的API。本章节使用tf.keras(一个高级API,用于TensorFlow中构建和训练模型),有关其他方法,请参阅TensorFlow保存和还原指南保存在eager中

1. 设置

1.1. 安装和导入

需要安装和导入TensorFlow和依赖项

1.2. 获取样本数据集

我们将使用MNIST数据集来训练我们的模型以演示保存权重,要加速这些演示运行,请只使用前1000个样本数据:

1.3. 定义模型

让我们构建一个简单的模型,我们将用它来演示保存和加载权重。

2. 在训练期间保存检查点

主要用例是在训练期间和训练结束时自动保存检查点,通过这种方式,您可以使用训练有素的模型,而无需重新训练,或者在您离开的地方继续训练,以防止训练过程中断。

tf.keras.callbacks.ModelCheckpoint是执行此任务的回调,回调需要几个参数来配置检查点。

2.1. 检查点回调使用情况

训练模型并将其传递给 ModelCheckpoint回调

这将创建一个TensorFlow检查点文件集合,这些文件在每个周期结束时更新。 文件夹checkpoint_dir下的内容如下:(Linux系统使用 ls命令查看)

创建一个新的未经训练的模型,仅从权重恢复模型时,必须具有与原始模型具有相同体系结构的模型,由于它是相同的模型架构,我们可以共享权重,尽管它是模型的不同示例。

现在重建一个新的,未经训练的模型,并在测试集中评估它。未经训练的模型将在随机水平(约10%的准确率):

然后从检查点加载权重,并重新评估:

2.2. 检查点选项

回调提供了几个选项,可以为生成的检查点提供唯一的名称,并调整检查点频率。

训练一个新模型,每5个周期保存一次唯一命名的检查点:

现在,查看生成的检查点并选择最新的检查点:

注意:默认的tensorflow格式仅保存最近的5个检查点。

要测试,请重置模型并加载最新的检查点:

3. 这些文件是什么?

上述代码将权重存储到检查点格式的文件集合中,这些文件仅包含二进制格式的训练权重. 检查点包含:

如果您只在一台机器上训练模型,那么您将有一个带有后缀的分片:.data-00000-of-00001

4. 手动保存权重

上面你看到了如何将权重加载到模型中。手动保存权重同样简单,使用Model.save_weights方法。

5. 保存整个模型

模型和优化器可以保存到包含其状态(权重和变量)和模型配置的文件中,这允许您导出模型,以便可以在不访问原始python代码的情况下使用它。由于恢复了优化器状态,您甚至可以从中断的位置恢复训练。

保存完整的模型非常有用,您可以在TensorFlow.js(HDF5, Saved Model) 中加载它们,然后在Web浏览器中训练和运行它们,或者使用TensorFlow Lite(HDF5, Saved Model)将它们转换为在移动设备上运行。

5.1. 作为HDF5文件

Keras使用HDF5标准提供基本保存格式,出于我们的目的,可以将保存的模型视为单个二进制blob。

现在从该文件重新创建模型:

检查模型的准确率:

此方法可保存模型的所有东西:

Keras通过检查架构来保存模型,目前它无法保存TensorFlow优化器(来自tf.train)。使用这些时,您需要在加载后重新编译模型,否则您将失去优化程序的状态。

5.2. 作为 saved_model

注意:这种保存tf.keras模型的方法是实验性的,在将来的版本中可能会有所改变。

创建一个新的模型:

创建saved_model,并将其放在带时间戳的目录中:

从保存的模型重新加载新的keras模型:

运行加载的模型进行预测:

6. 下一步是什么

这是使用tf.keras保存和加载的快速指南。

最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本:https://tensorflow.google.cn/beta/tutorials/keras/save_and_restore_models 翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/keras/save_and_restore_models.md