magic-pdf 0.9.1__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67):
  1. magic_pdf/dict2md/ocr_mkcontent.py +1 -1
  2. magic_pdf/libs/Constants.py +3 -1
  3. magic_pdf/libs/config_reader.py +1 -1
  4. magic_pdf/libs/draw_bbox.py +10 -4
  5. magic_pdf/libs/version.py +1 -1
  6. magic_pdf/model/pdf_extract_kit.py +42 -310
  7. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
  8. magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
  9. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
  10. magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
  11. magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
  12. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
  13. magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
  14. magic_pdf/model/sub_modules/model_init.py +144 -0
  15. magic_pdf/model/sub_modules/model_utils.py +51 -0
  16. magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
  17. magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
  18. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +259 -0
  19. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +168 -0
  20. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
  21. magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
  22. magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
  23. magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
  24. magic_pdf/model/sub_modules/table/__init__.py +0 -0
  25. magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
  26. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +14 -0
  27. magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
  28. magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
  29. magic_pdf/model/sub_modules/table/table_utils.py +11 -0
  30. magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
  31. magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +1 -1
  32. magic_pdf/para/para_split_v3.py +13 -15
  33. magic_pdf/pdf_parse_union_core_v2.py +56 -19
  34. magic_pdf/resources/model_config/model_configs.yaml +2 -1
  35. magic_pdf/tools/common.py +47 -3
  36. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/METADATA +35 -25
  37. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/RECORD +65 -44
  38. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/WHEEL +1 -1
  39. magic_pdf/model/pek_sub_modules/post_process.py +0 -36
  40. magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
  41. /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
  42. /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
  43. /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
  44. /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
  45. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
  46. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
  47. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
  48. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
  49. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
  50. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
  51. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
  52. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
  53. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
  54. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
  55. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
  56. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
  57. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
  58. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
  59. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
  60. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
  61. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
  62. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
  63. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
  64. /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
  65. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/LICENSE.md +0 -0
  66. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/entry_points.txt +0 -0
  67. {magic_pdf-0.9.1.dist-info → magic_pdf-0.9.3.dist-info}/top_level.txt +0 -0
@@ -168,7 +168,7 @@ def merge_para_with_text(para_block):
168
168
  # 如果是前一行带有-连字符,那么末尾不应该加空格
169
169
  if __is_hyphen_at_line_end(content):
170
170
  para_text += content[:-1]
171
- elif len(content) == 1 and content not in ['A', 'I', 'a', 'i']:
171
+ elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
172
172
  para_text += content
173
173
  else: # 西方文本语境下 content间需要空格分隔
174
174
  para_text += f"{content} "
@@ -50,4 +50,6 @@ class MODEL_NAME:
50
50
 
51
51
  YOLO_V8_MFD = "yolo_v8_mfd"
52
52
 
53
- UniMerNet_v2_Small = "unimernet_small"
53
+ UniMerNet_v2_Small = "unimernet_small"
54
+
55
+ RAPID_TABLE = "rapid_table"
@@ -92,7 +92,7 @@ def get_table_recog_config():
92
92
  table_config = config.get('table-config')
93
93
  if table_config is None:
94
94
  logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
95
- return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
95
+ return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
96
96
  else:
97
97
  return table_config
98
98
 
@@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
369
369
  if block['type'] in [BlockType.Image, BlockType.Table]:
370
370
  for sub_block in block['blocks']:
371
371
  if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
372
- for line in sub_block['virtual_lines']:
373
- bbox = line['bbox']
374
- index = line['index']
375
- page_line_list.append({'index': index, 'bbox': bbox})
372
+ if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
373
+ for line in sub_block['virtual_lines']:
374
+ bbox = line['bbox']
375
+ index = line['index']
376
+ page_line_list.append({'index': index, 'bbox': bbox})
377
+ else:
378
+ for line in sub_block['lines']:
379
+ bbox = line['bbox']
380
+ index = line['index']
381
+ page_line_list.append({'index': index, 'bbox': bbox})
376
382
  elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
377
383
  for line in sub_block['lines']:
378
384
  bbox = line['bbox']
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.9.1"
1
+ __version__ = "0.9.3"
@@ -1,195 +1,28 @@
1
+ import numpy as np
2
+ import torch
1
3
  from loguru import logger
2
4
  import os
3
5
  import time
4
- from pathlib import Path
5
- import shutil
6
- from magic_pdf.libs.Constants import *
7
- from magic_pdf.libs.clean_memory import clean_memory
8
- from magic_pdf.model.model_list import AtomicModel
6
+ import cv2
7
+ import yaml
8
+ from PIL import Image
9
9
 
10
10
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
11
11
  os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
12
+
12
13
  try:
13
- import cv2
14
- import yaml
15
- import argparse
16
- import numpy as np
17
- import torch
18
14
  import torchtext
19
15
 
20
16
  if torchtext.__version__ >= "0.18.0":
21
17
  torchtext.disable_torchtext_deprecation_warning()
22
- from PIL import Image
23
- from torchvision import transforms
24
- from torch.utils.data import Dataset, DataLoader
25
- from ultralytics import YOLO
26
- from unimernet.common.config import Config
27
- import unimernet.tasks as tasks
28
- from unimernet.processors import load_processor
29
- from doclayout_yolo import YOLOv10
30
-
31
- except ImportError as e:
32
- logger.exception(e)
33
- logger.error(
34
- 'Required dependency not installed, please install by \n'
35
- '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
36
- exit(1)
37
-
38
- from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
39
- from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
40
- from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
41
- from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
42
- from magic_pdf.model.ppTableModel import ppTableModel
43
-
44
-
45
- def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
46
- if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
47
- table_model = StructTableModel(model_path, max_time=max_time)
48
- elif table_model_type == MODEL_NAME.TABLE_MASTER:
49
- config = {
50
- "model_dir": model_path,
51
- "device": _device_
52
- }
53
- table_model = ppTableModel(config)
54
- else:
55
- logger.error("table model type not allow")
56
- exit(1)
57
- return table_model
58
-
59
-
60
- def mfd_model_init(weight):
61
- mfd_model = YOLO(weight)
62
- return mfd_model
63
-
64
-
65
- def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
66
- args = argparse.Namespace(cfg_path=cfg_path, options=None)
67
- cfg = Config(args)
68
- cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
69
- cfg.config.model.model_config.model_name = weight_dir
70
- cfg.config.model.tokenizer_config.path = weight_dir
71
- task = tasks.setup_task(cfg)
72
- model = task.build_model(cfg)
73
- model.to(_device_)
74
- model.eval()
75
- vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
76
- mfr_transform = transforms.Compose([vis_processor, ])
77
- return [model, mfr_transform]
78
-
79
-
80
- def layout_model_init(weight, config_file, device):
81
- model = Layoutlmv3_Predictor(weight, config_file, device)
82
- return model
83
-
84
-
85
- def doclayout_yolo_model_init(weight):
86
- model = YOLOv10(weight)
87
- return model
88
-
89
-
90
- def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
91
- if lang is not None:
92
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
93
- else:
94
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
95
- return model
96
-
97
-
98
- class MathDataset(Dataset):
99
- def __init__(self, image_paths, transform=None):
100
- self.image_paths = image_paths
101
- self.transform = transform
102
-
103
- def __len__(self):
104
- return len(self.image_paths)
105
-
106
- def __getitem__(self, idx):
107
- # if not pil image, then convert to pil image
108
- if isinstance(self.image_paths[idx], str):
109
- raw_image = Image.open(self.image_paths[idx])
110
- else:
111
- raw_image = self.image_paths[idx]
112
- if self.transform:
113
- image = self.transform(raw_image)
114
- return image
115
-
116
-
117
- class AtomModelSingleton:
118
- _instance = None
119
- _models = {}
120
-
121
- def __new__(cls, *args, **kwargs):
122
- if cls._instance is None:
123
- cls._instance = super().__new__(cls)
124
- return cls._instance
125
-
126
- def get_atom_model(self, atom_model_name: str, **kwargs):
127
- lang = kwargs.get("lang", None)
128
- layout_model_name = kwargs.get("layout_model_name", None)
129
- key = (atom_model_name, layout_model_name, lang)
130
- if key not in self._models:
131
- self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
132
- return self._models[key]
133
-
134
-
135
- def atom_model_init(model_name: str, **kwargs):
136
-
137
- if model_name == AtomicModel.Layout:
138
- if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
139
- atom_model = layout_model_init(
140
- kwargs.get("layout_weights"),
141
- kwargs.get("layout_config_file"),
142
- kwargs.get("device")
143
- )
144
- elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
145
- atom_model = doclayout_yolo_model_init(
146
- kwargs.get("doclayout_yolo_weights"),
147
- )
148
- elif model_name == AtomicModel.MFD:
149
- atom_model = mfd_model_init(
150
- kwargs.get("mfd_weights")
151
- )
152
- elif model_name == AtomicModel.MFR:
153
- atom_model = mfr_model_init(
154
- kwargs.get("mfr_weight_dir"),
155
- kwargs.get("mfr_cfg_path"),
156
- kwargs.get("device")
157
- )
158
- elif model_name == AtomicModel.OCR:
159
- atom_model = ocr_model_init(
160
- kwargs.get("ocr_show_log"),
161
- kwargs.get("det_db_box_thresh"),
162
- kwargs.get("lang")
163
- )
164
- elif model_name == AtomicModel.Table:
165
- atom_model = table_model_init(
166
- kwargs.get("table_model_name"),
167
- kwargs.get("table_model_path"),
168
- kwargs.get("table_max_time"),
169
- kwargs.get("device")
170
- )
171
- else:
172
- logger.error("model name not allow")
173
- exit(1)
174
-
175
- return atom_model
176
-
177
-
178
- # Unified crop img logic
179
- def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
180
- crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
181
- crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
182
- # Create a white background with an additional width and height of 50
183
- crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
184
- crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
185
- return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
18
+ except ImportError:
19
+ pass
186
20
 
187
- # Crop image
188
- crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
189
- cropped_img = input_pil_img.crop(crop_box)
190
- return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
191
- return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
192
- return return_image, return_list
21
+ from magic_pdf.libs.Constants import *
22
+ from magic_pdf.model.model_list import AtomicModel
23
+ from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
24
+ from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
25
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
193
26
 
194
27
 
195
28
  class CustomPEKModel:
@@ -226,7 +59,7 @@ class CustomPEKModel:
226
59
  self.table_config = kwargs.get("table_config")
227
60
  self.apply_table = self.table_config.get("enable", False)
228
61
  self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
229
- self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
62
+ self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
230
63
 
231
64
  # ocr config
232
65
  self.apply_ocr = ocr
@@ -235,7 +68,8 @@ class CustomPEKModel:
235
68
  logger.info(
236
69
  "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
237
70
  "apply_table: {}, table_model: {}, lang: {}".format(
238
- self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
71
+ self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
72
+ self.lang
239
73
  )
240
74
  )
241
75
  # 初始化解析方案
@@ -248,17 +82,17 @@ class CustomPEKModel:
248
82
 
249
83
  # 初始化公式识别
250
84
  if self.apply_formula:
251
-
252
85
  # 初始化公式检测模型
253
86
  self.mfd_model = atom_model_manager.get_atom_model(
254
87
  atom_model_name=AtomicModel.MFD,
255
- mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
88
+ mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
89
+ device=self.device
256
90
  )
257
91
 
258
92
  # 初始化公式解析模型
259
93
  mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
260
94
  mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
261
- self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
95
+ self.mfr_model = atom_model_manager.get_atom_model(
262
96
  atom_model_name=AtomicModel.MFR,
263
97
  mfr_weight_dir=mfr_weight_dir,
264
98
  mfr_cfg_path=mfr_cfg_path,
@@ -278,12 +112,11 @@ class CustomPEKModel:
278
112
  self.layout_model = atom_model_manager.get_atom_model(
279
113
  atom_model_name=AtomicModel.Layout,
280
114
  layout_model_name=MODEL_NAME.DocLayout_YOLO,
281
- doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
115
+ doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
116
+ device=self.device
282
117
  )
283
118
  # 初始化ocr
284
119
  if self.apply_ocr:
285
-
286
- # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
287
120
  self.ocr_model = atom_model_manager.get_atom_model(
288
121
  atom_model_name=AtomicModel.OCR,
289
122
  ocr_show_log=show_log,
@@ -301,43 +134,21 @@ class CustomPEKModel:
301
134
  device=self.device
302
135
  )
303
136
 
304
- home_directory = Path.home()
305
- det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
306
- rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
307
- det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
308
- rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
309
-
310
- if not os.path.exists(det_dest_dir):
311
- shutil.copytree(det_source, det_dest_dir)
312
- if not os.path.exists(rec_dest_dir):
313
- shutil.copytree(rec_source, rec_dest_dir)
314
-
315
137
  logger.info('DocAnalysis init done!')
316
138
 
317
139
  def __call__(self, image):
318
140
 
319
141
  page_start = time.time()
320
142
 
321
- latex_filling_list = []
322
- mf_image_list = []
323
-
324
143
  # layout检测
325
144
  layout_start = time.time()
145
+ layout_res = []
326
146
  if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
327
147
  # layoutlmv3
328
148
  layout_res = self.layout_model(image, ignore_catids=[])
329
149
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
330
150
  # doclayout_yolo
331
- layout_res = []
332
- doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
333
- for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
334
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
335
- new_item = {
336
- 'category_id': int(cla.item()),
337
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
338
- 'score': round(float(conf.item()), 3),
339
- }
340
- layout_res.append(new_item)
151
+ layout_res = self.layout_model.predict(image)
341
152
  layout_cost = round(time.time() - layout_start, 2)
342
153
  logger.info(f"layout detection time: {layout_cost}")
343
154
 
@@ -346,59 +157,21 @@ class CustomPEKModel:
346
157
  if self.apply_formula:
347
158
  # 公式检测
348
159
  mfd_start = time.time()
349
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
160
+ mfd_res = self.mfd_model.predict(image)
350
161
  logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
351
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
352
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
353
- new_item = {
354
- 'category_id': 13 + int(cla.item()),
355
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
356
- 'score': round(float(conf.item()), 2),
357
- 'latex': '',
358
- }
359
- layout_res.append(new_item)
360
- latex_filling_list.append(new_item)
361
- bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
362
- mf_image_list.append(bbox_img)
363
162
 
364
163
  # 公式识别
365
164
  mfr_start = time.time()
366
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
367
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
368
- mfr_res = []
369
- for mf_img in dataloader:
370
- mf_img = mf_img.to(self.device)
371
- with torch.no_grad():
372
- output = self.mfr_model.generate({'image': mf_img})
373
- mfr_res.extend(output['pred_str'])
374
- for res, latex in zip(latex_filling_list, mfr_res):
375
- res['latex'] = latex_rm_whitespace(latex)
165
+ formula_list = self.mfr_model.predict(mfd_res, image)
166
+ layout_res.extend(formula_list)
376
167
  mfr_cost = round(time.time() - mfr_start, 2)
377
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
378
-
379
- # Select regions for OCR / formula regions / table regions
380
- ocr_res_list = []
381
- table_res_list = []
382
- single_page_mfdetrec_res = []
383
- for res in layout_res:
384
- if int(res['category_id']) in [13, 14]:
385
- single_page_mfdetrec_res.append({
386
- "bbox": [int(res['poly'][0]), int(res['poly'][1]),
387
- int(res['poly'][4]), int(res['poly'][5])],
388
- })
389
- elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
390
- ocr_res_list.append(res)
391
- elif int(res['category_id']) in [5]:
392
- table_res_list.append(res)
393
-
394
- if torch.cuda.is_available() and self.device != 'cpu':
395
- properties = torch.cuda.get_device_properties(self.device)
396
- total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
397
- if total_memory <= 10:
398
- gc_start = time.time()
399
- clean_memory()
400
- gc_time = round(time.time() - gc_start, 2)
401
- logger.info(f"gc time: {gc_time}")
168
+ logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
169
+
170
+ # 清理显存
171
+ clean_vram(self.device, vram_threshold=8)
172
+
173
+ # 从layout_res中获取ocr区域、表格区域、公式区域
174
+ ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
402
175
 
403
176
  # ocr识别
404
177
  if self.apply_ocr:
@@ -406,23 +179,7 @@ class CustomPEKModel:
406
179
  # Process each area that requires OCR processing
407
180
  for res in ocr_res_list:
408
181
  new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
409
- paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
410
- # Adjust the coordinates of the formula area
411
- adjusted_mfdetrec_res = []
412
- for mf_res in single_page_mfdetrec_res:
413
- mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
414
- # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
415
- x0 = mf_xmin - xmin + paste_x
416
- y0 = mf_ymin - ymin + paste_y
417
- x1 = mf_xmax - xmin + paste_x
418
- y1 = mf_ymax - ymin + paste_y
419
- # Filter formula blocks outside the graph
420
- if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
421
- continue
422
- else:
423
- adjusted_mfdetrec_res.append({
424
- "bbox": [x0, y0, x1, y1],
425
- })
182
+ adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
426
183
 
427
184
  # OCR recognition
428
185
  new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
@@ -430,22 +187,8 @@ class CustomPEKModel:
430
187
 
431
188
  # Integration results
432
189
  if ocr_res:
433
- for box_ocr_res in ocr_res:
434
- p1, p2, p3, p4 = box_ocr_res[0]
435
- text, score = box_ocr_res[1]
436
-
437
- # Convert the coordinates back to the original coordinate system
438
- p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
439
- p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
440
- p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
441
- p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
442
-
443
- layout_res.append({
444
- 'category_id': 15,
445
- 'poly': p1 + p2 + p3 + p4,
446
- 'score': round(score, 2),
447
- 'text': text,
448
- })
190
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
191
+ layout_res.extend(ocr_result_list)
449
192
 
450
193
  ocr_cost = round(time.time() - ocr_start, 2)
451
194
  logger.info(f"ocr time: {ocr_cost}")
@@ -456,41 +199,30 @@ class CustomPEKModel:
456
199
  for res in table_res_list:
457
200
  new_image, _ = crop_img(res, pil_img)
458
201
  single_table_start_time = time.time()
459
- # logger.info("------------------table recognition processing begins-----------------")
460
- latex_code = None
461
202
  html_code = None
462
203
  if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
463
204
  with torch.no_grad():
464
205
  table_result = self.table_model.predict(new_image, "html")
465
206
  if len(table_result) > 0:
466
207
  html_code = table_result[0]
467
- else:
208
+ elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
468
209
  html_code = self.table_model.img2html(new_image)
469
-
210
+ elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
211
+ html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
470
212
  run_time = time.time() - single_table_start_time
471
- # logger.info(f"------------table recognition processing ends within {run_time}s-----")
472
213
  if run_time > self.table_max_time:
473
- logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
214
+ logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
474
215
  # 判断是否返回正常
475
-
476
- if latex_code:
477
- expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
478
- if expected_ending:
479
- res["latex"] = latex_code
480
- else:
481
- logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
482
- elif html_code:
216
+ if html_code:
483
217
  expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
484
218
  if expected_ending:
485
219
  res["html"] = html_code
486
220
  else:
487
221
  logger.warning(f"table recognition processing fails, not found expected HTML table end")
488
222
  else:
489
- logger.warning(f"table recognition processing fails, not get latex or html return")
223
+ logger.warning(f"table recognition processing fails, not get html return")
490
224
  logger.info(f"table time: {round(time.time() - table_start, 2)}")
491
225
 
492
226
  logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
493
227
 
494
228
  return layout_res
495
-
496
-
@@ -0,0 +1,21 @@
1
+ from doclayout_yolo import YOLOv10
2
+
3
+
4
+ class DocLayoutYOLOModel(object):
5
+ def __init__(self, weight, device):
6
+ self.model = YOLOv10(weight)
7
+ self.device = device
8
+
9
+ def predict(self, image):
10
+ layout_res = []
11
+ doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
12
+ for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
13
+ doclayout_yolo_res.boxes.cls.cpu()):
14
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
15
+ new_item = {
16
+ 'category_id': int(cla.item()),
17
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
18
+ 'score': round(float(conf.item()), 3),
19
+ }
20
+ layout_res.append(new_item)
21
+ return layout_res
File without changes
@@ -0,0 +1,12 @@
1
+ from ultralytics import YOLO
2
+
3
+
4
+ class YOLOv8MFDModel(object):
5
+ def __init__(self, weight, device='cpu'):
6
+ self.mfd_model = YOLO(weight)
7
+ self.device = device
8
+
9
+ def predict(self, image):
10
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
11
+ return mfd_res
12
+
File without changes
File without changes
@@ -0,0 +1,98 @@
1
+ import os
2
+ import argparse
3
+ import re
4
+
5
+ from PIL import Image
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torchvision import transforms
9
+ from unimernet.common.config import Config
10
+ import unimernet.tasks as tasks
11
+ from unimernet.processors import load_processor
12
+
13
+
14
+ class MathDataset(Dataset):
15
+ def __init__(self, image_paths, transform=None):
16
+ self.image_paths = image_paths
17
+ self.transform = transform
18
+
19
+ def __len__(self):
20
+ return len(self.image_paths)
21
+
22
+ def __getitem__(self, idx):
23
+ # if not pil image, then convert to pil image
24
+ if isinstance(self.image_paths[idx], str):
25
+ raw_image = Image.open(self.image_paths[idx])
26
+ else:
27
+ raw_image = self.image_paths[idx]
28
+ if self.transform:
29
+ image = self.transform(raw_image)
30
+ return image
31
+
32
+
33
+ def latex_rm_whitespace(s: str):
34
+ """Remove unnecessary whitespace from LaTeX code.
35
+ """
36
+ text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
37
+ letter = '[a-zA-Z]'
38
+ noletter = '[\W_^\d]'
39
+ names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
40
+ s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
41
+ news = s
42
+ while True:
43
+ s = news
44
+ news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
45
+ news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
46
+ news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
47
+ if news == s:
48
+ break
49
+ return s
50
+
51
+
52
+ class UnimernetModel(object):
53
+ def __init__(self, weight_dir, cfg_path, _device_='cpu'):
54
+
55
+ args = argparse.Namespace(cfg_path=cfg_path, options=None)
56
+ cfg = Config(args)
57
+ cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
58
+ cfg.config.model.model_config.model_name = weight_dir
59
+ cfg.config.model.tokenizer_config.path = weight_dir
60
+ task = tasks.setup_task(cfg)
61
+ self.model = task.build_model(cfg)
62
+ self.device = _device_
63
+ self.model.to(_device_)
64
+ self.model.eval()
65
+ vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
66
+ self.mfr_transform = transforms.Compose([vis_processor, ])
67
+
68
+ def predict(self, mfd_res, image):
69
+
70
+ formula_list = []
71
+ mf_image_list = []
72
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
73
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
74
+ new_item = {
75
+ 'category_id': 13 + int(cla.item()),
76
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
77
+ 'score': round(float(conf.item()), 2),
78
+ 'latex': '',
79
+ }
80
+ formula_list.append(new_item)
81
+ pil_img = Image.fromarray(image)
82
+ bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
83
+ mf_image_list.append(bbox_img)
84
+
85
+ dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
86
+ dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
87
+ mfr_res = []
88
+ for mf_img in dataloader:
89
+ mf_img = mf_img.to(self.device)
90
+ with torch.no_grad():
91
+ output = self.model.generate({'image': mf_img})
92
+ mfr_res.extend(output['pred_str'])
93
+ for res, latex in zip(formula_list, mfr_res):
94
+ res['latex'] = latex_rm_whitespace(latex)
95
+ return formula_list
96
+
97
+
98
+