
Loss does not fully converge when applying the Pure-Pytorch RCCA module to a video task #14

Open

Lanezzz opened this issue Apr 2, 2021 · 1 comment

Lanezzz commented Apr 2, 2021

Hello, I am planning to apply your PyTorch implementation of the RCCA module across different frames of a video, to compute frame-to-frame attention and thereby enhance the feature representation of the video frames. The main problem is that the loss does not fully converge and stays between 1 and 2. I would like to rule out whether my modifications to the network are the cause, and I need your help!

The main task is video salient object detection. I take any two frames from the same video and pass them through a shared ResNet-101 to get B x 256 x 47 x 47 features, which are then fed into the RCCA module. The module first produces Q_X, K_X, V_X, Q_Y, K_Y, V_Y, i.e. the two frames projected into the Q, K, V spaces. Then Q_X and K_Y form a correlation matrix that is applied to V_Y, and Q_Y and K_X form a correlation matrix that is applied to V_X. The implementation is below, with almost no changes from your code. I hope you can take a look, thanks!

```python
import torch
import torch.nn as nn


class RCCAModule(nn.Module):
    def __init__(self, in_channels, out_channels=256):
        super(RCCAModule, self).__init__()
        # inter_channels = in_channels // 4

        self.cca = CrissCrossAttention(in_channels)

        self.convbX = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
                                    nn.BatchNorm2d(in_channels))

        self.convbY = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
                                    nn.BatchNorm2d(in_channels))

        self.bottleneckX = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1, dilation=1, bias=True),
            nn.BatchNorm2d(out_channels),
            # nn.Dropout2d(0.1),  # would dropout also be useful here?
        )

        self.bottleneckY = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, padding=1, dilation=1, bias=True),
            nn.BatchNorm2d(out_channels),
            # nn.Dropout2d(0.1),  # would dropout also be useful here?
        )

    def forward(self, x, y, recurrence=2):
        # outputX = self.convaX(x)
        # outputY = self.convaY(y)
        outputX = x
        outputY = y
        for i in range(recurrence):
            outputX, outputY = self.cca(outputX, outputY)

        outputX = self.convbX(outputX)
        outputY = self.convbY(outputY)

        # residual-style fusion: concatenate the input with the attended features
        outputX = self.bottleneckX(torch.cat([x, outputX], 1))
        outputY = self.bottleneckY(torch.cat([y, outputY], 1))

        return outputX, outputY
```

```python
import torch
import torch.nn as nn
from torch.nn import Softmax


def INF(B, H, W):
    # -inf on the H x H diagonal so a position is not attended twice along its own column
    # (the repo's helper, with .cuda() dropped so the snippet also runs on CPU)
    return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention Module"""

    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        # The three 1x1 convs below project to Q, K, V; Q and K are reduced, V keeps the full dimension
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        # gamma is initialized to 0 but learnable: at the start the module passes the
        # ImageNet-pretrained features through unchanged, then gradually learns a weight,
        # which keeps the whole training process smoother
        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))
        # self.gamma2 = torch.zeros(1).cuda().requires_grad_()

    def forward(self, x, y):
        m_batchsize, _, height, width = x.size()  # B x C x H x W, e.g. B=2, C=256, H=W=47
        proj_query_X = self.query_conv(x)  # channel reduction to in_dim//2 (128 here): B,C,H,W
        proj_query_X_H = proj_query_X.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)  # BW,H,C
        proj_query_X_W = proj_query_X.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)  # BH,W,C
        proj_key_X = self.key_conv(x)  # channel reduction: B,C,H,W
        proj_key_X_H = proj_key_X.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_key_X_W = proj_key_X.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W
        proj_value_X = self.value_conv(x)  # no channel reduction
        proj_value_X_H = proj_value_X.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_value_X_W = proj_value_X.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W

        proj_query_Y = self.query_conv(y)  # same projections for the second frame
        proj_query_Y_H = proj_query_Y.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)  # BW,H,C
        proj_query_Y_W = proj_query_Y.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)  # BH,W,C
        proj_key_Y = self.key_conv(y)
        proj_key_Y_H = proj_key_Y.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_key_Y_W = proj_key_Y.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W
        proj_value_Y = self.value_conv(y)
        proj_value_Y_H = proj_value_Y.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)  # BW,C,H
        proj_value_Y_W = proj_value_Y.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)  # BH,C,W

        # Cross-frame attention: queries from X against keys from Y. In the BW,H,H map
        # each column carries the attention over the query frame's H positions; BH,W,W is analogous.
        energy_X_H = (torch.bmm(proj_query_X_H, proj_key_Y_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)  # B,H,W,H
        energy_X_W = torch.bmm(proj_query_X_W, proj_key_Y_W).view(m_batchsize, height, width, width)  # B,H,W,W
        concateX = self.softmax(torch.cat([energy_X_H, energy_X_W], 3))  # B,H,W,H+W

        att_X_H = concateX[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)  # BW,H,H
        att_X_W = concateX[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)  # BH,W,W

        # same as for X, with the roles of the two frames swapped
        energy_Y_H = (torch.bmm(proj_query_Y_H, proj_key_X_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_Y_W = torch.bmm(proj_query_Y_W, proj_key_X_W).view(m_batchsize, height, width, width)
        concateY = self.softmax(torch.cat([energy_Y_H, energy_Y_W], 3))

        att_Y_H = concateY[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_Y_W = concateY[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)

        # the permute() below acts as a transpose, so each row now carries the
        # attention over the query frame's H positions
        out_Y_H = torch.bmm(proj_value_Y_H, att_X_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_Y_W = torch.bmm(proj_value_Y_W, att_X_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)

        out_X_H = torch.bmm(proj_value_X_H, att_Y_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_X_W = torch.bmm(proj_value_X_W, att_Y_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)

        # residual connections scaled by the learnable gammas
        return (self.gamma1 * (out_X_H + out_X_W) + x), (self.gamma2 * (out_Y_H + out_Y_W) + y)
```
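As a quick sanity check, the two modules can be driven end to end with random tensors of the shapes described above (a minimal sketch, not taken from my training code):

```python
import torch

# Random tensors standing in for the ResNet-101 features of the two frames
x = torch.randn(2, 256, 47, 47)  # frame 1: B x 256 x 47 x 47
y = torch.randn(2, 256, 47, 47)  # frame 2

rcca = RCCAModule(in_channels=256, out_channels=256)
outX, outY = rcca(x, y, recurrence=2)
print(outX.shape, outY.shape)  # both torch.Size([2, 256, 47, 47])
```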
As for the initialization of this part: conv weights use Kaiming initialization with bias 0, and BN layers use weight 1 and bias 0.
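In code, that scheme looks roughly like this (a sketch of what I just described; `init_weights` is only an illustrative helper name):

```python
import torch.nn as nn

def init_weights(module):
    # Kaiming init for conv weights, zero bias; BN weight 1, bias 0
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
```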
@Serge-weihao

Serge-weihao (Owner) commented
If it does not converge even though you added the residual connection, consider whether the data (or the data augmentation) is the problem.
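One quick way to isolate this is to try overfitting a single fixed batch (a generic sketch; `model`, `criterion`, `optimizer`, and `train_loader` stand in for your own objects). If the loss cannot approach zero even on one batch, suspect the labels or augmentation rather than the attention module:

```python
# Overfit one fixed batch as a pipeline sanity check
x1, x2, target = next(iter(train_loader))
for step in range(200):
    optimizer.zero_grad()
    pred = model(x1, x2)
    loss = criterion(pred, target)
    loss.backward()
    optimizer.step()
    if step % 20 == 0:
        print(step, loss.item())
```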
