最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-images-segmentation.html 英文版本:https://tensorflow.google.cn/beta/tutorials/images/segmentation 翻译建议PR:https://github.com/mashangxue/tensorflow2-zh/edit/master/r2/tutorials/images/segmentation.md
本教程重点介绍使用修改后的U-Net进行图像分割的任务。
前面的章节我们学习了图像分类,网络算法的任务是为输入图像输出对应的标签或类。但是,假设您想知道对象在图像中的位置,该对象的形状,哪个像素属于哪个对象等。在这种情况下,您将要分割图像,即图像的每个像素都是给了一个标签。
因此,图像分割的任务是训练神经网络以输出图像的逐像素掩模。这有助于以更低的水平(即像素级别)理解图像。图像分割在医学成像,自动驾驶汽车和卫星成像等方面具有许多应用。
将用于本教程的数据集是由Parkhi等人创建的Oxford-IIIT Pet Dataset。数据集由图像、其对应的标签和像素方式的掩码组成。掩模基本上是每个像素的标签。每个像素分为三类:
下载依赖项目 https://github.com/tensorflow/examples, 把文件夹tensorflow_examples放到项目下,下面会导入pix2pix
安装tensorflow:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==2.0.0-beta1
安装tensorflow_datasets:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow_datasets
x1import tensorflow as tf
2
3from __future__ import absolute_import, division, print_function, unicode_literals
4
5from tensorflow_examples.models.pix2pix import pix2pix
6
7import tensorflow_datasets as tfds
8tfds.disable_progress_bar()
9
10from IPython.display import clear_output
11import matplotlib.pyplot as plt
数据集已包含在TensorFlow数据集中,只需下载即可。分段掩码包含在3.0.0版中,这就是使用此特定版本的原因。
xxxxxxxxxx
11dataset, info = tfds.load('oxford_iiit_pet:3.0.0', with_info=True)
以下代码执行翻转图像的简单扩充。另外,图像归一化为[0,1]。 最后,如上所述,分割掩模中的像素标记为{1,2,3}。为了方便起见,让我们从分割掩码中减去1,得到标签:{0,1,2}。
xxxxxxxxxx
251def normalize(input_image, input_mask):
2 input_image = tf.cast(input_image, tf.float32)/128.0 - 1
3 input_mask -= 1
4 return input_image, input_mask
5
6function .
7def load_image_train(datapoint):
8 input_image = tf.image.resize(datapoint['image'], (128, 128))
9 input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
10
11 if tf.random.uniform(()) > 0.5:
12 input_image = tf.image.flip_left_right(input_image)
13 input_mask = tf.image.flip_left_right(input_mask)
14
15 input_image, input_mask = normalize(input_image, input_mask)
16
17 return input_image, input_mask
18
19def load_image_test(datapoint):
20 input_image = tf.image.resize(datapoint['image'], (128, 128))
21 input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
22
23 input_image, input_mask = normalize(input_image, input_mask)
24
25 return input_image, input_mask
数据集已包含测试和训练所需的分割,因此让我们继续使用相同的分割。
xxxxxxxxxx
111TRAIN_LENGTH = info.splits['train'].num_examples
2BATCH_SIZE = 64
3BUFFER_SIZE = 1000
4STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
5
6train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
7test = dataset['test'].map(load_image_test)
8
9train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
10train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
11test_dataset = test.batch(BATCH_SIZE)
让我们看一下图像示例,它是数据集的相应掩模。
xxxxxxxxxx
151def display(display_list):
2 plt.figure(figsize=(15, 15))
3
4 title = ['Input Image', 'True Mask', 'Predicted Mask']
5
6 for i in range(len(display_list)):
7 plt.subplot(1, len(display_list), i+1)
8 plt.title(title[i])
9 plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
10 plt.axis('off')
11 plt.show()
12
13for image, mask in train.take(1):
14 sample_image, sample_mask = image, mask
15display([sample_image, sample_mask])
这里使用的模型是一个改进的U-Net。U-Net由编码器(下采样器)和解码器(上采样器)组成。为了学习鲁棒特征并减少可训练参数的数量,可以使用预训练模型作为编码器。因此,该任务的编码器将是预训练的MobileNetV2模型,其中间输出将被使用,并且解码器是已经在Pix2pix tutorial教程示例中实现的上采样块。
输出三个通道的原因是因为每个像素有三种可能的标签。可以将其视为多分类,其中每个像素被分为三类。
xxxxxxxxxx
11OUTPUT_CHANNELS = 3
如上所述,编码器将是一个预训练的MobileNetV2模型,它已经准备好并可以在tf.keras.applications中使用。编码器由模型中间层的特定输出组成。 请注意,在训练过程中不会训练编码器。
xxxxxxxxxx
161base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)
2
3# Use the activations of these layers
4layer_names = [
5 'block_1_expand_relu', # 64x64
6 'block_3_expand_relu', # 32x32
7 'block_6_expand_relu', # 16x16
8 'block_13_expand_relu', # 8x8
9 'block_16_project', # 4x4
10]
11layers = [base_model.get_layer(name).output for name in layer_names]
12
13# 创建特征提取模型
14down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
15
16down_stack.trainable = False
解码器/上采样器只是在TensorFlow示例中实现的一系列上采样块。
xxxxxxxxxx
321up_stack = [
2 pix2pix.upsample(512, 3), # 4x4 -> 8x8
3 pix2pix.upsample(256, 3), # 8x8 -> 16x16
4 pix2pix.upsample(128, 3), # 16x16 -> 32x32
5 pix2pix.upsample(64, 3), # 32x32 -> 64x64
6]
7
8
9def unet_model(output_channels):
10
11 # 这是模型的最后一层
12 last = tf.keras.layers.Conv2DTranspose(
13 output_channels, 3, strides=2,
14 padding='same', activation='softmax') #64x64 -> 128x128
15
16 inputs = tf.keras.layers.Input(shape=[128, 128, 3])
17 x = inputs
18
19 # 通过该模型进行下采样
20 skips = down_stack(x)
21 x = skips[-1]
22 skips = reversed(skips[:-1])
23
24 # Upsampling and establishing the skip connections
25 for up, skip in zip(up_stack, skips):
26 x = up(x)
27 concat = tf.keras.layers.Concatenate()
28 x = concat([x, skip])
29
30 x = last(x)
31
32 return tf.keras.Model(inputs=inputs, outputs=x)
现在,剩下要做的就是编译和训练模型。这里使用的损失是loss.sparse_categorical_crossentropy
。使用此丢失函数的原因是因为网络正在尝试为每个像素分配标签,就像多类预测一样。在真正的分割掩码中,每个像素都有{0,1,2}。这里的网络输出三个通道。基本上,每个频道都试图学习预测一个类,而 loss.sparse_categorical_crossentropy
是这种情况的推荐损失。使用网络输出,分配给像素的标签是具有最高值的通道。这就是create_mask函数正在做的事情。
xxxxxxxxxx
31model = unet_model(OUTPUT_CHANNELS)
2model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
3 metrics=['accuracy'])
让我们试试模型,看看它在训练前预测了什么。
xxxxxxxxxx
161def create_mask(pred_mask):
2 pred_mask = tf.argmax(pred_mask, axis=-1)
3 pred_mask = pred_mask[..., tf.newaxis]
4 return pred_mask[0]
5
6def show_predictions(dataset=None, num=1):
7 if dataset:
8 for image, mask in dataset.take(num):
9 pred_mask = model.predict(image)
10 display([image[0], mask[0], create_mask(pred_mask)])
11 else:
12 display([sample_image, sample_mask,
13 create_mask(model.predict(sample_image[tf.newaxis, ...]))])
14
15
16show_predictions()
让我们观察模型在训练时如何改进。要完成此任务,下面定义了回调函数。
xxxxxxxxxx
161class DisplayCallback(tf.keras.callbacks.Callback):
2 def on_epoch_end(self, epoch, logs=None):
3 clear_output(wait=True)
4 show_predictions()
5 print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
6
7
8EPOCHS = 20
9VAL_SUBSPLITS = 5
10VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS
11
12model_history = model.fit(train_dataset, epochs=EPOCHS,
13 steps_per_epoch=STEPS_PER_EPOCH,
14 validation_steps=VALIDATION_STEPS,
15 validation_data=test_dataset,
16 callbacks=[DisplayCallback()])
我们查看损失变化情况
xxxxxxxxxx
141loss = model_history.history['loss']
2val_loss = model_history.history['val_loss']
3
4epochs = range(EPOCHS)
5
6plt.figure()
7plt.plot(epochs, loss, 'r', label='Training loss')
8plt.plot(epochs, val_loss, 'bo', label='Validation loss')
9plt.title('Training and Validation Loss')
10plt.xlabel('Epoch')
11plt.ylabel('Loss Value')
12plt.ylim([0, 1])
13plt.legend()
14plt.show()
让我们做一些预测。为了节省时间,周期的数量很小,但您可以将其设置得更高以获得更准确的结果。
xxxxxxxxxx
11show_predictions(test_dataset, 1)
预测效果:
现在您已经了解了图像分割是什么,以及它是如何工作的,您可以尝试使用不同的中间层输出,甚至是不同的预训练模型。您也可以通过尝试在Kaggle上托管的Carvana图像掩蔽比赛来挑战自己。
您可能还希望查看[Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection),以获取您可以重新训练自己数据的其他模型。