MISSFormer

论文:《MISSFormer: An Effective Medical Image Segmentation Transformer》(TMI2022)

网络结构

1

编码器由Overlap Patch Embedding、Overlap Patch Merging以及Transformer Block with ReMix-FFN组成。

ReMix-FFN(Enhanced Mix-FFN):在 Enhanced Mix 块中引入递归跳跃连接。给定输入特征映射 xin,先应用深度卷积层捕获局部上下文,然后进行递归跳跃连接。

2

5

Enhanced Transformer Context Bridge来捕获不同尺度特征的局部和全局相关性

3

4

步骤:

1.将编码器生成的多尺度的特征的h、w,变成一个序列

2.将四个输出拼接到一起

3.对拼接后的结果进行注意力的计算

4.自注意力输出的结果和输入拼接的结果进行跳跃连接(残差连接),然后进行层归一化

5.把归一化的结果再分成四个输出

6.将这四个输出再进行EMix-FFN

7.将四个输出沿 N(h*w)维度拼接到一起,再与第 4 步跳跃连接后的结果(层归一化之前的 tx1)进行跳跃连接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class BridgeLayer_4(nn.Module):
    """One Enhanced Transformer Context Bridge layer (MISSFormer).

    Flattens the four multi-scale encoder feature maps into one token
    sequence, runs efficient self-attention across all scales, then feeds
    each scale's tokens through its own MixFFN before re-fusing them with
    a residual connection.

    NOTE(review): the token offsets 3136/1568/980/392 hard-code the spatial
    sizes 56x56, 28x28, 14x14, 7x7 with channel multipliers 1/2/5/8 — i.e.
    this layer assumes a 224x224 input image; confirm against the encoder.
    """

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()

        self.norm1 = nn.LayerNorm(dims)
        self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
        self.norm2 = nn.LayerNorm(dims)
        # One feed-forward branch per pyramid level; hidden size is 4x the
        # level's native channel count (dims, 2*dims, 5*dims, 8*dims).
        self.mixffn1 = MixFFN_skip(dims, dims * 4)
        self.mixffn2 = MixFFN_skip(dims * 2, dims * 8)
        self.mixffn3 = MixFFN_skip(dims * 5, dims * 20)
        self.mixffn4 = MixFFN_skip(dims * 8, dims * 32)

    def forward(self, inputs):
        """Fuse multi-scale features via cross-scale attention + MixFFNs.

        Args:
            inputs: either a list of four feature maps
                ``[c1 (B, C, 56, 56), c2 (B, 2C, 28, 28),
                   c3 (B, 5C, 14, 14), c4 (B, 8C, 7, 7)]``
                or an already-fused ``(B, 6076, C)`` token sequence produced
                by a previous bridge layer.

        Returns:
            A ``(B, 6076, C)`` fused token sequence.
        """
        # Dead assignments from the original (B = inputs[0].shape[0]; C = 64)
        # were removed: both branches below set B and C unconditionally, and
        # inputs[0].shape[0] would be wrong for a tensor input anyway.
        if isinstance(inputs, list):
            c1, c2, c3, c4 = inputs
            B, C, _, _ = c1.shape
            # Flatten every level to (B, tokens, C); levels with more than C
            # channels fold the surplus channels into extra tokens.
            c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C)  # 3136 tokens
            c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C)  # 1568 tokens
            c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C)  # 980 tokens
            c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C)  # 392 tokens
            inputs = torch.cat([c1f, c2f, c3f, c4f], -2)
        else:
            B, _, C = inputs.shape

        # Pre-norm attention with a residual connection over all scales.
        tx1 = inputs + self.attn(self.norm1(inputs))
        tx = self.norm2(tx1)

        # Split the normalized sequence back into the four levels, restoring
        # each level's native channel count (C, 2C, 5C, 8C).
        tem1 = tx[:, :3136, :].reshape(B, -1, C)
        tem2 = tx[:, 3136:4704, :].reshape(B, -1, C * 2)
        tem3 = tx[:, 4704:5684, :].reshape(B, -1, C * 5)
        tem4 = tx[:, 5684:6076, :].reshape(B, -1, C * 8)

        # Per-level MixFFN at the level's own spatial resolution, then
        # flatten back to C-channel tokens for re-fusion.
        m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
        m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
        m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
        m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)

        t1 = torch.cat([m1f, m2f, m3f, m4f], -2)

        # Residual over the pre-norm attention output (tx1, not tx).
        tx2 = tx1 + t1

        return tx2

上下文桥循环了四次:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class BridegeBlock_4(nn.Module):
    """Four stacked context-bridge layers (the "x4" loop in the MISSFormer
    architecture diagram), followed by a split of the fused token sequence
    back into the four per-scale 2-D feature maps."""

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()
        # Four independent bridge layers, applied one after another.
        self.bridge_layer1 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer2 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer3 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer4 = BridgeLayer_4(dims, head, reduction_ratios)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run x through all four bridge layers, then reshape the fused
        (B, 6076, C) sequence into a list of four (B, C_i, H_i, W_i) maps."""
        fused = x
        for layer in (self.bridge_layer1, self.bridge_layer2,
                      self.bridge_layer3, self.bridge_layer4):
            fused = layer(fused)

        B, _, C = fused.shape
        # (start, end, H, W, channel multiplier) per pyramid level; the
        # offsets mirror the token layout produced by BridgeLayer_4.
        level_specs = (
            (0, 3136, 56, 56, 1),
            (3136, 4704, 28, 28, 2),
            (4704, 5684, 14, 14, 5),
            (5684, 6076, 7, 7, 8),
        )
        outs = [
            fused[:, lo:hi, :].reshape(B, h, w, C * mult).permute(0, 3, 1, 2)
            for lo, hi, h, w, mult in level_specs
        ]

        return outs

6

实验结果

7