将 TF1.x 代码迁移到 TensorFlow 2.0(tensorflow2.0官方教程翻译)

在TensorFlow 2.0中,仍然可以运行未经修改的1.x代码(contrib除外):

但是,这并不能让您利用TensorFlow2.0中的许多改进。本指南将帮助您升级代码,使其更简单、更高效、更易于维护。

自动转换脚本

第一步是尝试运行升级脚本.

这将在将您的代码升级到TensorFlow 2.0时执行初始步骤。但是它不能使您的代码适合TensorFlowF 2.0。您的代码仍然可以使用tf.compat.v1 接口来访问占位符,会话,集合和其他1.x样式的功能。

使代码2.0原生化

本指南将介绍将TensorFlow 1.x代码转换为TensorFlow 2.0的几个示例。这些更改将使您的代码利用性能优化和简化的API调用。 在每一种情况下,模式是:

1. 替换tf.Session.run调用

每个tf.Session.run调用都应该被Python函数替换。

您可以使用标准Python工具(如pdb)逐步调试和调试函数

如果您对它的工作感到满意,可以添加一个tf.function装饰器,使其在图形模式下高效运行。有关其工作原理的更多信息,请参阅Autograph Guide

2. 使用Python对象来跟踪变量和损失

使用tf.Variable而不是tf.get_variable 每个variable_scope都可以转换为Python对象。通常这将是以下之一:

如果需要聚合变量列表(如 tf.Graph.get_collection(tf.GraphKeys.VARIABLES) ),请使用LayerModel对象的.variables.trainable_variables属性。

这些LayerModel类实现了几个不需要全局集合的其他属性。他们的.losses属性可以替代使用tf.GraphKeys.LOSSES集合。

有关详细信息,请参阅keras指南

警告:许多tf.compat.v1符号隐式使用全局集合。

3. 升级您的训练循环

使用适用于您的用例的最高级API。首选tf.keras.Model.fit构建自己的训练循环。

如果您编写自己的训练循环,这些高级函数可以管理很多可能容易遗漏的低级细节。例如,它们会自动收集正则化损失,并在调用模型时设置training = True参数。

4. 升级数据输入管道

使用tf.data数据集进行数据输入。这些对象是高效的,富有表现力的,并且与张量流很好地集成。

它们可以直接传递给tf.keras.Model.fit方法。

它们可以直接在标准Python上迭代:

转换模型

设置

低阶变量和操作执行

低级API使用的示例包括:

转换前

以下是使用TensorFlow 1.x在代码中看起来像这些模式的内容:

转换后

在转换后的代码中:

基于tf.layers的模型

tf.layers模块用于包含依赖于tf.variable_scope来定义和重用变量的层函数。

转换前

转换后

大多数参数保持不变,但注意区别:

同时也要注意:

混合变量和tf.layers

现存的代码通常将较低级别的TF 1.x变量和操作与较高级的 tf.layers 混合。

转换前

转换后

要转换此代码,请遵循将图层映射到图层的模式,如上例所示。

一般模式是:

tf.variable_scope实际上是它自己的一层。所以把它重写为tf.keras.layers.Layer 有关信息请参阅 指南

需要注意以下几点:

关于Slim&contrib.layers的说明

大量较旧的TensorFlow 1.x代码使用 Slim 库,与TensorFlow 1.x一起打包为tf.contrib.layers。作为contrib模块,TensorFlow 2.0中不再提供此功能,即使在tf.compat.v1中也是如此。使用Slim转换为TF 2.0比转换使用tf.layers的存储库更复杂。事实上,首先将Slim代码转换为tf.layers然后转换为Keras可能是有意义的。

一些tf.contrib图层可能没有被移动到核心TensorFlow,而是被移动到了 TF附加组件包.

训练

有很多方法可以将数据提供给tf.keras模型。他们将接受Python生成器和Numpy数组作为输入。

将数据提供给模型的推荐方法是使用tf.data包,其中包含一组用于处理数据的高性能类。

如果您仍在使用tf.queue,则仅支持这些作为数据结构,而不是数据管道。

使用Datasets

TensorFlow数据集包 (tfds) 包含用于将预定义数据集加载为 tf.data.Dataset 对象的使用程序。

对于此示例,使用 tfds 加载MNIST数据集:

然后为训练准备数据:

 

要使示例保持简短,请修剪数据集以仅返回5个批次:

使用Keras训练循环

如果你不需要对训练过程进行低级别的控制,建议使用Keras内置的fit、evaluate和predict方法,这些方法提供了一个统一的接口来训练模型,而不管实现是什么(sequential、functional或子类化的)。

这些方法的有点包括:

以下是使用数据集训练模型的示例:

编写你自己的训练循环

如果Keras模型的训练步骤适合您,但您需要在该步骤之外进行更多的控制,请考虑在您自己的数据迭代循环中使用 tf.keras.model.train_on_batch 方法。

记住:许多东西可以作为 tf.keras.Callback 的实现。

此方法具有上一节中提到的方法的许多优点,但允许用户控制外循环。

您还可以使用 tf.keras.model.test_on_batchtf.keras.Model.evaluate 来检查训练期间的性能。

注意:train_on_batchtest_on_batch,默认返回单批的损失和指标。如果你传递reset_metrics = False,它们会返回累积的指标,你必须记住适当地重置指标累加器。还要记住,像 AUC 这样的一些指标需要正确计算 reset_metrics = False

继续训练上面的模型:

自定义训练步骤

如果您需要更多的灵活性和控制,可以通过实现自己的训练循环来实现,有三个步骤:

  1. 迭代Python生成器或tf.data.Dataset以获取样本数据;
  2. 使用tf.GradientTape收集渐变;
  3. 使用tf.keras.optimizer将权重更新应用于模型。

记住:

请注意相对于v1的简化:

上面的模型:

新型指标

在TensorFlow 2.0中,metrics是对象,Metrics对象在eager和tf.functions中运行,一个metrics具有以下方法:

对象本身是可调用的,与 update_state 一样,调用新观察更新状态,并返回metrics的新结果。

你不需要手动初始化metrics的变量,而且因为TensorFlow 2.0具有自动控制依赖项,所以您也不需要担心这些。

下面的代码使用metrics来跟踪自定义训练循环中观察到的平均损失:

保存和加载

Checkpoint兼容性

TensorFlow 2.0使用基于对象的检查点。

如果小心的话,仍然可以加载旧式的基于名称的检查点,代码转换过程可能会导致变量名的更改,但是有一些变通的方法。

最简单的方法是将新模型的名称与检查点的名称对齐:

如果这不适合您的用例,请尝试使用 tf.compat.v1.train.init_from_checkpoint 函数,它需要一个 assignment_map 参数,该参数指定从旧名称到新名称的映射。

注意:与基于对象的检查点(可以延迟加载不同,基于名称的检查点要求在调用函数时构建所有变量。某些模型推迟构建变量,直到您调用 build 或在一批数据上运行模型。

保存的模型兼容性

对于保存的模型没有明显的兼容性问题:

Estimators

使用Estimators进行训练

TensorFlow 2.0支持Estimators,使用Estimators时,可以使用TensorFlow 1.x中的 input_fn()tf.extimatro.TrainSpectf.estimator.EvalSpec

以下是使用 input_fn 和train以及evaluate的示例:

创建input_fn和train/eval规范

使用Keras模型定义

在TensorFlow2.0中如何构建estimators存在一些差异。

我们建议您使用Keras定义模型,然后使用 tf.keras.model_to_estimator 将您的模型转换为estimator。下面的代码展示了如何在创建和训练estimator时使用这个功能。

使用自定义 model_fn

如果您需要维护现有的自定义估算器 model_fn,则可以将 model_fn 转换为使用Keras模型。

但是出于兼容性原因,自定义 model_fn 仍将以1.x样式的图形模式运行,这意味着没有eager execution,也没有自动控制依赖。

在自定义 model_fn 中使用Keras模型类似于在自定义训练循环中使用它:

但相对于自定义循环,存在重要差异:

注意:“更新”是每批后需要应用于模型的更改。例如,tf.keras.layers.BatchNormalization层中均值和方差的移动平均值。

以下代码从自定义model_fn创建一个估算器,说明所有这些问题。

TensorShape

这个类被简化为保存ints,而不是tf.compat.v1.Dimension对象。所以不需要调用.value()来获得int

仍然可以从tf.TensorShape.dims访问单个tf.compat.v1.Dimension对象。

以下演示了TensorFlow 1.x和TensorFlow 2.0之间的区别。

TF 1.x 运行:

TF 2.0 运行::

 

TF 1.x 运行::

TF 2.0 运行::

在TF 1.x(或使用任何其他维度方法)中运行:

TF 2.0运行:

如果等级已知,则 tf.TensorShape 的布尔值为“True”,否则为“False”。

其他行为改变

您可能会遇到TensorFlow 2.0中的一些其他行为变化。

ResourceVariables

TensorFlow 2.0默认创建ResourceVariables,而不是RefVariables

ResourceVariables被锁定用于写入,因此提供更直观的一致性保证。

Control Flow

控制流op实现得到了简化,因此在TensorFlow 2.0中生成了不同的图。

结论

回顾一下本节内容:

  1. 运行更新脚本
  2. 删除contrib符号
  3. 将模型切换为面向对象的样式(Keras)
  4. 尽可能使用tf.kerastf.estimator培训和评估循环。
  5. 否则,请使用自定义循环,但请务必避免会话和集合。

将代码转换为TensorFlow 2.0需要一些工作,但会有以下改变:

最新版本:https://www.mashangxue123.com/tensorflow/tf2-guide-migration_guide.html 英文版本:https://tensorflow.google.cn/beta/guide/migration_guide 翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/guide/migration_guide.md