使用 Estimator 构建线性模型

1. 概述

这个端到端的演练使用tf.estimator API训练逻辑回归模型。该模型通常用作其他更复杂算法的基准。 Estimator 是可扩展性最强且面向生产的 TensorFlow 模型类型。如需了解详情,请参阅 Estimator 指南

2. 安装和导入

安装sklearn命令: pip install sklearn

3. 加载泰坦尼克号数据集

您将使用泰坦尼克数据集,其以预测乘客的生存(相当病态)为目标,给出性别、年龄、阶级等特征。

4. 探索数据

数据集包含以下特征:

 sexagen_siblings_spousesparchfareclassdeckembark_townalone
0male22.0107.2500ThirdunknownSouthamptonn
1female38.01071.2833FirstCCherbourgn
2female26.0007.9250ThirdunknownSouthamptony
3female35.01053.1000FirstCSouthamptonn
4male28.0008.4583ThirdunknownQueenstowny
 agen_siblings_spousesparchfare
count627.000000627.000000627.000000627.000000
mean29.6313080.5454550.37958534.385399
std12.5118181.1510900.79299954.597730
min0.7500000.0000000.0000000.000000
25%23.0000000.0000000.0000007.895800
50%28.0000000.0000000.00000015.045800
75%35.0000001.0000000.00000031.387500
max80.0000008.0000005.000000512.329200

训练和评估集分别有627和264个样本数据:

大多数乘客都在20和30年代

png

机上的男性乘客大约是女性乘客的两倍。

png

大多数乘客都在“第三”阶级:

png

与男性相比,女性的生存机会要高得多,这显然是该模型的预测特征:

png

5. 模型的特征工程

Estimator使用称为特征列的系统来描述模型应如何解释每个原始输入特征,Estimator需要一个数字输入向量,而特征列描述模型应如何转换每个特征。

选择和制作正确的特征列是学习有效模型的关键,特征列可以是原始特征dict(基本特征列)中的原始输入之一,也可以是使用在一个或多个基本列(派生特征列)上定义的转换创建的任何新列。

线性Estimator同时使用数值和分类特征,特征列适用于所有TensorFlow Estimator,它们的目的是定义用于建模的特征。此外,它们还提供了一些特征工程功能,比如独热编码、归一化和分桶。

5.1. 基本特征列

input_function指定如何将数据转换为以流方式提供输入管道的tf.data.Datasettf.data.Dataset采用多种来源,如数据帧DataFrame,csv格式的文件等。

检查数据集:

您还可以使用tf.keras.layers.DenseFeatures层检查特征列的结果:

DenseFeatures只接受密集张量,要检查分类列,需要先将其转换为指示列:

将所有基本特征添加到模型后,让我们训练模型。使用tf.estimator API训练模型只是一个命令:

5.2. 派生特征列

现在你达到了75%的准确率。单独使用每个基本功能列可能不足以解释数据。例如,性别和标签之间的相关性可能因性别不同而不同。因此,如果您只学习gender="Male"gender="Female"的单一模型权重,您将无法捕捉每个年龄-性别组合(例如,区分gender="Male"age="30"gender="Male"age="40")。

要了解不同特征组合之间的差异,可以将交叉特征列添加到模型中(也可以在交叉列之前对年龄进行分桶):

将组合特征添加到模型之后,让我们再次训练模型:

它现在到达了77.6%的准确度,略好于仅在基本特征方面受过训练,您可以尝试使用更多特征和转换,看看您是否可以做得更好。

现在,您可以使用训练模型从评估集对乘客进行预测。TensorFlow模型经过优化,可以同时对样本的批处理或集合进行预测,之前的eval_input_fn是使用整个评估集定义的。

png

最后,查看结果的接收器操作特性(即ROC),这将使我们更好地了解真阳性率和假阳性率之间的权衡。

(0, 1.05)

png

最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-estimators-linear.html
英文版本:https://tensorflow.google.cn/beta/tutorials/estimators/linear
翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/estimators/linear.md