MinerU核心代码阅读笔记


记录一下MinerU核心代码阅读时的一些笔记。

1. 模型加载:CustomPEKModel

该类中加载了所有的模型。一共5个模型,已经在代码中注释了。

def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """Read the bundled model configuration and load every enabled model.

        Up to five models are loaded, in this order: formula detection (MFD),
        formula recognition (MFR), layout analysis, OCR and table recognition.

        Args:
            ocr: Whether to load the OCR model.
            show_log: Forwarded to the OCR model as ``ocr_show_log``.
            **kwargs: Optional overrides for values otherwise read from
                ``model_configs.yaml``: ``apply_layout``, ``apply_formula``,
                ``table_config``, ``device`` and ``models_dir``.

        Raises:
            AssertionError: If the layout model is disabled.
        """
        # This file (pdf_extract_kit.py) lives in <magic_pdf>/model, so the
        # package root is two directory levels above the file itself.
        package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        model_config_dir = os.path.join(package_root, 'resources', 'model_config')
        yaml_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(yaml_path, "r", encoding='utf-8') as yaml_file:
            self.configs = yaml.load(yaml_file, Loader=yaml.FullLoader)

        # Feature switches: keyword arguments take precedence over the YAML defaults.
        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])

        # Table recognition settings.
        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
        self.apply_table = self.table_config.get("is_table_recog_enable", False)
        self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
        self.table_model_type = self.table_config.get("model", TABLE_MASTER)

        self.apply_ocr = ocr
        init_message = "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
            self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
        )
        logger.info(init_message)
        assert self.apply_layout, "DocAnalysis must contain layout model."

        # Runtime device and the directory holding the model weights.
        self.device = kwargs.get("device", self.configs["config"]["device"])
        logger.info("using device: {}".format(self.device))
        models_dir = kwargs.get("models_dir", os.path.join(package_root, "resources", "models"))
        logger.info("using models_dir: {}".format(models_dir))

        def weights_path(weight_key):
            # Location of one entry of the "weights" config section under models_dir.
            return str(os.path.join(models_dir, self.configs["weights"][weight_key]))

        model_hub = AtomModelSingleton()

        if self.apply_formula:
            # ------------------------------------ 1. formula detection model ------------------------------------
            self.mfd_model = model_hub.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=weights_path("mfd"),
            )
            # ------------------------------------ 2. formula recognition model ------------------------------------
            self.mfr_model, self.mfr_transform = model_hub.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=weights_path("mfr"),
                mfr_cfg_path=str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml")),
                device=self.device,
            )

        # ------------------------------------ 3. layout analysis model ------------------------------------
        layout_yaml = str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"))
        self.layout_model = model_hub.get_atom_model(
            atom_model_name=AtomicModel.Layout,
            layout_weights=weights_path('layout'),
            layout_config_file=layout_yaml,
            device=self.device,
        )

        if self.apply_ocr:
            # ------------------------------------ 4. OCR model ------------------------------------
            self.ocr_model = model_hub.get_atom_model(
                atom_model_name=AtomicModel.OCR,
                ocr_show_log=show_log,
                det_db_box_thresh=0.3,
            )

        if self.apply_table:
            # ------------------------------------ 5. table recognition model ------------------------------------
            # The weights entry is looked up by the configured table model type.
            self.table_model = model_hub.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_type=self.table_model_type,
                table_model_path=weights_path(self.table_model_type),
                table_max_time=self.table_max_time,
                device=self.device,
            )

        logger.info('DocAnalysis init done!')

2. 模型推理:CustomPEKModel.__call__

基本上所有模型的推理都在这里完成。这里的ocr推理修改了原始的PaddleOCR预测方法,增加了公式检测的bbox作为输入,以便在ocr识别时把公式区域隔离出来,因为公式有专门的识别模型。另外,在ocr识别时,会先把相应的子图贴到一张四周各留50像素白边的画布上(即长和宽各增加100像素),目的应该是提升识别效果(猜测)。最后的返回结果是layout_res,里面包含了所有的推理结果。这是一个list,在后续的代码中都叫做model_list。每个item都是一个dict,保存一条推理结果,例如ocr识别出的文本行如下所示:

[
    {
        'category_id': 15,
        'poly': p1 + p2 + p3 + p4,
        'score': round(score, 2),
        'text': text,

    }
]

除了 text 字段,识别结果也可能保存在 html 或 latex 字段中,分别对应表格和公式的识别结果(table_model_type 为 STRUCT_EQTABLE 时,表格结果也保存在 latex 字段中)。坐标除了用 poly 表示,也可能用 bbox 表示,两种表示方式是混用的。

def __call__(self, image):
        """Run every enabled model on one page image and collect the detections.

        Args:
            image: Page image accepted by ``Image.fromarray`` — presumably an RGB
                ``numpy`` array, since OCR crops are converted with
                ``cv2.COLOR_RGB2BGR`` (TODO confirm against the caller).

        Returns:
            list: ``layout_res`` — the layout model's detections, extended in
            place with one item per detected formula (``category_id`` is
            ``13 + class index``, text in ``latex``) and one item per OCR text
            line (``category_id`` 15, text in ``text``). Table items
            (``category_id`` 5) gain a ``latex`` or ``html`` key when table
            recognition succeeds.
        """
        latex_filling_list = []
        mf_image_list = []

        # Layout detection.
        layout_start = time.time()
        layout_res = self.layout_model(image, ignore_catids=[])
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f"layout detection cost: {layout_cost}")

        # Convert the page to a PIL image once and reuse it for the formula, OCR
        # and table crops, instead of re-converting the whole page per formula.
        # NOTE(review): assumes get_croped_image does not modify the image it is
        # given — confirm against its definition.
        pil_img = Image.fromarray(image)

        if self.apply_formula:
            # Formula detection.
            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
                    'category_id': 13 + int(cla.item()),
                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    'score': round(float(conf.item()), 2),
                    'latex': '',
                }
                layout_res.append(new_item)
                # The same dict is kept here so the recognised LaTeX can be filled in below.
                latex_filling_list.append(new_item)
                bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
                mf_image_list.append(bbox_img)

            # Formula recognition, batched over all crops of this page.
            mfr_start = time.time()
            dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
            dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
            mfr_res = []
            for mf_img in dataloader:
                mf_img = mf_img.to(self.device)
                output = self.mfr_model.generate({'image': mf_img})
                mfr_res.extend(output['pred_str'])
            for res, latex in zip(latex_filling_list, mfr_res):
                res['latex'] = latex_rm_whitespace(latex)
            mfr_cost = round(time.time() - mfr_start, 2)
            logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")

        # Sort the detections into formula regions, regions to OCR and table regions.
        ocr_res_list = []
        table_res_list = []
        single_page_mfdetrec_res = []
        for res in layout_res:
            if int(res['category_id']) in [13, 14]:
                # poly is [x0, y0, x1, y0, x1, y1, x0, y1]; indices 0, 1, 4, 5 give the bbox.
                single_page_mfdetrec_res.append({
                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                             int(res['poly'][4]), int(res['poly'][5])],
                })
            elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
                ocr_res_list.append(res)
            elif int(res['category_id']) in [5]:
                table_res_list.append(res)

        def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
            """Crop one detection and paste it onto a white canvas.

            The canvas is ``crop_paste_x`` / ``crop_paste_y`` pixels larger than
            the crop on every side. Returns the canvas and the list
            ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height]``.
            """
            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
            # White background, padded by the paste offsets on each side.
            crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
            crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')

            # Crop the region and paste it at the padding offset.
            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
            cropped_img = input_pil_img.crop(crop_box)
            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
            return return_image, return_list

        # OCR recognition.
        if self.apply_ocr:
            ocr_start = time.time()
            # Process each region that requires OCR.
            for res in ocr_res_list:
                new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
                # Translate the formula boxes into the coordinate system of the padded crop.
                adjusted_mfdetrec_res = []
                for mf_res in single_page_mfdetrec_res:
                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
                    x0 = mf_xmin - xmin + paste_x
                    y0 = mf_ymin - ymin + paste_y
                    x1 = mf_xmax - xmin + paste_x
                    y1 = mf_ymax - ymin + paste_y
                    # Skip formula boxes that lie entirely outside the padded crop.
                    if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
                        continue
                    else:
                        adjusted_mfdetrec_res.append({
                            "bbox": [x0, y0, x1, y1],
                        })

                # The OCR model is given the crop in BGR channel order.
                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]

                # Merge the OCR lines into the page result.
                if ocr_res:
                    for box_ocr_res in ocr_res:
                        p1, p2, p3, p4 = box_ocr_res[0]
                        text, score = box_ocr_res[1]

                        # Convert the coordinates back to the original page coordinate system.
                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]

                        layout_res.append({
                            'category_id': 15,
                            'poly': p1 + p2 + p3 + p4,
                            'score': round(score, 2),
                            'text': text,
                        })

            ocr_cost = round(time.time() - ocr_start, 2)
            logger.info(f"ocr cost: {ocr_cost}")

        # Table recognition.
        if self.apply_table:
            table_start = time.time()
            for res in table_res_list:
                new_image, _ = crop_img(res, pil_img)
                single_table_start_time = time.time()
                logger.info("------------------table recognition processing begins-----------------")
                latex_code = None
                html_code = None
                if self.table_model_type == STRUCT_EQTABLE:
                    with torch.no_grad():
                        latex_code = self.table_model.image2latex(new_image)[0]
                else:
                    html_code = self.table_model.img2html(new_image)

                run_time = time.time() - single_table_start_time
                logger.info(f"------------table recognition processing ends within {run_time}s-----")
                # Exceeding the time budget is only reported; the result is still used.
                if run_time > self.table_max_time:
                    logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")

                # Keep the result only if the model returned something usable.
                if latex_code:
                    # A LaTeX table is accepted only when it ends with a closing table environment.
                    expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
                        'end{table}')
                    if expected_ending:
                        res["latex"] = latex_code
                    else:
                        logger.warning("------------table recognition processing fails----------")
                elif html_code:
                    res["html"] = html_code
                else:
                    logger.warning("------------table recognition processing fails----------")
            table_cost = round(time.time() - table_start, 2)
            logger.info(f"table cost: {table_cost}")

        return layout_res

3. 结果处理

没有一个专门叫做PostProcess的类,后处理的实现分散在三个地方:

1. MagicModel类

核心代码如下:

def __init__(self, model_list: list, docs: fitz.Document):
        """Keep the raw model output and the PDF, then clean the output up.

        Args:
            model_list: Model detections to be cleaned by the passes below.
            docs: The opened PDF document the detections belong to.
        """
        self.__model_list = model_list
        self.__docs = docs

        # The clean-up passes run in this fixed order.
        # 1. Add bbox information to every detection (scaling, poly -> bbox).
        self.__fix_axis()
        # 2. Drop detections with very low confidence (< 0.05) to improve quality.
        self.__fix_by_remove_low_confidence()
        # 3. Among detections with a high IoU (> 0.9), drop the less confident one.
        self.__fix_by_remove_high_iou_and_low_confidence()
        # 4. Decide whether each footnote belongs to an image or to a table.
        self.__fix_footnote()

2. pdf_parse_union_core.py:parse_page_core

这里的核心代码如下:

'''删除重叠spans中置信度较低的那些'''
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    '''删除重叠spans中较小的那些'''
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
    '''对image和table截图'''
    spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)

    '''将所有区块的bbox整理到一起'''
    # interline_equation_blocks参数不够准,后面切换到interline_equations上
    interline_equation_blocks = []
    if len(interline_equation_blocks) > 0:
        all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
            interline_equation_blocks, page_w, page_h)
    else:
        all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
            interline_equations, page_w, page_h)

    if len(drop_reasons) > 0:
        need_drop = True
        drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)

    '''先处理不需要排版的discarded_blocks'''
    discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)

    '''如果当前页面没有bbox则跳过'''
    if len(all_bboxes) == 0:
        logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
        return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
                                               [], [], interline_equations, fix_discarded_blocks,
                                               need_drop, drop_reason)

    """在切分之前,先检查一下bbox是否有左右重叠的情况,如果有,那么就认为这个pdf暂时没有能力处理好,这种左右重叠的情况大概率是由于pdf里的行间公式、表格没有被正确识别出来造成的 """

    while True:  # 循环检查左右重叠的情况,如果存在就删除掉较小的那个bbox,直到不存在左右重叠的情况
        is_useful_block_horz_overlap, all_bboxes = remove_horizontal_overlap_block_which_smaller(all_bboxes)
        if is_useful_block_horz_overlap:
            need_drop = True
            drop_reason.append(DropReason.USEFUL_BLOCK_HOR_OVERLAP)
        else:
            break

    '''根据区块信息计算layout'''
    page_boundry = [0, 0, page_w, page_h]
    layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)

    if len(text_blocks) > 0 and len(all_bboxes) > 0 and len(layout_bboxes) == 0:
        logger.warning(
            f"skip this page, page_id: {page_id}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
        need_drop = True
        drop_reason.append(DropReason.CAN_NOT_DETECT_PAGE_LAYOUT)

    """以下去掉复杂的布局和超过2列的布局"""
    if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]):  # 复杂的布局
        logger.warning(
            f"skip this page, page_id: {page_id}, reason: {DropReason.COMPLICATED_LAYOUT}")
        need_drop = True
        drop_reason.append(DropReason.COMPLICATED_LAYOUT)

    layout_column_width = get_columns_cnt_of_layout(layout_tree)
    if layout_column_width > 2:  # 去掉超过2列的布局pdf
        logger.warning(
            f"skip this page, page_id: {page_id}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
        need_drop = True
        drop_reason.append(DropReason.TOO_MANY_LAYOUT_COLUMNS)

    '''根据layout顺序,对当前页面所有需要留下的block进行排序'''
    sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)

    '''将span填入排好序的blocks中'''
    block_with_spans, spans = fill_spans_in_blocks(sorted_blocks, spans, 0.3)

    '''对block进行fix操作'''
    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)

    '''获取QA需要外置的list'''
    images, tables, interline_equations = get_qa_need_list_v2(fix_blocks)

    '''构造pdf_info_dict'''
    page_info = ocr_construct_page_component_v2(fix_blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
                                                images, tables, interline_equations, fix_discarded_blocks,
                                                need_drop, drop_reason)
    return page_info

3. 分段para_split_v2.py:para_split

这里的后处理主要是进行分段,核心代码如下:

def para_split(pdf_info_dict, debug_mode, lang="en"):
    """Split every page's blocks into paragraphs, then merge across pages.

    Each page dict in ``pdf_info_dict`` gains a ``'para_blocks'`` entry, which
    ends up as a flat list of blocks. Nothing is returned.

    Args:
        pdf_info_dict: Mapping of page key to page dict; every page dict must
            provide ``'preproc_blocks'`` and ``'layout_bboxes'``.
        debug_mode: Stored in the module-level ``debug_able`` flag; when truthy,
            cross-page merges are logged.
        lang: Language code forwarded to the splitting and merging helpers.
    """
    global debug_able
    debug_able = debug_mode

    layouts_per_page = []  # one pre-processed layout list per page
    list_info_per_page = []  # per page: whether it starts / ends with a list
    for page_key, page_info in pdf_info_dict.items():
        # Work on a copy so the page's 'preproc_blocks' stay untouched.
        page_blocks = copy.deepcopy(page_info['preproc_blocks'])
        raw_layout_bboxes = page_info['layout_bboxes']
        page_layout = __common_pre_proc(page_blocks, raw_layout_bboxes)
        layouts_per_page.append(page_layout)
        split_result, list_info = __do_split_page(page_blocks, raw_layout_bboxes, page_layout, page_key, lang)
        list_info_per_page.append(list_info)
        page_info['para_blocks'] = split_result

    # Join paragraphs that may continue from one page onto the next.
    pages = list(pdf_info_dict.values())
    for cur_idx in range(1, len(pages)):
        prev_idx = cur_idx - 1
        prev_paras = pages[prev_idx]['para_blocks']
        cur_paras = pages[cur_idx]['para_blocks']
        prev_layout = layouts_per_page[prev_idx]
        cur_layout = layouts_per_page[cur_idx]

        para_joined = __connect_para_inter_page(prev_paras, cur_paras, prev_layout,
                                                cur_layout, cur_idx, lang)
        if debug_able and para_joined:
            logger.info(f"连接了第{cur_idx - 1}页和第{cur_idx}页的段落")

        list_joined = __connect_list_inter_page(prev_paras, cur_paras, prev_layout,
                                                cur_layout, list_info_per_page[prev_idx],
                                                list_info_per_page[cur_idx], cur_idx, lang)
        if debug_able and list_joined:
            logger.info(f"连接了第{cur_idx - 1}页和第{cur_idx}页的列表段落")

    # Some mergeable content may still have been missed; connect it here:
    # 1. body text where one line is flush left and the following lines are indented;
    # 2. consecutive centred single lines of equal height, which may be one paragraph.
    for page_idx, page_info in enumerate(pdf_info_dict.values()):
        paras = page_info['para_blocks']
        layout = layouts_per_page[page_idx]
        __connect_middle_align_text(paras, layout, page_idx, lang)
        __merge_signle_list_text(paras, layout, page_idx, lang)

    # Flatten: replace the per-layout grouping with a single list of blocks per page.
    for page_info in pdf_info_dict.values():
        flattened = []
        for layout_paras in page_info['para_blocks']:
            flattened.extend(layout_paras)
        page_info["para_blocks"] = flattened

上面的三个后处理步骤,每一个都以前一个步骤的结果为基础。

4. 核心调用流程

magic_pdf/model/doc_analyze_by_custom_model.py里面的doc_analyze,是模型预测流程入口。 核心调用代码:

for index, img_dict in enumerate(images):
        img = img_dict["img"]
        page_width = img_dict["width"]
        page_height = img_dict["height"]
        if start_page_id <= index <= end_page_id:
            #此处的custom_model是CustomPEKModel类
            result = custom_model(img)
        else:
            result = []
        page_info = {"page_no": index, "height": page_height, "width": page_width}
        page_dict = {"layout_dets": result, "page_info": page_info}
        model_json.append(page_dict)
    doc_analyze_cost = time.time() - doc_analyze_start
    logger.info(f"doc analyze cost: {doc_analyze_cost}")

    return model_json

后处理的流程入口在magic_pdf/user_api.py中的parse_ocr_pdf方法中,核心代码:

def pdf_parse_union(pdf_bytes,
                    model_list,
                    imageWriter,
                    parse_mode,
                    start_page_id=0,
                    end_page_id=None,
                    debug_mode=False,
                    ):
    """Post-process the model output of a whole PDF, page by page.

    Args:
        pdf_bytes: Raw PDF content; hashed with ``compute_md5`` and opened with
            ``fitz.open("pdf", pdf_bytes)``.
        model_list: Model predictions used to build the ``MagicModel``.
        imageWriter: Forwarded unchanged to ``parse_page_core``.
        parse_mode: Forwarded unchanged to ``parse_page_core``.
        start_page_id: First page index to parse (inclusive).
        end_page_id: Last page index to parse (inclusive). ``None`` or a
            negative value means the last page; values past the end are clamped.
        debug_mode: When truthy, the time spent on each page is logged; the flag
            is also passed on to ``para_split``.

    Returns:
        dict: ``{"pdf_info": [...]}`` with one entry per page of the document.
    """
    pdf_bytes_md5 = compute_md5(pdf_bytes)
    pdf_docs = fitz.open("pdf", pdf_bytes)

    # Per-page results, keyed "page_<index>".
    pdf_info_dict = {}

    # First post-processing stage: wrap the model predictions in a MagicModel.
    magic_model = MagicModel(model_list, pdf_docs)

    # Resolve the page range to parse.
    last_page_id = len(pdf_docs) - 1
    if end_page_id is None or not (end_page_id >= 0):
        end_page_id = last_page_id
    if end_page_id > last_page_id:
        logger.warning("end_page_id is out of range, use pdf_docs length")
        end_page_id = last_page_id

    # Reference point for the per-page timing printed in debug mode.
    start_time = time.time()

    for page_id, page in enumerate(pdf_docs):
        if debug_mode:
            # Report how long the previous page took, then restart the clock.
            time_now = time.time()
            logger.info(
                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
            )
            start_time = time_now

        in_requested_range = start_page_id <= page_id <= end_page_id
        if in_requested_range:
            # Second post-processing stage.
            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
        else:
            # Pages outside the range get an empty component flagged as dropped.
            page_info = ocr_construct_page_component_v2([], [], page_id, page.rect.width, page.rect.height, [],
                                                        [], [], [], [],
                                                        True, "skip page")
        pdf_info_dict[f"page_{page_id}"] = page_info

    # Third post-processing stage: paragraph splitting.
    para_split(pdf_info_dict, debug_mode=debug_mode)

    # Convert the per-page dict into a list and wrap it.
    return {
        "pdf_info": dict_to_list(pdf_info_dict),
    }


原创文章,转载请注明出处,否则拒绝转载!
本文链接:抬头看浏览器地址栏

上篇: paddleocr蒸馏模型的center导出
下篇: 农村小伙的求学逆袭之路