使用Keras和TensorFlow Hub对电影评论进行文本分类 (tensorflow2.0官方教程翻译)

此教程本会将文本形式的影评分为“正面”或“负面”影评。这是一个二元分类(又称为两类分类)的示例,也是一种重要且广泛适用的机器学习问题。

本教程演示了使用TensorFlow Hub和Keras进行迁移学习的基本应用。

数据集使用 IMDB 数据集,其中包含来自互联网电影数据库 https://www.imdb.com/ 的50000 条影评文本。我们将这些影评拆分为训练集(25000 条影评)和测试集(25000 条影评)。训练集和测试集之间达成了平衡,意味着它们包含相同数量的正面和负面影评。

此教程使用tf.keras,一种用于在 TensorFlow 中构建和训练模型的高阶 API,以及TensorFlow Hub,一个用于迁移学习的库和平台。

有关使用 tf.keras 的更高级文本分类教程,请参阅 MLCC 文本分类指南

导入库:

1. 下载 IMDB 数据集

TensorFlow数据集上提供了IMDB数据集。以下代码将IMDB数据集下载到您的机器:

2. 探索数据

我们花点时间来了解一下数据的格式,每个样本表示电影评论和相应标签的句子,该句子不以任何方式进行预处理。每个标签都是整数值 0 或 1,其中 0 表示负面影评,1 表示正面影评。

我们先打印10个样本。

我们还打印前10个标签。

3. 构建模

神经网络通过堆叠层创建而成,这需要做出三个架构方面的主要决策:

在此示例中,输入数据由句子组成。要预测的标签是0或1。

表示文本的一种方法是将句子转换为嵌入向量。我们可以使用预先训练的文本嵌入作为第一层,这将具有两个优点:

对于此示例,我们将使用来自TensorFlow Hub 的预训练文本嵌入模型,名为google/tf2-preview/gnews-swivel-20dim/1.

要达到本教程的目的,还有其他三种预训练模型可供测试:

让我们首先创建一个使用TensorFlow Hub模型嵌入句子的Keras层,并在几个输入示例上进行尝试。请注意,无论输入文本的长度如何,嵌入的输出形状为:(num_examples, embedding_dimension)

现在让我们构建完整的模型:

这些图层按顺序堆叠以构建分类器:

  1. 第一层是TensorFlow Hub层。该层使用预先训练的保存模型将句子映射到其嵌入向量。我们正在使用的预训练文本嵌入模型(google/tf2-preview/gnews-swivel-20dim/1)将句子拆分为标记,嵌入每个标记然后组合嵌入。生成的维度为:(num_examples, embedding_dimension)
  2. 这个固定长度的输出矢量通过一个带有16个隐藏单元的完全连接(“密集”)层传输。
  3. 最后一层与单个输出节点密集连接。使用sigmoid激活函数,该值是0到1之间的浮点数,表示概率或置信度。

让我们编译模型。

3.1. 损失函数和优化器

模型在训练时需要一个损失函数和一个优化器。由于这是一个二元分类问题且模型会输出一个概率(应用 S 型激活函数的单个单元层),因此我们将使用 binary_crossentropy 损失函数。

该函数并不是唯一的损失函数,例如,您可以选择 mean_squared_error。但一般来说,binary_crossentropy 更适合处理概率问题,它可测量概率分布之间的“差距”,在本例中则为实际分布和预测之间的“差距”。

稍后,在探索回归问题(比如预测房价)时,我们将了解如何使用另一个称为均方误差的损失函数。

现在,配置模型以使用优化器和损失函数:

4. 训练模型

用有 512 个样本的小批次训练模型 40 个周期。这将对 x_train 和 y_train 张量中的所有样本进行 40 次迭代。在训练期间,监控模型在验证集的 10000 个样本上的损失和准确率:

5. 评估模型

我们来看看模型的表现如何。模型会返回两个值:损失(表示误差的数字,越低越好)和准确率。

使用这种相当简单的方法可实现约 87% 的准确率。如果采用更高级的方法,模型的准确率应该会接近 95%。

6. 进一步阅读

要了解处理字符串输入的更一般方法,以及更详细地分析训练过程中的准确性和损失,请查看 https://www.tensorflow.org/tutorials/keras/basic_text_classification

最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-basic_text_classification_with_tfhub.html 英文版本:https://tensorflow.google.cn/beta/tutorials/keras/basic_text_classification_with_tfhub 翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/keras/basic_text_classification_with_tfhub.md