FCN
Last updated
Last updated
Fully Convolutional Networks (简称 FCN) 是来自伯克利的 Jonathan Long、Evan Shelhamer、 Trevor Darrell 三人于 2014 年在论文《Fully Convolutional Networks for Semantic Segmentation》中提出的语义分割模型。 与更早期的工作相比,FCN 并没有复杂的前后处理过程,使用单个神经网络就完成了图像到掩码的计算,也是首个可以端到端训练的语义分割模型。 FCN 最大的贡献在于其提出的“卷积化”操作,使图像分类的网络经过简单的改造与模型微调就可以适应语义分割任务。
分类和分割在任务目标上有着很大的相似性,都是预测物体的类别。 区别在于分类只需要对图片整体给出一个预测结果,分割需要给出每个像素的预测结果。 在当时,卷积神经网络模型在图像分类已经取得了不错的成绩,且这种优势可以迁移到其他任务上,如目标检测等等。 这就让人自然联想到使用分类模型去解决分割问题。
显然,直接应用图像分类模型是不行的。 图像分类模型以整张图像作为输入,输出单个数值,表示图中物体的类别。 为了能让模型输出与图像相同分辨率的掩码,就必须配合一些其他技术。
一种朴素的方法是用训练好的分类模型在图像上进行滑窗扫描。 对于图像中每一个像素 ,我们以该像素为中心,裁取一个 224x224 大小的图像区域下来(如果遇到边界可以进行边界填充),使用分类网络预测该区域内物体的类别,作为掩码图上对应位置的预测值 。 卷积网络扫过整张图像后,也就输出了整幅图像的类别掩码。
这也是在 FCN 以前,一些基于深度学习的分割算法的基本思路。
这种方式最大的缺陷就是计算量。 不难想象,对于一张 224x224 分辨率大小的图像,得到一张推理掩码的计算量是得到一个分类结果的 倍。
另一方面,模型训练的过程也相对复杂。 由于算法中的神经网络仍然是一个分类网络,我们就需要根据分割数据集产生用于训练分类网络的图像-类别标签对(pair),并用这些图片对网络进行微调训练。 具体而言,对于训练图像上的每一个像素 ,我们需要将以此像素为中心的图像区域裁剪下来,并使用对应的掩码 作为该区域的类别标签。
不难注意到,在滑窗算法中,图像区域是大幅重叠的,许多计算也是重复冗余的。 FCN 在本质上并没有脱离上述滑窗扫描的模式,但其提出的“全卷积化”操作使得推理和训练尤为高效。
用于图像分类的卷积神经网络通常由位于底部的一系列卷积层和位于顶部的若干全连接层组成。 习惯上,我们将这一系列卷积层称为特征提取部分,而全连接层组成的单层或多层感知器称为分类器部分。 卷积层根据输入图像产生特征图,分类器以特征图为特征输出物体的类别。
卷积具有位置不变性,可以接受任意大小的图像,并输出分辨率对应缩小的特征图。 但构成分类器的全连接层只能接受固定维度的输入。 当网络接受非标准大小的图像并产生非标准大小的特征图时,网络中的分类器就不能工作。
为了能让全连接层也适应不同大小的特征图,我们可以用卷积实现全连接。 以 VGGNet 为例,第一个全连接层以 7x7x512 大小的特征图作为输入,将其向量化后,乘以一个 (7x7x512)x4096 大小的矩阵,输出对应 4096 维度的向量。
我们可以将这个参数矩阵重构成一个 7x7 大小、512 输入通道、4096 输出通道的卷积层的权重。 将这个卷积层应用到 VGGNet 输出的特征图上,可以得到与全连接层相同的结果。 二者仅仅在数组形状上有所区别,卷积层输出将表示为一个 1x1x4096 的三维数组,而不是 4096 维度的一维数组。
我们可以用同样的方法,将第二个全连接层变成一个 1x1 大小,4096 通道输入,4096 通道输出的卷积层。 将第三个全连接层变成 1x1 大小,4096 通道输入,1000 通道输出的卷积层。 第三个特征图的输出即为图中物体属于各个类别的相对概率,只不过表示不同类别维度转换到了“通道”维。
在 FCN 的论文中,这种操作成为“卷积化”。 经过卷积化的图像分类网络可以适应任意大小的输入,并在一个缩小的尺度上预测出物体的类别。
全卷积网络可以从图像计算出掩码图,但由于主干网络中存在步长大于 1 的卷积和池化层。 对于 VGGNet、ResNet 等主流结构,掩码图的分辨率通常只有原图的 1/32 大小。 为了输出与原图分辨率相同的掩码,需要对掩码图进行放大。
FCN 的原始论文提出了三种放大方法,一种与后续工作中提出的空洞卷积类似,但 FCN 并没有采用。第二种方法则是直接用双线性插值放大掩码图。
第三种,也是 FCN 论文中采用的方法,就是使用反卷积层对掩码图进行放大。事实上,我们可以构造合适的卷积核,使反卷积达到双线性插值相同的效果。 因而双线性插值可以看作是反卷积的一个特例。 FCN 之所以选取反卷积是因为反卷积中的参数是可学习的,而且还可以通过多层叠加构成相对复杂的非线性变换,而双线性插值是固定的。
这种放大方式还存在另一个问题。 由于图像的主干网络输出的特征图通常经过了高达 32 倍的降采样,空间信息已经严重丢失。 因此,从主干网络直接恢复出的掩码图通常非常粗糙,如下图第一幅 FCN-32s 所示。
为了提高掩码图的精细程度,FCN 还提出了使用高低层特征融合的方法,如下图所示。 图中以 VGGNet 为例,网络中五个池化层逐次对特征图进行 1/2 的空间降采样。 通常,低层特征分辨率高,语义信息相对贫乏,但特征的位置准确度高,而高层特征正好相反。
上述直接将顶层特征(对于 VGGNet 而言是 pool5 层输出的特征)升采样 32 倍的网络结构称为 FCN-32s 。 为了融合低层特征的位置信息,我们可以将顶层特征升先采样 2 倍,再与 pool4 层的特征图求和,再将求和结果升采样 16 倍,这个结构称为 FCN-16s。 我们也可以将顶层特征先升采样 4 倍,再与 pool3 层的特征求和,再将求和结果升采样 8 倍,这个结构称为 FCN-8s。
实验结果如上图所示,FCN-8s 产生的掩码图比 FCN-32s 产生的掩码图更为精细。
由于全卷积的结构是通过分类网络更改而来的,我们可以使用 ImageNet 与训练的参数对其进行初始化。再使用模型微调的技术再分割的数据集上进行微调训练。 由于全卷积网络可以直接输出掩码,因此可以进行端到端训练。 全卷积网络通常使用交叉熵损失函数(crossentropy loss)。前传一张图片得到预测掩码后,针对掩码上的每一个像素计算交叉熵损失,求和再回传。 针对整张图计算损失函数也等价于将不同的图片区域切割下来做成 batch 计算 loss,但是由于卷积可以消除重叠区域的冗余计算,使用整张图象计算损失的方式要比朴素方法高效很多。
mmsegmentation 提供了 FCN 的实现。 由于 FCN 提出的时间较早,mmsegmentation 中实现的 FCN 网络已经与原始论文中提出的结构有所不同,并加入了空洞卷积等后续提出的技术。 不过这些变化并没有改变 FCN 的本质,即全卷积网络与分辨率升采样。
下面给出的配置定义了一个典型的 FCN 网络。这份配置出自 fcn_r50-d8.py ,为了方便展示,我们删去了一些与模型本身关系不大的内容。
可以看到,该模型在 EncoderDecoder
的框架下定义,分为 backbone
和 decode_head
两部分。 其中 backbone
为 50 层的 ResNetV1c
结构。 与分类所采用的 ResNet 结构有所不同,在后两个阶段不进行降采样,并使用空洞卷积保持了特征图的分辨率,在配置中以 dilations
和 strides
字段体现。由于整个网络只有 stem 的 1/4 降采样与第二个阶段的 1/2 降采样,特征图的大小为原图的 1/8 。
decode_head
为 FCNHead
类型,以 in_channels=2048
通道的特征图为输入,经过 num_convs=2
层卷积输出类别数为 num_classes=19
的预测图。 auxiliary_head
用于在训练时增加辅助监督,与 FCN 模型本身并没有太大关系,这里从略。
FCNHead
实现在 fcn_head.py 文件中,代码如下。 为了方便阅读,我们删去了不相关的代码,并将基类中的部分函数合并展示。
通过 forward
函数我们可以看出FCNHead
模块的计算流程。
首先,self._transform_inputs
用于融合多尺度特征,这份配置文件并没有涉及对应的内容。因此该函数的作用仅仅是选出 ResNet 最后一个阶段的输出,即 1/8 分辨率,2048 通道的特征图。
接下来,特征图经过 convs
模块,该模块在 __init__
函数中构建,由 num_convs=2
个卷积层构成。 两个卷积都使用 kernel_size=3
大小的卷积核,padding=kernel_size // 2
,步长默认为 1,因此不改变特征图的空间分辨率。
如果设置了 concat_input=True
,经过两个卷积层后的特征图还需要与原始的特征图在通道维度进行拼接。 经过拼接后的特征图在 self.cls_seg
函数中,经过大小为 1x1,通道数等于 num_classes
的 conv_seg
卷积层,输出 1/8 分辨率、num_classes
个通道的类别概率图。
在训练和推理过程中,我们还需要将类别概率图放大到原始图像的大小,作为输出或与真值相比计算损失函数。这个过程在 EncoderDecoder
类型的各种训练和推理函数中实现,但最终都是经由 mmseg.ops.resize 函数调用 torch.nn.functional.interpolate
函数,以双线性插值的方式实现。