记录一下MinerU核心代码阅读时的一些笔记。
下面的 __init__ 属于 CustomPEKModel 类(位于 pdf_extract_kit.py),该类中加载了所有的模型。一共5个模型,已经在代码中用注释标出。
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
    """Load the YAML model configuration and build every model the pipeline uses.

    Up to five models are obtained from ``AtomModelSingleton``: formula
    detection (MFD), formula recognition (MFR), layout analysis, OCR and
    table recognition. The layout model is always created; the others
    depend on the configuration and on the ``ocr`` argument.

    Args:
        ocr: Whether to create the OCR model.
        show_log: Forwarded to the OCR model as ``ocr_show_log``.
        **kwargs: Optional overrides for values from ``model_configs.yaml``.
            Keys read here: ``apply_layout``, ``apply_formula``,
            ``table_config``, ``device`` and ``models_dir``.

    Raises:
        AssertionError: If the layout model is disabled.
    """
    # Locate <root>/resources/model_config, where <root> is the parent of the
    # directory that contains this file.
    here = os.path.dirname(os.path.abspath(__file__))
    root_dir = os.path.dirname(here)
    model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
    config_path = os.path.join(model_config_dir, 'model_configs.yaml')
    with open(config_path, "r", encoding='utf-8') as stream:
        # NOTE(review): FullLoader is acceptable for a bundled, trusted file;
        # do not point this at user-supplied YAML.
        self.configs = yaml.load(stream, Loader=yaml.FullLoader)

    # Feature switches: explicit kwargs take precedence over the YAML values.
    yaml_defaults = self.configs["config"]
    self.apply_layout = kwargs.get("apply_layout", yaml_defaults["layout"])
    self.apply_formula = kwargs.get("apply_formula", yaml_defaults["formula"])

    # Table recognition settings.
    self.table_config = kwargs.get("table_config", yaml_defaults["table_config"])
    self.apply_table = self.table_config.get("is_table_recog_enable", False)
    self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
    self.table_model_type = self.table_config.get("model", TABLE_MASTER)

    self.apply_ocr = ocr
    logger.info(
        "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
            self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
        )
    )
    assert self.apply_layout, "DocAnalysis must contain layout model."

    # Runtime device and the directory holding the model weights.
    self.device = kwargs.get("device", yaml_defaults["device"])
    logger.info("using device: {}".format(self.device))
    models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
    logger.info("using models_dir: {}".format(models_dir))

    def weight_path(weight_key):
        # Looked up lazily so a missing weights entry only fails when that
        # model is actually requested.
        return str(os.path.join(models_dir, self.configs["weights"][weight_key]))

    model_factory = AtomModelSingleton()

    if self.apply_formula:
        # ---- 1. formula detection model ----
        self.mfd_model = model_factory.get_atom_model(
            atom_model_name=AtomicModel.MFD,
            mfd_weights=weight_path("mfd"),
        )

        # ---- 2. formula recognition model ----
        mfr_weight_dir = weight_path("mfr")
        mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
        self.mfr_model, self.mfr_transform = model_factory.get_atom_model(
            atom_model_name=AtomicModel.MFR,
            mfr_weight_dir=mfr_weight_dir,
            mfr_cfg_path=mfr_cfg_path,
            device=self.device,
        )

    # ---- 3. layout analysis model (always required) ----
    self.layout_model = model_factory.get_atom_model(
        atom_model_name=AtomicModel.Layout,
        layout_weights=weight_path('layout'),
        layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
        device=self.device,
    )

    if self.apply_ocr:
        # ---- 4. OCR model ----
        self.ocr_model = model_factory.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            ocr_show_log=show_log,
            det_db_box_thresh=0.3,
        )

    if self.apply_table:
        # ---- 5. table recognition model ----
        self.table_model = model_factory.get_atom_model(
            atom_model_name=AtomicModel.Table,
            table_model_type=self.table_model_type,
            table_model_path=weight_path(self.table_model_type),
            table_max_time=self.table_max_time,
            device=self.device,
        )

    logger.info('DocAnalysis init done!')
基本上所有的模型推理结果都在这里。 这里的ocr模型推理修改了原始的PaddleOCR预测方法,添加了公式检测的bbox, 以便在ocr识别时,能够将公式区域隔离出来,因为有专门的公式识别模型。 而且在ocr识别时,会在相应子图的四周各补上50像素的白边(即宽、高各增加100像素),目的是提升识别效果? 最后的返回结果是layout_res,里面包含了所有的推理结果。 这是一个list,在后续的代码中,都叫做model_list。 每个item都是一个dict,里面包含了推理结果。 如下所示:
[
{
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
}
]
除了 text 字段,item 中也可能是 html 或 latex 字段:html 对应表格的识别结果,latex 对应公式的识别结果(表格模型为 StructEqTable 时,表格结果也写在 latex 字段里)。 poly也可能是bbox。两种表示方式是混用的。
def __call__(self, image):
    """Run layout, formula, OCR and table models on a single page image.

    Args:
        image: Page image as an array accepted by ``Image.fromarray`` and by
            the layout / formula-detection models (presumably RGB, since OCR
            crops are converted with ``COLOR_RGB2BGR`` — TODO confirm).

    Returns:
        list[dict]: The layout detections, extended in place with
        - one item per detected formula (``category_id`` 13 + class index,
          ``poly``, ``score``, ``latex``) when formula recognition is on;
        - one item per OCR text box (``category_id`` 15, ``poly``, ``score``,
          ``text``) when OCR is on;
        and with a ``latex`` or ``html`` key added to table items when table
        recognition is on and succeeds.
    """
    latex_filling_list = []
    mf_image_list = []

    # Layout detection.
    layout_start = time.time()
    layout_res = self.layout_model(image, ignore_catids=[])
    layout_cost = round(time.time() - layout_start, 2)
    logger.info(f"layout detection cost: {layout_cost}")

    # Convert to PIL once; all crops below reuse this object instead of
    # rebuilding it for every formula box.
    # NOTE(review): assumes get_croped_image does not modify its input image.
    pil_img = Image.fromarray(image)

    if self.apply_formula:
        # Formula detection.
        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            }
            layout_res.append(new_item)
            latex_filling_list.append(new_item)
            mf_image_list.append(get_croped_image(pil_img, [xmin, ymin, xmax, ymax]))

        # Formula recognition, batched over all crops of this page.
        mfr_start = time.time()
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            output = self.mfr_model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        # The same dicts are already in layout_res, so filling them here
        # updates the returned list as well.
        for res, latex in zip(latex_filling_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        mfr_cost = round(time.time() - mfr_start, 2)
        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")

    # Select regions for OCR / formula regions / table regions.
    ocr_res_list = []
    table_res_list = []
    single_page_mfdetrec_res = []
    for res in layout_res:
        category_id = int(res['category_id'])
        if category_id in (13, 14):
            single_page_mfdetrec_res.append({
                "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                         int(res['poly'][4]), int(res['poly'][5])],
            })
        elif category_id in (0, 1, 2, 4, 6, 7):
            ocr_res_list.append(res)
        elif category_id == 5:
            table_res_list.append(res)

    # Unified crop logic shared by OCR and table recognition.
    def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
        crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
        crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
        # White canvas padded by crop_paste_x / crop_paste_y on EACH side.
        crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
        crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
        return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
        # Crop the region and paste it at the padding offset.
        crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
        cropped_img = input_pil_img.crop(crop_box)
        return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
        return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax,
                       crop_new_width, crop_new_height]
        return return_image, return_list

    # OCR recognition.
    if self.apply_ocr:
        ocr_start = time.time()
        # Process each area that requires OCR.
        for res in ocr_res_list:
            new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
            paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list

            # Express formula boxes in the coordinate system of the padded
            # crop and drop those lying entirely outside it.
            adjusted_mfdetrec_res = []
            for mf_res in single_page_mfdetrec_res:
                mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
                x0 = mf_xmin - xmin + paste_x
                y0 = mf_ymin - ymin + paste_y
                x1 = mf_xmax - xmin + paste_x
                y1 = mf_ymax - ymin + paste_y
                outside_crop = x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height
                if not outside_crop:
                    adjusted_mfdetrec_res.append({
                        "bbox": [x0, y0, x1, y1],
                    })

            # Run OCR on the BGR version of the crop.
            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
            ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]

            # Merge results, mapping points back to page coordinates.
            if ocr_res:
                for box_ocr_res in ocr_res:
                    p1, p2, p3, p4 = box_ocr_res[0]
                    text, score = box_ocr_res[1]
                    p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
                    p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
                    p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
                    p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
                    layout_res.append({
                        'category_id': 15,
                        'poly': p1 + p2 + p3 + p4,
                        'score': round(score, 2),
                        'text': text,
                    })

        ocr_cost = round(time.time() - ocr_start, 2)
        logger.info(f"ocr cost: {ocr_cost}")

    # Table recognition.
    if self.apply_table:
        table_start = time.time()
        for res in table_res_list:
            new_image, _ = crop_img(res, pil_img)
            single_table_start_time = time.time()
            logger.info("------------------table recognition processing begins-----------------")
            latex_code = None
            html_code = None
            if self.table_model_type == STRUCT_EQTABLE:
                with torch.no_grad():
                    latex_code = self.table_model.image2latex(new_image)[0]
            else:
                html_code = self.table_model.img2html(new_image)

            run_time = time.time() - single_table_start_time
            logger.info(f"------------table recognition processing ends within {run_time}s-----")
            if run_time > self.table_max_time:
                logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")

            # Keep the result only when the model returned something usable.
            if latex_code:
                # A complete LaTeX table must end with one of these closers.
                expected_ending = latex_code.strip().endswith(('end{tabular}', 'end{table}'))
                if expected_ending:
                    res["latex"] = latex_code
                else:
                    logger.warning("------------table recognition processing fails----------")
            elif html_code:
                res["html"] = html_code
            else:
                logger.warning("------------table recognition processing fails----------")
        table_cost = round(time.time() - table_start, 2)
        logger.info(f"table cost: {table_cost}")

    return layout_res
没有一个专门的类叫做PostProcess。 后处理的实现分散在三个地方
核心代码如下:
def __init__(self, model_list: list, docs: fitz.Document):
    """Keep the model output and document, then run the clean-up passes in order.

    Args:
        model_list: Model results to be cleaned in place by the passes below.
        docs: The opened PDF document the results refer to.
    """
    self.__docs = docs
    self.__model_list = model_list

    # Add bbox information to all model data (scaling, poly -> bbox).
    self.__fix_axis()
    # Remove model data with very low confidence (< 0.05) to improve quality.
    self.__fix_by_remove_low_confidence()
    # Among items with high IoU (> 0.9), remove the one with lower confidence.
    self.__fix_by_remove_high_iou_and_low_confidence()
    # Decide whether each footnote belongs to an image or to a table.
    self.__fix_footnote()
这里的核心代码如下:
# Excerpt of the per-page post-processing (presumably the body of
# parse_page_core; the enclosing def is not part of this excerpt). It filters
# spans, computes the page layout, orders blocks by layout, fills spans into
# blocks and returns the page component built by ocr_construct_page_component_v2.
# Remove the lower-confidence spans among overlapping spans.
'''删除重叠spans中置信度较低的那些'''
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
# Remove the smaller spans among overlapping spans.
'''删除重叠spans中较小的那些'''
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
# Cut out image and table regions.
'''对image和table截图'''
spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
# Gather the bboxes of all blocks into one list.
'''将所有区块的bbox整理到一起'''
# interline_equation_blocks is not accurate enough, so interline_equations is used instead
interline_equation_blocks = []
# The list above is always empty here, so the else branch is the one that runs.
if len(interline_equation_blocks) > 0:
all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
interline_equation_blocks, page_w, page_h)
else:
all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
interline_equations, page_w, page_h)
if len(drop_reasons) > 0:
need_drop = True
drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)
# First handle discarded_blocks, which need no layout ordering.
'''先处理不需要排版的discarded_blocks'''
discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
# Skip the page when it has no bbox at all.
'''如果当前页面没有bbox则跳过'''
if len(all_bboxes) == 0:
logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
[], [], interline_equations, fix_discarded_blocks,
need_drop, drop_reason)
# Before splitting, check whether bboxes overlap horizontally. If they do, the
# PDF is considered not yet handled well; such overlap is most likely caused by
# interline equations or tables that were not recognised correctly.
"""在切分之前,先检查一下bbox是否有左右重叠的情况,如果有,那么就认为这个pdf暂时没有能力处理好,这种左右重叠的情况大概率是由于pdf里的行间公式、表格没有被正确识别出来造成的 """
while True: # keep checking for horizontal overlap, removing the smaller bbox each time, until none remains
is_useful_block_horz_overlap, all_bboxes = remove_horizontal_overlap_block_which_smaller(all_bboxes)
if is_useful_block_horz_overlap:
need_drop = True
drop_reason.append(DropReason.USEFUL_BLOCK_HOR_OVERLAP)
else:
break
# Compute the layout from the block information.
'''根据区块信息计算layout'''
page_boundry = [0, 0, page_w, page_h]
layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)
if len(text_blocks) > 0 and len(all_bboxes) > 0 and len(layout_bboxes) == 0:
logger.warning(
f"skip this page, page_id: {page_id}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
need_drop = True
drop_reason.append(DropReason.CAN_NOT_DETECT_PAGE_LAYOUT)
# The checks below flag complicated layouts and layouts with more than 2 columns.
"""以下去掉复杂的布局和超过2列的布局"""
if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]): # complicated layout
logger.warning(
f"skip this page, page_id: {page_id}, reason: {DropReason.COMPLICATED_LAYOUT}")
need_drop = True
drop_reason.append(DropReason.COMPLICATED_LAYOUT)
layout_column_width = get_columns_cnt_of_layout(layout_tree)
if layout_column_width > 2: # flag PDFs whose layout has more than 2 columns
logger.warning(
f"skip this page, page_id: {page_id}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
need_drop = True
drop_reason.append(DropReason.TOO_MANY_LAYOUT_COLUMNS)
# Sort all blocks kept on this page according to the layout order.
'''根据layout顺序,对当前页面所有需要留下的block进行排序'''
sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)
# Fill spans into the sorted blocks.
'''将span填入排好序的blocks中'''
block_with_spans, spans = fill_spans_in_blocks(sorted_blocks, spans, 0.3)
# Apply fix-ups to the blocks.
'''对block进行fix操作'''
fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
# Collect the lists that QA needs separately.
'''获取QA需要外置的list'''
images, tables, interline_equations = get_qa_need_list_v2(fix_blocks)
# Build the page entry of pdf_info_dict.
'''构造pdf_info_dict'''
page_info = ocr_construct_page_component_v2(fix_blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
images, tables, interline_equations, fix_discarded_blocks,
need_drop, drop_reason)
return page_info
这里的后处理主要是进行分段,核心代码如下:
def para_split(pdf_info_dict, debug_mode, lang="en"):
    """Split each page's blocks into paragraphs and merge paragraphs across pages.

    The function works in four passes:
    1. split every page independently and store the result in ``para_blocks``;
    2. connect paragraphs and list paragraphs between consecutive pages;
    3. merge special cases within a page (centre-aligned lines, single-line
       list text);
    4. flatten ``para_blocks`` from per-layout lists into one list per page.

    Args:
        pdf_info_dict: Mapping of page key to page info. Each page must provide
            ``preproc_blocks`` and ``layout_bboxes``; ``para_blocks`` is written
            in place.
        debug_mode: Stored in the module-level ``debug_able`` flag; when truthy,
            cross-page merges are logged.
        lang: Language hint forwarded to the splitting helpers.

    Returns:
        None. ``pdf_info_dict`` is modified in place.
    """
    global debug_able
    debug_able = debug_mode

    new_layout_of_pages = []  # one entry per page: that page's layout list
    all_page_list_info = []  # per page: whether it starts / ends with a list

    # Pass 1: split each page on its own. page_key is the dict key and is
    # passed on to the helper unchanged.
    for page_key, page in pdf_info_dict.items():
        blocks = copy.deepcopy(page['preproc_blocks'])
        layout_bboxes = page['layout_bboxes']
        new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
        new_layout_of_pages.append(new_layout_bbox)
        splited_blocks, page_list_info = __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_key, lang)
        all_page_list_info.append(page_list_info)
        page['para_blocks'] = splited_blocks

    # Pass 2: connect paragraphs that continue from one page to the next.
    # Starts at index 1 because the first page has no predecessor.
    pdf_infos = list(pdf_info_dict.values())
    for page_idx in range(1, len(pdf_infos)):
        pre_page_paras = pdf_infos[page_idx - 1]['para_blocks']
        next_page_paras = pdf_infos[page_idx]['para_blocks']
        pre_page_layout_bbox = new_layout_of_pages[page_idx - 1]
        next_page_layout_bbox = new_layout_of_pages[page_idx]

        is_conn = __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
                                            next_page_layout_bbox, page_idx, lang)
        if debug_able and is_conn:
            logger.info(f"连接了第{page_idx - 1}页和第{page_idx}页的段落")

        is_list_conn = __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
                                                 next_page_layout_bbox, all_page_list_info[page_idx - 1],
                                                 all_page_list_info[page_idx], page_idx, lang)
        if debug_able and is_list_conn:
            logger.info(f"连接了第{page_idx - 1}页和第{page_idx}页的列表段落")

    # Pass 3: pick up mergeable content the earlier passes may have missed:
    #   1. a flush-left line followed by several indented lines in body text;
    #   2. consecutive centred single lines of equal height, which may form
    #      one paragraph.
    for page_idx, page in enumerate(pdf_infos):
        page_paras = page['para_blocks']
        new_layout_bbox = new_layout_of_pages[page_idx]
        __connect_middle_align_text(page_paras, new_layout_bbox, page_idx, lang)
        __merge_signle_list_text(page_paras, new_layout_bbox, page_idx, lang)

    # Pass 4: flatten the per-layout lists into one block list per page.
    for page in pdf_infos:
        page["para_blocks"] = [block for layout in page['para_blocks'] for block in layout]
上面的三个后处理步骤,每一个都以前一个步骤的结果为基础。
magic_pdf/model/doc_analyze_by_custom_model.py里面的doc_analyze,是模型预测流程入口。 核心调用代码:
# Excerpt of doc_analyze (the enclosing def is not part of this excerpt): run the
# model on every page image inside [start_page_id, end_page_id] and collect one
# entry per page; pages outside that range get an empty detection list.
for index, img_dict in enumerate(images):
img = img_dict["img"]
page_width = img_dict["width"]
page_height = img_dict["height"]
if start_page_id <= index <= end_page_id:
# custom_model here is the CustomPEKModel class
result = custom_model(img)
else:
result = []
# Each page entry holds the detections plus the page number and size.
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
# Log the elapsed time measured from doc_analyze_start.
doc_analyze_cost = time.time() - doc_analyze_start
logger.info(f"doc analyze cost: {doc_analyze_cost}")
return model_json
后处理的流程入口在magic_pdf/user_api.py中的parse_ocr_pdf方法中,核心代码:
def pdf_parse_union(pdf_bytes,
                    model_list,
                    imageWriter,
                    parse_mode,
                    start_page_id=0,
                    end_page_id=None,
                    debug_mode=False,
                    ):
    """Post-process model output for a whole PDF and return its page info.

    Args:
        pdf_bytes: Raw bytes of the PDF.
        model_list: Model prediction results, passed to ``MagicModel``.
        imageWriter: Writer handed to the page parser.
        parse_mode: Parse mode forwarded to ``parse_page_core``.
        start_page_id: First page index to parse (inclusive).
        end_page_id: Last page index to parse (inclusive). ``None`` or a
            negative value means the last page; values past the end are
            clamped with a warning.
        debug_mode: When truthy, log the time spent per page; also forwarded
            to ``para_split``.

    Returns:
        dict: ``{"pdf_info": [...]}`` with one entry per page of the document.
    """
    pdf_bytes_md5 = compute_md5(pdf_bytes)
    pdf_docs = fitz.open("pdf", pdf_bytes)

    # Per-page results, keyed by "page_<index>".
    pdf_info_dict = {}

    # First post-processing stage: wrap the model output together with the
    # document in a MagicModel.
    magic_model = MagicModel(model_list, pdf_docs)

    # Normalise the requested page range against the document length.
    last_page_index = len(pdf_docs) - 1
    if end_page_id is None or end_page_id < 0:
        end_page_id = last_page_index
    elif end_page_id > last_page_index:
        logger.warning("end_page_id is out of range, use pdf_docs length")
        end_page_id = last_page_index

    # Reference time for the per-page timing printed in debug mode.
    start_time = time.time()

    for page_id, page in enumerate(pdf_docs):
        if debug_mode:
            # Report how long the previous page took, then restart the clock.
            time_now = time.time()
            logger.info(
                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
            )
            start_time = time_now

        if start_page_id <= page_id <= end_page_id:
            # Second post-processing stage: parse this page.
            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
        else:
            # Page outside the requested range: emit an empty, dropped page.
            page_rect = page.rect
            page_info = ocr_construct_page_component_v2([], [], page_id, page_rect.width, page_rect.height, [],
                                                        [], [], [], [],
                                                        True, "skip page")
        pdf_info_dict[f"page_{page_id}"] = page_info

    # Third post-processing stage: paragraph splitting.
    para_split(pdf_info_dict, debug_mode=debug_mode)

    # Convert the dict into a list and wrap it for the caller.
    return {
        "pdf_info": dict_to_list(pdf_info_dict),
    }