magic-pdf 0.9.2__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/config/constants.py +53 -0
- magic_pdf/config/drop_reason.py +35 -0
- magic_pdf/config/drop_tag.py +19 -0
- magic_pdf/config/make_content_config.py +11 -0
- magic_pdf/{libs/ModelBlockTypeEnum.py → config/model_block_type.py} +2 -1
- magic_pdf/data/read_api.py +1 -1
- magic_pdf/dict2md/mkcontent.py +226 -185
- magic_pdf/dict2md/ocr_mkcontent.py +12 -12
- magic_pdf/filter/pdf_meta_scan.py +101 -79
- magic_pdf/integrations/rag/utils.py +4 -5
- magic_pdf/libs/config_reader.py +6 -6
- magic_pdf/libs/draw_bbox.py +13 -6
- magic_pdf/libs/pdf_image_tools.py +36 -12
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +2 -0
- magic_pdf/model/magic_model.py +13 -13
- magic_pdf/model/pdf_extract_kit.py +142 -351
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
- magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
- magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
- magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
- magic_pdf/model/sub_modules/model_init.py +149 -0
- magic_pdf/model/sub_modules/model_utils.py +51 -0
- magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +285 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +176 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
- magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
- magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
- magic_pdf/model/sub_modules/table/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +16 -0
- magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
- magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
- magic_pdf/model/sub_modules/table/table_utils.py +11 -0
- magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
- magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +31 -29
- magic_pdf/para/para_split.py +411 -248
- magic_pdf/para/para_split_v2.py +352 -182
- magic_pdf/para/para_split_v3.py +121 -66
- magic_pdf/pdf_parse_by_ocr.py +2 -0
- magic_pdf/pdf_parse_by_txt.py +2 -0
- magic_pdf/pdf_parse_union_core.py +174 -100
- magic_pdf/pdf_parse_union_core_v2.py +253 -50
- magic_pdf/pipe/AbsPipe.py +28 -44
- magic_pdf/pipe/OCRPipe.py +5 -5
- magic_pdf/pipe/TXTPipe.py +5 -6
- magic_pdf/pipe/UNIPipe.py +24 -25
- magic_pdf/post_proc/pdf_post_filter.py +7 -14
- magic_pdf/pre_proc/cut_image.py +9 -11
- magic_pdf/pre_proc/equations_replace.py +203 -212
- magic_pdf/pre_proc/ocr_detect_all_bboxes.py +235 -49
- magic_pdf/pre_proc/ocr_dict_merge.py +5 -5
- magic_pdf/pre_proc/ocr_span_list_modify.py +122 -63
- magic_pdf/pre_proc/pdf_pre_filter.py +37 -33
- magic_pdf/pre_proc/remove_bbox_overlap.py +20 -18
- magic_pdf/pre_proc/remove_colored_strip_bbox.py +36 -14
- magic_pdf/pre_proc/remove_footer_header.py +2 -5
- magic_pdf/pre_proc/remove_rotate_bbox.py +111 -63
- magic_pdf/pre_proc/resolve_bbox_conflict.py +10 -17
- magic_pdf/resources/model_config/model_configs.yaml +2 -1
- magic_pdf/spark/spark_api.py +15 -17
- magic_pdf/tools/cli.py +3 -4
- magic_pdf/tools/cli_dev.py +6 -9
- magic_pdf/tools/common.py +70 -36
- magic_pdf/user_api.py +29 -38
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/METADATA +18 -13
- magic_pdf-0.10.0.dist-info/RECORD +198 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/WHEEL +1 -1
- magic_pdf/libs/Constants.py +0 -53
- magic_pdf/libs/MakeContentConfig.py +0 -11
- magic_pdf/libs/drop_reason.py +0 -27
- magic_pdf/libs/drop_tag.py +0 -19
- magic_pdf/model/pek_sub_modules/post_process.py +0 -36
- magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
- magic_pdf/para/para_pipeline.py +0 -297
- magic_pdf-0.9.2.dist-info/RECORD +0 -178
- /magic_pdf/{libs → config}/ocr_content_type.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
- /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
- /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,98 @@
|
|
1
|
+
import os
|
2
|
+
import argparse
|
3
|
+
import re
|
4
|
+
|
5
|
+
from PIL import Image
|
6
|
+
import torch
|
7
|
+
from torch.utils.data import Dataset, DataLoader
|
8
|
+
from torchvision import transforms
|
9
|
+
from unimernet.common.config import Config
|
10
|
+
import unimernet.tasks as tasks
|
11
|
+
from unimernet.processors import load_processor
|
12
|
+
|
13
|
+
|
14
|
+
class MathDataset(Dataset):
    """Dataset of formula images for batched recognition.

    Each item is either a filesystem path (opened with ``Image.open``) or an
    already-loaded image object, which is used as-is.
    """

    def __init__(self, image_paths, transform=None):
        # May hold path strings and/or image objects; see __getitem__.
        self.image_paths = image_paths
        # Optional callable applied to every image before it is returned.
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return item ``idx``, loading it from disk when it is a path string.

        When no ``transform`` is configured the raw image is returned
        unchanged (previously this path raised ``UnboundLocalError``).
        """
        item = self.image_paths[idx]
        # Only strings are treated as paths; anything else is assumed to already be an image.
        raw_image = Image.open(item) if isinstance(item, str) else item
        if self.transform:
            return self.transform(raw_image)
        return raw_image
|
31
|
+
|
32
|
+
|
33
|
+
def latex_rm_whitespace(s: str):
    r"""Remove unnecessary whitespace from LaTeX code.

    First, every ``\operatorname``/``\mathrm``/``\text``/``\mathbf`` group
    matched by ``text_reg`` (which requires a space right before the opening
    brace) has all of its spaces removed. Then, until the string stops
    changing, whitespace is removed between two non-letters
    (``[\W_^\d]``), between a non-letter and a letter, and between a letter
    and a non-letter. Whitespace between two letters is kept, and the first
    two rules never start a match at an escaped space ``\ ``.
    """
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = '[a-zA-Z]'
    # Raw string: '\W' and '\d' are invalid escape sequences in a normal
    # string literal (DeprecationWarning; SyntaxWarning since Python 3.12).
    # The resulting pattern text is unchanged.
    noletter = r'[\W_^\d]'
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)

    # Compile the fixed-point rewrite rules once instead of on every loop pass.
    nonletter_nonletter = re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter))
    nonletter_letter = re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter))
    letter_nonletter = re.compile(r'(%s)\s+?(%s)' % (letter, noletter))

    news = s
    while True:
        s = news
        news = nonletter_nonletter.sub(r'\1\2', s)
        news = nonletter_letter.sub(r'\1\2', news)
        news = letter_nonletter.sub(r'\1\2', news)
        if news == s:
            break
    return s
|
50
|
+
|
51
|
+
|
52
|
+
class UnimernetModel(object):
    """Formula-recognition (MFR) wrapper around a UniMERNet model.

    Builds the model from a UniMERNet config file plus a local weight
    directory, and turns formula-detection boxes on a page image into LaTeX.
    """

    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
        """Load the model and its evaluation image processor.

        Args:
            weight_dir: Directory containing ``pytorch_model.pth``; also used
                as the model name and tokenizer path.
            cfg_path: Path to the UniMERNet config file.
            _device_: Device the model is moved to with ``model.to``.
        """
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        # Point the config at the local weights/tokenizer rather than the values in cfg_path.
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
        self.mfr_transform = transforms.Compose([vis_processor, ])

    def predict(self, mfd_res, image, batch_size=64):
        """Recognize every detected formula in ``image``.

        Args:
            mfd_res: Formula-detection result exposing ``boxes.xyxy``,
                ``boxes.conf`` and ``boxes.cls`` tensors.
            image: Page image accepted by ``Image.fromarray`` (presumably an
                HxWxC uint8 array — confirm with callers).
            batch_size: Number of formula crops per recognition batch.

        Returns:
            A list of dicts with ``category_id`` (13 + detected class),
            ``poly`` (8 corner coordinates), ``score`` (confidence rounded to
            2 decimals) and ``latex`` (recognized LaTeX with whitespace
            normalized by ``latex_rm_whitespace``).
        """
        formula_list = []
        mf_image_list = []
        # Convert the page to PIL lazily and only once, rather than once per detected box.
        pil_img = None
        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            }
            formula_list.append(new_item)
            if pil_img is None:
                pil_img = Image.fromarray(image)
            mf_image_list.append(pil_img.crop((xmin, ymin, xmax, ymax)))

        if not mf_image_list:
            # Nothing detected: skip building a DataLoader over an empty dataset.
            return formula_list

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        for res, latex in zip(formula_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        return formula_list
|
96
|
+
|
97
|
+
|
98
|
+
|
File without changes
|
@@ -0,0 +1,149 @@
|
|
1
|
+
from loguru import logger
|
2
|
+
|
3
|
+
from magic_pdf.config.constants import MODEL_NAME
|
4
|
+
from magic_pdf.model.model_list import AtomicModel
|
5
|
+
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
|
6
|
+
DocLayoutYOLOModel
|
7
|
+
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
|
8
|
+
Layoutlmv3_Predictor
|
9
|
+
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
|
10
|
+
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
|
11
|
+
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
|
12
|
+
ModifiedPaddleOCR
|
13
|
+
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
|
14
|
+
RapidTableModel
|
15
|
+
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
|
16
|
+
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
|
17
|
+
StructTableModel
|
18
|
+
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
|
19
|
+
TableMasterPaddleModel
|
20
|
+
|
21
|
+
|
22
|
+
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Create the table-recognition model selected by ``table_model_type``.

    Args:
        table_model_type: ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Model location; not used by the RapidTable backend.
        max_time: Time limit, forwarded to StructEqTable only.
        _device_: Device, forwarded to TableMaster only.

    Returns:
        The constructed table model. For an unknown type the error is logged
        and the process exits with status 1 (``SystemExit``).
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    if table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_
        }
        return TableMasterPaddleModel(config)
    if table_model_type == MODEL_NAME.RAPID_TABLE:
        return RapidTableModel()

    logger.error('table model type not allow')
    # The ``exit`` builtin is injected by the ``site`` module for interactive
    # use and may be missing (e.g. ``python -S``, frozen apps); ``sys.exit``
    # raises the same SystemExit reliably.
    import sys
    sys.exit(1)
|
38
|
+
|
39
|
+
|
40
|
+
def mfd_model_init(weight, device='cpu'):
    """Build the formula-detection model (``YOLOv8MFDModel``) from ``weight`` on ``device``."""
    return YOLOv8MFDModel(weight, device)
|
43
|
+
|
44
|
+
|
45
|
+
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Build the formula-recognition model (``UnimernetModel``) from its weights and config."""
    return UnimernetModel(weight_dir, cfg_path, device)
|
48
|
+
|
49
|
+
|
50
|
+
def layout_model_init(weight, config_file, device):
    """Build the LayoutLMv3 layout predictor from ``weight`` and ``config_file`` on ``device``."""
    return Layoutlmv3_Predictor(weight, config_file, device)
|
53
|
+
|
54
|
+
|
55
|
+
def doclayout_yolo_model_init(weight, device='cpu'):
    """Build the DocLayout-YOLO layout model from ``weight`` on ``device``."""
    return DocLayoutYOLOModel(weight, device)
|
58
|
+
|
59
|
+
|
60
|
+
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Build a ``ModifiedPaddleOCR`` instance.

    ``lang`` is forwarded only when it is neither ``None`` nor ``''``;
    otherwise it is omitted so the OCR class's own default applies.
    All other arguments are always forwarded as keyword arguments.
    """
    # Keyword order mirrors the explicit call: show_log, det_db_box_thresh,
    # [lang], use_dilation, det_db_unclip_ratio.
    ocr_options = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
    }
    if lang is not None and lang != '':
        ocr_options['lang'] = lang
    ocr_options['use_dilation'] = use_dilation
    ocr_options['det_db_unclip_ratio'] = det_db_unclip_ratio
    return ModifiedPaddleOCR(**ocr_options)
|
83
|
+
|
84
|
+
|
85
|
+
class AtomModelSingleton:
    """Process-wide cache of initialized atomic models.

    ``AtomModelSingleton()`` always returns the same instance. Models are
    stored in the class-level ``_models`` dict, so each distinct cache key is
    initialized only once.
    """

    _instance = None
    # Cache: (model name, layout model name, lang, table model name) -> model.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for this configuration, creating it on first use.

        On a cache miss all ``kwargs`` are forwarded to ``atom_model_init``.
        The key includes ``table_model_name`` so that requesting a different
        table backend does not return a previously built one. Other kwargs
        (e.g. ``device``, weight paths) are not part of the key.
        """
        key = (
            atom_model_name,
            kwargs.get('layout_model_name', None),
            kwargs.get('lang', None),
            kwargs.get('table_model_name', None),
        )
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
|
101
|
+
|
102
|
+
|
103
|
+
def atom_model_init(model_name: str, **kwargs):
    """Initialize one atomic model (layout, MFD, MFR, OCR or table).

    Args:
        model_name: One of the ``AtomicModel`` names.
        **kwargs: Model-specific settings read per model type below
            (e.g. ``device``, weight paths, ``lang``, ``table_model_name``).

    Returns:
        The initialized model. If ``model_name`` is unknown, or no model was
        created (e.g. an unrecognized ``layout_model_name``), the error is
        logged and the process exits with status 1 (``SystemExit``).
    """
    # ``sys.exit`` instead of the ``site``-injected ``exit`` builtin, which is
    # meant for interactive use and may be unavailable.
    import sys

    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device')
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        # Forward only the options the caller actually supplied, so that
        # ocr_model_init's defaults (e.g. det_db_box_thresh=0.3) apply instead
        # of being overridden with None.
        ocr_kwargs = {}
        if 'ocr_show_log' in kwargs:
            ocr_kwargs['show_log'] = kwargs['ocr_show_log']
        if 'det_db_box_thresh' in kwargs:
            ocr_kwargs['det_db_box_thresh'] = kwargs['det_db_box_thresh']
        if 'lang' in kwargs:
            ocr_kwargs['lang'] = kwargs['lang']
        atom_model = ocr_model_init(**ocr_kwargs)
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device')
        )
    else:
        logger.error('model name not allow')
        sys.exit(1)

    if atom_model is None:
        logger.error('model init failed')
        sys.exit(1)
    return atom_model
|
@@ -0,0 +1,51 @@
|
|
1
|
+
import time
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from PIL import Image
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from magic_pdf.libs.clean_memory import clean_memory
|
8
|
+
|
9
|
+
|
10
|
+
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    """Crop the region of ``input_res`` out of ``input_pil_img`` onto a padded white canvas.

    The region spans from ``poly[0], poly[1]`` to ``poly[4], poly[5]`` of
    ``input_res['poly']``, truncated to int. The crop is pasted at offset
    (``crop_paste_x``, ``crop_paste_y``) on a white RGB canvas that is
    ``2 * crop_paste_x`` wider and ``2 * crop_paste_y`` taller than the
    region, so the padding appears on every side.

    Returns:
        ``(image, useful_list)`` where ``useful_list`` is
        ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height]``.
    """
    poly = input_res['poly']
    xmin, ymin = int(poly[0]), int(poly[1])
    xmax, ymax = int(poly[4]), int(poly[5])

    canvas_width = xmax - xmin + crop_paste_x * 2
    canvas_height = ymax - ymin + crop_paste_y * 2
    canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')

    region = input_pil_img.crop((xmin, ymin, xmax, ymax))
    canvas.paste(region, (crop_paste_x, crop_paste_y))

    useful_list = [crop_paste_x, crop_paste_y, xmin, ymin, xmax, ymax, canvas_width, canvas_height]
    return canvas, useful_list
|
24
|
+
|
25
|
+
|
26
|
+
def get_res_list_from_layout_res(layout_res):
    """Split layout detections into OCR regions, table regions and formula boxes.

    Category ids 13 and 14 become formula entries ``{"bbox": [x0, y0, x1, y1]}``
    built from ``poly`` indices 0, 1, 4 and 5 (truncated to int). Ids
    0, 1, 2, 4, 6 and 7 are collected for OCR, and id 5 for table recognition.
    Every other category is dropped.

    Returns:
        ``(ocr_res_list, table_res_list, single_page_mfdetrec_res)``.
    """
    ocr_res_list, table_res_list, single_page_mfdetrec_res = [], [], []
    for res in layout_res:
        category = int(res['category_id'])
        if category in (13, 14):
            poly = res['poly']
            single_page_mfdetrec_res.append({
                "bbox": [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])],
            })
        elif category in (0, 1, 2, 4, 6, 7):
            ocr_res_list.append(res)
        elif category == 5:
            table_res_list.append(res)
    return ocr_res_list, table_res_list, single_page_mfdetrec_res
|
42
|
+
|
43
|
+
|
44
|
+
def clean_vram(device, vram_threshold=8):
    """Call ``clean_memory()`` on small-memory CUDA devices and log how long it took.

    Does nothing unless CUDA is available and ``device`` is not ``'cpu'``.
    Otherwise ``clean_memory()`` runs only when the device's total memory is
    at most ``vram_threshold`` GB.
    """
    if not torch.cuda.is_available() or device == 'cpu':
        return
    bytes_per_gb = 1024 ** 3  # convert bytes to GB
    total_memory = torch.cuda.get_device_properties(device).total_memory / bytes_per_gb
    if total_memory > vram_threshold:
        return
    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f"gc time: {gc_time}")
|
File without changes
|
File without changes
|
@@ -0,0 +1,285 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from loguru import logger
|
5
|
+
|
6
|
+
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
7
|
+
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
|
8
|
+
|
9
|
+
|
10
|
+
def bbox_to_points(bbox):
    """Convert an ``[x0, y0, x1, y1]`` box into a float32 array of its four corners.

    Corners are ordered (x0, y0), (x1, y0), (x1, y1), (x0, y1).
    """
    x0, y0, x1, y1 = bbox
    corners = [
        [x0, y0],
        [x1, y0],
        [x1, y1],
        [x0, y1],
    ]
    return np.array(corners).astype('float32')
|
14
|
+
|
15
|
+
|
16
|
+
def points_to_bbox(points):
    """Convert four corner points back to ``[x0, y0, x1, y1]``.

    Uses the x and y of ``points[0]``, the x of ``points[1]`` and the y of
    ``points[2]``; ``points[3]`` is not read.
    """
    (x0, y0), (x1, _), (_, y1) = points[0], points[1], points[2]
    return [x0, y0, x1, y1]
|
22
|
+
|
23
|
+
|
24
|
+
def merge_intervals(intervals):
    """Merge overlapping ``[start, end]`` intervals.

    Intervals are processed in order of ``start``. An interval is merged into
    the previous result when its start is <= the previous end, so touching
    intervals such as ``[1, 3]`` and ``[3, 5]`` merge, while ``[1, 3]`` and
    ``[4, 5]`` do not.

    The caller's list and its items are left untouched; results are new lists.

    Args:
        intervals: Iterable of indexable ``(start, end)`` pairs (lists or tuples).

    Returns:
        A new list of ``[start, end]`` lists sorted by start.
    """
    merged = []
    # sorted() rather than list.sort(): do not reorder the caller's list.
    for interval in sorted(intervals, key=lambda x: x[0]):
        if not merged or merged[-1][1] < interval[0]:
            # No overlap with the previous interval: start a new (copied) one,
            # so later extensions never write into the caller's objects.
            merged.append(list(interval))
        else:
            # Overlap: extend the previous interval's end.
            merged[-1][1] = max(merged[-1][1], interval[1])
    return merged
|
39
|
+
|
40
|
+
|
41
|
+
def remove_intervals(original, masks):
    """Subtract mask intervals from the ``original`` ``[start, end]`` range.

    Bounds are treated as inclusive: a kept piece before a mask ends at
    ``mask_start - 1`` and the next piece resumes at ``mask_end + 1``. Masks
    that do not intersect the remaining range are ignored.

    Args:
        original: ``[start, end]`` range to cut.
        masks: ``[start, end]`` ranges to remove; merged with
            ``merge_intervals`` first.

    Returns:
        List of remaining ``[start, end]`` pieces, in increasing order.
    """
    cursor, stop = original
    pieces = []
    for mask_start, mask_end in merge_intervals(masks):
        # Skip masks lying entirely right of the range or left of what remains.
        if mask_start > stop or mask_end < cursor:
            continue
        if cursor < mask_start:
            pieces.append([cursor, mask_start - 1])
        cursor = max(mask_end + 1, cursor)

    if cursor <= stop:
        pieces.append([cursor, stop])
    return pieces
|
70
|
+
|
71
|
+
|
72
|
+
def update_det_boxes(dt_boxes, mfd_res):
    """Cut text detection boxes around formula boxes.

    For every text box that ``calculate_is_angle`` does not flag, the x-ranges
    of formula boxes passing ``__is_overlaps_y_exceeds_threshold`` against it
    are removed from its x-range. Each remaining piece becomes a new corner-
    point box with the original top and bottom. Flagged (tilted) boxes are
    kept unchanged and appended after all the cut boxes.

    Args:
        dt_boxes: Text boxes, each given as four corner points.
        mfd_res: Formula boxes as dicts with ``'bbox'`` = ``[x0, y0, x1, y1]``.

    Returns:
        A new list: cut boxes (from ``bbox_to_points``) followed by the tilted boxes.
    """
    cut_boxes = []
    tilted_boxes = []
    for text_box in dt_boxes:
        if calculate_is_angle(text_box):
            tilted_boxes.append(text_box)
            continue

        text_bbox = points_to_bbox(text_box)
        formula_x_ranges = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]
        kept_ranges = remove_intervals([text_bbox[0], text_bbox[2]], formula_x_ranges)
        cut_boxes.extend(
            bbox_to_points([left, text_bbox[1], right, text_bbox[3]])
            for left, right in kept_ranges
        )

    return cut_boxes + tilted_boxes
|
98
|
+
|
99
|
+
|
100
|
+
def merge_overlapping_spans(spans):
    """
    Merges horizontally overlapping spans on the same line.

    Spans are processed in order of their left x. A span is merged into the
    previous result when its x1 is <= the previous x2. Only horizontal overlap
    is checked, so callers should pass spans from a single line. Merged spans
    are emitted as tuples; spans that were not merged are returned as the
    original objects. The caller's list is not reordered.

    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
    :return: A new list of merged spans, sorted by x1
    """
    if not spans:
        return []

    merged = []
    # sorted() instead of list.sort() so the caller's list is left as-is.
    for span in sorted(spans, key=lambda x: x[0]):
        x1, y1, x2, y2 = span
        if not merged or merged[-1][2] < x1:
            # No horizontal overlap with the previous span: keep it as-is.
            merged.append(span)
        else:
            # Overlap: replace the previous span with the union of both boxes.
            last_span = merged.pop()
            merged.append((
                min(last_span[0], x1),
                min(last_span[1], y1),
                max(last_span[2], x2),
                max(last_span[3], y2),
            ))
    return merged
|
135
|
+
|
136
|
+
|
137
|
+
def merge_det_boxes(dt_boxes):
    """
    Merge text detection boxes that lie on the same line.

    Boxes not flagged by ``calculate_is_angle`` are converted to
    ``{'bbox': [x0, y0, x1, y1], 'type': 'text'}`` dicts and grouped into
    lines by ``merge_spans_to_line``. Within each line, horizontally
    overlapping boxes are merged by ``merge_overlapping_spans``. Flagged
    (tilted) boxes pass through unchanged and are appended at the end.

    Parameters:
        dt_boxes (list): Text detection boxes, each given by four corner points.

    Returns:
        list: Merged boxes as four-corner-point arrays, followed by the
        unmodified tilted boxes.
    """
    text_spans = []
    tilted_boxes = []
    for text_box in dt_boxes:
        bbox = points_to_bbox(text_box)
        if calculate_is_angle(text_box):
            tilted_boxes.append(text_box)
        else:
            text_spans.append({'bbox': bbox, 'type': 'text'})

    merged_boxes = []
    for line in merge_spans_to_line(text_spans):
        line_bboxes = [span['bbox'] for span in line]
        merged_boxes.extend(
            bbox_to_points(merged_span)
            for merged_span in merge_overlapping_spans(line_bboxes)
        )

    merged_boxes.extend(tilted_boxes)
    return merged_boxes
|
186
|
+
|
187
|
+
|
188
|
+
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    """Map page-level formula boxes into the coordinate frame of a padded crop.

    ``useful_list`` is ``[paste_x, paste_y, xmin, ymin, xmax, ymax,
    new_width, new_height]``. Each box is shifted by ``-xmin + paste_x`` and
    ``-ymin + paste_y``. Boxes whose right/bottom edge is negative, or whose
    left/top edge exceeds ``new_width``/``new_height``, are dropped.

    Returns:
        List of ``{"bbox": [x0, y0, x1, y1]}`` dicts in crop coordinates.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, new_width, new_height = useful_list
    adjusted = []
    for mf_res in single_page_mfdetrec_res:
        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
        x0 = mf_xmin - xmin + paste_x
        y0 = mf_ymin - ymin + paste_y
        x1 = mf_xmax - xmin + paste_x
        y1 = mf_ymax - ymin + paste_y
        outside_crop = x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height
        if not outside_crop:
            adjusted.append({
                "bbox": [x0, y0, x1, y1],
            })
    return adjusted
|
207
|
+
|
208
|
+
|
209
|
+
def get_ocr_result_list(ocr_res, useful_list):
    """Convert OCR output on a padded crop into page-level OCR spans.

    Each item of ``ocr_res`` is either ``[points, (text, score)]`` (length 2)
    or just the four corner points (detection only: text ``""``, score 1).
    A box flagged by ``calculate_is_angle`` is first replaced by an axis-
    aligned box. That box keeps the same centre, its height is the mean of
    the left/right edge heights, and its width is ``p3.x - p1.x``. Points are
    then shifted from crop to page coordinates using
    ``useful_list = [paste_x, paste_y, xmin, ymin, ...]``.

    Returns:
        List of dicts with ``category_id`` 15, an 8-value ``poly``, ``score``
        rounded to 2 decimals, and ``text``.
    """
    paste_x, paste_y, xmin, ymin, _, _, _, _ = useful_list

    def to_page(point):
        # Undo the crop offset and the paste padding.
        return [point[0] - paste_x + xmin, point[1] - paste_y + ymin]

    results = []
    for box_ocr_res in ocr_res:
        if len(box_ocr_res) == 2:
            p1, p2, p3, p4 = box_ocr_res[0]
            text, score = box_ocr_res[1]
        else:
            p1, p2, p3, p4 = box_ocr_res
            text, score = "", 1

        poly = [p1, p2, p3, p4]
        if calculate_is_angle(poly):
            # Tilted box: replace it with an axis-aligned box around its geometric centre.
            center_x = sum(point[0] for point in poly) / 4
            center_y = sum(point[1] for point in poly) / 4
            box_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            box_width = p3[0] - p1[0]
            p1 = [center_x - box_width / 2, center_y - box_height / 2]
            p2 = [center_x + box_width / 2, center_y - box_height / 2]
            p3 = [center_x + box_width / 2, center_y + box_height / 2]
            p4 = [center_x - box_width / 2, center_y + box_height / 2]

        page_poly = [coord for point in (p1, p2, p3, p4) for coord in to_page(point)]
        results.append({
            'category_id': 15,
            'poly': page_poly,
            'score': float(round(score, 2)),
            'text': text,
        })

    return results
|
250
|
+
|
251
|
+
|
252
|
+
def calculate_angle_degrees(poly):
    """Return the absolute mean angle, in degrees, of the two diagonals of ``poly``.

    ``poly`` is four (x, y) points. Each diagonal (``poly[0]`` to ``poly[2]``
    and ``poly[1]`` to ``poly[3]``) contributes ``atan(slope)`` in degrees,
    with a vertical diagonal treated as an infinite slope (90 degrees). The
    result is the absolute value of their average.
    """
    def diagonal_angle(start, end):
        # Angle of the segment start->end with the x axis.
        if end[0] != start[0]:
            slope = (end[1] - start[1]) / (end[0] - start[0])
        else:
            slope = float('inf')
        return math.degrees(math.atan(slope))

    first = diagonal_angle(poly[0], poly[2])
    second = diagonal_angle(poly[1], poly[3])
    return abs((first + second) / 2)
|
276
|
+
|
277
|
+
|
278
|
+
def calculate_is_angle(poly):
    """Return True when the quadrilateral ``poly`` looks tilted.

    ``poly`` is four (x, y) points p1..p4. The box counts as axis-aligned when
    the vertical distance from p1 to p3 lies within 80%..120% of the mean of
    its left (p1 to p4) and right (p2 to p3) edge heights. Otherwise it is
    reported as tilted.
    """
    p1, p2, p3, p4 = poly
    mean_edge_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
    diagonal_drop = p3[1] - p1[1]
    within_tolerance = 0.8 * mean_edge_height <= diagonal_drop <= 1.2 * mean_edge_height
    return not within_tolerance
|