图像分割 (tensorflow2.0官方教程翻译)

最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-images-segmentation.html 英文版本:https://tensorflow.google.cn/beta/tutorials/images/segmentation 翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/images/segmentation.md

本教程重点介绍使用修改后的U-Net进行图像分割的任务。

什么是图像分割?

前面的章节我们学习了图像分类,网络算法的任务是为输入图像输出对应的标签或类。但是,假设您想知道对象在图像中的位置,该对象的形状,哪个像素属于哪个对象等。在这种情况下,您将要分割图像,即图像的每个像素都是给了一个标签。

因此,图像分割的任务是训练神经网络以输出图像的逐像素掩模。这有助于以更低的水平(即像素级别)理解图像。图像分割在医学成像,自动驾驶汽车和卫星成像等方面具有许多应用。

将用于本教程的数据集是由Parkhi等人创建的Oxford-IIIT Pet Dataset。数据集由图像、其对应的标签和像素方式的掩码组成。掩模基本上是每个像素的标签。每个像素分为三类:

下载依赖项目 https://github.com/tensorflow/examples 把文件夹tensorflow_examples放到项目下,下面会导入pix2pix

安装tensorflow:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==2.0.0-beta1

安装tensorflow_datasets:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow_datasets

导入各种依赖包

下载Oxford-IIIT Pets数据集

数据集已包含在TensorFlow数据集中,只需下载即可。分段掩码包含在3.0.0版中,这就是使用此特定版本的原因。

以下代码执行翻转图像的简单扩充。另外,图像归一化为[0,1]。 最后,如上所述,分割掩模中的像素标记为{1,2,3}。为了方便起见,让我们从分割掩码中减去1,得到标签:{0,1,2}。

数据集已包含测试和训练所需的分割,因此让我们继续使用相同的分割。

让我们看一下图像示例,它是数据集的相应掩模。

定义模型

这里使用的模型是一个改进的U-Net。U-Net由编码器(下采样器)和解码器(上采样器)组成。为了学习鲁棒特征并减少可训练参数的数量,可以使用预训练模型作为编码器。因此,该任务的编码器将是预训练的MobileNetV2模型,其中间输出将被使用,并且解码器是已经在Pix2pix tutorial教程示例中实现的上采样块。

输出三个通道的原因是因为每个像素有三种可能的标签。可以将其视为多分类,其中每个像素被分为三类。

如上所述,编码器将是一个预训练的MobileNetV2模型,它已经准备好并可以在tf.keras.applications中使用。编码器由模型中间层的特定输出组成。 请注意,在训练过程中不会训练编码器。

解码器/上采样器只是在TensorFlow示例中实现的一系列上采样块。

训练模型

现在,剩下要做的就是编译和训练模型。这里使用的损失是loss.sparse_categorical_crossentropy。使用此丢失函数的原因是因为网络正在尝试为每个像素分配标签,就像多类预测一样。在真正的分割掩码中,每个像素都有{0,1,2}。这里的网络输出三个通道。基本上,每个频道都试图学习预测一个类,而 loss.sparse_categorical_crossentropy 是这种情况的推荐损失。使用网络输出,分配给像素的标签是具有最高值的通道。这就是create_mask函数正在做的事情。

让我们试试模型,看看它在训练前预测了什么。

让我们观察模型在训练时如何改进。要完成此任务,下面定义了回调函数。

我们查看损失变化情况

作出预测

让我们做一些预测。为了节省时间,周期的数量很小,但您可以将其设置得更高以获得更准确的结果。

预测效果:

下一步

现在您已经了解了图像分割是什么,以及它是如何工作的,您可以尝试使用不同的中间层输出,甚至是不同的预训练模型。您也可以通过尝试在Kaggle上托管的Carvana图像掩蔽比赛来挑战自己。

您可能还希望查看[Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection),以获取您可以重新训练自己数据的其他模型。