我们在搜索阶段(通常在小的/局部空间)使用TEM 模块以便融合更多具有辨识度的特征。
class RFB_modified(nn.Module):
def __init__(self, in_channel, out_channel):
super(RFB_modified, self).__init__()
self.relu = nn.ReLU(True)
self.branch0 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
)
self.branch1 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
)
self.branch2 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
)
self.branch3 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
)
self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
self.conv_res = BasicConv2d(in_channel, out_channel, 1)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
x = self.relu(x_cat + self.conv_res(x))
return x
在聚合多个特征金字塔时仍有两个关键问题:那就是如何保持层内语义一致性和如何桥接层间的上下文内容。这里提出 近邻连接解码器(NCD)来解决这些问题。具体而言,通过近邻连接函数修改了部分解码器 (PDC)模块并得到三个提纯后的特征:$f _ { k } ^ { n c } = F _ { N C } \left( f _ { k } ^ { \prime } ; \mathbf { W } _ { N C } ^ { u } \right)$,其中 $k ∈ \{3, 4, 5\}$ 以及 $u ∈ \{1, 2, 3\}$,整个过程定义如下:
$$ \left\{ \begin{array} { l } f _ { 5 } ^ { n c } = f _ { 5 } ^ { \prime } \\ f _ { 4 } ^ { n c } = f _ { 4 } ^ { \prime } \otimes g \left[ \delta _ { \uparrow } ^ { 2 } \left( f _ { 5 } ^ { \prime } \right) ; \mathbf { W } _ { N C } ^ { 1 } \right] \\ f _ { 3 } ^ { n c } = f _ { 3 } ^ { \prime } \otimes g \left[ \delta _ { \uparrow } ^ { 2 } \left( f _ { 4 } ^ { n c } \right) ; \mathbf { W } _ { N C } ^ { 2 } \right] \otimes g \left[ \delta _ { \uparrow } ^ { 2 } \left( f _ { 4 } ^ { \prime } \right) ; \mathbf { W } _ { N C } ^ { 3 } \right] \end{array} \right. $$
其中 $g[·;\boldsymbol W^u_{NC}]$ 表示一个 3×3 卷积层接一个批归一化操作。为了确保候选特征之间的尺寸是匹配的,在元素级别的相乘 $⊗$ 之前运用上采样操作(例如两倍上采样)$\delta _ { \uparrow } ^ { 2 } ( \cdot )$。接着,将 $f _ { k } ^ { n c }$ , $k \in \{ 3,4,5 \}$ 传入近邻连接解码器(NCD)并生成粗糙的定位图 $C_6$。
class NeighborConnectionDecoder(nn.Module):
def __init__(self, channel):
super(NeighborConnectionDecoder, self).__init__()
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
self.conv5 = nn.Conv2d(3*channel, 1, 1)
def forward(self, x1, x2, x3):
x1_1 = x1
x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
x3_1 = self.conv_upsample2(self.upsample(x2_1)) * self.conv_upsample3(self.upsample(x2)) * x3
x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
x2_2 = self.conv_concat2(x2_2)
x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
x3_2 = self.conv_concat3(x3_2)
x = self.conv4(x3_2)
x = self.conv5(x)
return x
全局定位图 $C_6$ 由最高三层特征所生成,它仅仅捕捉了相对粗略的隐蔽物体的位置,而忽略结构和纹理细节。为了解决上述问题,本文提出了一个原则性的策略,通过抹除目标来提取具鉴别性的隐蔽区域。
# Group-Reversal Attention (GRA) Block
class GRA(nn.Module):
def __init__(self, channel, subchannel):
super(GRA, self).__init__()
self.group = channel//subchannel
self.conv = nn.Sequential(
nn.Conv2d(channel + self.group, channel, 3, padding=1), nn.ReLU(True),
)
self.score = nn.Conv2d(channel, 1, 3, padding=1)
def forward(self, x, y):
if self.group == 1:
x_cat = torch.cat((x, y), 1)
elif self.group == 2:
xs = torch.chunk(x, 2, dim=1)
x_cat = torch.cat((xs[0], y, xs[1], y), 1)
elif self.group == 4:
xs = torch.chunk(x, 4, dim=1)
x_cat = torch.cat((xs[0], y, xs[1], y, xs[2], y, xs[3], y), 1)
else:
raise Exception("Invalid Channel")
x = x + self.conv(x_cat)
y = y + self.score(x)
return x, y
class ReverseStage(nn.Module):
def __init__(self, channel):
super(ReverseStage, self).__init__()
self.weak_gra = GRA(channel, channel)
self.medium_gra = GRA(channel, 8)
self.strong_gra = GRA(channel, 1)
def forward(self, x, y):
# reverse guided block
y = -1 * (torch.sigmoid(y)) + 1
# three group-reversal attention blocks
x, y = self.weak_gra(x, y)
x, y = self.medium_gra(x, y)
_, y = self.strong_gra(x, y)
return y