【OCR】EAST算法讲解及实现
【OCR】EAST算法讲解及实现一、概念介绍OCR(Optical Character Recognition):光学字符识别;基本流程:输入图片 —— 预处理 —— 文字定位(采用EAST算法)—— 文字分割 —— 图像识别涉及的技术点:检测、分割、分类三个大任务,感觉像图像中很多任务都是在分类、检测、分割上实现的;二、算法原理论文地址:https://arxiv.org/pdf/1704.03
【OCR】EAST算法讲解及实现
一、概念介绍
OCR(Optical Character Recognition):光学字符识别;
基本流程:
输入图片 —— 预处理 —— 文字定位(采用EAST算法)—— 文字分割 —— 图像识别
涉及的技术点:检测、分割、分类三个大任务,感觉像图像中很多任务都是在分类、检测、分割上实现的;
二、算法原理
论文地址:https://arxiv.org/pdf/1704.03155.pdf
关键:提出了基于两阶段的文本检测方法:全卷积神经网络和非极大值抑制,消除中间过程冗余,减少检测时间;
作用:可以检测单词级别或文本行级别的文本,检测框可以为任意形状的四边形,可带有旋转;
总结:相比于传统检测算法,输出更多的信息,能够检测除了标准四边形外的其他形状;
1、模型结构图
可以看出,这里类似于一个编解码的过程,Feature extractor对特征图进行降采样,Feature merging对特征图进行上采样(特征融合),有点类似于图像分割的中的trick;
2、模型输入及输出
输出:score(置信度)、text boxes(x,y,w,h)、text rotation(θ),这里的文本框x,y,w,h也可以用d1,d2,d3,d4表示,意思是点离框的上下左右的距离;
首先要理解GT是怎么生成的:
3、Loss的计算(分为三个部分)
Score部分的计算:
可以看出,是用到一个平衡交叉熵的计算;
RBOX部分的计算:
主要用到IOU计算Loss;
角度部分的计算:
三、训练部分实现
1、train.py
作用:训练代码,将Loss和Model传入进来,实现模型的训练;
实现步骤如下:
一、引入所需要的包
# 引入一些包
import torch
#数据增强处理包
from torch.utils import data
from torch import nn
#引入学习率变化策略包
from torch.optim import lr_scheduler
#引入自己写的数据加载包,把数据整理成EAST_pytorch需要的格式
from dataset import custom_dataset
#引入自己写的EAST模型
from model import EAST
#引入自己写的Loss
from loss import Loss
import os
import time
import numpy as np
Model和Loss两个包将在后面讲解;
二、加载数据
def train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval):
# 得到文件夹下得文件列表,然后计算数目
file_num = len(os.listdir(train_img_path))
#利用我们自己custom_dataset把数据从文件夹里读取,并整理成pytorch.utils.data.DataLoader需要的格式trainset
trainset = custom_dataset(train_img_path, train_gt_path)
#把数据按照batch_size,shuffle,num_workers,drop_last等原则进行整理
train_loader = data.DataLoader(trainset, batch_size=batch_size, \
shuffle=True, num_workers=num_workers, drop_last=True)
三、加载Loss和Model
#加载自己实现的loss函数
criterion = Loss()
#获取计算设备,使用gpu或cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 加载自己实现的模型
model = EAST()
#是否并行计算(多GPU)
data_parallel = False
# 如果有gpu设备,就把模型转为并行模型,方便用并行得方式把数据输入到模型中
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
data_parallel = True
#把模型绑定到gpu或cpu上
model.to(device)
四、加载优化器
#优化器,把模型得参数给到优化器,让优化器按照lr去更新这些参数
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#设定优化器工作的时候,lr的更新策略
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter//2], gamma=0.1)
五、设置epoch次数,进行训练
for epoch in range(epoch_iter):
model.train()
# 每个epoch开始前,更新一个学习率
scheduler.step()
epoch_loss = 0
epoch_time = time.time()
for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
start_time = time.time()
img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(device)
# 将图片输入east模型中,得到score map 和 geo_map
pred_score, pred_geo = model(img)
# 计算score的loss,geomerty的loss
loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)
# epoch内累加loss
epoch_loss += loss.item()
# 原来梯度清零
optimizer.zero_grad()
# 重新计算梯度
loss.backward()
optimizer.step()
#打印当前epoch,当前iter的mini-batch,time cost,batch_loss等信息
print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(epoch+1, epoch_iter, i+1, int(file_num/batch_size), time.time()-start_time, loss.item()))
#打印整个epoch的loss,时间耗费
print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(epoch_loss/int(file_num/batch_size), time.time()-epoch_time))
print(time.asctime(time.localtime(time.time())))
print('='*50)
# 保存模型参数
if (epoch + 1) % interval == 0:
# 并行计算时保存模型有所不同
state_dict = model.module.state_dict() if data_parallel else model.state_dict()
torch.save(state_dict, os.path.join(pths_path, 'model_epoch_{}.pth'.format(epoch+1)))
拓展:
-
上述代码中model.train()有什么作用,对比eval()模式有什么不同?
函数train和函数eval的作用是将Module及其SubModule分别设置为training mode和evaluation mode。这两个函数只对特定的Module有影响,例如Class Dropout、Class BatchNorm;
六、主函数实现(传入参数并运行)
if __name__ == '__main__':
#图片路径
train_img_path = os.path.abspath('../ICDAR_2015/train_img')
#图片标注路径
train_gt_path = os.path.abspath('../ICDAR_2015/train_gt')
#模型保存路径,含预先训练得模型和本次训练得模型
pths_path = './pths'
batch_size = 24
# 用batch_size等于1进行验证
# batch_size = 1
lr= 1e-3
#数据处理得线程数目(可以设定大一些)
num_workers = 1
# 训练轮次
epoch_iter= 600
# 保存模型的间隔
save_interval = 50
# 传入参数进行训练
train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, save_interval)
2、model.py
在讲解模型代码之前,先放一篇文章,这个博主写得很好,把EAST网络得由来都说明了一遍;
文章:https://blog.csdn.net/u011046017/article/details/93392862
源码地址:https://github.com/SakuraRiven/EAST
接下来再看一下模型结构图:
首先extractor结构,主要作用是特征提取,这里可以选用VGG、ResNet等骨干网络,本次选用较为简单的VGG;
接着是merging结构,主要作用是特征合并,通过特征融合的方式进行上采样;
左后是output结构,将得到的特征图转换成对应ground truth的特征格式;
实现步骤:
一、extroctor结构的实现
class extractor(nn.Module):
def __init__(self, pretrained):
super(extractor, self).__init__()
vgg16_bn = VGG(make_layers(cfg, batch_norm=True))
# 加载预先训练得模型(需提前下载)
if pretrained:
vgg16_bn.load_state_dict(torch.load('./pths/vgg16_bn-6c64b313.pth'))
self.features = vgg16_bn.features
def forward(self, x):
out = []
for m in self.features:
x = m(x)
# 遇到MaxPool2d保存feature maps
if isinstance(m, nn.MaxPool2d):
out.append(x)
#只取后四个feature map
return out[1:]
这里VGG的实现就不展示代码了,也可以替换成其他骨干网络,如ResNet等;
二、merging结构的实现
这里的x输入为特征提取层中的四个特征图输出;
def forward(self, x):
# 对照论文上融合过程进行实现
y = F.interpolate(x[3], scale_factor=2, mode='bilinear', align_corners=True)
y = torch.cat((y, x[2]), 1)
y = self.relu1(self.bn1(self.conv1(y)))
y = self.relu2(self.bn2(self.conv2(y)))
y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
y = torch.cat((y, x[1]), 1)
y = self.relu3(self.bn3(self.conv3(y)))
y = self.relu4(self.bn4(self.conv4(y)))
y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
y = torch.cat((y, x[0]), 1)
y = self.relu5(self.bn5(self.conv5(y)))
y = self.relu6(self.bn6(self.conv6(y)))
y = self.relu7(self.bn7(self.conv7(y)))
return y
三、output结构实现
class output(nn.Module):
def __init__(self, scope=512):
super(output, self).__init__()
self.conv1 = nn.Conv2d(32, 1, 1)
self.sigmoid1 = nn.Sigmoid()
self.conv2 = nn.Conv2d(32, 4, 1)
self.sigmoid2 = nn.Sigmoid()
self.conv3 = nn.Conv2d(32, 1, 1)
self.sigmoid3 = nn.Sigmoid()
self.scope = 512
# 参数初始化
for m in self.modules():
if isinstance(m, nn.Conv2d):
# 何凯明初始化方法
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# bias初始化为0
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
score = self.sigmoid1(self.conv1(x))
loc = self.sigmoid2(self.conv2(x)) * self.scope
#角度数值转化为弧度
angle = (self.sigmoid3(self.conv3(x)) - 0.5) * math.pi
geo = torch.cat((loc, angle), 1)
return score, geo
到这里就得到了最终需要的score map和geo map,可以计算损失更新参数了;
3、loss.py实现
一、Score map的损失函数定义
首先score map的损失函数用平衡交叉熵损失,作为一个最常见的二分类损失函数,加入超参数β减少类别不平衡的问题;
当然,也可以选用dice_loss损失函数,二者的效果会在推理部分进行展示;
def get_class_balanced_cross_entropy(gt_score,pred_score):
# 防止inf,nan
pred_score[pred_score==0]=1e-8
pred_score[pred_score==1]=1-1e-8
# 按照公式求
beta = 1 - torch.sum(gt_score) / gt_score.numel() # 计算正例/所有样例
score_loss = -beta * torch.sum(gt_score*torch.log(pred_score))
- (1-beta) * torch.sum((1-gt_score)*torch.log(1-pred_score))
return score_loss
def get_dice_loss(gt_score, pred_score):
# 按照dice_loss的数学公式写
inter = torch.sum(gt_score * pred_score)
union = torch.sum(gt_score) + torch.sum(pred_score) + 1e-5
return 1. - (2 * inter / union)
这里的Loss相对于标准的交叉熵Loss进行了改进,还有一种改进为Focal Loss,也是提升了一定的效果,说明Loss在实际任务中起到最重要的作用,定义好的Loss有助于提升模型的拟合能力;
numel()函数的作用:计算所有像素点的数量;
二、geo map的损失函数定义
实际上边框四个点的坐标,可以通过IOU Loss进行计算,一般回归检测框的任务都少不了IOU的计算;
angle loss则通过特定公式进行计算;
def get_geo_loss(gt_geo, pred_geo):
# 得到ground truth的值
d1_gt, d2_gt, d3_gt, d4_gt, angle_gt = torch.split(gt_geo, 1, 1)
# 得到模型预测值
d1_pred, d2_pred, d3_pred, d4_pred, angle_pred = torch.split(pred_geo, 1, 1)
# 得到gt的检测框面积
area_gt = (d1_gt + d2_gt) * (d3_gt + d4_gt)
# 得到预测框的面积
area_pred = (d1_pred + d2_pred) * (d3_pred + d4_pred)
# 通过计算d3和d4最小和相加,得到重合框的宽
w_union = torch.min(d3_gt, d3_pred) + torch.min(d4_gt, d4_pred)
# 通过计算d1和d2最小和相加,得到重合框的高
h_union = torch.min(d1_gt, d1_pred) + torch.min(d2_gt, d2_pred)
# 得到两个框交集的面积
area_intersect = w_union * h_union
# 计算两个框并集的面积
area_union = area_gt + area_pred - area_intersect
# 计算IOU Loss的值(-log(交集/并集))
iou_loss_map = -torch.log((area_intersect + 1.0)/(area_union + 1.0))
# angle的loss直接带入公式即可
angle_loss_map = 1 - torch.cos(angle_pred - angle_gt)
return iou_loss_map, angle_loss_map
四、推理部分实现
1、主代码实现
推理部分的逻辑如下代码所示:
if __name__ == '__main__':
img_path = './image/img_3.jpg'
model_path = './pths/best.pth'
res_img = './res.jpg'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#加载模型
model = EAST().to(device)
#加载模型参数
model.load_state_dict(torch.load(model_path))
#把模型设置为测试状态,避免bn,dropout发生计算
model.eval()
#读入图片
img = Image.open(img_path)
#进行检测得到框
boxes = detect(img, model, device)
#绘制boxes到图片上
plot_img = plot_boxes(img, boxes)
#保存图片到./res.bmp
plot_img.save(res_img)
基本上所有的模型推理都是这个流程,至于其中的前后处理都包含在了函数中;
2、detect部分实现
def detect(img, model, device):
'''
Input:
img : PIL Image
model : detection model
device: gpu if gpu is available
Output:
detected polys
'''
# 图片缩放(记录缩放比例)
img, ratio_h, ratio_w = resize_img(img)
# 计算出score_map,geo_map
with torch.no_grad():
score, geo = model(load_pil(img).to(device))
#get_boxes函数里面用到 locality aware NMS(后处理)
boxes = get_boxes(score.squeeze(0).cpu().numpy(), geo.squeeze(0).cpu().numpy())
return adjust_ratio(boxes, ratio_w, ratio_h)
上面有前处理和后处理的操作,比如resize_img和load_pil这两个函数做的就是前处理;
get_boxes这个函数做的就是后处理,这里主要讲解一下后处理的操作实现;
-
get_boxes函数实现
def get_boxes(score, geo, score_thresh=0.9, nms_thresh=0.2): ''' Input: score : score map from model <numpy.ndarray, (1,row,col)> geo : geo map from model <numpy.ndarray, (5,row,col)> score_thresh: threshold to segment score map(文本框选取的阈值) nms_thresh : threshold in nms(nms重合阈值) Output: boxes : final polys <numpy.ndarray, (n,9)> ''' # 整个图片上得点,每个点代表一个候选框 score = score[0,:,:] # 去掉batch这个维度,输出[row, col] #留下score大于score_thresh的候选框对应的像素点坐标 xy_text = np.argwhere(score > score_thresh) # n x (y, x), 得到满足条件的点的坐标索引 # 如果没有任何点满足条件,退出函数 if xy_text.size == 0: return None #按行排序 xy_text = xy_text[np.argsort(xy_text[:, 0])] #得到像素点坐标 valid_pos = xy_text[:, ::-1].copy() # n x 2, [x, y] #利用像素点坐标拿到候选框的geo值 valid_geo = geo[:, xy_text[:, 0], xy_text[:, 1]] # 5 x n #把d1,d2,d3,d4转换为四个顶点的坐标 polys_restored, index = restore_polys(valid_pos, valid_geo, score.shape) if polys_restored.size == 0: return None boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32) boxes[:, :8] = polys_restored boxes[:, 8] = score[xy_text[index, 0], xy_text[index, 1]] #boxes是quad的形式,四边形得四个顶点的坐标(x1,y1,x2,y2,x3,y3,x4,y4) boxes = locality_aware_nms(boxes.astype('float32'), nms_thresh) return boxes
拓展:
-
np.argwhere的作用是什么?
返回满足条件的像素点的坐标位置,也就是起到索引的一个作用;
-
3、Locality-Aware NMS
前提说明:NMS适用于多目标检测,单目标直接取score最大值即可;非极大值抑制的原理是选取score最大的框,将IOU重合度高的框去除,再取剩下框中score最高的框执行同样操作;
但是候选框的数量比较大,如果一一比较的话计算量过大,这里采用局部merge,再对merge后的框进行NMS;
作用:
1、计算量少;
2、检测稳定(抖动大的时候检测框也稳定)
实现要点:
合并两个预选框时,采用加权平均的方式,权值为score的值,将d的值合并起来;
实现方式:
1、通过调用lanms包中的函数实现;
import lanms
boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thresh)
2、通过numpy实现;
def nms_locality(quads, thres=0.3):
'''
locality aware nms of EAST
:param quads: a N*9 numpy array. first 8 coordinates, then prob
:return: boxes after nms
'''
S = []
p = None
for g in quads:
# intersection作用:计算两个框的IOU
if p is not None and intersection(g, p) > thres:
# 如果大于阈值,则进行两个框的合并
p = weighted_merge(g, p)
else:
if p is not None:
S.append(p)
p = g
if p is not None:
S.append(p)
if len(S) == 0:
return np.array([])
# 最后再经过标准的NMS过滤
return standard_nms(np.array(S), thres)
4、效果展示
首先来看一下训练了600个epochs的效果:
这里采用的分类损失函数为dice_loss,采用类平衡交叉熵损失会造成很多的误检率,模型不小心被删掉了…
五、总结
主要有以下几点看法:
1、作用通用场景下的OCR识别,EAST模型的效果还是比较可观的;
2、在分类损失函数的选择上,使用类平衡交叉熵有较高的误检率,而使用dice_loss有时召回率过低,同时使用二者来进行加权应该能达到更好的效果;
3、这是一个2017年开源的模型,相信以最新的算法来说已经有一个更好更快的通用模型了,包括说paddleOCR,也是支持在端侧上进行部署使用了;
4、实际上一个项目的实现最困难的部分不是模型,而是数据的前后处理,特别是对一些检测类的任务,往往标注的时候会比较简单(节约成本),那么处理上就相对比较复杂,如果有一处地方处理错误,往往会导致整个模型往错误的方向进行学习,所以学好numpy还是很有必要的;
更多推荐
所有评论(0)