Red-Dot Position Detection Based on RNN (GRU, LSTM) + CNN (PyTorch)
This project uses deep learning to precisely identify and output the pixel positions of three red lines in an image. We compare several architectures, including pure RNN, CNN+RNN, and RNN+CNN, and introduce multi-head attention into the RNN and SEAttention into the CNN to improve performance. The experiments show that the RNN+CNN model performs best in both accuracy and overall performance, reliably locating the red lines and offering a new approach to this kind of image-recognition task.
1 Project Background
The task is to precisely locate the three red lines in an image and output the positions of these three pixels.
Each red line spans more than one pixel, and its colors are not a high-saturation, high-brightness red-on-black palette; zoomed in, a red line may look something like this.
Our goal is to output the position of each red dot exactly, down to the pixel. That is, for each red line the model must output the pixel indicated by the orange arrow, not the one indicated by the blue arrow.
We previously ran experiments with pure RNNs, and also with placing a CNN in front of the RNN so the data carries convolutional features. With an image width of 1080 and low noise, the comparison results were as follows:
Experiment | loss | Exactly correct points |
---|---|---|
GRU | 129.6641 | 1762.0/9000 (20%) |
LSTM | 249.2053 | 1267.0/9000 (14%) |
CNN+GRU | 1419.5781 | 601.0/9000 (7%) |
CNN+LSTM | 1166.4599 | 762.0/9000 (8%) |
Yes, putting a CNN in front actually backfired. A colleague who had tried something similar said the results were essentially no different from using an RNN directly.
2 Dataset
We again use the synthetic dataset generated by the code from before; each dataset contains roughly 15,000 images. Without added noise, a sample preview looks like this:
With noise added, a sample preview looks like this:
The black regions in the image contain weak noise and are not pure black.
The dataset consists of two parts: a folder containing the JPEG-compressed image data,
and a CSV file that lists each image's name together with the pixel positions of the three red lines.
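The loading code is not shown in the post; a minimal PyTorch Dataset sketch for this layout might look like the following (the file names and the CSV column order are assumptions):

    import os
    import pandas as pd
    import torch
    from PIL import Image
    import torchvision.transforms.functional as TF
    from torch.utils.data import Dataset

    class RedLineDataset(Dataset):
        """Sketch of a loader for the synthetic dataset described above.
        Assumes the CSV has an image-name column followed by three position columns;
        the actual column names used in the project are not shown in the post."""
        def __init__(self, csv_file, img_dir):
            self.labels = pd.read_csv(csv_file)
            self.img_dir = img_dir

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            row = self.labels.iloc[idx]
            img = Image.open(os.path.join(self.img_dir, str(row.iloc[0])))
            x = TF.to_tensor(img)  # (channels, height, width)
            y = torch.tensor(row.iloc[1:4].to_numpy(dtype="float32"))  # three pixel positions
            return x, y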
3 Approach
The earlier CNN+RNN idea treated the CNN as a feature extractor and the RNN as the decision model. This time the question is how much stronger the model becomes when the CNN makes the decision directly, since CNNs should have a clear advantage on this kind of image task. In other words, the RNN now acts as a feature extractor over the image data, and the CNN then locates the three points. Following this idea, the RNN+CNN pipeline looks like this:
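A minimal shape sketch of the RNN-as-feature-extractor, CNN-as-decider idea (the 1-pixel-tall, 3-channel strip and width 1080 are assumptions based on sections 1 and 5):

    import torch
    import torch.nn as nn

    batch, channels, width, hidden = 2, 3, 1080, 16
    x = torch.randn(batch, channels, 1, width)            # input image strip
    rnn = nn.GRU(channels, hidden, batch_first=True, bidirectional=True)
    proj = nn.Linear(2 * hidden, 3)                       # a few extra feature channels per pixel
    rnn_out, _ = rnn(x.squeeze(2).permute(0, 2, 1))       # (batch, width, 2 * hidden)
    feat = proj(rnn_out).transpose(1, 2).unsqueeze(2)     # (batch, 3, 1, width)
    cnn_in = torch.cat((x, feat), dim=1)                  # (batch, channels + 3, 1, width)
    print(cnn_in.shape)  # torch.Size([2, 6, 1, 1080]) -- the CNN then regresses the 3 positions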
We then add some attention on top of this model: multi-head attention over the GRU output and SEAttention blocks inside the CNN (see the code in section 5).
4 Experimental Results
Experiment | train loss | val loss | test loss | Exactly correct test samples | Point 1 mean offset (px) | Point 2 mean offset (px) | Point 3 mean offset (px) |
---|---|---|---|---|---|---|---|
GRU | 17.1150 | 16.2752 | 233.5694 | 536.0/4500 (12%) | 3.3181 | 3.0701 | 3.3957 |
LSTM | 378.7690 | 47.6191 | 367.7041 | 499.0/4500 (11%) | 4.2166 | 3.6437 | 4.0777 |
CNN | 6.6049 | 13.6372 | 231.4501 | 650.0/4500 (14%) | 2.1816 | 3.0884 | 3.9680 |
CNN+RNN | 5.3883 | 6.6833 | 76.0979 | 821.0/4500 (18%) | 1.8977 | 2.5229 | 1.8854 |
RNN+CNN | 2.6558 | 1.7714 | 28.4280 | 1318.0/4500 (29%) | 1.4926 | 1.3679 | 1.5234 |
RNN+CNN+Attention | 6.5938 | 42.4060 | 41.9453 | 1264.0/4500 (28%) | 1.5860 | 1.5557 | 1.8804 |
Multi-Head Attention + RNN | 174.5019 | 18.1041 | 149.0297 | 645.0/4500 (14%) | 2.6598 | 3.2243 | 2.4309 |
The GRU model clearly overfits. Using a CNN for the decision step really does beat the earlier pure RNN decisively, which suggests convolutions are simply a better fit for image-style tasks, while RNNs, built for sequential information, are of limited use here. Plotting the offsets of the three predicted points for the first six models shows that the RNN+CNN model's errors are mostly concentrated around 0 and 1.
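The post does not show how "exactly correct samples" and the per-point mean offsets are computed; a plausible sketch, assuming predictions are rounded to integer pixel positions, is:

    import torch

    def evaluate_batch(pred, target):
        """pred, target: (batch, 3) pixel positions.
        Counts samples where all three rounded predictions match exactly and returns
        the mean absolute offset per point. How rounding is handled in the original
        project is an assumption."""
        pred_px = pred.round()
        exact = (pred_px == target).all(dim=1).sum().item()   # fully correct samples
        mean_offset = (pred - target).abs().mean(dim=0)        # per-point mean deviation
        return exact, mean_offset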
I also ran experiments on the effect of multi-head attention inside the RNN versus attention inside the CNN. It turns out that attention in the CNN is not a good fit and actually hurts:
Experiment | train loss | val loss | test loss | Exactly correct test samples | Point 1 mean offset (px) | Point 2 mean offset (px) | Point 3 mean offset (px) |
---|---|---|---|---|---|---|---|
RNN+CNN | 2.6558 | 1.7714 | 28.4280 | 1318.0/4500 (29%) | 1.4926 | 1.3679 | 1.5234 |
RNN+CNN+Attention | 6.5938 | 42.4060 | 41.9453 | 1264.0/4500 (28%) | 1.5860 | 1.5557 | 1.8804 |
RNN(Attention)+CNN | 3.3199 | 3.7312 | 22.7644 | 1498.0/4500 (33%) | 1.4721 | 1.2609 | 1.2932 |
RNN+CNN(Attention) | 4.2012 | 4.5143 | 65.8752 | 1039.0/4500 (23%) | 1.5869 | 2.3705 | 1.9389 |
As the results above show, RNN(Attention)+CNN clearly outperforms the other two variants.
As for positional information: in earlier experiments, embedding positional information into the RNN significantly improved the model, but on this problem it does not help. This suggests the positional information actually interferes quite strongly with the CNN's decisions.
Experiment | train loss | val loss | test loss | Exactly correct test samples | Point 1 mean offset (px) | Point 2 mean offset (px) | Point 3 mean offset (px) |
---|---|---|---|---|---|---|---|
RNN+CNN+Attention+Position | 11.9669 | 88.9042 | 103.9887 | 739.0/4500 (16%) | 2.4452 | 2.3939 | 2.3833 |
RNN+CNN+Attention+learnable embedding | 19.2102 | 23.4937 | 223.7447 | 473.0/4500 (11%) | 2.9559 | 3.0082 | 3.6864 |
RNN+CNN+Attention+learnable embedding with position | 21.5659 | 25.1544 | 170.9156 | 677.0/4500 (15%) | 2.3320 | 2.6873 | 2.9070 |
In the table above, Position means using the sin/cos positional encoding from the transformer, learnable embedding means mapping the indices [0, seq_length) directly to a learnable embedding, and learnable embedding with position means initializing that learnable embedding with the sin/cos positional encoding.
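For reference, the Position variant refers to the standard transformer sin/cos encoding, for position $pos$, channel index $i$, and model dimension $d_{model}$:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$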
Judging from the results, neither the transformer positional encoding nor the learnable embedding improves on the original model.
5 Code
GRU+CNN+Attention
import torch
import torch.nn as nn


class Config(object):
    def __init__(self, device, csv_file, img_dir, width, input_size):
        self.device = device
        self.model_name = 'GRU_CNN_Attention'
        self.input_size = input_size
        self.hidden_size = 128
        self.num_layers = 2
        self.epoch_number = 150
        self.batch_size = 32
        self.learn_rate = 0.0002
        self.csv_file = csv_file
        self.img_dir = img_dir
        self.width = width


class GRU_CNN(nn.Module):
    def __init__(self, config):
        super(GRU_CNN, self).__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device
        self.sequence_length = config.width
        self.channels = config.input_size
        # Bidirectional GRU over the width dimension: each pixel column is one time step
        self.gru = nn.GRU(input_size=self.channels, hidden_size=self.hidden_size, num_layers=self.num_layers,
                          batch_first=True, bidirectional=True, dropout=0.6)
        self.attention = nn.MultiheadAttention(embed_dim=2 * self.hidden_size, num_heads=4, batch_first=True)
        # Project each time step to 4 feature channels that are stacked onto the image
        self.fc = nn.Linear(2 * self.hidden_size, 4)
        self.conv1 = nn.Conv2d(4 + self.channels, 32, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.se1 = SEAttention(32)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.se2 = SEAttention(64)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.se3 = SEAttention(128)
        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.fc1 = nn.Linear(128 * (self.sequence_length // 8), 128)
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        # x: (batch_size, channels, 1, sequence_length)
        rnn_x = x.squeeze(2).permute(0, 2, 1)  # (batch_size, sequence_length, channels)
        # x = x + self.pos_encoding[:, :x.size(1), :].to(x.device)
        h0 = torch.zeros(self.num_layers * 2, rnn_x.size(0), self.hidden_size).to(x.device)
        gru_output, _ = self.gru(rnn_x, h0)  # (batch_size, sequence_length, 2 * hidden_size)
        context_vector, _ = self.attention(gru_output, gru_output, gru_output)  # (batch_size, sequence_length, 2 * hidden_size)
        gru_output_fc = self.fc(context_vector)  # (batch_size, sequence_length, 4)
        gru_output_fc = gru_output_fc.transpose(1, 2).unsqueeze(2)  # (batch_size, 4, 1, sequence_length)
        x = torch.cat((x, gru_output_fc), dim=1)  # stack GRU features onto the raw image channels
        x = self.pool1(self.se1(self.relu(self.conv1(x))))
        x = self.pool2(self.se2(self.relu(self.conv2(x))))
        x = self.pool3(self.se3(self.relu(self.conv3(x))))
        x = x.view(-1, 128 * (self.sequence_length // 8))
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # (batch_size, 3) predicted pixel positions
        return x


class SEAttention(nn.Module):
    """Squeeze-and-Excitation channel attention."""
    def __init__(self, channel, reduction=16):
        super(SEAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
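For reference, a hypothetical instantiation with dummy input. The image width of 1080 comes from section 1; the 3-channel input and the csv_file/img_dir paths are assumptions for illustration:

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = Config(device, csv_file='labels.csv', img_dir='images/', width=1080, input_size=3)
    model = GRU_CNN(config).to(device)

    dummy = torch.randn(2, 3, 1, 1080).to(device)  # (batch, channels, height=1, width)
    out = model(dummy)                             # (batch, 3) predicted pixel positions
    print(out.shape)                               # torch.Size([2, 3])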
GRU+CNN
import torch
import torch.nn as nn


class Config(object):
    def __init__(self, device, csv_file, img_dir, width, input_size):
        self.device = device
        self.model_name = 'GRU_CNN'
        self.input_size = input_size
        self.hidden_size = 128
        self.num_layers = 2
        self.epoch_number = 100
        self.batch_size = 32
        self.learn_rate = 0.001
        self.csv_file = csv_file
        self.img_dir = img_dir
        self.width = width


class GRU_CNN(nn.Module):
    def __init__(self, config):
        super(GRU_CNN, self).__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device
        self.sequence_length = config.width
        self.channels = config.input_size
        self.gru = nn.GRU(input_size=self.channels, hidden_size=self.hidden_size, num_layers=self.num_layers,
                          batch_first=True, bidirectional=True, dropout=0.6)
        self.fc = nn.Linear(2 * self.hidden_size, 3)
        self.conv1 = nn.Conv2d(3 + self.channels, 32, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.fc1 = nn.Linear(128 * (self.sequence_length // 8), 128)
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        rnn_x = x.squeeze(2).permute(0, 2, 1)  # (batch_size, sequence_length, channels)
        # x = x + self.pos_encoding[:, :x.size(1), :].to(x.device)
        h0 = torch.zeros(self.num_layers * 2, rnn_x.size(0), self.hidden_size).to(x.device)
        gru_output, _ = self.gru(rnn_x, h0)  # (batch_size, sequence_length, 2 * hidden_size)
        gru_output_fc = self.fc(gru_output)  # (batch_size, sequence_length, 3)
        gru_output_fc = gru_output_fc.transpose(1, 2).unsqueeze(2)  # (batch_size, 3, 1, sequence_length)
        x = torch.cat((x, gru_output_fc), dim=1)
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))
        x = self.pool3(self.relu(self.conv3(x)))
        x = x.view(-1, 128 * (self.sequence_length // 8))
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
Combining the learnable embedding with the transformer positional encoding:
import math

import torch
import torch.nn as nn


class GRU_CNN(nn.Module):
    def __init__(self, config):
        super(GRU_CNN, self).__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device
        self.sequence_length = config.width
        self.channels = config.input_size
        self.gru = nn.GRU(input_size=self.channels, hidden_size=self.hidden_size, num_layers=self.num_layers,
                          batch_first=True, bidirectional=True, dropout=0.6)
        self.attention = nn.MultiheadAttention(embed_dim=2 * self.hidden_size, num_heads=4, batch_first=True)
        self.fc = nn.Linear(2 * self.hidden_size, 4)
        self.conv1 = nn.Conv2d(4 + self.channels, 32, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.fc1 = nn.Linear(128 * (self.sequence_length // 8), 128)
        self.fc2 = nn.Linear(128, 3)
        # Learnable positional embedding initialized with the sin/cos encoding
        self.positional_embedding = self.generate_positional_encoding(config.width, self.channels).to(self.device)

    def generate_positional_encoding(self, seq_length, d_model):
        def generate_sin_cos_positional_encoding(seq_len, d_model):
            pos = torch.arange(seq_len).unsqueeze(1)  # (seq_len, 1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))  # (d_model / 2)
            pe = torch.zeros(seq_len, d_model)
            pe[:, 0::2] = torch.sin(pos * div_term)
            pe[:, 1::2] = torch.cos(pos * div_term)[:, :d_model // 2]  # slice guards against odd d_model
            return pe

        positional_encoding = generate_sin_cos_positional_encoding(seq_length, d_model)
        embedding = nn.Embedding(seq_length, d_model)
        embedding.weight = nn.Parameter(positional_encoding, requires_grad=True)
        return embedding

    def forward(self, x):
        rnn_x = x.squeeze(2).permute(0, 2, 1)  # (batch_size, sequence_length, channels)
        positions = torch.arange(rnn_x.size(1), device=x.device).unsqueeze(0).expand(rnn_x.size(0), -1)
        rnn_x = rnn_x + self.positional_embedding(positions)
        h0 = torch.zeros(self.num_layers * 2, rnn_x.size(0), self.hidden_size).to(x.device)
        gru_output, _ = self.gru(rnn_x, h0)  # (batch_size, sequence_length, 2 * hidden_size)
        context_vector, _ = self.attention(gru_output, gru_output, gru_output)  # (batch_size, sequence_length, 2 * hidden_size)
        gru_output_fc = self.fc(context_vector)  # (batch_size, sequence_length, 4)
        gru_output_fc = gru_output_fc.transpose(1, 2).unsqueeze(2)  # (batch_size, 4, 1, sequence_length)
        x = torch.cat((x, gru_output_fc), dim=1)
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))
        x = self.pool3(self.relu(self.conv3(x)))
        x = x.view(-1, 128 * (self.sequence_length // 8))
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
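The training script itself is not included in the post. A minimal sketch of how these models might be trained with the Config hyperparameters could look like this; the loss function and optimizer are assumptions, since the post does not state which were used:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader

    def train(model, config, train_set, val_set):
        # Assumed setup: Adam + MSE regression on the three pixel positions.
        model.to(config.device)
        train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
        val_loader = DataLoader(val_set, batch_size=config.batch_size)
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learn_rate)

        for epoch in range(config.epoch_number):
            model.train()
            for x, y in train_loader:
                x, y = x.to(config.device), y.to(config.device)
                optimizer.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                optimizer.step()

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for x, y in val_loader:
                    x, y = x.to(config.device), y.to(config.device)
                    val_loss += criterion(model(x), y).item() * x.size(0)
            print(f'epoch {epoch}: val loss {val_loss / len(val_set):.4f}')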