【OCR】EAST算法讲解及实现

一、概念介绍

OCR(Optical Character Recognition):光学字符识别;

基本流程:

输入图片 —— 预处理 —— 文字定位(采用EAST算法)—— 文字分割 —— 图像识别

涉及的技术点:检测、分割、分类三个大任务,感觉像图像中很多任务都是在分类、检测、分割上实现的;

二、算法原理

论文地址:https://arxiv.org/pdf/1704.03155.pdf

关键:提出了基于两阶段的文本检测方法:全卷积神经网络和非极大值抑制,消除中间过程冗余,减少检测时间;

作用:可以检测单词级别或文本行级别的文本,检测框可以为任意形状的四边形,可带有旋转;

总结:相比于传统检测算法,输出更多的信息,能够检测除了标准四边形外的其他形状;

1、模型结构图

在这里插入图片描述

可以看出,这里类似于一个编解码的过程,Feature extractor对特征图进行降采样,Feature merging对特征图进行上采样(特征融合),有点类似于图像分割的中的trick;

2、模型输入及输出

输出:score(置信度)、text boxes(x,y,w,h)、text rotation(θ),这里的文本框x,y,w,h也可以用d1,d2,d3,d4表示,意思是点离框的上下左右的距离;

首先要理解GT是怎么生成的:

在这里插入图片描述

3、Loss的计算(分为三个部分)

Score部分的计算:

在这里插入图片描述

可以看出,是用到一个平衡交叉熵的计算;

RBOX部分的计算:

在这里插入图片描述

主要用到IOU计算Loss;

角度部分的计算:

在这里插入图片描述

三、训练部分实现

1、train.py

作用:训练代码,将Loss和Model传入进来,实现模型的训练;

实现步骤如下:

一、引入所需要的包

# 引入一些包
import torch
#数据增强处理包
from torch.utils import data
from torch import nn
#引入学习率变化策略包
from torch.optim import lr_scheduler
#引入自己写的数据加载包,把数据整理成EAST_pytorch需要的格式
from dataset import custom_dataset
#引入自己写的EAST模型
from model import EAST
#引入自己写的Loss
from loss import Loss
import os
import time
import numpy as np

Model和Loss两个包将在后面讲解;

二、加载数据

def train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval):
	# 得到文件夹下得文件列表,然后计算数目
	file_num = len(os.listdir(train_img_path))
	#利用我们自己custom_dataset把数据从文件夹里读取,并整理成pytorch.utils.data.DataLoader需要的格式trainset
	trainset = custom_dataset(train_img_path, train_gt_path)
	#把数据按照batch_size,shuffle,num_workers,drop_last等原则进行整理
	train_loader = data.DataLoader(trainset, batch_size=batch_size, \
                                   shuffle=True, num_workers=num_workers, drop_last=True)

三、加载Loss和Model

#加载自己实现的loss函数	
criterion = Loss()
#获取计算设备,使用gpu或cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 加载自己实现的模型
model = EAST()
#是否并行计算(多GPU)
data_parallel = False
# 如果有gpu设备,就把模型转为并行模型,方便用并行得方式把数据输入到模型中
if torch.cuda.device_count() > 1:
	model = nn.DataParallel(model)
	data_parallel = True
#把模型绑定到gpu或cpu上
model.to(device)

四、加载优化器

#优化器,把模型得参数给到优化器,让优化器按照lr去更新这些参数
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#设定优化器工作的时候,lr的更新策略
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter//2], gamma=0.1)

五、设置epoch次数,进行训练

for epoch in range(epoch_iter):	
	model.train()
	# 每个epoch开始前,更新一个学习率
	scheduler.step()
	epoch_loss = 0
	epoch_time = time.time()
	for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
		start_time = time.time()
		img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(device)
		# 将图片输入east模型中,得到score map 和 geo_map
		pred_score, pred_geo = model(img)
		# 计算score的loss,geomerty的loss
		loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)
		# epoch内累加loss	
		epoch_loss += loss.item()
		# 原来梯度清零
		optimizer.zero_grad()
		# 重新计算梯度
		loss.backward()
		optimizer.step()
		#打印当前epoch,当前iter的mini-batch,time cost,batch_loss等信息
		print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(epoch+1, epoch_iter, i+1, int(file_num/batch_size), time.time()-start_time, loss.item()))
	#打印整个epoch的loss,时间耗费
	print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(epoch_loss/int(file_num/batch_size), time.time()-epoch_time))
	print(time.asctime(time.localtime(time.time())))
	print('='*50)
	# 保存模型参数
	if (epoch + 1) % interval == 0:
        # 并行计算时保存模型有所不同
		state_dict = model.module.state_dict() if data_parallel else model.state_dict()
		torch.save(state_dict, os.path.join(pths_path, 'model_epoch_{}.pth'.format(epoch+1)))

拓展:

  • 上述代码中model.train()有什么作用,对比eval()模式有什么不同?

    函数train和函数eval的作用是将Module及其SubModule分别设置为training mode和evaluation mode。这两个函数只对特定的Module有影响,例如Class Dropout、Class BatchNorm;

六、主函数实现(传入参数并运行)

if __name__ == '__main__':
	#图片路径
	train_img_path = os.path.abspath('../ICDAR_2015/train_img')
	#图片标注路径
	train_gt_path = os.path.abspath('../ICDAR_2015/train_gt')
	#模型保存路径,含预先训练得模型和本次训练得模型
	pths_path = './pths'
	batch_size = 24 
	# 用batch_size等于1进行验证
	# batch_size = 1 	
	lr= 1e-3
	#数据处理得线程数目(可以设定大一些)
	num_workers = 1	
	# 训练轮次
	epoch_iter= 600
	# 保存模型的间隔
	save_interval = 50
    # 传入参数进行训练
	train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, save_interval)

2、model.py

在讲解模型代码之前,先放一篇文章,这个博主写得很好,把EAST网络得由来都说明了一遍;

文章:https://blog.csdn.net/u011046017/article/details/93392862

源码地址:https://github.com/SakuraRiven/EAST

接下来再看一下模型结构图:

在这里插入图片描述

首先extractor结构,主要作用是特征提取,这里可以选用VGG、ResNet等骨干网络,本次选用较为简单的VGG;

接着是merging结构,主要作用是特征合并,通过特征融合的方式进行上采样;

左后是output结构,将得到的特征图转换成对应ground truth的特征格式;

实现步骤:

一、extroctor结构的实现

class extractor(nn.Module):
	def __init__(self, pretrained):
		super(extractor, self).__init__()
		vgg16_bn = VGG(make_layers(cfg, batch_norm=True))
		# 加载预先训练得模型(需提前下载)
		if pretrained:
			vgg16_bn.load_state_dict(torch.load('./pths/vgg16_bn-6c64b313.pth'))
		self.features = vgg16_bn.features
	
	def forward(self, x):
		out = []
		for m in self.features:
			x = m(x)
			# 遇到MaxPool2d保存feature maps
			if isinstance(m, nn.MaxPool2d):
				out.append(x)
		#只取后四个feature map
		return out[1:]

这里VGG的实现就不展示代码了,也可以替换成其他骨干网络,如ResNet等;

二、merging结构的实现

这里的x输入为特征提取层中的四个特征图输出;

def forward(self, x):
	# 对照论文上融合过程进行实现
	y = F.interpolate(x[3], scale_factor=2, mode='bilinear', align_corners=True)
	y = torch.cat((y, x[2]), 1)
	y = self.relu1(self.bn1(self.conv1(y)))		
	y = self.relu2(self.bn2(self.conv2(y)))

    y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
    y = torch.cat((y, x[1]), 1)
    y = self.relu3(self.bn3(self.conv3(y)))		
    y = self.relu4(self.bn4(self.conv4(y)))

    y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
    y = torch.cat((y, x[0]), 1)
    y = self.relu5(self.bn5(self.conv5(y)))		
    y = self.relu6(self.bn6(self.conv6(y)))

    y = self.relu7(self.bn7(self.conv7(y)))
    return y

三、output结构实现

class output(nn.Module):
	def __init__(self, scope=512):
		super(output, self).__init__()
		self.conv1 = nn.Conv2d(32, 1, 1)
		self.sigmoid1 = nn.Sigmoid()
		self.conv2 = nn.Conv2d(32, 4, 1)
		self.sigmoid2 = nn.Sigmoid()
		self.conv3 = nn.Conv2d(32, 1, 1)
		self.sigmoid3 = nn.Sigmoid()
		self.scope = 512
		# 参数初始化
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
			        # 何凯明初始化方法
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				# bias初始化为0
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)

	def forward(self, x):
		score = self.sigmoid1(self.conv1(x))
		loc   = self.sigmoid2(self.conv2(x)) * self.scope
		#角度数值转化为弧度
		angle = (self.sigmoid3(self.conv3(x)) - 0.5) * math.pi
		geo   = torch.cat((loc, angle), 1) 
		return score, geo

到这里就得到了最终需要的score map和geo map,可以计算损失更新参数了;

3、loss.py实现

一、Score map的损失函数定义

首先score map的损失函数用平衡交叉熵损失,作为一个最常见的二分类损失函数,加入超参数β减少类别不平衡的问题;

当然,也可以选用dice_loss损失函数,二者的效果会在推理部分进行展示;

def get_class_balanced_cross_entropy(gt_score,pred_score):
	# 防止inf,nan
	pred_score[pred_score==0]=1e-8
	pred_score[pred_score==1]=1-1e-8
	# 按照公式求
	beta = 1 - torch.sum(gt_score) / gt_score.numel()	# 计算正例/所有样例
	score_loss = -beta * torch.sum(gt_score*torch.log(pred_score))
	- (1-beta) * torch.sum((1-gt_score)*torch.log(1-pred_score))
	return score_loss

def get_dice_loss(gt_score, pred_score):
	# 按照dice_loss的数学公式写
    inter = torch.sum(gt_score * pred_score)
    union = torch.sum(gt_score) + torch.sum(pred_score) + 1e-5
    return 1. - (2 * inter / union)

这里的Loss相对于标准的交叉熵Loss进行了改进,还有一种改进为Focal Loss,也是提升了一定的效果,说明Loss在实际任务中起到最重要的作用,定义好的Loss有助于提升模型的拟合能力;

numel()函数的作用:计算所有像素点的数量;

二、geo map的损失函数定义

实际上边框四个点的坐标,可以通过IOU Loss进行计算,一般回归检测框的任务都少不了IOU的计算;

angle loss则通过特定公式进行计算;

def get_geo_loss(gt_geo, pred_geo):
	# 得到ground truth的值
	d1_gt, d2_gt, d3_gt, d4_gt, angle_gt = torch.split(gt_geo, 1, 1)
    # 得到模型预测值
	d1_pred, d2_pred, d3_pred, d4_pred, angle_pred = torch.split(pred_geo, 1, 1)
    # 得到gt的检测框面积
	area_gt = (d1_gt + d2_gt) * (d3_gt + d4_gt)
    # 得到预测框的面积
	area_pred = (d1_pred + d2_pred) * (d3_pred + d4_pred)
    # 通过计算d3和d4最小和相加,得到重合框的宽
	w_union = torch.min(d3_gt, d3_pred) + torch.min(d4_gt, d4_pred)
    # 通过计算d1和d2最小和相加,得到重合框的高
	h_union = torch.min(d1_gt, d1_pred) + torch.min(d2_gt, d2_pred)
    # 得到两个框交集的面积
	area_intersect = w_union * h_union
    # 计算两个框并集的面积
	area_union = area_gt + area_pred - area_intersect
    # 计算IOU Loss的值(-log(交集/并集))
	iou_loss_map = -torch.log((area_intersect + 1.0)/(area_union + 1.0))
    # angle的loss直接带入公式即可
	angle_loss_map = 1 - torch.cos(angle_pred - angle_gt)
	return iou_loss_map, angle_loss_map

四、推理部分实现

1、主代码实现

推理部分的逻辑如下代码所示:

if __name__ == '__main__':
	img_path    = './image/img_3.jpg'
	model_path  = './pths/best.pth'
	res_img     = './res.jpg'
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	#加载模型
	model = EAST().to(device)
	#加载模型参数
	model.load_state_dict(torch.load(model_path))
	#把模型设置为测试状态,避免bn,dropout发生计算
	model.eval()
	#读入图片
	img = Image.open(img_path)
	#进行检测得到框 	
	boxes = detect(img, model, device)
	#绘制boxes到图片上
	plot_img = plot_boxes(img, boxes)	
	#保存图片到./res.bmp
	plot_img.save(res_img)

基本上所有的模型推理都是这个流程,至于其中的前后处理都包含在了函数中;

2、detect部分实现

def detect(img, model, device):
	'''
	Input:
		img   : PIL Image
		model : detection model
		device: gpu if gpu is available
	Output:
		detected polys
	'''
	# 图片缩放(记录缩放比例)
	img, ratio_h, ratio_w = resize_img(img)
	# 计算出score_map,geo_map
	with torch.no_grad():
		score, geo = model(load_pil(img).to(device))
	#get_boxes函数里面用到 locality aware NMS(后处理)
	boxes = get_boxes(score.squeeze(0).cpu().numpy(), geo.squeeze(0).cpu().numpy())
	return adjust_ratio(boxes, ratio_w, ratio_h)

上面有前处理和后处理的操作,比如resize_img和load_pil这两个函数做的就是前处理;

get_boxes这个函数做的就是后处理,这里主要讲解一下后处理的操作实现;

  • get_boxes函数实现

    def get_boxes(score, geo, score_thresh=0.9, nms_thresh=0.2):
    	'''
    	Input:
    		score       : score map from model <numpy.ndarray, (1,row,col)>
    		geo         : geo map from model <numpy.ndarray, (5,row,col)>
    		score_thresh: threshold to segment score map(文本框选取的阈值)
    		nms_thresh  : threshold in nms(nms重合阈值)
    	Output:
    		boxes       : final polys <numpy.ndarray, (n,9)>
    	'''
    	# 整个图片上得点,每个点代表一个候选框
    	score = score[0,:,:]	 # 去掉batch这个维度,输出[row, col]
    	#留下score大于score_thresh的候选框对应的像素点坐标
    	xy_text = np.argwhere(score > score_thresh) # n x (y, x), 得到满足条件的点的坐标索引
        # 如果没有任何点满足条件,退出函数
    	if xy_text.size == 0:
    		return None
    	#按行排序
    	xy_text = xy_text[np.argsort(xy_text[:, 0])]
    	#得到像素点坐标
    	valid_pos = xy_text[:, ::-1].copy() # n x 2, [x, y]
    	#利用像素点坐标拿到候选框的geo值
    	valid_geo = geo[:, xy_text[:, 0], xy_text[:, 1]] # 5 x n
    	#把d1,d2,d3,d4转换为四个顶点的坐标
    	polys_restored, index = restore_polys(valid_pos, valid_geo, score.shape) 
    	if polys_restored.size == 0:
    		return None
    	
    	boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32)
    	boxes[:, :8] = polys_restored
    	boxes[:, 8] = score[xy_text[index, 0], xy_text[index, 1]]
    	#boxes是quad的形式,四边形得四个顶点的坐标(x1,y1,x2,y2,x3,y3,x4,y4)
    	boxes = locality_aware_nms(boxes.astype('float32'), nms_thresh)
    	return boxes
    

    拓展:

    • np.argwhere的作用是什么?

      返回满足条件的像素点的坐标位置,也就是起到索引的一个作用;

3、Locality-Aware NMS

前提说明:NMS适用于多目标检测,单目标直接取score最大值即可;非极大值抑制的原理是选取score最大的框,将IOU重合度高的框去除,再取剩下框中score最高的框执行同样操作;

但是候选框的数量比较大,如果一一比较的话计算量过大,这里采用局部merge,再对merge后的框进行NMS;

作用:

1、计算量少;

2、检测稳定(抖动大的时候检测框也稳定)

实现要点:

合并两个预选框时,采用加权平均的方式,权值为score的值,将d的值合并起来;

实现方式:

1、通过调用lanms包中的函数实现;

import lanms
boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thresh)

2、通过numpy实现;

def nms_locality(quads, thres=0.3):
    '''
    locality aware nms of EAST
    :param quads: a N*9 numpy array. first 8 coordinates, then prob
    :return: boxes after nms
    '''
    S = []
    p = None
    for g in quads:
        # intersection作用:计算两个框的IOU
        if p is not None and intersection(g, p) > thres:
            # 如果大于阈值,则进行两个框的合并
            p = weighted_merge(g, p)
        else:
            if p is not None:
                S.append(p)
            p = g
    if p is not None:
        S.append(p)

    if len(S) == 0:
        return np.array([])
    
    # 最后再经过标准的NMS过滤
    return standard_nms(np.array(S), thres)

4、效果展示

首先来看一下训练了600个epochs的效果:
在这里插入图片描述

这里采用的分类损失函数为dice_loss,采用类平衡交叉熵损失会造成很多的误检率,模型不小心被删掉了…

五、总结

主要有以下几点看法:

1、作用通用场景下的OCR识别,EAST模型的效果还是比较可观的;

2、在分类损失函数的选择上,使用类平衡交叉熵有较高的误检率,而使用dice_loss有时召回率过低,同时使用二者来进行加权应该能达到更好的效果;

3、这是一个2017年开源的模型,相信以最新的算法来说已经有一个更好更快的通用模型了,包括说paddleOCR,也是支持在端侧上进行部署使用了;

4、实际上一个项目的实现最困难的部分不是模型,而是数据的前后处理,特别是对一些检测类的任务,往往标注的时候会比较简单(节约成本),那么处理上就相对比较复杂,如果有一处地方处理错误,往往会导致整个模型往错误的方向进行学习,所以学好numpy还是很有必要的;

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐