magic-pdf 0.9.2__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110) hide show
  1. magic_pdf/config/constants.py +53 -0
  2. magic_pdf/config/drop_reason.py +35 -0
  3. magic_pdf/config/drop_tag.py +19 -0
  4. magic_pdf/config/make_content_config.py +11 -0
  5. magic_pdf/{libs/ModelBlockTypeEnum.py → config/model_block_type.py} +2 -1
  6. magic_pdf/data/read_api.py +1 -1
  7. magic_pdf/dict2md/mkcontent.py +226 -185
  8. magic_pdf/dict2md/ocr_mkcontent.py +12 -12
  9. magic_pdf/filter/pdf_meta_scan.py +101 -79
  10. magic_pdf/integrations/rag/utils.py +4 -5
  11. magic_pdf/libs/config_reader.py +6 -6
  12. magic_pdf/libs/draw_bbox.py +13 -6
  13. magic_pdf/libs/pdf_image_tools.py +36 -12
  14. magic_pdf/libs/version.py +1 -1
  15. magic_pdf/model/doc_analyze_by_custom_model.py +2 -0
  16. magic_pdf/model/magic_model.py +13 -13
  17. magic_pdf/model/pdf_extract_kit.py +142 -351
  18. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
  19. magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
  20. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
  21. magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
  22. magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
  25. magic_pdf/model/sub_modules/model_init.py +149 -0
  26. magic_pdf/model/sub_modules/model_utils.py +51 -0
  27. magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
  28. magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
  29. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +285 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +176 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
  32. magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
  33. magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
  34. magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
  35. magic_pdf/model/sub_modules/table/__init__.py +0 -0
  36. magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
  37. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +16 -0
  38. magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
  39. magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
  40. magic_pdf/model/sub_modules/table/table_utils.py +11 -0
  41. magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
  42. magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +31 -29
  43. magic_pdf/para/para_split.py +411 -248
  44. magic_pdf/para/para_split_v2.py +352 -182
  45. magic_pdf/para/para_split_v3.py +121 -66
  46. magic_pdf/pdf_parse_by_ocr.py +2 -0
  47. magic_pdf/pdf_parse_by_txt.py +2 -0
  48. magic_pdf/pdf_parse_union_core.py +174 -100
  49. magic_pdf/pdf_parse_union_core_v2.py +253 -50
  50. magic_pdf/pipe/AbsPipe.py +28 -44
  51. magic_pdf/pipe/OCRPipe.py +5 -5
  52. magic_pdf/pipe/TXTPipe.py +5 -6
  53. magic_pdf/pipe/UNIPipe.py +24 -25
  54. magic_pdf/post_proc/pdf_post_filter.py +7 -14
  55. magic_pdf/pre_proc/cut_image.py +9 -11
  56. magic_pdf/pre_proc/equations_replace.py +203 -212
  57. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +235 -49
  58. magic_pdf/pre_proc/ocr_dict_merge.py +5 -5
  59. magic_pdf/pre_proc/ocr_span_list_modify.py +122 -63
  60. magic_pdf/pre_proc/pdf_pre_filter.py +37 -33
  61. magic_pdf/pre_proc/remove_bbox_overlap.py +20 -18
  62. magic_pdf/pre_proc/remove_colored_strip_bbox.py +36 -14
  63. magic_pdf/pre_proc/remove_footer_header.py +2 -5
  64. magic_pdf/pre_proc/remove_rotate_bbox.py +111 -63
  65. magic_pdf/pre_proc/resolve_bbox_conflict.py +10 -17
  66. magic_pdf/resources/model_config/model_configs.yaml +2 -1
  67. magic_pdf/spark/spark_api.py +15 -17
  68. magic_pdf/tools/cli.py +3 -4
  69. magic_pdf/tools/cli_dev.py +6 -9
  70. magic_pdf/tools/common.py +70 -36
  71. magic_pdf/user_api.py +29 -38
  72. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/METADATA +18 -13
  73. magic_pdf-0.10.0.dist-info/RECORD +198 -0
  74. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/WHEEL +1 -1
  75. magic_pdf/libs/Constants.py +0 -53
  76. magic_pdf/libs/MakeContentConfig.py +0 -11
  77. magic_pdf/libs/drop_reason.py +0 -27
  78. magic_pdf/libs/drop_tag.py +0 -19
  79. magic_pdf/model/pek_sub_modules/post_process.py +0 -36
  80. magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
  81. magic_pdf/para/para_pipeline.py +0 -297
  82. magic_pdf-0.9.2.dist-info/RECORD +0 -178
  83. /magic_pdf/{libs → config}/ocr_content_type.py +0 -0
  84. /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
  85. /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
  86. /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
  87. /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
  88. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
  89. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
  90. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
  91. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
  92. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
  93. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
  94. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
  95. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
  96. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
  97. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
  98. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
  99. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
  100. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
  101. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
  102. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
  103. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
  104. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
  105. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
  106. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
  107. /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
  108. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/LICENSE.md +0 -0
  109. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/entry_points.txt +0 -0
  110. {magic_pdf-0.9.2.dist-info → magic_pdf-0.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,98 @@
1
+ import os
2
+ import argparse
3
+ import re
4
+
5
+ from PIL import Image
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torchvision import transforms
9
+ from unimernet.common.config import Config
10
+ import unimernet.tasks as tasks
11
+ from unimernet.processors import load_processor
12
+
13
+
14
class MathDataset(Dataset):
    """Dataset over formula images given as file paths or in-memory images.

    Args:
        image_paths: Sequence whose items are either ``str`` paths (opened
            with ``Image.open`` on access) or already-loaded image objects.
        transform: Optional callable applied to each image before it is
            returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return item ``idx``, transformed when a transform is configured."""
        item = self.image_paths[idx]
        # Only string entries are treated as paths; anything else is assumed
        # to already be an image object and is passed through untouched.
        raw_image = Image.open(item) if isinstance(item, str) else item
        if self.transform:
            return self.transform(raw_image)
        # Without a transform, hand back the raw image instead of failing to
        # produce an item at all.
        return raw_image
31
+
32
+
33
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code.

    Spaces inside ``\\operatorname``/``\\mathrm``/``\\text``/``\\mathbf``
    groups (written with a space before the opening brace) are stripped
    first. Then whitespace is repeatedly removed between two non-letters,
    between a non-letter and a letter, and between a letter and a non-letter,
    until the string stops changing. Whitespace between two letters is kept.

    Args:
        s: LaTeX source string.

    Returns:
        The string with redundant whitespace removed.
    """
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    # Raw string: '\W' and '\d' are regex escapes, not Python string escapes.
    noletter = r'[\W_^\d]'

    # Group 1 spans the whole match, so the callback can strip its spaces
    # directly instead of replaying a pre-computed list of replacements.
    s = re.sub(text_reg, lambda match: match.group(1).replace(' ', ''), s)

    # The negative lookahead keeps an escaped space ("\ ") from being used as
    # the left-hand side of a collapse.
    rules = (
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter)),
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter)),
        re.compile(r'(%s)\s+?(%s)' % (letter, noletter)),
    )

    news = s
    while True:
        s = news
        for rule in rules:
            news = rule.sub(r'\1\2', news)
        if news == s:
            break
    return s
50
+
51
+
52
class UnimernetModel(object):
    """Formula-recognition (MFR) wrapper built from a UniMERNet config.

    Args:
        weight_dir: Directory used for the model name, the tokenizer path and
            the ``pytorch_model.pth`` checkpoint.
        cfg_path: Path of the config file handed to ``Config``.
        _device_: Device the model is moved to (default ``'cpu'``).
    """

    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        # Point every weight-related config entry at the local weight dir.
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
        self.mfr_transform = transforms.Compose([vis_processor, ])

    def predict(self, mfd_res, image):
        """Recognise the LaTeX of every detected formula box.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` tensors.
            image: Page image accepted by ``Image.fromarray``.

        Returns:
            One dict per box with ``category_id`` (13 + class index), ``poly``
            (the four corners flattened to eight numbers), ``score`` (rounded
            to two decimals) and ``latex`` (whitespace-cleaned prediction).
        """
        formula_list = []
        crop_boxes = []
        boxes = mfd_res.boxes
        for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append({
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            })
            crop_boxes.append((xmin, ymin, xmax, ymax))

        mf_image_list = []
        if crop_boxes:
            # The page image is the same for every box: convert it once
            # rather than once per detected formula.
            pil_img = Image.fromarray(image)
            mf_image_list = [pil_img.crop(box) for box in crop_boxes]

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        # Predictions come back in dataset order, i.e. in box order.
        for res, latex in zip(formula_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        return formula_list
96
+
97
+
98
+
File without changes
@@ -0,0 +1,149 @@
1
+ from loguru import logger
2
+
3
+ from magic_pdf.config.constants import MODEL_NAME
4
+ from magic_pdf.model.model_list import AtomicModel
5
+ from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
6
+ DocLayoutYOLOModel
7
+ from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
8
+ Layoutlmv3_Predictor
9
+ from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
10
+ from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
11
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
12
+ ModifiedPaddleOCR
13
+ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
14
+ RapidTableModel
15
+ # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
16
+ from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
17
+ StructTableModel
18
+ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
19
+ TableMasterPaddleModel
20
+
21
+
22
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Instantiate the table-recognition model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Model location; used by the StructEqTable and TableMaster
            backends only.
        max_time: Forwarded as ``max_time`` to ``StructTableModel`` only.
        _device_: Forwarded as ``'device'`` to the TableMaster config only.

    Returns:
        The constructed table model.

    Raises:
        SystemExit: With code 1 when the model type is not recognised.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error('table model type not allow')
        # Raise SystemExit directly: the bare ``exit`` name is an
        # interactive-shell helper installed by ``site`` and is not
        # guaranteed to exist in every runtime.
        raise SystemExit(1)

    return table_model
38
+
39
+
40
def mfd_model_init(weight, device='cpu'):
    """Create the formula-detection model.

    Args:
        weight: Weight argument passed to ``YOLOv8MFDModel``.
        device: Device argument passed to ``YOLOv8MFDModel``.

    Returns:
        The ``YOLOv8MFDModel`` instance.
    """
    return YOLOv8MFDModel(weight, device)
43
+
44
+
45
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Create the formula-recognition model.

    Args:
        weight_dir: Weight directory passed to ``UnimernetModel``.
        cfg_path: Config path passed to ``UnimernetModel``.
        device: Device passed to ``UnimernetModel``.

    Returns:
        The ``UnimernetModel`` instance.
    """
    return UnimernetModel(weight_dir, cfg_path, device)
48
+
49
+
50
def layout_model_init(weight, config_file, device):
    """Create the LayoutLMv3 layout predictor.

    Args:
        weight: Weight argument passed to ``Layoutlmv3_Predictor``.
        config_file: Config file passed to ``Layoutlmv3_Predictor``.
        device: Device passed to ``Layoutlmv3_Predictor``.

    Returns:
        The ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)
53
+
54
+
55
def doclayout_yolo_model_init(weight, device='cpu'):
    """Create the DocLayout-YOLO layout model.

    Args:
        weight: Weight argument passed to ``DocLayoutYOLOModel``.
        device: Device passed to ``DocLayoutYOLOModel``.

    Returns:
        The ``DocLayoutYOLOModel`` instance.
    """
    return DocLayoutYOLOModel(weight, device)
58
+
59
+
60
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create the OCR model.

    ``lang`` is forwarded to ``ModifiedPaddleOCR`` only when it is neither
    ``None`` nor the empty string; otherwise the keyword is omitted entirely.
    All other arguments are always forwarded by keyword.

    Returns:
        The ``ModifiedPaddleOCR`` instance.
    """
    ocr_kwargs = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
    }
    if lang is not None and lang != '':
        ocr_kwargs['lang'] = lang
    ocr_kwargs['use_dilation'] = use_dilation
    ocr_kwargs['det_db_unclip_ratio'] = det_db_unclip_ratio
    # use_angle_cls is deliberately not passed (it is commented out upstream).
    return ModifiedPaddleOCR(**ocr_kwargs)
83
+
84
+
85
class AtomModelSingleton:
    """Singleton cache of initialised atomic models.

    Every instantiation returns the same object, and models are stored in a
    class-level dict so they are created at most once per cache key.
    """

    _instance = None
    # Shared on the class on purpose: this dict is the model cache.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for the request, creating it on first use.

        The cache key is ``(atom_model_name, layout_model_name, lang,
        table_model_name)``, read from ``kwargs`` (missing entries count as
        ``None``). ``table_model_name`` is part of the key so that asking for
        a different table backend does not return a previously cached one.
        Callers must therefore pass the same ``table_model_name`` on every
        lookup to hit the cache. Other kwargs (for example ``device``) do not
        take part in the key.

        Args:
            atom_model_name: Atomic model kind, passed to ``atom_model_init``
                as ``model_name``.
            **kwargs: Forwarded unchanged to ``atom_model_init`` on a miss.

        Returns:
            The cached or newly initialised model.
        """
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
        table_model_name = kwargs.get('table_model_name', None)
        key = (atom_model_name, layout_model_name, lang, table_model_name)
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
101
+
102
+
103
def atom_model_init(model_name: str, **kwargs):
    """Build one atomic model by dispatching on ``model_name``.

    Args:
        model_name: One of ``AtomicModel.Layout``, ``AtomicModel.MFD``,
            ``AtomicModel.MFR``, ``AtomicModel.OCR`` or ``AtomicModel.Table``.
        **kwargs: Per-model settings read with ``kwargs.get``:
            layout -> ``layout_model_name`` plus ``layout_weights``,
            ``layout_config_file``, ``device`` (LayoutLMv3) or
            ``doclayout_yolo_weights``, ``device`` (DocLayout-YOLO);
            MFD -> ``mfd_weights``, ``device``;
            MFR -> ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``;
            OCR -> ``ocr_show_log``, ``det_db_box_thresh``, ``lang``;
            table -> ``table_model_name``, ``table_model_path``,
            ``table_max_time``, ``device``.

    Returns:
        The initialised model.

    Raises:
        SystemExit: With code 1 when ``model_name`` is unknown, or when no
            model was produced (for example an unrecognised
            ``layout_model_name``).
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device')
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        # NOTE: values are passed positionally, so a missing key reaches
        # ocr_model_init as None rather than falling back to its defaults.
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang')
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device')
        )
    else:
        logger.error('model name not allow')
        # Raise SystemExit directly instead of relying on the ``site``-provided
        # interactive ``exit`` helper.
        raise SystemExit(1)

    if atom_model is None:
        logger.error('model init failed')
        raise SystemExit(1)
    return atom_model
@@ -0,0 +1,51 @@
1
+ import time
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from loguru import logger
6
+
7
+ from magic_pdf.libs.clean_memory import clean_memory
8
+
9
+
10
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    """Crop a region out of a page image and paste it onto a white canvas.

    Args:
        input_res: Dict whose ``'poly'`` entry holds the region corners; items
            0/1 are read as the top-left x/y and items 4/5 as the
            bottom-right x/y.
        input_pil_img: Image object providing ``crop``.
        crop_paste_x: Horizontal margin added on each side of the crop.
        crop_paste_y: Vertical margin added on each side of the crop.

    Returns:
        ``(image, info)`` where ``image`` is the white RGB canvas containing
        the crop and ``info`` is ``[crop_paste_x, crop_paste_y, xmin, ymin,
        xmax, ymax, canvas_width, canvas_height]``.
    """
    poly = input_res['poly']
    crop_xmin, crop_ymin = int(poly[0]), int(poly[1])
    crop_xmax, crop_ymax = int(poly[4]), int(poly[5])

    # The canvas is the crop size plus the paste margin on both sides.
    canvas_width = crop_xmax - crop_xmin + crop_paste_x * 2
    canvas_height = crop_ymax - crop_ymin + crop_paste_y * 2
    canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')

    region = input_pil_img.crop((crop_xmin, crop_ymin, crop_xmax, crop_ymax))
    canvas.paste(region, (crop_paste_x, crop_paste_y))

    info = [crop_paste_x, crop_paste_y,
            crop_xmin, crop_ymin, crop_xmax, crop_ymax,
            canvas_width, canvas_height]
    return canvas, info
24
+
25
+
26
+ # Select regions for OCR / formula regions / table regions
27
def get_res_list_from_layout_res(layout_res):
    """Split layout results into OCR, table and formula groups by category id.

    Category ids 13 and 14 are reduced to ``{'bbox': [x0, y0, x1, y1]}``
    entries (taken from ``poly`` items 0, 1, 4, 5 as ints); ids 0, 1, 2, 4, 6
    and 7 are kept as-is for OCR; id 5 is kept as-is for tables. Any other id
    is ignored.

    Args:
        layout_res: Iterable of dicts with ``'category_id'`` and ``'poly'``.

    Returns:
        ``(ocr_res_list, table_res_list, single_page_mfdetrec_res)``.
    """
    formula_ids = (13, 14)
    ocr_ids = (0, 1, 2, 4, 6, 7)
    table_ids = (5,)

    ocr_res_list, table_res_list, single_page_mfdetrec_res = [], [], []
    for res in layout_res:
        category = int(res['category_id'])
        if category in formula_ids:
            poly = res['poly']
            bbox = [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])]
            single_page_mfdetrec_res.append({"bbox": bbox})
        elif category in ocr_ids:
            ocr_res_list.append(res)
        elif category in table_ids:
            table_res_list.append(res)
    return ocr_res_list, table_res_list, single_page_mfdetrec_res
42
+
43
+
44
def clean_vram(device, vram_threshold=8):
    """Run ``clean_memory`` when the CUDA device has little total memory.

    Nothing happens when CUDA is unavailable, when ``device`` equals
    ``'cpu'``, or when the device's total memory (in GB) exceeds
    ``vram_threshold``. Otherwise memory is cleaned and the elapsed time is
    logged.

    Args:
        device: Device identifier passed to
            ``torch.cuda.get_device_properties``.
        vram_threshold: Total-memory limit in GB at or below which cleaning
            is performed.
    """
    if not (torch.cuda.is_available() and device != 'cpu'):
        return

    # Convert bytes to GB.
    total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
    if total_memory > vram_threshold:
        return

    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f"gc time: {gc_time}")
File without changes
File without changes
@@ -0,0 +1,285 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ from loguru import logger
5
+
6
+ from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
7
+ from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
8
+
9
+
10
def bbox_to_points(bbox):
    """Convert an ``[x0, y0, x1, y1]`` bbox into an array of its four corners.

    Returns:
        A float32 array ordered top-left, top-right, bottom-right,
        bottom-left.
    """
    x0, y0, x1, y1 = bbox
    corners = [
        [x0, y0],
        [x1, y0],
        [x1, y1],
        [x0, y1],
    ]
    return np.array(corners).astype('float32')
14
+
15
+
16
def points_to_bbox(points):
    """Convert a four-corner point array into ``[x0, y0, x1, y1]``.

    ``x0``/``y0`` come from the first point, ``x1`` from the second point and
    ``y1`` from the third point; the fourth point is not read.
    """
    (x0, y0), (x1, _), (_, y1) = points[0], points[1], points[2]
    return [x0, y0, x1, y1]
22
+
23
+
24
def merge_intervals(intervals):
    """Merge overlapping or touching ``[start, end]`` intervals.

    The input is left untouched: it is neither re-ordered nor are its inner
    intervals modified, and intervals given as tuples are accepted.

    Args:
        intervals: Iterable of two-item ``[start, end]`` sequences.

    Returns:
        A new list of ``[start, end]`` lists, sorted by start, in which no
        interval starts at or before the end of the previous one.
    """
    merged = []
    # Sort a copy by start value so the caller's list keeps its order.
    for interval in sorted(intervals, key=lambda item: item[0]):
        if not merged or merged[-1][1] < interval[0]:
            # No overlap with the previous interval: start a new one. Copying
            # keeps later merges from writing into the caller's data.
            merged.append(list(interval))
        else:
            # Overlap (or touching ends): stretch the previous interval.
            merged[-1][1] = max(merged[-1][1], interval[1])
    return merged
39
+
40
+
41
def remove_intervals(original, masks):
    """Subtract mask intervals from a single ``[start, end]`` range.

    Masks are first combined with ``merge_intervals``. Each mask that
    intersects the range cuts it: the piece left of the mask ends at
    ``mask_start - 1`` and the search resumes at ``mask_end + 1``.

    Args:
        original: Two-item ``[start, end]`` range.
        masks: List of ``[start, end]`` intervals to remove.

    Returns:
        List of the ``[start, end]`` pieces of ``original`` left uncovered.
    """
    remaining_start, range_end = original
    pieces = []

    for mask_start, mask_end in merge_intervals(masks):
        # Masks entirely right or entirely left of what remains are ignored.
        if mask_start > range_end:
            continue
        if mask_end < remaining_start:
            continue

        # Keep the uncovered stretch in front of this mask, if any.
        if remaining_start < mask_start:
            pieces.append([remaining_start, mask_start - 1])

        remaining_start = max(mask_end + 1, remaining_start)

    # Whatever is left after the last mask is part of the result as well.
    if remaining_start <= range_end:
        pieces.append([remaining_start, range_end])

    return pieces
70
+
71
+
72
def update_det_boxes(dt_boxes, mfd_res):
    """Cut formula regions out of non-angled text detection boxes.

    Boxes for which ``calculate_is_angle`` is true are passed through
    unchanged and placed at the end of the result. Every other box has the
    x-ranges of the formula boxes that overlap it vertically (per
    ``__is_overlaps_y_exceeds_threshold``) removed, which may split it into
    several boxes or drop it entirely.

    Args:
        dt_boxes: Iterable of four-corner text boxes.
        mfd_res: Iterable of dicts with a ``'bbox'`` entry
            ``[x0, y0, x1, y1]``.

    Returns:
        List of four-corner boxes: the split text boxes followed by the
        angled boxes.
    """
    split_boxes = []
    angle_boxes_list = []
    for text_box in dt_boxes:
        if calculate_is_angle(text_box):
            angle_boxes_list.append(text_box)
            continue

        text_bbox = points_to_bbox(text_box)
        # Horizontal extents of every formula sharing this text line.
        masks_list = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]
        kept_ranges = remove_intervals([text_bbox[0], text_bbox[2]], masks_list)
        for kept in kept_ranges:
            split_boxes.append(
                bbox_to_points([kept[0], text_bbox[1], kept[1], text_bbox[3]])
            )

    return split_boxes + angle_boxes_list
98
+
99
+
100
def merge_overlapping_spans(spans):
    """Merge horizontally overlapping spans on the same line.

    Note that ``spans`` is sorted in place by its left edge.

    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
    :return: A list of merged spans. Spans that needed no merging are the
        original objects; merged ones are new ``(x1, y1, x2, y2)`` tuples.
    """
    if not spans:
        return []

    spans.sort(key=lambda item: item[0])

    merged = []
    for span in spans:
        left, top, right, bottom = span
        if not merged or merged[-1][2] < left:
            # Strictly to the right of the previous span: keep it separate.
            merged.append(span)
            continue
        # Overlapping or touching: replace the previous span with the union
        # of both rectangles.
        prev = merged[-1]
        merged[-1] = (
            min(prev[0], left),
            min(prev[1], top),
            max(prev[2], right),
            max(prev[3], bottom),
        )
    return merged
135
+
136
+
137
def merge_det_boxes(dt_boxes):
    """Merge text detection boxes into larger per-line regions.

    Non-angled boxes are converted to bboxes, grouped into lines with
    ``merge_spans_to_line`` and, within each line, horizontally overlapping
    boxes are fused. Boxes for which ``calculate_is_angle`` is true skip the
    merging and are appended unchanged at the end.

    Parameters:
        dt_boxes (list): Text detection boxes, each given by four corner
            points.

    Returns:
        list: Merged regions as four-corner point arrays, followed by the
        untouched angled boxes.
    """
    text_spans = []
    angle_boxes_list = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)
        if calculate_is_angle(text_box):
            angle_boxes_list.append(text_box)
        else:
            text_spans.append({'bbox': text_bbox, 'type': 'text'})

    new_dt_boxes = []
    for line in merge_spans_to_line(text_spans):
        line_bboxes = [span['bbox'] for span in line]
        # Fuse overlapping boxes of the line, then go back to corner format.
        for merged_bbox in merge_overlapping_spans(line_bboxes):
            new_dt_boxes.append(bbox_to_points(merged_bbox))

    new_dt_boxes.extend(angle_boxes_list)
    return new_dt_boxes
186
+
187
+
188
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    """Translate formula boxes into the coordinates of a cropped canvas.

    Args:
        single_page_mfdetrec_res: Iterable of dicts with a ``'bbox'`` entry
            ``[x0, y0, x1, y1]``.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
            new_height]`` describing the crop.

    Returns:
        List of ``{'bbox': [x0, y0, x1, y1]}`` dicts in canvas coordinates.
        Boxes lying entirely left of / above the canvas origin, or starting
        beyond its width / height, are dropped.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, new_width, new_height = useful_list

    adjusted = []
    for mf_res in single_page_mfdetrec_res:
        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
        # Shift from page coordinates to crop-canvas coordinates.
        x0 = mf_xmin - xmin + paste_x
        y0 = mf_ymin - ymin + paste_y
        x1 = mf_xmax - xmin + paste_x
        y1 = mf_ymax - ymin + paste_y

        outside = x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height
        if not outside:
            adjusted.append({"bbox": [x0, y0, x1, y1]})
    return adjusted
207
+
208
+
209
def get_ocr_result_list(ocr_res, useful_list):
    """Convert raw OCR output on a crop into page-coordinate result dicts.

    Each entry of ``ocr_res`` is either ``(points, (text, score))`` or just
    the four points, in which case the text is ``""`` and the score ``1``.
    Boxes for which ``calculate_is_angle`` is true are replaced by an
    axis-aligned rectangle around their geometric centre before being shifted
    back to page coordinates.

    Args:
        ocr_res: Iterable of OCR entries as described above.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
            new_height]`` describing the crop; only the first four values are
            used.

    Returns:
        List of dicts with ``category_id`` 15, ``poly`` (eight numbers),
        ``score`` (rounded to two decimals) and ``text``.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, _new_width, _new_height = useful_list

    ocr_result_list = []
    for box_ocr_res in ocr_res:
        if len(box_ocr_res) == 2:
            p1, p2, p3, p4 = box_ocr_res[0]
            text, score = box_ocr_res[1]
        else:
            p1, p2, p3, p4 = box_ocr_res
            text, score = "", 1

        poly = [p1, p2, p3, p4]
        if calculate_is_angle(poly):
            # The box is tilted: rebuild it as an upright rectangle centred
            # on the geometric centre of the four points.
            x_center = sum(point[0] for point in poly) / 4
            y_center = sum(point[1] for point in poly) / 4
            box_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            box_width = p3[0] - p1[0]
            p1 = [x_center - box_width / 2, y_center - box_height / 2]
            p2 = [x_center + box_width / 2, y_center - box_height / 2]
            p3 = [x_center + box_width / 2, y_center + box_height / 2]
            p4 = [x_center - box_width / 2, y_center + box_height / 2]

        # Map the corners from crop coordinates back to page coordinates and
        # flatten them into a single eight-number list.
        flat_poly = []
        for point in (p1, p2, p3, p4):
            flat_poly.append(point[0] - paste_x + xmin)
            flat_poly.append(point[1] - paste_y + ymin)

        ocr_result_list.append({
            'category_id': 15,
            'poly': flat_poly,
            'score': float(round(score, 2)),
            'text': text,
        })

    return ocr_result_list
250
+
251
+
252
def calculate_angle_degrees(poly):
    """Return the absolute mean angle of a quad's two diagonals, in degrees.

    Each diagonal's angle to the x-axis is ``atan`` of its slope (a vertical
    diagonal counts as an infinite slope, i.e. 90 degrees). The two angles
    are averaged and the absolute value of that average is returned.

    Args:
        poly: Four ``(x, y)`` points; diagonals run between points 0-2 and
            points 1-3.
    """

    def diagonal_angle(start, end):
        # Angle between the segment and the x-axis, in degrees.
        if end[0] == start[0]:
            slope = float('inf')
        else:
            slope = (end[1] - start[1]) / (end[0] - start[0])
        return math.degrees(math.atan(slope))

    first = diagonal_angle(poly[0], poly[2])
    second = diagonal_angle(poly[1], poly[3])
    return abs((first + second) / 2)
276
+
277
+
278
def calculate_is_angle(poly):
    """Tell whether a four-point box is treated as tilted.

    The box height is the mean of the two vertical side lengths (p4 - p1 and
    p3 - p2). The box counts as not tilted when the vertical distance from p1
    to p3 lies within 80%-120% of that height.

    Args:
        poly: Four ``(x, y)`` points ``p1..p4``.

    Returns:
        ``True`` when the box is considered tilted, ``False`` otherwise.
    """
    p1, p2, p3, p4 = poly
    height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
    diagonal_rise = p3[1] - p1[1]
    return not (0.8 * height <= diagonal_rise <= 1.2 * height)