结构化数据分类实战:心脏病预测(tensorflow2.0官方教程翻译)

本教程演示了如何对结构化数据进行分类(例如CSV格式的表格数据)。 我们将使用Keras定义模型,并使用特征列作为桥梁,将CSV中的列映射到用于训练模型的特性。 本教程包含完整的代码:

1. 数据集

我们将使用克利夫兰诊所心脏病基金会提供的一个小数据集 。CSV中有几百行,每行描述一个患者,每列描述一个属性。我们将使用此信息来预测患者是否患有心脏病,该疾病在该数据集中是二元分类任务。

以下是此数据集的说明。请注意,有数字和分类列。

ColumnDescriptionFeature TypeData Type
AgeAge in yearsNumericalinteger
Sex(1 = male; 0 = female)Categoricalinteger
CPChest pain type (0, 1, 2, 3, 4)Categoricalinteger
TrestbpdResting blood pressure (in mm Hg on admission to the hospital)Numericalinteger
CholSerum cholestoral in mg/dlNumericalinteger
FBS(fasting blood sugar > 120 mg/dl) (1 = true; 0 = false)Categoricalinteger
RestECGResting electrocardiographic results (0, 1, 2)Categoricalinteger
ThalachMaximum heart rate achievedNumericalinteger
ExangExercise induced angina (1 = yes; 0 = no)Categoricalinteger
OldpeakST depression induced by exercise relative to restNumericalinteger
SlopeThe slope of the peak exercise ST segmentNumericalfloat
CANumber of major vessels (0-3) colored by flourosopyNumericalinteger
Thal3 = normal; 6 = fixed defect; 7 = reversable defectCategoricalstring
TargetDiagnosis of heart disease (1 = true; 0 = false)Classificationinteger

2. 导入TensorFlow和其他库

安装sklearn依赖库

3. 使用Pandas创建数据帧

Pandas 是一个Python库,包含许多有用的实用程序,用于加载和处理结构化数据。我们将使用Pandas从URL下载数据集,并将其加载到数据帧中。

4. 将数据拆分为训练、验证和测试

我们下载的数据集是一个CSV文件,并将其分为训练,验证和测试集。

5. 使用tf.data创建输入管道

接下来,我们将使用tf.data包装数据帧,这将使我们能够使用特征列作为桥梁从Pandas数据框中的列映射到用于训练模型的特征。如果我们使用非常大的CSV文件(如此之大以至于它不适合内存),我们将使用tf.data直接从磁盘读取它,本教程不涉及这一点。

6. 理解输入管道

现在我们已经创建了输入管道,让我们调用它来查看它返回的数据的格式,我们使用了一小批量来保持输出的可读性。

我们可以看到数据集返回一个列名称(来自数据帧),该列表映射到数据帧中行的列值。

7. 演示几种类型的特征列

TensorFlow提供了许多类型的特性列。在本节中,我们将创建几种类型的特性列,并演示它们如何从dataframe转换列。

7.1. 数字列

特征列的输出成为模型的输入(使用上面定义的演示函数,我们将能够准确地看到数据帧中每列的转换方式),数字列是最简单的列类型,它用于表示真正有价值的特征,使用此列时,模型将从数据帧中接收未更改的列值。

在心脏病数据集中,数据帧中的大多数列都是数字。

7.2. Bucketized列(桶列)

通常,您不希望将数字直接输入模型,而是根据数值范围将其值分成不同的类别,考虑代表一个人年龄的原始数据,我们可以使用bucketized列将年龄分成几个桶,而不是将年龄表示为数字列。 请注意,下面的one-hot(独热编码)值描述了每行匹配的年龄范围。

7.3. 分类列

在该数据集中,thal表示为字符串(例如“固定”,“正常”或“可逆”),我们无法直接将字符串提供给模型,相反,我们必须首先将它们映射到数值。分类词汇表列提供了一种将字符串表示为独热矢量的方法(就像上面用年龄段看到的那样)。词汇表可以使用categorical_column_with_vocabulary_list作为列表传递,或者使用categorical_column_with_vocabulary_file从文件加载。

在更复杂的数据集中,许多列将是分类的(例如字符串),在处理分类数据时,特征列最有价值。虽然此数据集中只有一个分类列,但我们将使用它来演示在处理其他数据集时可以使用的几种重要类型的特征列。

7.4. 嵌入列

假设我们不是只有几个可能的字符串,而是每个类别有数千(或更多)值。由于多种原因,随着类别数量的增加,使用独热编码训练神经网络变得不可行,我们可以使用嵌入列来克服此限制。 嵌入列不是将数据表示为多维度的独热矢量,而是将数据表示为低维密集向量,其中每个单元格可以包含任意数字,而不仅仅是0或1.嵌入的大小(在下面的例子中是8)是必须调整的参数。

关键点:当分类列具有许多可能的值时,最好使用嵌入列,我们在这里使用一个用于演示目的,因此您有一个完整的示例,您可以在将来修改其他数据集。

7.5. 哈希特征列

表示具有大量值的分类列的另一种方法是使用categorical_column_with_hash_bucket. 此特征列计算输入的哈希值,然后选择一个hash_bucket_size存储桶来编码字符串,使用此列时,您不需要提供词汇表,并且可以选择使hash_buckets的数量远远小于实际类别的数量以节省空间。

关键点:该技术的一个重要缺点是可能存在冲突,其中不同的字符串被映射到同一个桶,实际上,无论如何,这对某些数据集都有效。

7.6. 交叉特征列

将特征组合成单个特征(也称为特征交叉),使模型能够为每个特征组合学习单独的权重。 在这里,我们将创建一个age和thal交叉的新功能, 请注意,crossed_column不会构建所有可能组合的完整表(可能非常大),相反,它由hashed_column支持,因此您可以选择表的大小。

8. 选择要使用的列

我们已经了解了如何使用几种类型的特征列,现在我们将使用它们来训练模型。本教程的目标是向您展示使用特征列所需的完整代码(例如,机制),我们选择了几列来任意训练我们的模型。

关键点:如果您的目标是建立一个准确的模型,请尝试使用您自己的更大数据集,并仔细考虑哪些特征最有意义,以及如何表示它们。

8.1. 创建特征层

现在我们已经定义了我们的特征列,我们将使用DenseFeatures层将它们输入到我们的Keras模型中。

之前,我们使用小批量大小来演示特征列的工作原理,我们创建了一个具有更大批量的新输入管道。

9. 创建、编译和训练模型

训练过程的输出

测试

关键点:通常使用更大更复杂的数据集进行深度学习,您将看到最佳结果。使用像这样的小数据集时,我们建议使用决策树或随机森林作为强基线。

本教程的目标不是为了训练一个准确的模型,而是为了演示使用结构化数据的机制,因此您在将来使用自己的数据集时需要使用代码作为起点。

10. 下一步

了解有关分类结构化数据的更多信息的最佳方法是亲自尝试,我们建议找到另一个可以使用的数据集,并训练模型使用类似于上面的代码对其进行分类,要提高准确性,请仔细考虑模型中包含哪些特征以及如何表示这些特征。

最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-feature_columns.html 英文版本:https://tensorflow.google.cn/beta/tutorials/keras/feature_columns 翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/keras/feature_columns.md