本改进已同步到YOLO-Magic框架!

在这里插入图片描述

摘要:摘要。基于Transformer的恶劣天气图像修复方法取得了显著进展。大多数方法通过沿通道维度或在空间上固定范围的块内使用自注意力,以减少计算负担。然而,这种折中方式在捕获长距离空间特征方面存在局限性。受到恶劣天气导致的退化因素主要引发类似遮挡和亮度变化的观察启发,本文提出了一种高效的直方图Transformer(Histoformer)用于修复受恶劣天气影响的图像。其核心机制是直方图自注意力,该机制根据强度将空间特征排序并分割到不同的bin中,然后在bin之间或每个bin内应用自注意力,有选择性地关注动态范围的空间特征,并共同处理长距离范围内类似退化的像素。为了增强直方图自注意力,我们提出了一种动态范围卷积,使传统卷积能够操作类似像素而非相邻像素。我们还观察到,常见的逐像素损失忽略了输出与真实值之间的线性关联与相关性。因此,我们建议利用皮尔逊相关系数作为损失函数,确保恢复的像素与真实值保持相同的顺序。大量实验验证了我们提出方法的有效性和优越性。我们已在Github上发布了相关代码。

论文地址:https://arxiv.org/pdf/2407.10172


文章速览

Q1 论文试图解决什么问题?
该论文试图解决轻量化图像修复任务中的局部和全局信息捕获不足的问题。传统的轻量级图像恢复网络在计算效率和参数数量上有很大限制,同时在恢复图像时要么无法充分捕捉全局依赖,要么忽视了局部信息。该论文提出了一种名为Reciprocal Attention Mixing Transformer (RAMiT) 的新型轻量级网络框架来解决这一问题。

Q2 这是否是一个新的问题?
图像修复领域的轻量化一直是研究的热点,尽管处理局部与全局依赖之间的平衡不是全新问题,但通过Transformer网络来同时处理这两种依赖是一种较新的方法。

Q3 这篇文章要验证一个什么科学假设?
科学假设是,结合维度互补的双向注意力机制(D-RAMiT)与层级互补注意力(H-RAMi),可以同时捕捉全局和局部信息,从而提升轻量级图像恢复任务的表现。

Q4 论文中提到的解决方案之关键是什么?
论文的关键在于提出了一种双向自注意力机制(D-RAMiT),该机制可以并行地计算空间和通道的自注意力,弥补两者各自的不足。此外,层级互补注意力(H-RAMi)可以弥补多尺度特征的像素级信息丢失。

Q5 论文中的实验是如何设计的?
实验设计涵盖了五个图像修复任务,包括超分辨率、低光增强、去雨、彩色降噪和灰度降噪。使用不同的数据集进行训练和评估,并且设置了对比实验以展示RAMiT相对于其他方法的优越性。

Q6 用于定量评估的数据集是什么?代码有没有开源?
实验使用了DIV2K、Rain13K等公开数据集。代码已经在GitHub上开源,链接为https://github.com/rami0205/RAMiT

Q7 论文中的实验及结果有没有很好地支持需要验证的科学假设?
实验结果表明,RAMiT在多个轻量级图像修复任务上实现了最先进的性能,证明了结合空间和通道自注意力可以提高图像恢复的质量。

Q8 这篇论文到底有什么贡献?
论文的贡献在于提出了一个轻量化图像修复框架,RAMiT结合了维度和层级的注意力机制,能够捕捉全局和局部的依赖关系,从而在保持高效的同时提升了修复精度。

Q9 下一步呢?有什么工作可以继续深入?
下一步可以研究如何进一步优化注意力机制在其他低级视觉任务中的表现,如去噪和去雨等,同时也可以扩展到更多图像处理任务。


在这里插入图片描述


2 源代码

import numbers
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

Conv2d = nn.Conv2d


## Layer Norm
def to_2d(x):
    return rearrange(x, "b c h w -> b (h w c)")


def to_3d(x):
    #    return rearrange(x, 'b c h w -> b c (h w)')
    return rearrange(x, "b c h w -> b (h w) c")


def to_4d(x, h, w):
    #    return rearrange(x, 'b c (h w) -> b c h w',h=h,w=w)
    return rearrange(x, "b (h w) c -> b c h w", h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5)  # * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5)  # * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type="WithBias"):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == "BiasFree":
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv_5 = Conv2d(
            hidden_features // 4,
            hidden_features // 4,
            kernel_size=5,
            stride=1,
            padding=2,
            groups=hidden_features // 4,
            bias=bias,
        )
        self.dwconv_dilated2_1 = Conv2d(
            hidden_features // 4,
            hidden_features // 4,
            kernel_size=3,
            stride=1,
            padding=2,
            groups=hidden_features // 4,
            bias=bias,
            dilation=2,
        )
        self.p_unshuffle = nn.PixelUnshuffle(2)
        self.p_shuffle = nn.PixelShuffle(2)

        self.project_out = Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x = self.p_shuffle(x)
        x1, x2 = x.chunk(2, dim=1)
        x1 = self.dwconv_5(x1)
        x2 = self.dwconv_dilated2_1(x2)
        x = F.mish(x2) * x1
        x = self.p_unshuffle(x)
        x = self.project_out(x)
        return x


class Attention_histogram(nn.Module):
    def __init__(self, dim, num_heads=4, bias=False, ifBox=True):
        super(Attention_histogram, self).__init__()
        self.factor = num_heads
        self.ifBox = ifBox
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = Conv2d(dim, dim * 5, kernel_size=1, bias=bias)
        self.qkv_dwconv = Conv2d(
            dim * 5,
            dim * 5,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 5,
            bias=bias,
        )
        self.project_out = Conv2d(dim, dim, kernel_size=1, bias=bias)

    def pad(self, x, factor):
        hw = x.shape[-1]
        t_pad = [0, 0] if hw % factor == 0 else [0, (hw // factor + 1) * factor - hw]
        x = F.pad(x, t_pad, "constant", 0)
        return x, t_pad

    def unpad(self, x, t_pad):
        _, _, hw = x.shape
        return x[:, :, t_pad[0] : hw - t_pad[1]]

    def softmax_1(self, x, dim=-1):
        logit = x.exp()
        logit = logit / (logit.sum(dim, keepdim=True) + 1)
        return logit

    def normalize(self, x):
        mu = x.mean(-2, keepdim=True)
        sigma = x.var(-2, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5)  # * self.weight + self.bias

    def reshape_attn(self, q, k, v, ifBox):
        b, c = q.shape[:2]
        q, t_pad = self.pad(q, self.factor)
        k, t_pad = self.pad(k, self.factor)
        v, t_pad = self.pad(v, self.factor)
        hw = q.shape[-1] // self.factor
        shape_ori = "b (head c) (factor hw)" if ifBox else "b (head c) (hw factor)"
        shape_tar = "b head (c factor) hw"
        q = rearrange(
            q,
            "{} -> {}".format(shape_ori, shape_tar),
            factor=self.factor,
            hw=hw,
            head=self.num_heads,
        )
        k = rearrange(
            k,
            "{} -> {}".format(shape_ori, shape_tar),
            factor=self.factor,
            hw=hw,
            head=self.num_heads,
        )
        v = rearrange(
            v,
            "{} -> {}".format(shape_ori, shape_tar),
            factor=self.factor,
            hw=hw,
            head=self.num_heads,
        )
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = self.softmax_1(attn, dim=-1)
        out = attn @ v
        out = rearrange(
            out,
            "{} -> {}".format(shape_tar, shape_ori),
            factor=self.factor,
            hw=hw,
            b=b,
            head=self.num_heads,
        )
        out = self.unpad(out, t_pad)
        return out

    def forward(self, x):
        b, c, h, w = x.shape
        x_sort, idx_h = x[:, : c // 2].sort(-2)
        x_sort, idx_w = x_sort.sort(-1)
        x[:, : c // 2] = x_sort
        qkv = self.qkv_dwconv(self.qkv(x))
        q1, k1, q2, k2, v = qkv.chunk(5, dim=1)  # b,c,x,x

        v, idx = v.view(b, c, -1).sort(dim=-1)
        q1 = torch.gather(q1.view(b, c, -1), dim=2, index=idx)
        k1 = torch.gather(k1.view(b, c, -1), dim=2, index=idx)
        q2 = torch.gather(q2.view(b, c, -1), dim=2, index=idx)
        k2 = torch.gather(k2.view(b, c, -1), dim=2, index=idx)

        out1 = self.reshape_attn(q1, k1, v, True)
        out2 = self.reshape_attn(q2, k2, v, False)

        out1 = torch.scatter(out1, 2, idx, out1).view(b, c, h, w)
        out2 = torch.scatter(out2, 2, idx, out2).view(b, c, h, w)
        out = out1 * out2
        out = self.project_out(out)
        out_replace = out[:, : c // 2]
        out_replace = torch.scatter(out_replace, -1, idx_w, out_replace)
        out_replace = torch.scatter(out_replace, -2, idx_h, out_replace)
        out[:, : c // 2] = out_replace
        return out


class HistogramTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=4,
        ffn_expansion_factor=2.5,
        bias=False,
        LayerNorm_type="WithBias",
    ):  ## Other option 'BiasFree'
        super(HistogramTransformerBlock, self).__init__()
        self.attn_g = Attention_histogram(dim, num_heads, bias, True)
        self.norm_g = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

        self.norm_ff1 = LayerNorm(dim, LayerNorm_type)

    def forward(self, x):
        x = x + self.attn_g(self.norm_g(x))
        x_out = x + self.ffn(self.norm_ff1(x))

        return x_out


if __name__ == "__main__":
    input = torch.randn(1, 64, 128, 128)
    transformer_block = HistogramTransformerBlock(64)
    output = transformer_block(input)
    print(input.size())
    print(output.size())

3 添加方式

  1. ultralytics/ultralytics/nn/modules/layers 文件夹下新建 HistogramTransformerBlock.py 文件,将上述源代码粘贴进去,没有layers文件夹的新建一个;
  2. ultralytics/ultralytics/nn/modules/task.py 中导包;
from ultralytics.nn.modules.layers.HTB import HistogramTransformerBlock
  1. ultralytics/ultralytics/nn/modules/task.py 中添加模块名模块;在这里插入图片描述
HistogramTransformerBlock
        elif m in [HistogramTransformerBlock]:
            c1 = ch[f]
            args = [c1, *args[0:]]
  1. 修改 yaml 文件,开始训练。

4 模型 yaml 文件

更多yaml请关注YOLO-Magic框架!

yolov11-backbone-HistogramTransformerBlock
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024]
  s: [0.50, 0.50, 1024]
  m: [0.50, 1.00, 512]
  l: [1.00, 1.00, 512]
  x: [1.00, 1.50, 512]

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 1, HistogramTransformerBlock, []] # ---> You can add your attention module name here
  - [-1, 2, C2PSA, [1024]] # 11

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 17 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

yolov11-neck-PPA
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024]
  s: [0.50, 0.50, 1024]
  m: [0.50, 1.00, 512] 
  l: [1.00, 1.00, 512] 
  x: [1.00, 1.50, 512] 

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, HistogramTransformerBlock, []] # ---> You can add your attention module name here

  - [-2, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)
  - [-1, 1, HistogramTransformerBlock, []] # ---> You can add your attention module name here

  - [-2, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 24 (P5/32-large)
  - [-1, 1, HistogramTransformerBlock, []] # ---> You can add your attention module name here

  - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)

yolov11-small-PPA
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024]
  s: [0.50, 0.50, 1024]
  m: [0.50, 1.00, 512]
  l: [1.00, 1.00, 512]
  x: [1.00, 1.50, 512]

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, HistogramTransformerBlock, []] # ---> You can add your attention module name here

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐