magic-pdf 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. magic_pdf/config/__init__.py +0 -0
  2. magic_pdf/config/enums.py +7 -0
  3. magic_pdf/config/exceptions.py +32 -0
  4. magic_pdf/data/__init__.py +0 -0
  5. magic_pdf/data/data_reader_writer/__init__.py +12 -0
  6. magic_pdf/data/data_reader_writer/base.py +51 -0
  7. magic_pdf/data/data_reader_writer/filebase.py +59 -0
  8. magic_pdf/data/data_reader_writer/multi_bucket_s3.py +137 -0
  9. magic_pdf/data/data_reader_writer/s3.py +69 -0
  10. magic_pdf/data/dataset.py +194 -0
  11. magic_pdf/data/io/__init__.py +0 -0
  12. magic_pdf/data/io/base.py +42 -0
  13. magic_pdf/data/io/http.py +37 -0
  14. magic_pdf/data/io/s3.py +114 -0
  15. magic_pdf/data/read_api.py +95 -0
  16. magic_pdf/data/schemas.py +15 -0
  17. magic_pdf/data/utils.py +32 -0
  18. magic_pdf/dict2md/ocr_mkcontent.py +74 -234
  19. magic_pdf/libs/Constants.py +21 -8
  20. magic_pdf/libs/MakeContentConfig.py +1 -0
  21. magic_pdf/libs/boxbase.py +35 -0
  22. magic_pdf/libs/clean_memory.py +10 -0
  23. magic_pdf/libs/config_reader.py +53 -23
  24. magic_pdf/libs/draw_bbox.py +150 -65
  25. magic_pdf/libs/ocr_content_type.py +2 -0
  26. magic_pdf/libs/version.py +1 -1
  27. magic_pdf/model/doc_analyze_by_custom_model.py +77 -32
  28. magic_pdf/model/magic_model.py +331 -15
  29. magic_pdf/model/pdf_extract_kit.py +164 -80
  30. magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +8 -1
  31. magic_pdf/model/ppTableModel.py +2 -2
  32. magic_pdf/model/pp_structure_v2.py +5 -2
  33. magic_pdf/model/v3/__init__.py +0 -0
  34. magic_pdf/model/v3/helpers.py +125 -0
  35. magic_pdf/para/para_split_v3.py +296 -0
  36. magic_pdf/pdf_parse_by_ocr.py +6 -3
  37. magic_pdf/pdf_parse_by_txt.py +6 -3
  38. magic_pdf/pdf_parse_union_core_v2.py +644 -0
  39. magic_pdf/pipe/AbsPipe.py +5 -1
  40. magic_pdf/pipe/OCRPipe.py +10 -4
  41. magic_pdf/pipe/TXTPipe.py +10 -4
  42. magic_pdf/pipe/UNIPipe.py +16 -7
  43. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +83 -1
  44. magic_pdf/pre_proc/ocr_dict_merge.py +27 -2
  45. magic_pdf/resources/model_config/UniMERNet/demo.yaml +7 -7
  46. magic_pdf/resources/model_config/model_configs.yaml +5 -13
  47. magic_pdf/tools/cli.py +14 -1
  48. magic_pdf/tools/common.py +18 -8
  49. magic_pdf/user_api.py +25 -6
  50. magic_pdf/utils/__init__.py +0 -0
  51. magic_pdf/utils/annotations.py +11 -0
  52. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.0.dist-info}/LICENSE.md +1 -0
  53. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.0.dist-info}/METADATA +120 -75
  54. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.0.dist-info}/RECORD +57 -33
  55. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.0.dist-info}/WHEEL +0 -0
  56. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.0.dist-info}/entry_points.txt +0 -0
  57. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/pdf_extract_kit.py

@@ -1,11 +1,14 @@
 from loguru import logger
 import os
 import time
-
+from pathlib import Path
+import shutil
 from magic_pdf.libs.Constants import *
+from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
 try:
     import cv2
     import yaml
@@ -23,6 +26,7 @@ try:
     from unimernet.common.config import Config
     import unimernet.tasks as tasks
     from unimernet.processors import load_processor
+    from doclayout_yolo import YOLOv10

 except ImportError as e:
     logger.exception(e)
@@ -32,21 +36,26 @@ except ImportError as e:
     exit(1)

 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
+from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
+# from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 from magic_pdf.model.ppTableModel import ppTableModel


 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
-    else:
+    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
+        # table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
+        logger.error("StructEqTable is under upgrade, the current version does not support it.")
+        exit(1)
+    elif table_model_type == MODEL_NAME.TABLE_MASTER:
         config = {
             "model_dir": model_path,
             "device": _device_
         }
         table_model = ppTableModel(config)
+    else:
+        logger.error("table model type not allow")
+        exit(1)
     return table_model
@@ -58,12 +67,13 @@ def mfd_model_init(weight):
 def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     args = argparse.Namespace(cfg_path=cfg_path, options=None)
     cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
+    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
     cfg.config.model.model_config.model_name = weight_dir
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)
     model = task.build_model(cfg)
-    model = model.to(_device_)
+    model.to(_device_)
+    model.eval()
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]
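`mfr_model_init` now puts the UniMERNet model into eval mode after moving it to the target device, and a later hunk wraps `generate` in `torch.no_grad()`. This is the standard PyTorch inference pairing; a minimal self-contained illustration of what each call buys:

```python
import torch
import torch.nn as nn

# eval() switches dropout (and batch norm) to inference behavior;
# no_grad() stops autograd from recording the forward pass.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.eval()

x = torch.ones(1, 4)
with torch.no_grad():
    y1, y2 = model(x), model(x)

assert torch.equal(y1, y2)   # dropout disabled, outputs deterministic
assert not y1.requires_grad  # no autograd graph was built
```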
@@ -74,8 +84,16 @@ def layout_model_init(weight, config_file, device):
     return model


-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
-    model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+def doclayout_yolo_model_init(weight):
+    model = YOLOv10(weight)
+    return model
+
+
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
+    if lang is not None:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
+    else:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
     return model
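`ocr_model_init` forwards `lang` to `ModifiedPaddleOCR` only when it is explicitly set, leaving PaddleOCR's own default language selection in place otherwise; the `CustomPaddleModel` change near the end of this diff follows the same pattern. The pattern in isolation, as a small sketch:

```python
# Conditional keyword forwarding: pass `lang` through only when given,
# otherwise let the downstream constructor use its own default.
def ocr_kwargs(show_log=False, det_db_box_thresh=0.3, lang=None,
               use_dilation=True, det_db_unclip_ratio=1.8):
    kwargs = dict(show_log=show_log, det_db_box_thresh=det_db_box_thresh,
                  use_dilation=use_dilation,
                  det_db_unclip_ratio=det_db_unclip_ratio)
    if lang is not None:
        kwargs["lang"] = lang
    return kwargs

assert "lang" not in ocr_kwargs()
assert ocr_kwargs(lang="en")["lang"] == "en"
```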
@@ -108,19 +126,27 @@ class AtomModelSingleton:
         return cls._instance

     def get_atom_model(self, atom_model_name: str, **kwargs):
-        if atom_model_name not in self._models:
-            self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[atom_model_name]
+        lang = kwargs.get("lang", None)
+        layout_model_name = kwargs.get("layout_model_name", None)
+        key = (atom_model_name, layout_model_name, lang)
+        if key not in self._models:
+            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
+        return self._models[key]


 def atom_model_init(model_name: str, **kwargs):

     if model_name == AtomicModel.Layout:
-        atom_model = layout_model_init(
-            kwargs.get("layout_weights"),
-            kwargs.get("layout_config_file"),
-            kwargs.get("device")
-        )
+        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+            atom_model = layout_model_init(
+                kwargs.get("layout_weights"),
+                kwargs.get("layout_config_file"),
+                kwargs.get("device")
+            )
+        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+            atom_model = doclayout_yolo_model_init(
+                kwargs.get("doclayout_yolo_weights"),
+            )
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get("mfd_weights")
@@ -134,11 +160,12 @@ def atom_model_init(model_name: str, **kwargs):
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
             kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh")
+            kwargs.get("det_db_box_thresh"),
+            kwargs.get("lang")
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
-            kwargs.get("table_model_type"),
+            kwargs.get("table_model_name"),
             kwargs.get("table_model_path"),
             kwargs.get("table_max_time"),
             kwargs.get("device")
@@ -150,6 +177,23 @@ def atom_model_init(model_name: str, **kwargs):
     return atom_model


+# Unified crop img logic
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+    # Create a white background with an additional width and height of 50
+    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+    # Crop image
+    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+    cropped_img = input_pil_img.crop(crop_box)
+    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    return return_image, return_list
+
+
 class CustomPEKModel:

     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
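`crop_img` is now a module-level helper rather than a closure inside `__call__` (the hunk that removes the old nested copy appears further down). Given a detection dict whose `poly` holds the four box corners, it pastes the crop onto a white canvas padded by `crop_paste_x`/`crop_paste_y` on each side. A small usage sketch with made-up coordinates, assuming the `crop_img` defined above is in scope:

```python
from PIL import Image

page = Image.new('RGB', (200, 100), 'gray')       # stand-in page image
res = {'poly': [10, 10, 80, 10, 80, 40, 10, 40]}  # TL, TR, BR, BL corners

# with 25px of paste padding, the canvas is 50px wider and taller
# than the detected box itself
img, (px, py, xmin, ymin, xmax, ymax, w, h) = crop_img(res, page,
                                                       crop_paste_x=25,
                                                       crop_paste_y=25)
assert (w, h) == (80 - 10 + 50, 40 - 10 + 50)
assert img.size == (w, h)
```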
@@ -169,22 +213,35 @@ class CustomPEKModel:
         with open(config_path, "r", encoding='utf-8') as f:
             self.configs = yaml.load(f, Loader=yaml.FullLoader)
         # initialize parsing config
-        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
-        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+
+        # layout config
+        self.layout_config = kwargs.get("layout_config")
+        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
+
+        # formula config
+        self.formula_config = kwargs.get("formula_config")
+        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
+        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
+        self.apply_formula = self.formula_config.get("enable", True)
+
         # table config
-        self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
-        self.apply_table = self.table_config.get("is_table_recog_enable", False)
+        self.table_config = kwargs.get("table_config")
+        self.apply_table = self.table_config.get("enable", False)
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_type = self.table_config.get("model", TABLE_MASTER)
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+
+        # ocr config
         self.apply_ocr = ocr
+        self.lang = kwargs.get("lang", None)
+
         logger.info(
-            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
-                self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
+            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
+            "apply_table: {}, table_model: {}, lang: {}".format(
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
             )
         )
-        assert self.apply_layout, "DocAnalysis must contain layout model."
         # initialize parsing scheme
-        self.device = kwargs.get("device", self.configs["config"]["device"])
+        self.device = kwargs.get("device", "cpu")
         logger.info("using device: {}".format(self.device))
         models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
         logger.info("using models_dir: {}".format(models_dir))
@@ -193,17 +250,16 @@ class CustomPEKModel:

        # initialize formula recognition
        if self.apply_formula:
+
            # initialize formula detection model
-            # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
            )
+
            # initialize formula parsing model
-            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
+            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
-            # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
            self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
@@ -212,17 +268,20 @@ class CustomPEKModel:
            )

        # initialize layout model
-        # self.layout_model = Layoutlmv3_Predictor(
-        #     str(os.path.join(models_dir, self.configs['weights']['layout'])),
-        #     str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-        #     device=self.device
-        # )
-        self.layout_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.Layout,
-            layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
-            layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-            device=self.device
-        )
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.LAYOUTLMv3,
+                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
+                device=self.device
+            )
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            self.layout_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.Layout,
+                layout_model_name=MODEL_NAME.DocLayout_YOLO,
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+            )
        # initialize ocr
        if self.apply_ocr:

@@ -230,37 +289,67 @@ class CustomPEKModel:
            self.ocr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.OCR,
                ocr_show_log=show_log,
-                det_db_box_thresh=0.3
+                det_db_box_thresh=0.3,
+                lang=self.lang
            )
        # init table model
        if self.apply_table:
-            table_model_dir = self.configs["weights"][self.table_model_type]
-            # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
-            #                                     max_time=self.table_max_time, _device_=self.device)
+            table_model_dir = self.configs["weights"][self.table_model_name]
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
-                table_model_type=self.table_model_type,
+                table_model_name=self.table_model_name,
                table_model_path=str(os.path.join(models_dir, table_model_dir)),
                table_max_time=self.table_max_time,
                device=self.device
            )

+            home_directory = Path.home()
+            det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
+            rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
+            det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
+            rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
+
+            if not os.path.exists(det_dest_dir):
+                shutil.copytree(det_source, det_dest_dir)
+            if not os.path.exists(rec_dest_dir):
+                shutil.copytree(rec_source, rec_dest_dir)
+
        logger.info('DocAnalysis init done!')

    def __call__(self, image):

+        page_start = time.time()
+
        latex_filling_list = []
        mf_image_list = []

        # layout detection
        layout_start = time.time()
-        layout_res = self.layout_model(image, ignore_catids=[])
+        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            layout_res = self.layout_model(image, ignore_catids=[])
+        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            layout_res = []
+            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 3),
+                }
+                layout_res.append(new_item)
        layout_cost = round(time.time() - layout_start, 2)
-        logger.info(f"layout detection cost: {layout_cost}")
+        logger.info(f"layout detection time: {layout_cost}")
+
+        pil_img = Image.fromarray(image)

        if self.apply_formula:
            # formula detection
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+            mfd_start = time.time()
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
@@ -271,7 +360,7 @@ class CustomPEKModel:
                }
                layout_res.append(new_item)
                latex_filling_list.append(new_item)
-                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                mf_image_list.append(bbox_img)

            # formula recognition
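Both the DocLayout-YOLO branch and the MFD loop above flatten a detector's `[xmin, ymin, xmax, ymax]` box into the 8-value clockwise `poly` that the rest of the pipeline consumes:

```python
def xyxy_to_poly(xmin, ymin, xmax, ymax):
    # clockwise from the top-left corner: TL, TR, BR, BL
    return [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]

assert xyxy_to_poly(10, 20, 30, 40) == [10, 20, 30, 20, 30, 40, 10, 40]
```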
@@ -281,7 +370,8 @@ class CustomPEKModel:
            mfr_res = []
            for mf_img in dataloader:
                mf_img = mf_img.to(self.device)
-                output = self.mfr_model.generate({'image': mf_img})
+                with torch.no_grad():
+                    output = self.mfr_model.generate({'image': mf_img})
                mfr_res.extend(output['pred_str'])
            for res, latex in zip(latex_filling_list, mfr_res):
                res['latex'] = latex_rm_whitespace(latex)
@@ -303,23 +393,14 @@ class CustomPEKModel:
            elif int(res['category_id']) in [5]:
                table_res_list.append(res)

-        # Unified crop img logic
-        def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
-            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
-            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-            # Create a white background with an additional width and height of 50
-            crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
-            crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
-
-            # Crop image
-            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-            cropped_img = input_pil_img.crop(crop_box)
-            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-            return return_image, return_list
-
-        pil_img = Image.fromarray(image)
+        if torch.cuda.is_available():
+            properties = torch.cuda.get_device_properties(self.device)
+            total_memory = properties.total_memory / (1024 ** 3)  # convert bytes to GB
+            if total_memory <= 10:
+                gc_start = time.time()
+                clean_memory()
+                gc_time = round(time.time() - gc_start, 2)
+                logger.info(f"gc time: {gc_time}")

        # ocr recognition
        if self.apply_ocr:
@@ -369,7 +450,7 @@ class CustomPEKModel:
                    })

            ocr_cost = round(time.time() - ocr_start, 2)
-            logger.info(f"ocr cost: {ocr_cost}")
+            logger.info(f"ocr time: {ocr_cost}")

        # table recognition
        if self.apply_table:
@@ -377,17 +458,17 @@ class CustomPEKModel:
            for res in table_res_list:
                new_image, _ = crop_img(res, pil_img)
                single_table_start_time = time.time()
-                logger.info("------------------table recognition processing begins-----------------")
+                # logger.info("------------------table recognition processing begins-----------------")
                latex_code = None
                html_code = None
-                if self.table_model_type == STRUCT_EQTABLE:
+                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
                        latex_code = self.table_model.image2latex(new_image)[0]
                else:
                    html_code = self.table_model.img2html(new_image)

                run_time = time.time() - single_table_start_time
-                logger.info(f"------------table recognition processing ends within {run_time}s-----")
+                # logger.info(f"------------table recognition processing ends within {run_time}s-----")
                if run_time > self.table_max_time:
                    logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
                # check whether the result returned is valid
@@ -398,12 +479,15 @@ class CustomPEKModel:
                    if expected_ending:
                        res["latex"] = latex_code
                    else:
-                        logger.warning(f"------------table recognition processing fails----------")
+                        logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
                elif html_code:
                    res["html"] = html_code
                else:
-                    logger.warning(f"------------table recognition processing fails----------")
-            table_cost = round(time.time() - table_start, 2)
-            logger.info(f"table cost: {table_cost}")
+                    logger.warning(f"table recognition processing fails, not get latex or html return")
+            logger.info(f"table time: {round(time.time() - table_start, 2)}")
+
+        logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")

        return layout_res
+
+
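`clean_memory` comes from the new `magic_pdf/libs/clean_memory.py` (+10 lines in the file list above) and is invoked once per page when the CUDA device has 10 GB of memory or less. Its diff is not shown here; a plausible sketch of such a helper, stated as an assumption rather than the package's actual implementation:

```python
import gc
import torch

def clean_memory():
    # Assumed behavior: drop cached CUDA allocations, then force a
    # Python GC pass; the real clean_memory may differ in detail.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()
```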
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py

@@ -1,5 +1,12 @@
-from struct_eqtable.model import StructTable
+from loguru import logger
+
+try:
+    from struct_eqtable.model import StructTable
+except ImportError:
+    logger.error("StructEqTable is under upgrade, the current version does not support it.")
 from pypandoc import convert_text
+
+
 class StructTableModel:
     def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
         # init
magic_pdf/model/ppTableModel.py

@@ -52,11 +52,11 @@ class ppTableModel(object):
         rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
         rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
         device = kwargs.get("device", "cpu")
-        use_gpu = True if device == "cuda" else False
+        use_gpu = True if device.startswith("cuda") else False
         config = {
             "use_gpu": use_gpu,
             "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": TABLE_MASTER,
+            "table_algorithm": "TableMaster",
             "table_model_dir": table_model_dir,
             "table_char_dict_path": table_char_dict_path,
             "det_model_dir": det_model_dir,
magic_pdf/model/pp_structure_v2.py

@@ -18,8 +18,11 @@ def region_to_bbox(region):


 class CustomPaddleModel:
-    def __init__(self, ocr: bool = False, show_log: bool = False):
-        self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
+    def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
+        if lang is not None:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
+        else:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)

     def __call__(self, img):
         try:
magic_pdf/model/v3/__init__.py (file without changes)
magic_pdf/model/v3/helpers.py

@@ -0,0 +1,125 @@
+from collections import defaultdict
+from typing import List, Dict
+
+import torch
+from transformers import LayoutLMv3ForTokenClassification
+
+MAX_LEN = 510
+CLS_TOKEN_ID = 0
+UNK_TOKEN_ID = 3
+EOS_TOKEN_ID = 2
+
+
+class DataCollator:
+    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
+        bbox = []
+        labels = []
+        input_ids = []
+        attention_mask = []
+
+        # clip bbox and labels to max length, build input_ids and attention_mask
+        for feature in features:
+            _bbox = feature["source_boxes"]
+            if len(_bbox) > MAX_LEN:
+                _bbox = _bbox[:MAX_LEN]
+            _labels = feature["target_index"]
+            if len(_labels) > MAX_LEN:
+                _labels = _labels[:MAX_LEN]
+            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
+            _attention_mask = [1] * len(_bbox)
+            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
+            bbox.append(_bbox)
+            labels.append(_labels)
+            input_ids.append(_input_ids)
+            attention_mask.append(_attention_mask)
+
+        # add CLS and EOS tokens
+        for i in range(len(bbox)):
+            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
+            labels[i] = [-100] + labels[i] + [-100]
+            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
+            attention_mask[i] = [1] + attention_mask[i] + [1]
+
+        # padding to max length
+        max_len = max(len(x) for x in bbox)
+        for i in range(len(bbox)):
+            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
+            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
+            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
+            attention_mask[i] = attention_mask[i] + [0] * (
+                max_len - len(attention_mask[i])
+            )
+
+        ret = {
+            "bbox": torch.tensor(bbox),
+            "attention_mask": torch.tensor(attention_mask),
+            "labels": torch.tensor(labels),
+            "input_ids": torch.tensor(input_ids),
+        }
+        # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
+        ret["labels"][ret["labels"] > MAX_LEN] = -100
+        # set label > 0 to label-1, because original labels are 1-indexed
+        ret["labels"][ret["labels"] > 0] -= 1
+        return ret
+
+
+def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
+    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
+    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
+    attention_mask = [1] + [1] * len(boxes) + [1]
+    return {
+        "bbox": torch.tensor([bbox]),
+        "attention_mask": torch.tensor([attention_mask]),
+        "input_ids": torch.tensor([input_ids]),
+    }
+
+
+def prepare_inputs(
+    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
+) -> Dict[str, torch.Tensor]:
+    ret = {}
+    for k, v in inputs.items():
+        v = v.to(model.device)
+        if torch.is_floating_point(v):
+            v = v.to(model.dtype)
+        ret[k] = v
+    return ret
+
+
+def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
+    """
+    parse logits to orders
+
+    :param logits: logits from model
+    :param length: input length
+    :return: orders
+    """
+    logits = logits[1 : length + 1, :length]
+    orders = logits.argsort(descending=False).tolist()
+    ret = [o.pop() for o in orders]
+    while True:
+        order_to_idxes = defaultdict(list)
+        for idx, order in enumerate(ret):
+            order_to_idxes[order].append(idx)
+        # filter idxes len > 1
+        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
+        if not order_to_idxes:
+            break
+        # filter
+        for order, idxes in order_to_idxes.items():
+            # find original logits of idxes
+            idxes_to_logit = {}
+            for idx in idxes:
+                idxes_to_logit[idx] = logits[idx, order]
+            idxes_to_logit = sorted(
+                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
+            )
+            # keep the highest logit as order, set others to next candidate
+            for idx, _ in idxes_to_logit[1:]:
+                ret[idx] = orders[idx].pop()
+
+    return ret
+
+
+def check_duplicate(a: List[int]) -> bool:
+    return len(a) != len(set(a))
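Taken together, `boxes2inputs`, `prepare_inputs`, and `parse_logits` run reading-order prediction over layout boxes with a LayoutLMv3 token-classification head: each box becomes one `UNK` token, and `parse_logits` greedily resolves duplicate order assignments by logit score. A hedged end-to-end sketch (the checkpoint name is illustrative, not something this package pins):

```python
import torch
from transformers import LayoutLMv3ForTokenClassification

# Illustrative reading-order checkpoint; substitute the LayoutLMv3
# weights you actually use.
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")

# Boxes are [x0, y0, x1, y1], normalized to a 0-1000 page grid.
boxes = [[38, 40, 320, 68], [38, 80, 320, 108], [360, 40, 640, 68]]
inputs = prepare_inputs(boxes2inputs(boxes), model)
with torch.no_grad():
    logits = model(**inputs).logits[0]  # (seq_len, num_labels)
reading_order = parse_logits(logits, len(boxes))
assert not check_duplicate(reading_order)  # each position assigned once
```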