# Requires M_EfficientSelfAtten and MixFFN_skip, defined elsewhere in the same codebase.
import torch
import torch.nn as nn


class BridgeLayer_4(nn.Module):
    def __init__(self, dims, head, reduction_ratios):
        super().__init__()

        self.norm1 = nn.LayerNorm(dims)
        self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
        self.norm2 = nn.LayerNorm(dims)
        self.mixffn1 = MixFFN_skip(dims, dims * 4)
        self.mixffn2 = MixFFN_skip(dims * 2, dims * 8)
        self.mixffn3 = MixFFN_skip(dims * 5, dims * 20)
        self.mixffn4 = MixFFN_skip(dims * 8, dims * 32)

    def forward(self, inputs):
        B = inputs[0].shape[0]
        C = 64
        if isinstance(inputs, list):
            # Four encoder feature maps: flatten each to (B, H*W, C) tokens
            # and concatenate along the token dimension.
            c1, c2, c3, c4 = inputs
            B, C, _, _ = c1.shape
            c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C)  # 3136 x 64
            c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C)  # 1568 x 64
            c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C)  # 980 x 64
            c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C)  # 392 x 64
            inputs = torch.cat([c1f, c2f, c3f, c4f], -2)
        else:
            B, _, C = inputs.shape

        # Cross-scale self-attention with a residual connection.
        tx1 = inputs + self.attn(self.norm1(inputs))
        tx = self.norm2(tx1)

        # Split the fused sequence back into per-scale chunks at their
        # native channel widths (C, 2C, 5C, 8C).
        tem1 = tx[:, :3136, :].reshape(B, -1, C)
        tem2 = tx[:, 3136:4704, :].reshape(B, -1, C * 2)
        tem3 = tx[:, 4704:5684, :].reshape(B, -1, C * 5)
        tem4 = tx[:, 5684:6076, :].reshape(B, -1, C * 8)

        # Per-scale MixFFNs, each output flattened back to width C.
        m1f = self.mixffn1(tem1, 56, 56).reshape(B, -1, C)
        m2f = self.mixffn2(tem2, 28, 28).reshape(B, -1, C)
        m3f = self.mixffn3(tem3, 14, 14).reshape(B, -1, C)
        m4f = self.mixffn4(tem4, 7, 7).reshape(B, -1, C)

        # Second residual connection around the MixFFN branch.
        t1 = torch.cat([m1f, m2f, m3f, m4f], -2)
        tx2 = tx1 + t1

        return tx2
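A minimal usage sketch of the layer above (not from the source): it assumes M_EfficientSelfAtten and MixFFN_skip from the surrounding codebase are importable, and the head count and reduction ratios shown are illustrative placeholders rather than the authors' exact configuration. The four encoder maps are flattened and fused into one token sequence of length 3136 + 1568 + 980 + 392 = 6076 at width 64.

import torch

# Illustrative settings; head=1 and reduction_ratios=[8, 4, 2, 1] are assumptions.
bridge = BridgeLayer_4(dims=64, head=1, reduction_ratios=[8, 4, 2, 1])

B = 2
c1 = torch.randn(B, 64, 56, 56)   # stage 1: 56*56 = 3136 tokens, 64 channels
c2 = torch.randn(B, 128, 28, 28)  # stage 2: 28*28 = 784 tokens, 128 channels
c3 = torch.randn(B, 320, 14, 14)  # stage 3: 14*14 = 196 tokens, 320 channels
c4 = torch.randn(B, 512, 7, 7)    # stage 4: 7*7 = 49 tokens, 512 channels

out = bridge([c1, c2, c3, c4])
print(out.shape)  # expected: torch.Size([2, 6076, 64])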