摘要

本文总结了各种注意力,即插即用,方便大家将注意力加到自己的论文中。
在这里插入图片描述

SE

论文标题:Squeeze-and-Excitation Networks

论文链接:https://arxiv.org/abs/1709.01507 或 https://arxiv.org/pdf/1709.01507.pdf(直接下载PDF版本)

SE(Squeeze-and-Excitation)注意力机制是一种基于特征通道的注意力机制,用于加强网络对不同特征通道的关注度,从而提高网络的性能。以下是对SE注意力机制的详细解析:

一、基本原理

SE注意力机制的核心思想是通过学习一个表示通道间关系的注意力权重向量,将通道间的信息进行重新分配,从而增强网络对重要特征通道的关注度。这一机制包括两个关键步骤:Squeeze(压缩)和Excitation(激励)。

  1. Squeeze(压缩)

    • 通过一个全局平均池化层(Global Average Pooling),将每个通道的特征图压缩成一个标量值,这个标量值可以看作是该通道的全局特征描述符。这一步骤实现了对特征图的空间维度的压缩,只保留通道维度的信息。
  2. Excitation(激励)

    • 通过一个全连接层(Fully Connected Layer),将全局特征描述符映射成一个权重向量。这个权重向量表示了每个通道的重要性。
    • 接着,使用一个激活函数(如sigmoid函数)对权重向量进行归一化,确保每个通道的权重值都在0到1之间。这样,就得到了一个注意力权重向量,用于表示每个通道的重要程度。

二、作用过程

将得到的注意力权重向量与原始特征图相乘,得到加权后的特征图。这一步实现了通道特征的重新分配,使得网络能够更多地关注那些重要的特征通道,从而提高网络的性能。

三、应用场景

SE注意力机制在计算机视觉领域有着广泛的应用,如图像分类、目标检测和图像分割等任务。通过引入SE注意力机制,这些任务中的网络模型能够更好地学习到数据中的重要特征,从而提高模型的准确性和鲁棒性。

四、优缺点

优点

  • SE注意力机制计算简单,能够有效地提取全局特征。
  • 通过学习通道间的相关性,能够自适应地调整通道特征的重要性,提高模型的性能。

缺点

  • SE注意力机制只考虑了通道维度上的注意力,没有考虑空间维度上的注意力。因此,在处理需要空间信息的任务时,可能效果有限。
  • 在某些场景下,对于通道数较少的网络结构,SE注意力机制可能无法充分发挥其作用。

综上所述,SE注意力机制是一种简单而有效的注意力机制,它通过学习和调整通道特征的重要性,提高了网络模型的性能。然而,它也存在着一定的局限性,需要根据具体任务和网络结构进行选择和调整。

import torch  
from torch import nn  
  
class SEAttention(nn.Module):  
    """  
    SENet(Squeeze-and-Excitation Networks)中的注意力模块。  
    通过全局平均池化后,使用两个全连接层来学习通道间的相关性,  
    最后通过sigmoid激活函数得到每个通道的权重,用于对输入特征进行重标定。  
  
    Args:  
        channel (int): 输入特征的通道数。  
        reduction (int): 第一个全连接层的压缩比例,用于减少参数和计算量。  
    """  
  
    def __init__(self, channel=512, reduction=16):  
        super(SEAttention, self).__init__()  
        # 使用自适应平均池化将特征图的空间维度压缩为1x1  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        # 定义两个全连接层,中间使用ReLU激活,最后使用Sigmoid得到权重  
        self.fc = nn.Sequential(  
            nn.Linear(channel, channel // reduction, bias=False),  # 压缩通道数  
            nn.ReLU(inplace=True),  
            nn.Linear(channel // reduction, channel, bias=False),  # 恢复通道数  
            nn.Sigmoid()  # 得到每个通道的权重  
        )  
  
    def forward(self, x):  
        """  
        前向传播函数。  
  
        Args:  
            x (torch.Tensor): 输入特征图,形状为(batch_size, channel, height, width)。  
  
        Returns:  
            torch.Tensor: 经过注意力机制重标定后的特征图,形状与输入相同。  
        """  
        b, c, _, _ = x.size()  # 获取批次大小、通道数、高度和宽度  
        # 通过全局平均池化压缩空间维度  
        y = self.avg_pool(x).view(b, c)  # 形状变为(batch_size, channel)  
        # 通过全连接层学习通道间的相关性,并应用sigmoid激活得到权重  
        y = self.fc(y).view(b, c, 1, 1)  # 形状调整为(batch_size, channel, 1, 1)以便与输入特征图相乘  
        # 使用权重对输入特征图进行重标定  
        return x * y.expand_as(x)  
  
if __name__ == '__main__':  
    # 创建一个随机的输入张量,模拟一批数据  
    input_tensor = torch.randn(64, 512, 20, 20)  # 形状为(batch_size, channel, height, width)  
    # 实例化SEAttention模块  
    se_attention = SEAttention(channel=512, reduction=8)  
    # 通过模块处理输入张量  
    output_tensor = se_attention(input_tensor)  
    # 打印输出张量的形状,应与输入相同  
    print(output_tensor.shape)

A2-Nets: Double Attention Networks

A2-Nets: Double Attention Networks 是一种创新的网络架构,其核心思想在于通过双注意力机制来捕获和分配图像或视频中的关键特征。以下是对该网络的详细解析:

一、论文基本信息

  • 论文标题:A2-Nets: Double Attention Networks
  • 论文链接https://arxiv.org/abs/1810.11579
  • 发布时间:2018年(具体日期可能因版本更新而有所不同)

二、核心思想

A2-Nets的核心思想是将整个空间的关键特征收集到一个紧凑的集合中,然后自适应地将其分布到每个位置。这样,后续的卷积层即使没有很大的接收域,也可以感知整个空间的特征。这一思想通过两步双注意机制实现:

  1. Feature Gathering(特征收集)

    • 使用second-order attention pooling(二阶注意力池化)将整个空间的特征集合成一个紧凑的集合。这一步有选择地从整个空间中收集关键特征,隐式地计算池化特征的二阶统计,从而捕获复杂的外观和运动相关性。
  2. Feature Distribution(特征分配)

    • 采用另一种注意力机制,将收集到的关键特征自适应地分配到每个位置。这些特征有助于补充高级任务的每个时空位置,使得每个位置都能接收到其自定义的全局信息。

三、网络架构

A2-Nets中的双注意块(Double Attention Block)是实现上述思想的关键组件。该块包含两个主要部分:

  1. 第一级注意力操作

    • 通过二级注意池化(second-order attention pooling)将整个空间的特征集合成一个紧凑的集合。这一步通常涉及特征图的重新排列和注意力权重的计算,以选择关键特征。
  2. 第二级注意力操作

    • 使用另一种注意力机制将收集到的关键特征分配到每个位置。这一步通常涉及注意力向量的生成和全局特征描述子的分配,以确保每个位置都能接收到适当的全局信息。

四、优势与特点

  1. 高效性

    • 相比传统的CNN模型,A2-Nets能够更有效地捕获长距离特征关系,而不需要显著增加模型的深度和复杂度。
  2. 灵活性

    • 双注意块可以作为一个即插即用的模块嵌入到各种网络架构中,以提升其性能。
  3. 性能提升

    • 在多个计算机视觉任务中,如图像分类、目标检测等,A2-Nets均取得了显著的性能提升。

五、应用场景

A2-Nets在计算机视觉领域具有广泛的应用前景,包括但不限于图像分类、目标检测、图像分割、视频识别等任务。通过引入双注意力机制,该网络能够更有效地处理复杂场景下的图像和视频数据,提高模型的准确性和鲁棒性。

六、总结

A2-Nets: Double Attention Networks通过提出双注意力机制,实现了对图像或视频中关键特征的有效捕获和分配。该网络架构在多个计算机视觉任务中均取得了显著的性能提升,为相关领域的研究提供了新的思路和方法。

import torch  
from torch import nn  
from torch.nn import functional as F  
  
  
class DoubleAttention(nn.Module):  
    """  
    双注意力模块,结合了特征门控和特征分布机制。  
  
    Args:  
        in_channels (int): 输入特征的通道数。  
        c_m (int): 卷积层convA的输出通道数。  
        c_n (int): 卷积层convB和convV的输出通道数。  
        reconstruct (bool): 是否在注意力处理后使用卷积层进行重构。  
    """  
  
    def __init__(self, in_channels, c_m=128, c_n=128, reconstruct=True):  
        super(DoubleAttention, self).__init__()  
        self.in_channels = in_channels  
        self.reconstruct = reconstruct  
        self.c_m = c_m  
        self.c_n = c_n  
        # 定义三个卷积层  
        self.convA = nn.Conv2d(in_channels, c_m, 1)  
        self.convB = nn.Conv2d(in_channels, c_n, 1)  
        self.convV = nn.Conv2d(in_channels, c_n, 1)  
        # 如果需要重构,则添加一个卷积层  
        if self.reconstruct:  
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)  
  
    def forward(self, x):  
        """  
        前向传播函数。  
  
        Args:  
            x (torch.Tensor): 输入特征图,形状为(batch_size, in_channels, height, width)。  
  
        Returns:  
            torch.Tensor: 经过双注意力机制处理后的特征图,形状可能根据reconstruct参数变化。  
        """  
        b, c, h, w = x.shape  
        assert c == self.in_channels, "输入通道数与预期不符"  
  
        # 通过三个不同的卷积层得到不同的特征图  
        A = self.convA(x)  # b, c_m, h, w  
        B = self.convB(x)  # b, c_n, h, w  
        V = self.convV(x)  # b, c_n, h, w  
  
        # 将特征图A展平以便进行矩阵乘法  
        tmpA = A.view(b, self.c_m, -1)  # b, c_m, h*w  
  
        # 计算注意力图  
        attention_maps = F.softmax(B.view(b, self.c_n, -1), dim=-1)  # b, c_n, h*w  
        attention_vectors = F.softmax(V.view(b, self.c_n, -1), dim=1)  # b, c_n, h*w  
  
        # 步骤1: 特征门控  
        global_descriptors = torch.bmm(tmpA, attention_maps.permute(0, 2, 1))  # b, c_m, c_n  
  
        # 步骤2: 特征分布  
        tmpZ = torch.bmm(global_descriptors, attention_vectors)  # b, c_m, h*w  
        tmpZ = tmpZ.view(b, self.c_m, h, w)  # b, c_m, h, w  
  
        # 如果需要重构,则通过卷积层处理tmpZ  
        if self.reconstruct:  
            tmpZ = self.conv_reconstruct(tmpZ)  
  
        return tmpZ  
  
if __name__ == '__main__':  
    # 创建一个随机的输入张量  
    input_tensor = torch.randn(64, 512, 20, 20)  
    # 实例化双注意力模块  
    double_attention = DoubleAttention(512)  
    # 通过模块处理输入张量  
    output_tensor = double_attention(input_tensor)  
    # 打印输出张量的形状  
    print(output_tensor.shape)

BAM

BAM论文链接
BAM: https://arxiv.org/pdf/1807.06514

BAM(Bottleneck Attention Module)注意力机制是一种用于计算机视觉领域的深度学习模型结构,旨在提高神经网络对图像的特征提取和感受野处理能力。以下是关于BAM注意力机制的详细解析:

一、定义与概述

BAM模块是一种引入通道注意力和空间注意力机制的深度学习模块,通过自适应地调整通道和空间特征响应,帮助模型更好地捕获重要信息,提高性能。它已经被广泛应用于计算机视觉任务中,特别是图像分类和物体检测领域。

二、核心组件

BAM模块由两个关键组件组成:通道注意力机制和空间注意力机制。

  1. 通道注意力机制

    • 目的:自适应地调整每个通道的特征响应。
    • 步骤
      1. 全局平均池化:对每个通道的特征图执行全局平均池化操作,将每个通道的特征图池化为一个标量值,该值代表该通道特征的全局重要性。
      2. 全连接层:将每个通道的全局平均池化结果通过两个全连接层传递,学习如何加权每个通道的特征响应。
      3. Sigmoid激活函数:在全连接层之后,通过Sigmoid激活函数将输出限制在0到1之间,以表示每个通道的权重。
      4. 通道特征加权:将通道的特征响应与学习到的通道权重相乘,得到加权后的通道特征响应。
  2. 空间注意力机制

    • 目的:处理不同空间位置的特征。
    • 步骤
      1. 全局最大池化:对每个通道的特征图执行全局最大池化操作,将每个通道的特征图池化为一个标量值,该值代表该通道特征的局部重要性。
      2. 全连接层:与通道注意力机制类似,将每个通道的全局最大池化结果通过两个全连接层传递,学习如何加权每个通道的特征响应。
      3. Sigmoid激活函数:在全连接层之后,通过Sigmoid激活函数将输出限制在0到1之间,以表示每个通道的权重。
      4. 空间特征加权:将通道的特征响应与学习到的空间权重相乘,得到加权后的通道特征响应。

三、工作机制

通道注意力机制和空间注意力机制的输出分别通过相乘的方式融合在一起,得到最终的BAM模块输出。这个输出是经过自适应调整的通道特征响应,对于不同通道和空间位置的特征都有不同程度的强调。

四、优势与应用

  • 优势:BAM模块通过自适应地调整通道和空间特征响应,有助于模型更好地捕获重要信息,提高性能。其轻量化和高效的设计使得它可以轻松集成到各种现有的深度学习架构中。
  • 应用:BAM模块已经在多个基准数据集(如CIFAR-100、ImageNet-1K、VOC 2007和MS COCO)上进行了广泛的实验验证,并在分类和检测任务中取得了显著的性能提升。它已经成为一系列模式识别应用的强大工具,包括分类、检测、分割和控制问题等。

五、总结

BAM注意力机制通过引入通道注意力和空间注意力机制,提高了深度神经网络对图像特征的提取和感受野处理能力。其简单而有效的设计使得它成为计算机视觉领域中的一个重要工具,为各种视觉任务的性能提升提供了新的思路和方法。

import torch  
from torch import nn  
  
def autopad(kernel_size, padding=None, dilation=1):  
    """  
    计算并返回'same'形状输出所需的自动填充大小。  
  
    Args:  
        kernel_size (int or list of int): 卷积核大小。  
        padding (int or list of int, optional): 填充大小。如果为None,则自动计算。  
        dilation (int, optional): 扩张率。默认为1。  
  
    Returns:  
        int or list of int: 所需的填充大小。  
    """  
    if dilation > 1:  
        kernel_size = dilation * (kernel_size - 1) + 1 if isinstance(kernel_size, int) else [dilation * (x - 1) + 1 for x in kernel_size]  
    if padding is None:  
        padding = kernel_size // 2 if isinstance(kernel_size, int) else [x // 2 for x in kernel_size]  
    return padding  
  
  
class Flatten(nn.Module):  
    """  
    将输入张量展平为二维张量。  
    """  
    def forward(self, x):  
        return x.view(x.size(0), -1)  
  
  
class ChannelAttention(nn.Module):  
    """  
    通道注意力机制模块。  
    """  
    def __init__(self, in_channels, reduction=16, num_layers=3):  
        """  
        初始化通道注意力模块。  
  
        Args:  
            in_channels (int): 输入通道数。  
            reduction (int, optional): 通道数减少的比例。默认为16。  
            num_layers (int, optional): 内部全连接层的数量。默认为3。  
        """  
        super(ChannelAttention, self).__init__()  
        self.avgpool = nn.AdaptiveAvgPool2d(1)  
        gate_channels = [in_channels]  
        gate_channels += [in_channels // reduction] * num_layers  
        gate_channels += [in_channels]  
  
        self.ca = nn.Sequential()  
        self.ca.add_module('flatten', Flatten())  
        for i in range(len(gate_channels) - 2):  
            self.ca.add_module(f'fc_{i}', nn.Linear(gate_channels[i], gate_channels[i + 1]))  
            self.ca.add_module(f'bn_{i}', nn.BatchNorm1d(gate_channels[i + 1]))  
            self.ca.add_module(f'relu_{i}', nn.ReLU())  
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))  
  
    def forward(self, x):  
        """  
        前向传播。  
  
        Args:  
            x (torch.Tensor): 输入张量。  
  
        Returns:  
            torch.Tensor: 经过通道注意力加权后的张量。  
        """  
        res = self.avgpool(x)  
        res = self.ca(res)  
        res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)  
        return res
  
# 空间注意力模块  
class SpatialAttention(nn.Module):  
    def __init__(self, in_channels, reduction=16, num_layers=3, dilation=2):  
        """  
        初始化空间注意力模块。  
  
        Args:  
            in_channels (int): 输入通道数。  
            reduction (int, optional): 通道数减少的比例。默认为16。  
            num_layers (int, optional): 内部卷积层的数量。默认为3。  
            dilation (int, optional): 卷积层的扩张率。默认为2。  
        """  
        super(SpatialAttention, self).__init__()  
        self.sa = nn.Sequential()  
        # 第一个卷积层,用于减少通道数  
        self.sa.add_module('conv_reduce', nn.Conv2d(kernel_size=1, in_channels=in_channels, out_channels=in_channels // reduction))  
        self.sa.add_module('bn_reduce', nn.BatchNorm2d(in_channels // reduction))  
        self.sa.add_module('relu_reduce', nn.ReLU())  
        # 添加多个卷积层  
        for i in range(num_layers):  
            self.sa.add_module(f'conv_{i}', nn.Conv2d(kernel_size=3, in_channels=in_channels // reduction,  
                                                     out_channels=in_channels // reduction,  
                                                     padding=autopad(3, None, dilation), dilation=dilation))  
            self.sa.add_module(f'bn_{i}', nn.BatchNorm2d(in_channels // reduction))  
            self.sa.add_module(f'relu_{i}', nn.ReLU())  
        # 最后一个卷积层,输出单通道特征图  
        self.sa.add_module('last_conv', nn.Conv2d(in_channels // reduction, 1, kernel_size=1))  
  
    def forward(self, x):  
        """  
        前向传播。  
  
        Args:  
            x (torch.Tensor): 输入张量。  
  
        Returns:  
            torch.Tensor: 经过空间注意力加权后的张量(单通道),之后将用于扩展。  
        """  
        res = self.sa(x)  
        res = res.expand_as(x)  # 将单通道张量扩展为与输入相同的形状  
        return res  
  
  
# BAM块,结合了通道注意力和空间注意力  
class BAMBlock(nn.Module):  
    def __init__(self, in_channels=512, reduction=16, dilation=2):  
        """  
        初始化BAM块。  
  
        Args:  
            in_channels (int, optional): 输入通道数。默认为512。  
            reduction (int, optional): 通道数减少的比例。默认为16。  
            dilation (int, optional): 空间注意力中卷积层的扩张率。默认为2。  
        """  
        super(BAMBlock, self).__init__()  
        self.ca = ChannelAttention(in_channels=in_channels, reduction=reduction)  
        self.sa = SpatialAttention(in_channels=in_channels, reduction=reduction, dilation=dilation)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        """  
        前向传播。  
  
        Args:  
            x (torch.Tensor): 输入张量。  
  
        Returns:  
            torch.Tensor: 经过BAM块处理后的输出张量。  
        """  
        sa_out = self.sa(x)  # 空间注意力输出  
        ca_out = self.ca(x)  # 通道注意力输出  
        # 将空间注意力和通道注意力相加,并通过sigmoid激活函数得到权重  
        weight = self.sigmoid(sa_out + ca_out)  
        # 将权重应用于输入张量,并进行残差连接  
        out = (1 + weight) * x  
        return out  
  
  
# 测试BAM块  
if __name__ == '__main__':  
    input = torch.randn(64, 512, 7, 7)  
    bam = BAMBlock(in_channels=512, reduction=16, dilation=2)  
    output = bam(input)  
    print(output.shape)  # 应该输出 shape.

BiFormer

BiFormer是一种具有双层路由注意力的视觉Transformer(Vision Transformer,简称ViT)模型,该模型在CVPR 2023(计算机视觉与模式识别会议)上被提出。以下是对BiFormer的详细解析:

一、基本信息

二、核心思想

BiFormer的核心思想是通过双层路由注意力机制来实现动态稀疏注意力,以减轻传统Transformer在计算所有空间位置上的token交互时所带来的巨大计算负担和内存占用。具体来说,双层路由注意力包括在粗略的区域级别过滤掉不相关的键值对,然后在剩余候选区域(即路由区域)的并集中应用细粒度的token-to-token注意力。

三、方法与技术细节

  1. 双层路由注意力

    • 第一层(区域级别):构建区域级关联图,并通过修剪仅保留每个节点的前k个连接,从而在粗粒度上过滤掉最不相关的键值对。
    • 第二层(token级别):在剩余候选区域(路由区域)的联合中,应用细粒度的token-to-token注意力。为了避免稀疏矩阵乘法在现代GPU中的低效率问题,通过收集key/value tokens来执行密集矩阵乘法。
  2. BiFormer模型构建

    • 利用所提出的双层路由注意力作为核心构建块,构建了一种新的通用ViT模型,即BiFormer。
    • BiFormer能够以查询自适应的方式处理一小部分相关标记,而不会分散其他不相关标记的注意力,从而具有良好的性能和高计算效率。

四、实验结果与应用

BiFormer在多个计算机视觉任务中表现出了卓越的性能,包括图像分类、目标检测和语义分割等。具体实验结果表明,BiFormer能够以较少的计算资源和内存占用,达到甚至超过现有模型的性能水平。

此外,BiFormer作为一种即插即用的模块,可以方便地集成到现有的深度学习架构中,以提升模型的性能。例如,在YOLOv5等目标检测模型中引入BiFormer作为主干网络,可以显著提高小目标检测的性能。

五、总结

BiFormer是一种具有双层路由注意力的视觉Transformer模型,通过动态稀疏注意力机制实现了高效的计算分配和内存使用。该模型在多个计算机视觉任务中表现出了卓越的性能,并具有广泛的应用前景。随着深度学习技术的不断发展,BiFormer有望成为未来计算机视觉领域的重要研究方向之一。

"""
Core of BiFormer, Bi-Level Routing Attention.

To be refactored.

author: ZHU Lei
github: https://github.com/rayleizhu
email: ray.leizhu@outlook.com

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, LongTensor


class TopkRouting(nn.Module):
    """
    differentiable topk routing with scaling
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
        with_param: bool, wether inorporate learnable params in routing unit
        diff_routing: bool, wether make routing differentiable
        soft_routing: bool, wether make output value multiplied by routing weights
    """
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)
    
    def forward(self, query:Tensor, key:Tensor)->Tuple[Tensor]:
        """
        Args:
            q, k: (n, p^2, c) tensor
        Return:
            r_weight, topk_index: (n, p^2, topk) tensor
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c) 
        attn_logit = (query_hat*self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k)
        
        return r_weight, topk_index
        

class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx:Tensor, r_weight:Tensor, kv:Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)

        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # print(r_idx.size(), r_weight.size())
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel? 
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), # (n, p^2, p^2, w^2, c_kv) without mem cpy
                                dim=2,
                                index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) # (n, p^2, k, w^2, c_kv)
                               )

        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        return topk_kv

class QKVLinear(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)
    
    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=-1)
        return q, kv
        # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
        # return q, k, v

class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
    param_routing: extra linear for routing
    diff_routing: wether to set routing differentiable
    soft_routing: wether to multiply soft routing weights 
    """
    def __init__(self, dim, n_win=7, num_heads=8, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=True):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5


        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)
        
        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing: # soft routing, always diffrentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing: # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')
        
        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity': # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kenel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v 
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad=auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        x: NHWC tensor

        Return:
            NHWC tensor
        """
        x = rearrange(x, "n c h w -> n h w c")
        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0, # dim=-1
                          pad_l, pad_r, # dim=-2
                          pad_t, pad_b)) # dim=-3
            _, H, W, _ = x.size() # padded size
        else:
            N, H, W, C = x.size()
            assert H%self.n_win == 0 and W%self.n_win == 0 #


        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separte kv if there were memory leak issue caused by gather
        q, kv = self.qkv(x) 

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)

        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)


        r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensors

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)
        
        ######### do attention as normal ####################
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)

        # param-free multihead attention
        attn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H//self.n_win, w=W//self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return rearrange(out, "n h w c -> n c h w")

class Attention(nn.Module):
    """
    vanilla attention
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        args:
            x: NCHW tensor
        return:
            NCHW tensor
        """
        _, _, H, W = x.size()
        x = rearrange(x, 'n c h w -> n (h w) c')
        
        #######################################
        B, N, C = x.shape        
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        #######################################

        x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)
        return x

class AttentionLePE(nn.Module):
    """
    vanilla attention
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
                    lambda x: torch.zeros_like(x)

    def forward(self, x):
        """
        args:
            x: NCHW tensor
        return:
            NCHW tensor
        """
        _, _, H, W = x.size()
        x = rearrange(x, 'n c h w -> n (h w) c')
        
        #######################################
        B, N, C = x.shape        
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W))
        lepe = rearrange(lepe, 'n c h w -> n (h w) c')

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = x + lepe

        x = self.proj(x)
        x = self.proj_drop(x)
        #######################################

        x = rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)
        return x

def _grid2seq(x:Tensor, region_size:Tuple[int], num_heads:int):
    """
    Args:
        x: BCHW tensor
        region size: int
        num_heads: number of attention heads
    Return:
        out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim)
        region_h, region_w: number of regions per col/row
    """
    B, C, H, W = x.size()
    region_h, region_w =  H//region_size[0],  W//region_size[1]
    x = x.view(B, num_heads, C//num_heads, region_h, region_size[0], region_w, region_size[1])
    x = torch.einsum('bmdhpwq->bmhwpqd', x).flatten(2, 3).flatten(-3, -2) # (bs, nhead, nregion, reg_size, head_dim)
    return x, region_h, region_w


def _seq2grid(x:Tensor, region_h:int, region_w:int, region_size:Tuple[int]):
    """
    Args: 
        x: (bs, nhead, nregion, reg_size^2, head_dim)
    Return:
        x: (bs, C, H, W)
    """
    bs, nhead, nregion, reg_size_square, head_dim = x.size()
    x = x.view(bs, nhead, region_h, region_w, region_size[0], region_size[1], head_dim)
    x = torch.einsum('bmhwpqd->bmdhpwq', x).reshape(bs, nhead*head_dim,
        region_h*region_size[0], region_w*region_size[1])
    return x


def regional_routing_attention_torch(
    query:Tensor, key:Tensor, value:Tensor, scale:float,
    region_graph:LongTensor, region_size:Tuple[int],
    kv_region_size:Optional[Tuple[int]]=None,
    auto_pad=True)->Tensor:
    """
    Args:
        query, key, value: (B, C, H, W) tensor
        scale: the scale/temperature for dot product attention
        region_graph: (B, nhead, h_q*w_q, topk) tensor, topk <= h_k*w_k
        region_size: region/window size for queries, (rh, rw)
        key_region_size: optional, if None, key_region_size=region_size
        auto_pad: required to be true if the input sizes are not divisible by the region_size
    Return:
        output: (B, C, H, W) tensor
        attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_size) attention matrix
    """
    kv_region_size = kv_region_size or region_size
    bs, nhead, q_nregion, topk = region_graph.size()
    
    # Auto pad to deal with any input size 
    q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0
    if auto_pad:
        _, _, Hq, Wq = query.size()
        q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0]
        q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1]
        if (q_pad_b > 0 or q_pad_r > 0):
            query = F.pad(query, (0, q_pad_r, 0, q_pad_b)) # zero padding

        _, _, Hk, Wk = key.size()
        kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0]
        kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1]
        if (kv_pad_r > 0 or kv_pad_b > 0):
            key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b)) # zero padding
            value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b)) # zero padding
    
    # to sequence format, i.e. (bs, nhead, nregion, reg_size, head_dim)
    query, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead)
    key, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead)
    value, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead)

    # gather key and values.
    # TODO: is seperate gathering slower than fused one (our old version) ?
    # torch.gather does not support broadcasting, hence we do it manually
    bs, nhead, kv_nregion, kv_region_size, head_dim = key.size()
    broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1).\
        expand(-1, -1, -1, -1, kv_region_size, head_dim)
    key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\
        expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
        index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
    value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_region_size, head_dim).\
        expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
        index=broadcasted_region_graph) # (bs, nhead, q_nregion, topk, kv_region_size, head_dim)
    
    # token-to-token attention
    # (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_size)
    # -> (bs, nhead, q_nregion, reg_size, topk*kv_region_size)
    # TODO: mask padding region
    attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2)
    attn = torch.softmax(attn, dim=-1)
    # (bs, nhead, q_nregion, reg_size, topk*kv_region_size) @ (bs, nhead, q_nregion, topk*kv_region_size, head_dim)
    # -> (bs, nhead, q_nregion, reg_size, head_dim)
    output = attn @ value_g.flatten(-3, -2)

    # to BCHW format
    output = _seq2grid(output, region_h=q_region_h, region_w=q_region_w, region_size=region_size)

    # remove paddings if needed
    if auto_pad and (q_pad_b > 0 or q_pad_r > 0):
        output = output[:, :, :Hq, :Wq]
    return output, attn

CAA

https://export.arxiv.org/pdf/2403.06258

import torch.nn as nn  
  
def autopad(kernel_size, padding=None, dilation=1):  
    """  
    根据kernel_size, padding和dilation自动计算padding。  
    如果dilation大于1,则先调整kernel_size。  
    如果padding未指定,则使用kernel_size的一半作为padding(对于每个维度)。  
    """  
    if dilation > 1:  
        kernel_size = dilation * (kernel_size - 1) + 1 if isinstance(kernel_size, int) else [dilation * (x - 1) + 1 for x in kernel_size]  
    if padding is None:  
        padding = kernel_size // 2 if isinstance(kernel_size, int) else [x // 2 for x in kernel_size]  
    return padding  
  
  
class ConvLayer(nn.Module):  
    """  
    标准的卷积层,包括卷积、批归一化和可选的激活函数。  
    """  
    default_activation = nn.SiLU()  # 默认激活函数  
  
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=None, groups=1, dilation=1, activation=True):  
        """  
        初始化卷积层。  
          
        参数:  
        - in_channels: 输入通道数  
        - out_channels: 输出通道数  
        - kernel_size: 卷积核大小  
        - stride: 卷积步长  
        - padding: 填充大小,如果为None则自动计算  
        - groups: 分组卷积的组数  
        - dilation: 空洞卷积的扩张率  
        - activation: 是否应用激活函数,或者指定一个激活函数  
        """  
        super().__init__()  
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, autopad(kernel_size, padding, dilation), groups=groups, dilation=dilation, bias=False)  
        self.bn = nn.BatchNorm2d(out_channels)  
        self.activation = self.default_activation if activation is True else activation if isinstance(activation, nn.Module) else nn.Identity()  
  
    def forward(self, x):  
        """  
        对输入应用卷积、批归一化和激活函数。  
        """  
        return self.activation(self.bn(self.conv(x)))  
  
    def forward_fuse(self, x):  
        """  
        (注意:此方法名可能有些误导,因为它并没有执行融合操作,只是跳过了批归一化)  
        对输入应用卷积和激活函数,跳过批归一化。  
        """  
        return self.activation(self.conv(x))  
  
  
class CAA(nn.Module):  
    def __init__(self, channels, h_kernel_size=11, v_kernel_size=11):  
        """  
        跨维度注意力聚合模块。  
          
        参数:  
        - channels: 输入和输出的通道数  
        - h_kernel_size: 水平卷积核大小  
        - v_kernel_size: 垂直卷积核大小  
        """  
        super().__init__()  
        self.avg_pool = nn.AvgPool2d(7, stride=1, padding=3)  # 使用padding=3来保持输出尺寸  
        self.conv1 = ConvLayer(channels, channels)  
        self.h_conv = nn.Conv2d(channels, channels, (1, h_kernel_size), stride=1, padding=(0, h_kernel_size // 2), groups=channels)  
        self.v_conv = nn.Conv2d(channels, channels, (v_kernel_size, 1), stride=1, padding=(v_kernel_size // 2, 0), groups=channels)  
        self.conv2 = ConvLayer(channels, channels)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        """  
        计算注意力权重并将其应用于输入特征图。  
        """  
        attn_factor = self.sigmoid(self.conv2(self.v_conv(self.h_conv(self.conv1(self.avg_pool(x))))))  
        return attn_factor * x

CBAM

CBAM(Convolutional Block Attention Module)是一种结合了通道注意力和空间注意力的注意力机制模块,旨在提高卷积神经网络(CNN)的表征能力和泛化能力。以下是对CBAM的详细解析:

一、概述

CBAM由Sanghyun Woo等人在2018年的论文中提出,并在多个计算机视觉任务中取得了显著的效果。CBAM通过动态地调整网络中每个特征通道的权重,以及特征图在空间维度上的权重,使网络更加关注重要的特征,从而提升模型的性能。

二、核心原理

CBAM包含两个关键部分:通道注意力模块(Channel Attention Module, CAM)和空间注意力模块(Spatial Attention Module, SAM)。这两个模块顺序执行,首先计算通道注意力,然后将其应用于特征图上,接着计算空间注意力,并对特征图进行进一步的加权。

1. 通道注意力模块(CAM)

通道注意力模块通过学习每个通道的重要性来增强特征图的表达能力。其主要步骤如下:

  • 全局平均池化和全局最大池化:对于输入特征图,首先对每个通道执行全局平均池化和全局最大池化操作,计算每个通道上的平均特征值和最大特征值。这两个操作能够捕捉到不同尺度的特征信息。
  • 特征向量处理:将全局平均池化和全局最大池化后的特征向量输入到一个共享的全连接层(或称为多层感知机,MLP)中。这个全连接层用于学习每个通道的注意力权重。
  • Sigmoid激活:对全连接层的输出应用Sigmoid激活函数,生成通道注意力权重。这些权重位于0到1之间,用于表示每个通道的重要性。
  • 注意力加权:将得到的通道注意力权重与原始特征图的每个通道相乘,得到注意力加权后的通道特征图。
2. 空间注意力模块(SAM)

空间注意力模块通过学习每个空间位置的重要性来增强特征图的局部表达能力。其主要步骤如下:

  • 基于通道的全局池化:对输入特征图进行基于通道的全局最大池化和全局平均池化操作,生成两个包含空间信息的特征图。
  • 特征图拼接与卷积:将两个池化后的特征图在通道维度上进行拼接,然后通过一个卷积层进行处理,以生成空间注意力权重。这个卷积层通常使用较小的卷积核(如7x7),以捕捉局部空间信息。
  • Sigmoid激活:对卷积层的输出应用Sigmoid激活函数,生成空间注意力权重。
  • 注意力加权:将得到的空间注意力权重与原始特征图进行乘法操作,对特征图的每个空间位置进行加权。

三、优势与应用

CBAM的优势在于它同时考虑了特征图中的通道和空间信息,将这两种关注方式结合在一起,简化了网络结构,且计算成本相对较低。CBAM已被广泛应用于多种计算机视觉任务中,如图像分类、目标检测、语义分割等,并取得了显著的性能提升。

四、结论

CBAM作为一种高效的注意力机制模块,通过引入通道注意力和空间注意力,显著提升了卷积神经网络的性能。随着深度学习技术的不断发展,CBAM及其改进版本有望在更多领域得到应用和推广。

import torch  
from torch import nn  
  
class ChannelAttention(nn.Module):  
    """  
    通道注意力机制模块,使用Squeeze-and-Excitation (SE) 结构。  
    """  
    def __init__(self, channels, reduction=16):  
        """  
        初始化通道注意力模块。  
          
        参数:  
        - channels: 输入特征图的通道数。  
        - reduction: 压缩通道数的比例。  
        """  
        super().__init__()  
        self.maxpool = nn.AdaptiveMaxPool2d(1)  # 全局最大池化  
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # 全局平均池化  
        self.se_block = nn.Sequential(  # SE结构  
            nn.Conv2d(channels, channels // reduction, 1, bias=False),  
            nn.ReLU(inplace=True),  
            nn.Conv2d(channels // reduction, channels, 1, bias=False)  
        )  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        """  
        前向传播,计算通道注意力权重。  
        """  
        max_pooled = self.maxpool(x)  
        avg_pooled = self.avgpool(x)  
        max_out = self.se_block(max_pooled)  
        avg_out = self.se_block(avg_pooled)  
        output = self.sigmoid(max_out + avg_out)  
        return output  
  
  
class SpatialAttention(nn.Module):  
    """  
    空间注意力机制模块。  
    """  
    def __init__(self, kernel_size=7):  
        """  
        初始化空间注意力模块。  
          
        参数:  
        - kernel_size: 卷积核大小,用于空间注意力权重计算。  
        """  
        super().__init__()  
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        """  
        前向传播,计算空间注意力权重。  
        """  
        max_result, _ = torch.max(x, dim=1, keepdim=True)  
        avg_result = torch.mean(x, dim=1, keepdim=True)  
        result = torch.cat([max_result, avg_result], dim=1)  
        output = self.conv(result)  
        output = self.sigmoid(output)  
        return output  
  
  
class CBAM(nn.Module):  
    """  
    CBAM注意力机制模块,结合了通道注意力和空间注意力。  
    """  
    def __init__(self, channels=512, reduction=16, kernel_size=7):  
        """  
        初始化CBAM模块。  
          
        参数:  
        - channels: 输入特征图的通道数。  
        - reduction: 通道注意力中压缩通道数的比例。  
        - kernel_size: 空间注意力中卷积核的大小。  
        """  
        super().__init__()  
        self.channel_attention = ChannelAttention(channels=channels, reduction=reduction)  
        self.spatial_attention = SpatialAttention(kernel_size=kernel_size)  
  
    def forward(self, x):  
        """  
        前向传播,依次应用通道注意力和空间注意力。  
        """  
        out = x * self.channel_attention(x)  # 应用通道注意力  
        out = out * self.spatial_attention(out)  # 应用空间注意力  
        return out  
  
  
if __name__ == '__main__':  
    # 示例用法  
    input_tensor = torch.randn(64, 512, 20, 20)  # 假设输入特征图的形状为[batch_size, channels, height, width]  
    cbam_module = CBAM(channels=512, reduction=16, kernel_size=7)  
    output = cbam_module(input_tensor)  
    print(output.shape)  # 输出应与输入形状相同

CloAttention

https://arxiv.org/pdf/2303.17803.pdf

import torch
import torch.nn as nn

class MemoryEfficientSwish(nn.Module):
    # 节省内存的Swish 不采用自动求导(自己写前向传播和反向传播) 更高效
    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # save_for_backward会保留x的全部信息(一个完整的外挂Autograd Function的Variable),
            # 并提供避免in-place操作导致的input在backward被修改的情况.
            # in-place操作指不通过中间变量计算的变量间的操作。
            ctx.save_for_backward(x)
            return x * torch.sigmoid(x)

        @staticmethod
        def backward(ctx, grad_output):
            # 此处saved_tensors[0] 作用同上文 save_for_backward
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            # 返回该激活函数求导之后的结果 求导过程见上文
            return grad_output * (sx * (1 + x * (1 - sx)))

    def forward(self, x): # 应用前向传播方法
        return self.F.apply(x)

class AttnMap(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act_block = nn.Sequential(
                            nn.Conv2d(dim, dim, 1, 1, 0),
                            MemoryEfficientSwish(),
                            nn.Conv2d(dim, dim, 1, 1, 0)
                         )
    def forward(self, x):
        return self.act_block(x)

class EfficientAttention(nn.Module):
    def __init__(self, dim, num_heads=8, group_split=[4, 4], kernel_sizes=[5], window_size=4, 
                 attn_drop=0., proj_drop=0., qkv_bias=True):
        super().__init__()
        assert sum(group_split) == num_heads
        assert len(kernel_sizes) + 1 == len(group_split)
        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.scalor = self.dim_head ** -0.5
        self.kernel_sizes = kernel_sizes
        self.window_size = window_size
        self.group_split = group_split
        convs = []
        act_blocks = []
        qkvs = []
        #projs = []
        for i in range(len(kernel_sizes)):
            kernel_size = kernel_sizes[i]
            group_head = group_split[i]
            if group_head == 0:
                continue
            convs.append(nn.Conv2d(3*self.dim_head*group_head, 3*self.dim_head*group_head, kernel_size,
                         1, kernel_size//2, groups=3*self.dim_head*group_head))
            act_blocks.append(AttnMap(self.dim_head*group_head))
            qkvs.append(nn.Conv2d(dim, 3*group_head*self.dim_head, 1, 1, 0, bias=qkv_bias))
            #projs.append(nn.Linear(group_head*self.dim_head, group_head*self.dim_head, bias=qkv_bias))
        if group_split[-1] != 0:
            self.global_q = nn.Conv2d(dim, group_split[-1]*self.dim_head, 1, 1, 0, bias=qkv_bias)
            self.global_kv = nn.Conv2d(dim, group_split[-1]*self.dim_head*2, 1, 1, 0, bias=qkv_bias)
            #self.global_proj = nn.Linear(group_split[-1]*self.dim_head, group_split[-1]*self.dim_head, bias=qkv_bias)
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size!=1 else nn.Identity()

        self.convs = nn.ModuleList(convs)
        self.act_blocks = nn.ModuleList(act_blocks)
        self.qkvs = nn.ModuleList(qkvs)
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):
        '''
        x: (b c h w)
        '''
        b, c, h, w = x.size()
        qkv = to_qkv(x) #(b (3 m d) h w)
        qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous() #(3 b (m d) h w)
        q, k, v = qkv #(b (m d) h w)
        attn = attn_block(q.mul(k)).mul(self.scalor)
        attn = self.attn_drop(torch.tanh(attn))
        res = attn.mul(v) #(b (m d) h w)
        return res
        
    def low_fre_attention(self, x : torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):
        '''
        x: (b c h w)
        '''
        b, c, h, w = x.size()
        
        q = to_q(x).reshape(b, -1, self.dim_head, h*w).transpose(-1, -2).contiguous() #(b m (h w) d)
        kv = avgpool(x) #(b c h w)
        kv = to_kv(kv).view(b, 2, -1, self.dim_head, (h*w)//(self.window_size**2)).permute(1, 0, 2, 4, 3).contiguous() #(2 b m (H W) d)
        k, v = kv #(b m (H W) d)
        attn = self.scalor * q @ k.transpose(-1, -2) #(b m (h w) (H W))
        attn = self.attn_drop(attn.softmax(dim=-1))
        res = attn @ v #(b m (h w) d)
        res = res.transpose(2, 3).reshape(b, -1, h, w).contiguous()
        return res

    def forward(self, x: torch.Tensor):
        '''
        x: (b c h w)
        '''
        res = []
        for i in range(len(self.kernel_sizes)):
            if self.group_split[i] == 0:
                continue
            res.append(self.high_fre_attntion(x, self.qkvs[i], self.convs[i], self.act_blocks[i]))
        if self.group_split[-1] != 0:
            res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool))
        return self.proj_drop(self.proj(torch.cat(res, dim=1)))

CrissCrossAttention

https://arxiv.org/pdf/1811.11721

'''
This code is borrowed from Serge-weihao/CCNet-Pure-Pytorch
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax


def INF(B,H,W):
     return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))


    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        #print(concate)
        #print(att_H) 
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x



if __name__ == '__main__':
    model = CrissCrossAttention(64)
    x = torch.randn(2, 64, 5, 6)
    out = model(x)
    print(out.shape)

CoordAttention

import torch
import torch.nn as nn
import torch.nn.functional as F


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordinateAttention(nn.Module):
    def __init__(self, inp, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

if __name__ == '__main__':
    input = torch.randn(64, 512, 20, 20)
    pna = CoordinateAttention(inp=512)
    output = pna(input)
    print(output.shape)

CoTAttention

import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2


if __name__ == '__main__':
    input = torch.randn(64, 512, 20, 20)
    cot = CoTAttention(dim=512, kernel_size=3)
    output = cot(input)
    print(output.shape)

CPCA

import torch
import torch.nn as nn
import torch.nn.functional as F

class CPCA_ChannelAttention(nn.Module):

    def __init__(self, input_channels, internal_neurons):
        super(CPCA_ChannelAttention, self).__init__()
        self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True)
        self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, bias=True)
        self.input_channels = input_channels

    def forward(self, inputs):
        x1 = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        x1 = self.fc1(x1)
        x1 = F.relu(x1, inplace=True)
        x1 = self.fc2(x1)
        x1 = torch.sigmoid(x1)
        x2 = F.adaptive_max_pool2d(inputs, output_size=(1, 1))
        x2 = self.fc1(x2)
        x2 = F.relu(x2, inplace=True)
        x2 = self.fc2(x2)
        x2 = torch.sigmoid(x2)
        x = x1 + x2
        x = x.view(-1, self.input_channels, 1, 1)
        return inputs * x

class CPCA(nn.Module):
    def __init__(self, channels, channelAttention_reduce=4):
        super().__init__()

        self.ca = CPCA_ChannelAttention(input_channels=channels, internal_neurons=channels // channelAttention_reduce)
        self.dconv5_5 = nn.Conv2d(channels,channels,kernel_size=5,padding=2,groups=channels)
        self.dconv1_7 = nn.Conv2d(channels,channels,kernel_size=(1,7),padding=(0,3),groups=channels)
        self.dconv7_1 = nn.Conv2d(channels,channels,kernel_size=(7,1),padding=(3,0),groups=channels)
        self.dconv1_11 = nn.Conv2d(channels,channels,kernel_size=(1,11),padding=(0,5),groups=channels)
        self.dconv11_1 = nn.Conv2d(channels,channels,kernel_size=(11,1),padding=(5,0),groups=channels)
        self.dconv1_21 = nn.Conv2d(channels,channels,kernel_size=(1,21),padding=(0,10),groups=channels)
        self.dconv21_1 = nn.Conv2d(channels,channels,kernel_size=(21,1),padding=(10,0),groups=channels)
        self.conv = nn.Conv2d(channels,channels,kernel_size=(1,1),padding=0)
        self.act = nn.GELU()

    def forward(self, inputs):
        #   Global Perceptron
        inputs = self.conv(inputs)
        inputs = self.act(inputs)
        
        inputs = self.ca(inputs)

        x_init = self.dconv5_5(inputs)
        x_1 = self.dconv1_7(x_init)
        x_1 = self.dconv7_1(x_1)
        x_2 = self.dconv1_11(x_init)
        x_2 = self.dconv11_1(x_2)
        x_3 = self.dconv1_21(x_init)
        x_3 = self.dconv21_1(x_3)
        x = x_1 + x_2 + x_3 + x_init
        spatial_att = self.conv(x)
        out = spatial_att * inputs
        out = self.conv(out)
        return out

DAttention

import torch, einops
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import trunc_normal_

class LayerNormProxy(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = einops.rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return einops.rearrange(x, 'b h w c -> b c h w')

class DAttention(nn.Module):
    # Vision Transformer with Deformable Attention CVPR2022
    # fixed_pe=True need adujust 640x640
    def __init__(
        self, channel, q_size, n_heads=8, n_groups=4,
        attn_drop=0.0, proj_drop=0.0, stride=1, 
        offset_range_factor=4, use_pe=True, dwc_pe=True,
        no_off=False, fixed_pe=False, ksize=3, log_cpb=False, kv_size=None
    ):
        super().__init__()
        n_head_channels = channel // n_heads
        self.dwc_pe = dwc_pe
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels ** -0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        # self.kv_h, self.kv_w = kv_size
        self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.fixed_pe = fixed_pe
        self.no_off = no_off
        self.offset_range_factor = offset_range_factor
        self.ksize = ksize
        self.log_cpb = log_cpb
        self.stride = stride
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0

        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )
        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)

        self.proj_q = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_out = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe and not self.no_off:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(
                    self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
                )
                trunc_normal_(self.rpe_table, std=0.01)
            elif self.log_cpb:
                # Borrowed from Swin-V2
                self.rpe_table = nn.Sequential(
                    nn.Linear(2, 32, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, self.n_group_heads, bias=False)
                )
            else:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
                trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2

        return ref
    
    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2

        return ref

    def forward(self, x):

        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off).contiguous()  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor >= 0 and not self.no_off:
            offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
        else:
            pos = pos.type(x.dtype)
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
                grid=pos[..., (1, 0)], # y, x -> x, y
                mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
                

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)

        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):

            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            elif self.log_cpb:
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(4.0) # d_y, d_x [-8, +8]
                displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
                attn_bias = self.rpe_table(displacement) # B * g, H * W, n_sample, h_g
                attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
                attn_bias = F.grid_sample(
                    input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True) # B * g, h_g, HW, Ns

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        y = self.proj_drop(self.proj_out(out))

        return y
Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐