PSPNet 是 Hengshuang Zhao 等人与 2016 年提出的语义分割算法。 PSPNet 仍然基于空洞卷积改造的主干网络生成特征,并使用全卷积网络(FCN)产生类别掩码。 与以往工作不同的是,PSPNet 在 FCN 之前加入了一个池化金字塔结构。该结构可以有效地将图像的上下文信息融入到局部特征种,进而提高像素分类的准确性。
上下文信息的重要性
(回顾一下 FCN 的结构)
也就是说,特征图上某一个位置的特征,只包含了原图上对应位置附近一小块区域内的图像信息。
这对于语义分割任务是不利的,因为不同类别的物体可能在局部区域十分相似。为了能更准确地确定途中物体的类别,我们还需要结合周围图像的内容,这就是所谓的上下文(Context)信息。
下图为 PSPNet 论文中的一个例子。图中船的顶棚和汽车在外观上十分相似,从局部看很难区分二者。但如果我们知道这个周围都是水,就可以更准确地将这块区域预测为“船”类别,而不是“汽车”类别。而 PSPNet 最大的创新点就在于将上下文信息融入到局部特征中,以获得更准确的预测。
用池化金字塔获取上下文信息
我们知道池化,尤其是平均池化,可以将不同位置的特征融合在一起。 PSPNet 使用了不同尺度的池化层,以获取不同尺度的的上下文信息,如下图所示。
图中(b)为主干网络(含空洞卷积的 ResNet)计算出的原始特征图。图中(c)为池化金字塔,最上方红色的 1x1 的特征由原始特征图经过全局平均池化得来。它包含了全部空间位置的图像信息,是尺度最大的上下文特征。 下面橙色 2x2、绿色 3x3、蓝色 6x6 的特征图分别由原始特征图池化至对应的大小得来,这些上下文特征图包含了从大到小不同尺度的上下文信息。需要注意的是,特征图的尺寸越大,上下文的尺度就越小。
接下来,不同尺度的上下文特征通过 1x1 的卷积层,以降低特征图的通道数。为了平衡原始特征图和上下文特征图的权重,PSPNet 将上下文特称图的通道数压缩至 1/N,这里 N 为尺度的个数,再上图中为 4 个,即 1x1、2x2、3x3、6x6 四个尺度。
接下来,经过维度压缩的特征图通过双线性插值恢复空间尺寸,并与原始特征图在空间维度拼接在一起。这时,每个位置的特征包含两部分,一部分是主干网络提取的局部图像特征,另一部分为池化金字塔产生的上下文特征。
最后,如图中(d)部分所示,我们再用一个 3x3 的卷积网络预测出每个像素对应物体的类别,这与 FCN 相似。
使用辅助监督进行训练
ResNet 等工作已经表明,网络越深就越难以训练,尤其是低层的部分。 在 PSPNet 这篇工作中,作者还提出使用辅助监督信号,更好地训练主干网络的低层。
具体而言,我们需要在 ResNet 的主干网络上嫁接一个 FCN 的分支网络,使用倒数第二层特征产生类别掩码,与真值比较计算损失函数,并参与回传。经过实验比对,辅助损失函数的权重设置在 0.4 可以达到更好的效果。
辅助监督信号在 GoogLeNet 等工作中也有所使用,但在图像分割算法中非常常见。
Pytorch 实现
mmsegmentation 实现了 PSPNet,池化金字塔模块的实现在psp_head.py文件中,网络整体对应的配置文件为 pspnet_r50-d8.py。一些变种,如使用 ResNet101 作为主干网络的结构也包含在同级文件夹中。我们将代码摘录如下,并自顶向下进行讲解。
整体配置文件
网络整体的配置文件如下。
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='PSPHead',
in_channels=2048,
in_index=3,
channels=512,
pool_scales=(1, 2, 3, 6),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
)
可以看到,网络整体采用了 EncoderDecoder
结构。其中 Encoder 部分,即主干网络,为使用了空洞卷积的 50 层 ResNetV1c 结构,可以产生原图变长 1/8 大小,通道数为 2048 的特征图。
Decoder 部分为 PSPHead 结构,以 ResNet 最后一层(由in_index=3
指定)的特征图为输入,分别使用尺度 pool_scales=(1, 2, 3, 6)
的池化层产生池化金字塔,并将上下文特征降低的通道数降低为 1/4 ,即channels=512
,最终通过卷积层输出 num_classes=19
个类别的掩码图。
辅助监督信号由 auxiliary_head
字段定义,该分支为 FCN 结构,以 ResNet 的倒数第二层特征图(由in_index=2
指定)为输入。该层特征图的通道数为 in_channels=1024
,经过 num_convs=1
层的 3x3 的卷积层产生一个 channels=256
的中间特征图,再使用 1x1 的卷积层产生 num_classes=19
类别的掩码图。需要注意的是,辅助监督信号的比重由loss_decode.loss_weight
指定为 0.4 。
另外需要注意的一个细节是,第一行norm_cfg
指定 BN 的类型为SyncBN
,这也是分割网络常用的归一化配置。
PSPHead
PSPHead 的实现如下。
class PSPHead(BaseDecodeHead):
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(PSPHead, self).__init__(**kwargs)
assert isinstance(pool_scales, (list, tuple))
self.pool_scales = pool_scales
self.psp_modules = PPM(
self.pool_scales,
self.in_channels,
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
x = self._transform_inputs(inputs)
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
output = self.cls_seg(output)
return output
从forward
函数不难看出PSPHead
的逻辑。函数的参数 inputs
为 ResNet 产生的 4 阶段的特征图,_transform_input(inputs)
则按照配置文件取出最后一层的特征图。
psp_modules
为池化金字塔模块,该模块可以产生 1、2、3、6 四个池化尺度的上下文特征图,并且将其上采样至原始大小。
接下来,对四个尺度的上下文特征图和原始特征图在通道维度进行拼接,并通过一个 3x3 的卷积层——bottleneck
模块。最后通过 1x1 的卷积层cls_seg
模块产生类别掩码图。
池化金字塔 PPM
池化金字塔模块是 PSPNet 模型的核心,实现如下。
class PPM(nn.ModuleList):
"""Pooling Pyramid Module used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
act_cfg, align_corners):
super(PPM, self).__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)))
def forward(self, x):
"""Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
upsampled_ppm_out = resize(
ppm_out,
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
从理论讲解部分我们知道,池化金字塔包含若干个分支,每个分支包含一个均值池化层和一个 1x1 的卷积层。
在实现上,PPM
继承自 torch.nn.ModuleList
,并在 __init__
函数中加入了所有的分支子模块,即以下代码的部分。
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
不难看出,每一个分支都是由一个均值池化层和一个 1x1 的卷积层串联构成。
在 forward
函数中,主干网络产生的特征图分别输入至每个池化分支,产生对应尺度的上下文特征,再被上采样至原特征图的大小。最终,所有上采样后的上下文特征图被收集在一个列表中返回,供 PSPHead
模块使用。