magic-pdf 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. magic_pdf/dict2md/ocr_mkcontent.py +1 -1
  2. magic_pdf/libs/Constants.py +3 -1
  3. magic_pdf/libs/config_reader.py +1 -1
  4. magic_pdf/libs/draw_bbox.py +10 -4
  5. magic_pdf/libs/version.py +1 -1
  6. magic_pdf/model/pdf_extract_kit.py +42 -297
  7. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +21 -0
  8. magic_pdf/model/sub_modules/mfd/__init__.py +0 -0
  9. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +12 -0
  10. magic_pdf/model/sub_modules/mfd/yolov8/__init__.py +0 -0
  11. magic_pdf/model/sub_modules/mfr/__init__.py +0 -0
  12. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +98 -0
  13. magic_pdf/model/sub_modules/mfr/unimernet/__init__.py +0 -0
  14. magic_pdf/model/sub_modules/model_init.py +144 -0
  15. magic_pdf/model/sub_modules/model_utils.py +51 -0
  16. magic_pdf/model/sub_modules/ocr/__init__.py +0 -0
  17. magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py +0 -0
  18. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +259 -0
  19. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +168 -0
  20. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +213 -0
  21. magic_pdf/model/sub_modules/reading_oreder/__init__.py +0 -0
  22. magic_pdf/model/sub_modules/reading_oreder/layoutreader/__init__.py +0 -0
  23. magic_pdf/model/sub_modules/reading_oreder/layoutreader/xycut.py +242 -0
  24. magic_pdf/model/sub_modules/table/__init__.py +0 -0
  25. magic_pdf/model/sub_modules/table/rapidtable/__init__.py +0 -0
  26. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +14 -0
  27. magic_pdf/model/sub_modules/table/structeqtable/__init__.py +0 -0
  28. magic_pdf/model/{pek_sub_modules/structeqtable/StructTableModel.py → sub_modules/table/structeqtable/struct_eqtable.py} +3 -11
  29. magic_pdf/model/sub_modules/table/table_utils.py +11 -0
  30. magic_pdf/model/sub_modules/table/tablemaster/__init__.py +0 -0
  31. magic_pdf/model/{ppTableModel.py → sub_modules/table/tablemaster/tablemaster_paddle.py} +1 -1
  32. magic_pdf/para/para_split_v3.py +13 -15
  33. magic_pdf/pdf_parse_union_core_v2.py +56 -19
  34. magic_pdf/resources/model_config/model_configs.yaml +2 -1
  35. magic_pdf/tools/common.py +47 -3
  36. {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/METADATA +9 -3
  37. {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/RECORD +65 -44
  38. {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/WHEEL +1 -1
  39. magic_pdf/model/pek_sub_modules/post_process.py +0 -36
  40. magic_pdf/model/pek_sub_modules/self_modify.py +0 -388
  41. /magic_pdf/model/{pek_sub_modules → sub_modules}/__init__.py +0 -0
  42. /magic_pdf/model/{pek_sub_modules/layoutlmv3 → sub_modules/layout}/__init__.py +0 -0
  43. /magic_pdf/model/{pek_sub_modules/structeqtable → sub_modules/layout/doclayout_yolo}/__init__.py +0 -0
  44. /magic_pdf/model/{v3 → sub_modules/layout/layoutlmv3}/__init__.py +0 -0
  45. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/backbone.py +0 -0
  46. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/beit.py +0 -0
  47. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/deit.py +0 -0
  48. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/__init__.py +0 -0
  49. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/__init__.py +0 -0
  50. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/cord.py +0 -0
  51. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/data_collator.py +0 -0
  52. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/funsd.py +0 -0
  53. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/image_utils.py +0 -0
  54. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/data/xfund.py +0 -0
  55. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/__init__.py +0 -0
  56. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +0 -0
  57. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +0 -0
  58. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +0 -0
  59. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +0 -0
  60. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +0 -0
  61. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/model_init.py +0 -0
  62. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/rcnn_vl.py +0 -0
  63. /magic_pdf/model/{pek_sub_modules → sub_modules/layout}/layoutlmv3/visualizer.py +0 -0
  64. /magic_pdf/model/{v3 → sub_modules/reading_oreder/layoutreader}/helpers.py +0 -0
  65. {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/LICENSE.md +0 -0
  66. {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/entry_points.txt +0 -0
  67. {magic_pdf-0.9.2.dist-info → magic_pdf-0.9.3.dist-info}/top_level.txt +0 -0
@@ -168,7 +168,7 @@ def merge_para_with_text(para_block):
168
168
  # 如果是前一行带有-连字符,那么末尾不应该加空格
169
169
  if __is_hyphen_at_line_end(content):
170
170
  para_text += content[:-1]
171
- elif len(content) == 1 and content not in ['A', 'I', 'a', 'i']:
171
+ elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
172
172
  para_text += content
173
173
  else: # 西方文本语境下 content间需要空格分隔
174
174
  para_text += f"{content} "
@@ -50,4 +50,6 @@ class MODEL_NAME:
50
50
 
51
51
  YOLO_V8_MFD = "yolo_v8_mfd"
52
52
 
53
- UniMerNet_v2_Small = "unimernet_small"
53
+ UniMerNet_v2_Small = "unimernet_small"
54
+
55
+ RAPID_TABLE = "rapid_table"
@@ -92,7 +92,7 @@ def get_table_recog_config():
92
92
  table_config = config.get('table-config')
93
93
  if table_config is None:
94
94
  logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
95
- return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
95
+ return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
96
96
  else:
97
97
  return table_config
98
98
 
@@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
369
369
  if block['type'] in [BlockType.Image, BlockType.Table]:
370
370
  for sub_block in block['blocks']:
371
371
  if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
372
- for line in sub_block['virtual_lines']:
373
- bbox = line['bbox']
374
- index = line['index']
375
- page_line_list.append({'index': index, 'bbox': bbox})
372
+ if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
373
+ for line in sub_block['virtual_lines']:
374
+ bbox = line['bbox']
375
+ index = line['index']
376
+ page_line_list.append({'index': index, 'bbox': bbox})
377
+ else:
378
+ for line in sub_block['lines']:
379
+ bbox = line['bbox']
380
+ index = line['index']
381
+ page_line_list.append({'index': index, 'bbox': bbox})
376
382
  elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
377
383
  for line in sub_block['lines']:
378
384
  bbox = line['bbox']
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.9.2"
1
+ __version__ = "0.9.3"
@@ -1,195 +1,28 @@
1
+ import numpy as np
2
+ import torch
1
3
  from loguru import logger
2
4
  import os
3
5
  import time
4
- from pathlib import Path
5
- import shutil
6
- from magic_pdf.libs.Constants import *
7
- from magic_pdf.libs.clean_memory import clean_memory
8
- from magic_pdf.model.model_list import AtomicModel
6
+ import cv2
7
+ import yaml
8
+ from PIL import Image
9
9
 
10
10
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
11
11
  os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
12
+
12
13
  try:
13
- import cv2
14
- import yaml
15
- import argparse
16
- import numpy as np
17
- import torch
18
14
  import torchtext
19
15
 
20
16
  if torchtext.__version__ >= "0.18.0":
21
17
  torchtext.disable_torchtext_deprecation_warning()
22
- from PIL import Image
23
- from torchvision import transforms
24
- from torch.utils.data import Dataset, DataLoader
25
- from ultralytics import YOLO
26
- from unimernet.common.config import Config
27
- import unimernet.tasks as tasks
28
- from unimernet.processors import load_processor
29
- from doclayout_yolo import YOLOv10
30
-
31
- except ImportError as e:
32
- logger.exception(e)
33
- logger.error(
34
- 'Required dependency not installed, please install by \n'
35
- '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
36
- exit(1)
37
-
38
- from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
39
- from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
40
- from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
41
- from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
42
- from magic_pdf.model.ppTableModel import ppTableModel
43
-
44
-
45
- def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
46
- if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
47
- table_model = StructTableModel(model_path, max_time=max_time)
48
- elif table_model_type == MODEL_NAME.TABLE_MASTER:
49
- config = {
50
- "model_dir": model_path,
51
- "device": _device_
52
- }
53
- table_model = ppTableModel(config)
54
- else:
55
- logger.error("table model type not allow")
56
- exit(1)
57
- return table_model
58
-
59
-
60
- def mfd_model_init(weight):
61
- mfd_model = YOLO(weight)
62
- return mfd_model
63
-
64
-
65
- def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
66
- args = argparse.Namespace(cfg_path=cfg_path, options=None)
67
- cfg = Config(args)
68
- cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
69
- cfg.config.model.model_config.model_name = weight_dir
70
- cfg.config.model.tokenizer_config.path = weight_dir
71
- task = tasks.setup_task(cfg)
72
- model = task.build_model(cfg)
73
- model.to(_device_)
74
- model.eval()
75
- vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
76
- mfr_transform = transforms.Compose([vis_processor, ])
77
- return [model, mfr_transform]
78
-
79
-
80
- def layout_model_init(weight, config_file, device):
81
- model = Layoutlmv3_Predictor(weight, config_file, device)
82
- return model
83
-
84
-
85
- def doclayout_yolo_model_init(weight):
86
- model = YOLOv10(weight)
87
- return model
88
-
89
-
90
- def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
91
- if lang is not None:
92
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
93
- else:
94
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
95
- return model
96
-
97
-
98
- class MathDataset(Dataset):
99
- def __init__(self, image_paths, transform=None):
100
- self.image_paths = image_paths
101
- self.transform = transform
102
-
103
- def __len__(self):
104
- return len(self.image_paths)
105
-
106
- def __getitem__(self, idx):
107
- # if not pil image, then convert to pil image
108
- if isinstance(self.image_paths[idx], str):
109
- raw_image = Image.open(self.image_paths[idx])
110
- else:
111
- raw_image = self.image_paths[idx]
112
- if self.transform:
113
- image = self.transform(raw_image)
114
- return image
115
-
116
-
117
- class AtomModelSingleton:
118
- _instance = None
119
- _models = {}
120
-
121
- def __new__(cls, *args, **kwargs):
122
- if cls._instance is None:
123
- cls._instance = super().__new__(cls)
124
- return cls._instance
125
-
126
- def get_atom_model(self, atom_model_name: str, **kwargs):
127
- lang = kwargs.get("lang", None)
128
- layout_model_name = kwargs.get("layout_model_name", None)
129
- key = (atom_model_name, layout_model_name, lang)
130
- if key not in self._models:
131
- self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
132
- return self._models[key]
133
-
134
-
135
- def atom_model_init(model_name: str, **kwargs):
136
-
137
- if model_name == AtomicModel.Layout:
138
- if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
139
- atom_model = layout_model_init(
140
- kwargs.get("layout_weights"),
141
- kwargs.get("layout_config_file"),
142
- kwargs.get("device")
143
- )
144
- elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
145
- atom_model = doclayout_yolo_model_init(
146
- kwargs.get("doclayout_yolo_weights"),
147
- )
148
- elif model_name == AtomicModel.MFD:
149
- atom_model = mfd_model_init(
150
- kwargs.get("mfd_weights")
151
- )
152
- elif model_name == AtomicModel.MFR:
153
- atom_model = mfr_model_init(
154
- kwargs.get("mfr_weight_dir"),
155
- kwargs.get("mfr_cfg_path"),
156
- kwargs.get("device")
157
- )
158
- elif model_name == AtomicModel.OCR:
159
- atom_model = ocr_model_init(
160
- kwargs.get("ocr_show_log"),
161
- kwargs.get("det_db_box_thresh"),
162
- kwargs.get("lang")
163
- )
164
- elif model_name == AtomicModel.Table:
165
- atom_model = table_model_init(
166
- kwargs.get("table_model_name"),
167
- kwargs.get("table_model_path"),
168
- kwargs.get("table_max_time"),
169
- kwargs.get("device")
170
- )
171
- else:
172
- logger.error("model name not allow")
173
- exit(1)
174
-
175
- return atom_model
176
-
18
+ except ImportError:
19
+ pass
177
20
 
178
- # Unified crop img logic
179
- def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
180
- crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
181
- crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
182
- # Create a white background with an additional width and height of 50
183
- crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
184
- crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
185
- return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
186
-
187
- # Crop image
188
- crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
189
- cropped_img = input_pil_img.crop(crop_box)
190
- return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
191
- return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
192
- return return_image, return_list
21
+ from magic_pdf.libs.Constants import *
22
+ from magic_pdf.model.model_list import AtomicModel
23
+ from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
24
+ from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
25
+ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
193
26
 
194
27
 
195
28
  class CustomPEKModel:
@@ -226,7 +59,7 @@ class CustomPEKModel:
226
59
  self.table_config = kwargs.get("table_config")
227
60
  self.apply_table = self.table_config.get("enable", False)
228
61
  self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
229
- self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
62
+ self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
230
63
 
231
64
  # ocr config
232
65
  self.apply_ocr = ocr
@@ -235,7 +68,8 @@ class CustomPEKModel:
235
68
  logger.info(
236
69
  "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
237
70
  "apply_table: {}, table_model: {}, lang: {}".format(
238
- self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
71
+ self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
72
+ self.lang
239
73
  )
240
74
  )
241
75
  # 初始化解析方案
@@ -248,17 +82,17 @@ class CustomPEKModel:
248
82
 
249
83
  # 初始化公式识别
250
84
  if self.apply_formula:
251
-
252
85
  # 初始化公式检测模型
253
86
  self.mfd_model = atom_model_manager.get_atom_model(
254
87
  atom_model_name=AtomicModel.MFD,
255
- mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
88
+ mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
89
+ device=self.device
256
90
  )
257
91
 
258
92
  # 初始化公式解析模型
259
93
  mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
260
94
  mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
261
- self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
95
+ self.mfr_model = atom_model_manager.get_atom_model(
262
96
  atom_model_name=AtomicModel.MFR,
263
97
  mfr_weight_dir=mfr_weight_dir,
264
98
  mfr_cfg_path=mfr_cfg_path,
@@ -278,7 +112,8 @@ class CustomPEKModel:
278
112
  self.layout_model = atom_model_manager.get_atom_model(
279
113
  atom_model_name=AtomicModel.Layout,
280
114
  layout_model_name=MODEL_NAME.DocLayout_YOLO,
281
- doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
115
+ doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
116
+ device=self.device
282
117
  )
283
118
  # 初始化ocr
284
119
  if self.apply_ocr:
@@ -305,26 +140,15 @@ class CustomPEKModel:
305
140
 
306
141
  page_start = time.time()
307
142
 
308
- latex_filling_list = []
309
- mf_image_list = []
310
-
311
143
  # layout检测
312
144
  layout_start = time.time()
145
+ layout_res = []
313
146
  if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
314
147
  # layoutlmv3
315
148
  layout_res = self.layout_model(image, ignore_catids=[])
316
149
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
317
150
  # doclayout_yolo
318
- layout_res = []
319
- doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
320
- for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
321
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
322
- new_item = {
323
- 'category_id': int(cla.item()),
324
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
325
- 'score': round(float(conf.item()), 3),
326
- }
327
- layout_res.append(new_item)
151
+ layout_res = self.layout_model.predict(image)
328
152
  layout_cost = round(time.time() - layout_start, 2)
329
153
  logger.info(f"layout detection time: {layout_cost}")
330
154
 
@@ -333,59 +157,21 @@ class CustomPEKModel:
333
157
  if self.apply_formula:
334
158
  # 公式检测
335
159
  mfd_start = time.time()
336
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
160
+ mfd_res = self.mfd_model.predict(image)
337
161
  logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
338
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
339
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
340
- new_item = {
341
- 'category_id': 13 + int(cla.item()),
342
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
343
- 'score': round(float(conf.item()), 2),
344
- 'latex': '',
345
- }
346
- layout_res.append(new_item)
347
- latex_filling_list.append(new_item)
348
- bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
349
- mf_image_list.append(bbox_img)
350
162
 
351
163
  # 公式识别
352
164
  mfr_start = time.time()
353
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
354
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
355
- mfr_res = []
356
- for mf_img in dataloader:
357
- mf_img = mf_img.to(self.device)
358
- with torch.no_grad():
359
- output = self.mfr_model.generate({'image': mf_img})
360
- mfr_res.extend(output['pred_str'])
361
- for res, latex in zip(latex_filling_list, mfr_res):
362
- res['latex'] = latex_rm_whitespace(latex)
165
+ formula_list = self.mfr_model.predict(mfd_res, image)
166
+ layout_res.extend(formula_list)
363
167
  mfr_cost = round(time.time() - mfr_start, 2)
364
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
365
-
366
- # Select regions for OCR / formula regions / table regions
367
- ocr_res_list = []
368
- table_res_list = []
369
- single_page_mfdetrec_res = []
370
- for res in layout_res:
371
- if int(res['category_id']) in [13, 14]:
372
- single_page_mfdetrec_res.append({
373
- "bbox": [int(res['poly'][0]), int(res['poly'][1]),
374
- int(res['poly'][4]), int(res['poly'][5])],
375
- })
376
- elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
377
- ocr_res_list.append(res)
378
- elif int(res['category_id']) in [5]:
379
- table_res_list.append(res)
380
-
381
- if torch.cuda.is_available() and self.device != 'cpu':
382
- properties = torch.cuda.get_device_properties(self.device)
383
- total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
384
- if total_memory <= 10:
385
- gc_start = time.time()
386
- clean_memory()
387
- gc_time = round(time.time() - gc_start, 2)
388
- logger.info(f"gc time: {gc_time}")
168
+ logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
169
+
170
+ # 清理显存
171
+ clean_vram(self.device, vram_threshold=8)
172
+
173
+ # 从layout_res中获取ocr区域、表格区域、公式区域
174
+ ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
389
175
 
390
176
  # ocr识别
391
177
  if self.apply_ocr:
@@ -393,23 +179,7 @@ class CustomPEKModel:
393
179
  # Process each area that requires OCR processing
394
180
  for res in ocr_res_list:
395
181
  new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
396
- paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
397
- # Adjust the coordinates of the formula area
398
- adjusted_mfdetrec_res = []
399
- for mf_res in single_page_mfdetrec_res:
400
- mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
401
- # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
402
- x0 = mf_xmin - xmin + paste_x
403
- y0 = mf_ymin - ymin + paste_y
404
- x1 = mf_xmax - xmin + paste_x
405
- y1 = mf_ymax - ymin + paste_y
406
- # Filter formula blocks outside the graph
407
- if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
408
- continue
409
- else:
410
- adjusted_mfdetrec_res.append({
411
- "bbox": [x0, y0, x1, y1],
412
- })
182
+ adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
413
183
 
414
184
  # OCR recognition
415
185
  new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
@@ -417,22 +187,8 @@ class CustomPEKModel:
417
187
 
418
188
  # Integration results
419
189
  if ocr_res:
420
- for box_ocr_res in ocr_res:
421
- p1, p2, p3, p4 = box_ocr_res[0]
422
- text, score = box_ocr_res[1]
423
-
424
- # Convert the coordinates back to the original coordinate system
425
- p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
426
- p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
427
- p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
428
- p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
429
-
430
- layout_res.append({
431
- 'category_id': 15,
432
- 'poly': p1 + p2 + p3 + p4,
433
- 'score': round(score, 2),
434
- 'text': text,
435
- })
190
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
191
+ layout_res.extend(ocr_result_list)
436
192
 
437
193
  ocr_cost = round(time.time() - ocr_start, 2)
438
194
  logger.info(f"ocr time: {ocr_cost}")
@@ -443,41 +199,30 @@ class CustomPEKModel:
443
199
  for res in table_res_list:
444
200
  new_image, _ = crop_img(res, pil_img)
445
201
  single_table_start_time = time.time()
446
- # logger.info("------------------table recognition processing begins-----------------")
447
- latex_code = None
448
202
  html_code = None
449
203
  if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
450
204
  with torch.no_grad():
451
205
  table_result = self.table_model.predict(new_image, "html")
452
206
  if len(table_result) > 0:
453
207
  html_code = table_result[0]
454
- else:
208
+ elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
455
209
  html_code = self.table_model.img2html(new_image)
456
-
210
+ elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
211
+ html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
457
212
  run_time = time.time() - single_table_start_time
458
- # logger.info(f"------------table recognition processing ends within {run_time}s-----")
459
213
  if run_time > self.table_max_time:
460
- logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
214
+ logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
461
215
  # 判断是否返回正常
462
-
463
- if latex_code:
464
- expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
465
- if expected_ending:
466
- res["latex"] = latex_code
467
- else:
468
- logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
469
- elif html_code:
216
+ if html_code:
470
217
  expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
471
218
  if expected_ending:
472
219
  res["html"] = html_code
473
220
  else:
474
221
  logger.warning(f"table recognition processing fails, not found expected HTML table end")
475
222
  else:
476
- logger.warning(f"table recognition processing fails, not get latex or html return")
223
+ logger.warning(f"table recognition processing fails, not get html return")
477
224
  logger.info(f"table time: {round(time.time() - table_start, 2)}")
478
225
 
479
226
  logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
480
227
 
481
228
  return layout_res
482
-
483
-
@@ -0,0 +1,21 @@
1
+ from doclayout_yolo import YOLOv10
2
+
3
+
4
class DocLayoutYOLOModel(object):
    """Layout detector backed by a ``YOLOv10`` model.

    Detections are returned as plain dictionaries rather than as the
    model's own result objects.
    """

    def __init__(self, weight, device):
        """Build the ``YOLOv10`` model from ``weight`` and remember ``device``."""
        self.model = YOLOv10(weight)
        self.device = device

    def predict(self, image):
        """Run layout detection on ``image``.

        Only the first result returned by the underlying model is used.

        Returns:
            A list with one dict per detected box, holding ``category_id``,
            ``poly`` and ``score``.
        """
        first_result = self.model.predict(
            image,
            imgsz=1024,
            conf=0.25,
            iou=0.45,
            verbose=True,
            device=self.device,
        )[0]
        detected = first_result.boxes
        rows = zip(detected.xyxy.cpu(), detected.conf.cpu(), detected.cls.cpu())
        return [
            self._to_layout_item(corners, score, label)
            for corners, score, label in rows
        ]

    @staticmethod
    def _to_layout_item(corners, score, label):
        """Convert one box (corner coordinates, confidence, class) to a dict."""
        # Coordinates are truncated to integers.
        left, top, right, bottom = (int(value.item()) for value in corners)
        return {
            'category_id': int(label.item()),
            # Corner order: (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax).
            'poly': [left, top, right, top, right, bottom, left, bottom],
            'score': round(float(score.item()), 3),
        }
File without changes
@@ -0,0 +1,12 @@
1
+ from ultralytics import YOLO
2
+
3
+
4
class YOLOv8MFDModel(object):
    """Formula detector backed by a ``YOLO`` model."""

    def __init__(self, weight, device='cpu'):
        """Build the ``YOLO`` model from ``weight`` and remember ``device``."""
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        """Run detection on ``image`` and return the model's first result."""
        options = {
            'imgsz': 1888,
            'conf': 0.25,
            'iou': 0.45,
            'verbose': True,
            'device': self.device,
        }
        results = self.mfd_model.predict(image, **options)
        return results[0]
12
+
File without changes
File without changes
@@ -0,0 +1,98 @@
1
+ import os
2
+ import argparse
3
+ import re
4
+
5
+ from PIL import Image
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torchvision import transforms
9
+ from unimernet.common.config import Config
10
+ import unimernet.tasks as tasks
11
+ from unimernet.processors import load_processor
12
+
13
+
14
class MathDataset(Dataset):
    """Dataset over formula images given as file paths or in-memory images.

    Args:
        image_paths: Indexable collection whose items are either ``str``
            paths (opened with ``Image.open``) or already-loaded images.
        transform: Optional callable applied to each image before it is
            returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return item ``idx``, transformed when a transform is configured."""
        item = self.image_paths[idx]
        # Only string entries are treated as paths; anything else is used as-is.
        raw_image = Image.open(item) if isinstance(item, str) else item
        if self.transform:
            return self.transform(raw_image)
        # Without a transform the previous code hit an unbound local variable;
        # hand back the untransformed image instead.
        return raw_image
31
+
32
+
33
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code.

    Two passes are applied:

    1. Every space inside a ``\\operatorname``/``\\mathrm``/``\\text``/
       ``\\mathbf`` command followed by `` {...}`` is removed.
    2. Whitespace between two non-letters, between a non-letter and a letter,
       and between a letter and a non-letter is removed repeatedly until the
       string stops changing. Whitespace between two letters is kept.

    Args:
        s: LaTeX source string.

    Returns:
        The string with the whitespace described above removed.
    """
    # Raw strings throughout: ``\W`` and ``\d`` are regex escapes, not string
    # escapes, and a non-raw literal triggers an invalid-escape warning.
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    noletter = r'[\W_^\d]'

    # Group 1 spans the whole match, so stripping spaces from it rewrites the
    # command in a single substitution pass.
    s = re.sub(text_reg, lambda match: match.group(1).replace(' ', ''), s)

    # The ``(?!\\ )`` lookahead keeps a match from starting at a
    # backslash-space sequence.
    squeezers = (
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter)),
        re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter)),
        re.compile(r'(%s)\s+?(%s)' % (letter, noletter)),
    )
    while True:
        news = s
        for squeezer in squeezers:
            news = squeezer.sub(r'\1\2', news)
        if news == s:
            # Fixed point reached: another pass would change nothing.
            return s
        s = news
50
+
51
+
52
class UnimernetModel(object):
    """Formula recognizer built from a UniMERNet config and weight directory."""

    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
        """Build the model and its image transform.

        Args:
            weight_dir: Directory holding ``pytorch_model.pth``; also used as
                the model name and tokenizer path in the loaded config.
            cfg_path: Path of the config file passed to ``Config``.
            _device_: Device the model is moved to and on which batches are
                placed during prediction.
        """
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
        self.mfr_transform = transforms.Compose([vis_processor, ])

    def predict(self, mfd_res, image):
        """Recognize the formulas detected in ``image``.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls``.
            image: Page image in a form accepted by ``Image.fromarray``.

        Returns:
            A list with one dict per detected box, holding ``category_id``
            (13 plus the detected class index), ``poly``, ``score`` and the
            recognized ``latex`` string.
        """
        detected = mfd_res.boxes
        formula_list = []
        crop_boxes = []
        for xyxy, conf, cla in zip(detected.xyxy.cpu(), detected.conf.cpu(), detected.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append({
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            })
            crop_boxes.append((xmin, ymin, xmax, ymax))

        if not formula_list:
            # Nothing detected: skip the image conversion and the model run.
            return formula_list

        # Convert the page once instead of once per detected formula.
        pil_img = Image.fromarray(image)
        mf_image_list = [pil_img.crop(box) for box in crop_boxes]

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        # Predictions come back in dataset order, matching formula_list.
        for res, latex in zip(formula_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        return formula_list
96
+
97
+
98
+
File without changes