使用RNN生成文本实战:莎士比亚风格诗句 (tensorflow2.0官方教程翻译)

本教程演示了如何使用基于字符的 RNN 生成文本。我们将使用 Andrej Karpathy 在 The Unreasonable Effectiveness of Recurrent Neural Networks 一文中提供的莎士比亚作品数据集。我们根据此数据(“Shakespear”)中的给定字符序列训练一个模型,让它预测序列的下一个字符(“e”)。通过重复调用该模型,可以生成更长的文本序列。

注意:启用 GPU 加速可提高执行速度。在 Colab 中依次选择“运行时”>“更改运行时类型”>“硬件加速器”>“GPU”。如果在本地运行,请确保 TensorFlow 的版本为 1.11.0 或更高版本。

本教程中包含使用 tf.kerasEager Execution 实现的可运行代码。以下是本教程中的模型训练了30个周期时的示例输出,并以字符串“Q”开头:

QUEENE:
I had thought thou hadst a Roman; for the oracle,
Thus by All bids the man against the word,
Which are so weak of care, by old care done;
Your children were in your holy love,
And the precipitation through the bleeding throne.

BISHOP OF ELY: Marry, and will, my lord, to weep in such a one were prettiest; Yet now I was adopted heir Of the world's lamentable day, To watch the next way with his father with his face?

ESCALUS: The cause why then we are all resolved more sons.

VOLUMNIA: O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead, And love and pale as any will to that word.

QUEEN ELIZABETH: But how long have I heard the soul for this world, And show his hands of life be proved to stand.

PETRUCHIO: I say he look'd on, if I must be content To stay him from the fatal of our country's bliss. His lordship pluck'd from this sentence then for prey, And then let us twain, being the moon, were she such a case as fills m

虽然有些句子合乎语法规则,但大多数句子都没有意义。该模型尚未学习单词的含义,但请考虑以下几点:

1. 设置Setup

1.1. 导入 TensorFlow 和其他库

1.2. 下载莎士比亚数据集

通过更改以下行可使用您自己的数据运行此代码。

1.3. 读取数据

首先,我们来看一下文本内容。

 

 

2. 处理文本

2.1. 向量化文本

在训练之前,我们需要将字符串映射到数字表示值。创建两个对照表:一个用于将字符映射到数字,另一个用于将数字映射到字符。

现在,每个字符都有一个对应的整数表示值。请注意,我们按从 0 到 len(unique) 的索引映射字符。

 

2.2. 预测任务

根据给定的字符或字符序列预测下一个字符最有可能是什么?这是我们要训练模型去执行的任务。模型的输入将是字符序列,而我们要训练模型去预测输出,即每一个时间步的下一个字符。

由于 RNN 会依赖之前看到的元素来维持内部状态,那么根据目前为止已计算过的所有字符,下一个字符是什么?

2.3. 创建训练样本和目标

将文本划分为训练样本和训练目标。每个训练样本都包含从文本中选取的 seq_length 个字符。

相应的目标也包含相同长度的文本,但是将所选的字符序列向右顺移一个字符。

将文本拆分成文本块,每个块的长度为 seq_length+1 个字符。例如,假设 seq_length 为 4,我们的文本为“Hello”,则可以将“Hell”创建为训练样本,将“ello”创建为目标。

为此,首先使用tf.data.Dataset.from_tensor_slices函数将文本向量转换为字符索引流。

批处理方法可以让我们轻松地将这些单个字符转换为所需大小的序列。

对于每个序列,复制并移动它以创建输入文本和目标文本,方法是使用 map 方法将简单函数应用于每个批处理:

打印第一个样本输入和目标值:

这些向量的每个索引均作为一个时间步来处理。对于时间步 0 的输入,我们收到了映射到字符 “F” 的索引,并尝试预测 “i” 的索引作为下一个字符。在下一个时间步,执行相同的操作,但除了当前字符外,RNN 还要考虑上一步的信息。

2.4. 使用 tf.data 创建批次文本并重排这些批次

我们使用 tf.data 将文本拆分为可管理的序列。但在将这些数据馈送到模型中之前,我们需要对数据进行重排,并将其打包成批。

3. 实现模型

使用tf.keras.Sequential来定义模型。对于这个简单的例子,我们可以使用三个层来定义模型:

对于每个字符,模型查找嵌入,以嵌入作为输入一次运行GRU,并应用密集层生成预测下一个字符的对数可能性的logits:

A drawing of the data passing through the model

4. 试试这个模型

现在运行模型以查看它的行为符合预期,首先检查输出的形状:

在上面的示例中,输入的序列长度为 100 ,但模型可以在任何长度的输入上运行:

为了从模型中获得实际预测,我们需要从输出分布中进行采样,以获得实际的字符索引。此分布由字符词汇表上的logits定义。

注意:从这个分布中进行sample(采样)非常重要,因为获取分布的argmax可以轻松地将模型卡在循环中。

尝试批处理中的第一个样本:

这使我们在每个时间步都预测下一个字符索引:

解码这些以查看此未经训练的模型预测的文本:

5. 训练模型

此时,问题可以被视为标准分类问题。给定先前的RNN状态,以及此时间步的输入,预测下一个字符的类。

5.1. 添加优化器和损失函数

标准的tf.keras.losses.sparse_softmax_crossentropy损失函数在这种情况下有效,因为它应用于预测的最后一个维度。

因为我们的模型返回logits,所以我们需要设置from_logits标志。

使用 tf.keras.Model.compile 方法配置培训过程。我们将使用带有默认参数和损失函数的 tf.keras.optimizers.Adam

5.2. 配置检查点

使用tf.keras.callbacks.ModelCheckpoint确保在训练期间保存检查点:

5.3. 开始训练

为了使训练时间合理,使用10个时期来训练模型。在Colab中,将运行时设置为GPU以便更快地进行训练。

6. 生成文本

6.1. 加载最新的检查点

要使此预测步骤简单,请使用批处理大小1。

由于RNN状态从时间步长传递到时间步的方式,模型一旦构建就只接受固定大小的批次数据。

要使用不同的 batch_size 运行模型,我们需要重建模型并从检查点恢复权重。

6.2. 预测循环

下面的代码块可生成文本:

To generate text the model's output is fed back to the input

查看生成的文本后,您会发现模型知道何时应使用大写字母,以及如何构成段落和模仿莎士比亚风格的词汇。由于执行的训练周期较少,因此该模型尚未学会生成连贯的句子。

如果要改进结果,最简单的方法是增加模型训练的时长(请尝试 EPOCHS=30)。

您还可以尝试使用不同的起始字符,或尝试添加另一个 RNN 层以提高模型的准确率,又或者调整温度参数以生成具有一定随机性的预测值。

7. 高级:自定义训练

上述训练程序很简单,但不会给你太多控制。

所以现在您已经了解了如何手动运行模型,让我们解压缩训练循环,并自己实现。例如,如果要实施课程学习以帮助稳定模型的开环输出,这就是一个起点。

我们将使用 tf.GradientTape 来跟踪梯度。您可以通过阅读eager execution guide来了解有关此方法的更多信息。

该程序的工作原理如下:

最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-text-text_generation.html 英文版本:https://tensorflow.google.cn/beta/tutorials/text/text_generation 翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/text/text_generation.md