magic-pdf 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57):
  1. magic_pdf/config/__init__.py +0 -0
  2. magic_pdf/config/enums.py +7 -0
  3. magic_pdf/config/exceptions.py +32 -0
  4. magic_pdf/data/__init__.py +0 -0
  5. magic_pdf/data/data_reader_writer/__init__.py +12 -0
  6. magic_pdf/data/data_reader_writer/base.py +51 -0
  7. magic_pdf/data/data_reader_writer/filebase.py +59 -0
  8. magic_pdf/data/data_reader_writer/multi_bucket_s3.py +143 -0
  9. magic_pdf/data/data_reader_writer/s3.py +73 -0
  10. magic_pdf/data/dataset.py +194 -0
  11. magic_pdf/data/io/__init__.py +6 -0
  12. magic_pdf/data/io/base.py +42 -0
  13. magic_pdf/data/io/http.py +37 -0
  14. magic_pdf/data/io/s3.py +114 -0
  15. magic_pdf/data/read_api.py +95 -0
  16. magic_pdf/data/schemas.py +19 -0
  17. magic_pdf/data/utils.py +32 -0
  18. magic_pdf/dict2md/ocr_mkcontent.py +106 -244
  19. magic_pdf/libs/Constants.py +21 -8
  20. magic_pdf/libs/MakeContentConfig.py +1 -0
  21. magic_pdf/libs/boxbase.py +35 -0
  22. magic_pdf/libs/clean_memory.py +10 -0
  23. magic_pdf/libs/config_reader.py +53 -23
  24. magic_pdf/libs/draw_bbox.py +150 -65
  25. magic_pdf/libs/ocr_content_type.py +2 -0
  26. magic_pdf/libs/version.py +1 -1
  27. magic_pdf/model/doc_analyze_by_custom_model.py +77 -32
  28. magic_pdf/model/magic_model.py +331 -15
  29. magic_pdf/model/pdf_extract_kit.py +170 -83
  30. magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py +40 -16
  31. magic_pdf/model/ppTableModel.py +8 -6
  32. magic_pdf/model/pp_structure_v2.py +5 -2
  33. magic_pdf/model/v3/__init__.py +0 -0
  34. magic_pdf/model/v3/helpers.py +125 -0
  35. magic_pdf/para/para_split_v3.py +322 -0
  36. magic_pdf/pdf_parse_by_ocr.py +6 -3
  37. magic_pdf/pdf_parse_by_txt.py +6 -3
  38. magic_pdf/pdf_parse_union_core_v2.py +644 -0
  39. magic_pdf/pipe/AbsPipe.py +5 -1
  40. magic_pdf/pipe/OCRPipe.py +10 -4
  41. magic_pdf/pipe/TXTPipe.py +10 -4
  42. magic_pdf/pipe/UNIPipe.py +16 -7
  43. magic_pdf/pre_proc/ocr_detect_all_bboxes.py +83 -1
  44. magic_pdf/pre_proc/ocr_dict_merge.py +27 -2
  45. magic_pdf/resources/model_config/UniMERNet/demo.yaml +7 -7
  46. magic_pdf/resources/model_config/model_configs.yaml +5 -13
  47. magic_pdf/tools/cli.py +14 -1
  48. magic_pdf/tools/common.py +18 -8
  49. magic_pdf/user_api.py +25 -6
  50. magic_pdf/utils/__init__.py +0 -0
  51. magic_pdf/utils/annotations.py +11 -0
  52. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/LICENSE.md +1 -0
  53. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/METADATA +124 -78
  54. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/RECORD +57 -33
  55. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/WHEEL +0 -0
  56. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/entry_points.txt +0 -0
  57. {magic_pdf-0.8.1.dist-info → magic_pdf-0.9.1.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,14 @@
1
1
  from loguru import logger
2
2
  import os
3
3
  import time
4
-
4
+ from pathlib import Path
5
+ import shutil
5
6
  from magic_pdf.libs.Constants import *
7
+ from magic_pdf.libs.clean_memory import clean_memory
6
8
  from magic_pdf.model.model_list import AtomicModel
7
9
 
8
10
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
11
+ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
9
12
  try:
10
13
  import cv2
11
14
  import yaml
@@ -23,6 +26,7 @@ try:
23
26
  from unimernet.common.config import Config
24
27
  import unimernet.tasks as tasks
25
28
  from unimernet.processors import load_processor
29
+ from doclayout_yolo import YOLOv10
26
30
 
27
31
  except ImportError as e:
28
32
  logger.exception(e)
@@ -32,21 +36,24 @@ except ImportError as e:
32
36
  exit(1)
33
37
 
34
38
  from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
35
- from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
39
+ from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
36
40
  from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
37
41
  from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
38
42
  from magic_pdf.model.ppTableModel import ppTableModel
39
43
 
40
44
 
41
45
  def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
42
- if table_model_type == STRUCT_EQTABLE:
43
- table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
44
- else:
46
+ if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
47
+ table_model = StructTableModel(model_path, max_time=max_time)
48
+ elif table_model_type == MODEL_NAME.TABLE_MASTER:
45
49
  config = {
46
50
  "model_dir": model_path,
47
51
  "device": _device_
48
52
  }
49
53
  table_model = ppTableModel(config)
54
+ else:
55
+ logger.error("table model type not allow")
56
+ exit(1)
50
57
  return table_model
51
58
 
52
59
 
@@ -58,12 +65,13 @@ def mfd_model_init(weight):
58
65
  def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
59
66
  args = argparse.Namespace(cfg_path=cfg_path, options=None)
60
67
  cfg = Config(args)
61
- cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
68
+ cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
62
69
  cfg.config.model.model_config.model_name = weight_dir
63
70
  cfg.config.model.tokenizer_config.path = weight_dir
64
71
  task = tasks.setup_task(cfg)
65
72
  model = task.build_model(cfg)
66
- model = model.to(_device_)
73
+ model.to(_device_)
74
+ model.eval()
67
75
  vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
68
76
  mfr_transform = transforms.Compose([vis_processor, ])
69
77
  return [model, mfr_transform]
@@ -74,8 +82,16 @@ def layout_model_init(weight, config_file, device):
74
82
  return model
75
83
 
76
84
 
77
- def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
78
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
85
+ def doclayout_yolo_model_init(weight):
86
+ model = YOLOv10(weight)
87
+ return model
88
+
89
+
90
+ def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
91
+ if lang is not None:
92
+ model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
93
+ else:
94
+ model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
79
95
  return model
80
96
 
81
97
 
@@ -108,19 +124,27 @@ class AtomModelSingleton:
108
124
  return cls._instance
109
125
 
110
126
  def get_atom_model(self, atom_model_name: str, **kwargs):
111
- if atom_model_name not in self._models:
112
- self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
113
- return self._models[atom_model_name]
127
+ lang = kwargs.get("lang", None)
128
+ layout_model_name = kwargs.get("layout_model_name", None)
129
+ key = (atom_model_name, layout_model_name, lang)
130
+ if key not in self._models:
131
+ self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
132
+ return self._models[key]
114
133
 
115
134
 
116
135
  def atom_model_init(model_name: str, **kwargs):
117
136
 
118
137
  if model_name == AtomicModel.Layout:
119
- atom_model = layout_model_init(
120
- kwargs.get("layout_weights"),
121
- kwargs.get("layout_config_file"),
122
- kwargs.get("device")
123
- )
138
+ if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
139
+ atom_model = layout_model_init(
140
+ kwargs.get("layout_weights"),
141
+ kwargs.get("layout_config_file"),
142
+ kwargs.get("device")
143
+ )
144
+ elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
145
+ atom_model = doclayout_yolo_model_init(
146
+ kwargs.get("doclayout_yolo_weights"),
147
+ )
124
148
  elif model_name == AtomicModel.MFD:
125
149
  atom_model = mfd_model_init(
126
150
  kwargs.get("mfd_weights")
@@ -134,11 +158,12 @@ def atom_model_init(model_name: str, **kwargs):
134
158
  elif model_name == AtomicModel.OCR:
135
159
  atom_model = ocr_model_init(
136
160
  kwargs.get("ocr_show_log"),
137
- kwargs.get("det_db_box_thresh")
161
+ kwargs.get("det_db_box_thresh"),
162
+ kwargs.get("lang")
138
163
  )
139
164
  elif model_name == AtomicModel.Table:
140
165
  atom_model = table_model_init(
141
- kwargs.get("table_model_type"),
166
+ kwargs.get("table_model_name"),
142
167
  kwargs.get("table_model_path"),
143
168
  kwargs.get("table_max_time"),
144
169
  kwargs.get("device")
@@ -150,6 +175,23 @@ def atom_model_init(model_name: str, **kwargs):
150
175
  return atom_model
151
176
 
152
177
 
178
+ # Unified crop img logic
179
+ def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
180
+ crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
181
+ crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
182
+ # Create a white background with an additional width and height of 50
183
+ crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
184
+ crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
185
+ return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
186
+
187
+ # Crop image
188
+ crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
189
+ cropped_img = input_pil_img.crop(crop_box)
190
+ return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
191
+ return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
192
+ return return_image, return_list
193
+
194
+
153
195
  class CustomPEKModel:
154
196
 
155
197
  def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
@@ -169,22 +211,35 @@ class CustomPEKModel:
169
211
  with open(config_path, "r", encoding='utf-8') as f:
170
212
  self.configs = yaml.load(f, Loader=yaml.FullLoader)
171
213
  # 初始化解析配置
172
- self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
173
- self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
214
+
215
+ # layout config
216
+ self.layout_config = kwargs.get("layout_config")
217
+ self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
218
+
219
+ # formula config
220
+ self.formula_config = kwargs.get("formula_config")
221
+ self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
222
+ self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
223
+ self.apply_formula = self.formula_config.get("enable", True)
224
+
174
225
  # table config
175
- self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
176
- self.apply_table = self.table_config.get("is_table_recog_enable", False)
226
+ self.table_config = kwargs.get("table_config")
227
+ self.apply_table = self.table_config.get("enable", False)
177
228
  self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
178
- self.table_model_type = self.table_config.get("model", TABLE_MASTER)
229
+ self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
230
+
231
+ # ocr config
179
232
  self.apply_ocr = ocr
233
+ self.lang = kwargs.get("lang", None)
234
+
180
235
  logger.info(
181
- "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
182
- self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
236
+ "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
237
+ "apply_table: {}, table_model: {}, lang: {}".format(
238
+ self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
183
239
  )
184
240
  )
185
- assert self.apply_layout, "DocAnalysis must contain layout model."
186
241
  # 初始化解析方案
187
- self.device = kwargs.get("device", self.configs["config"]["device"])
242
+ self.device = kwargs.get("device", "cpu")
188
243
  logger.info("using device: {}".format(self.device))
189
244
  models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
190
245
  logger.info("using models_dir: {}".format(models_dir))
@@ -193,17 +248,16 @@ class CustomPEKModel:
193
248
 
194
249
  # 初始化公式识别
195
250
  if self.apply_formula:
251
+
196
252
  # 初始化公式检测模型
197
- # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
198
253
  self.mfd_model = atom_model_manager.get_atom_model(
199
254
  atom_model_name=AtomicModel.MFD,
200
- mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
255
+ mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
201
256
  )
257
+
202
258
  # 初始化公式解析模型
203
- mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
259
+ mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
204
260
  mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
205
- # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
206
- # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
207
261
  self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
208
262
  atom_model_name=AtomicModel.MFR,
209
263
  mfr_weight_dir=mfr_weight_dir,
@@ -212,17 +266,20 @@ class CustomPEKModel:
212
266
  )
213
267
 
214
268
  # 初始化layout模型
215
- # self.layout_model = Layoutlmv3_Predictor(
216
- # str(os.path.join(models_dir, self.configs['weights']['layout'])),
217
- # str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
218
- # device=self.device
219
- # )
220
- self.layout_model = atom_model_manager.get_atom_model(
221
- atom_model_name=AtomicModel.Layout,
222
- layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
223
- layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
224
- device=self.device
225
- )
269
+ if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
270
+ self.layout_model = atom_model_manager.get_atom_model(
271
+ atom_model_name=AtomicModel.Layout,
272
+ layout_model_name=MODEL_NAME.LAYOUTLMv3,
273
+ layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
274
+ layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
275
+ device=self.device
276
+ )
277
+ elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
278
+ self.layout_model = atom_model_manager.get_atom_model(
279
+ atom_model_name=AtomicModel.Layout,
280
+ layout_model_name=MODEL_NAME.DocLayout_YOLO,
281
+ doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
282
+ )
226
283
  # 初始化ocr
227
284
  if self.apply_ocr:
228
285
 
@@ -230,37 +287,67 @@ class CustomPEKModel:
230
287
  self.ocr_model = atom_model_manager.get_atom_model(
231
288
  atom_model_name=AtomicModel.OCR,
232
289
  ocr_show_log=show_log,
233
- det_db_box_thresh=0.3
290
+ det_db_box_thresh=0.3,
291
+ lang=self.lang
234
292
  )
235
293
  # init table model
236
294
  if self.apply_table:
237
- table_model_dir = self.configs["weights"][self.table_model_type]
238
- # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
239
- # max_time=self.table_max_time, _device_=self.device)
295
+ table_model_dir = self.configs["weights"][self.table_model_name]
240
296
  self.table_model = atom_model_manager.get_atom_model(
241
297
  atom_model_name=AtomicModel.Table,
242
- table_model_type=self.table_model_type,
298
+ table_model_name=self.table_model_name,
243
299
  table_model_path=str(os.path.join(models_dir, table_model_dir)),
244
300
  table_max_time=self.table_max_time,
245
301
  device=self.device
246
302
  )
247
303
 
304
+ home_directory = Path.home()
305
+ det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
306
+ rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
307
+ det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
308
+ rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
309
+
310
+ if not os.path.exists(det_dest_dir):
311
+ shutil.copytree(det_source, det_dest_dir)
312
+ if not os.path.exists(rec_dest_dir):
313
+ shutil.copytree(rec_source, rec_dest_dir)
314
+
248
315
  logger.info('DocAnalysis init done!')
249
316
 
250
317
  def __call__(self, image):
251
318
 
319
+ page_start = time.time()
320
+
252
321
  latex_filling_list = []
253
322
  mf_image_list = []
254
323
 
255
324
  # layout检测
256
325
  layout_start = time.time()
257
- layout_res = self.layout_model(image, ignore_catids=[])
326
+ if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
327
+ # layoutlmv3
328
+ layout_res = self.layout_model(image, ignore_catids=[])
329
+ elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
330
+ # doclayout_yolo
331
+ layout_res = []
332
+ doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
333
+ for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
334
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
335
+ new_item = {
336
+ 'category_id': int(cla.item()),
337
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
338
+ 'score': round(float(conf.item()), 3),
339
+ }
340
+ layout_res.append(new_item)
258
341
  layout_cost = round(time.time() - layout_start, 2)
259
- logger.info(f"layout detection cost: {layout_cost}")
342
+ logger.info(f"layout detection time: {layout_cost}")
343
+
344
+ pil_img = Image.fromarray(image)
260
345
 
261
346
  if self.apply_formula:
262
347
  # 公式检测
263
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
348
+ mfd_start = time.time()
349
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
350
+ logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
264
351
  for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
265
352
  xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
266
353
  new_item = {
@@ -271,7 +358,7 @@ class CustomPEKModel:
271
358
  }
272
359
  layout_res.append(new_item)
273
360
  latex_filling_list.append(new_item)
274
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
361
+ bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
275
362
  mf_image_list.append(bbox_img)
276
363
 
277
364
  # 公式识别
@@ -281,7 +368,8 @@ class CustomPEKModel:
281
368
  mfr_res = []
282
369
  for mf_img in dataloader:
283
370
  mf_img = mf_img.to(self.device)
284
- output = self.mfr_model.generate({'image': mf_img})
371
+ with torch.no_grad():
372
+ output = self.mfr_model.generate({'image': mf_img})
285
373
  mfr_res.extend(output['pred_str'])
286
374
  for res, latex in zip(latex_filling_list, mfr_res):
287
375
  res['latex'] = latex_rm_whitespace(latex)
@@ -303,23 +391,14 @@ class CustomPEKModel:
303
391
  elif int(res['category_id']) in [5]:
304
392
  table_res_list.append(res)
305
393
 
306
- # Unified crop img logic
307
- def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
308
- crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
309
- crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
310
- # Create a white background with an additional width and height of 50
311
- crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
312
- crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
313
- return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
314
-
315
- # Crop image
316
- crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
317
- cropped_img = input_pil_img.crop(crop_box)
318
- return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
319
- return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
320
- return return_image, return_list
321
-
322
- pil_img = Image.fromarray(image)
394
+ if torch.cuda.is_available() and self.device != 'cpu':
395
+ properties = torch.cuda.get_device_properties(self.device)
396
+ total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
397
+ if total_memory <= 10:
398
+ gc_start = time.time()
399
+ clean_memory()
400
+ gc_time = round(time.time() - gc_start, 2)
401
+ logger.info(f"gc time: {gc_time}")
323
402
 
324
403
  # ocr识别
325
404
  if self.apply_ocr:
@@ -369,7 +448,7 @@ class CustomPEKModel:
369
448
  })
370
449
 
371
450
  ocr_cost = round(time.time() - ocr_start, 2)
372
- logger.info(f"ocr cost: {ocr_cost}")
451
+ logger.info(f"ocr time: {ocr_cost}")
373
452
 
374
453
  # 表格识别 table recognition
375
454
  if self.apply_table:
@@ -377,33 +456,41 @@ class CustomPEKModel:
377
456
  for res in table_res_list:
378
457
  new_image, _ = crop_img(res, pil_img)
379
458
  single_table_start_time = time.time()
380
- logger.info("------------------table recognition processing begins-----------------")
459
+ # logger.info("------------------table recognition processing begins-----------------")
381
460
  latex_code = None
382
461
  html_code = None
383
- if self.table_model_type == STRUCT_EQTABLE:
462
+ if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
384
463
  with torch.no_grad():
385
- latex_code = self.table_model.image2latex(new_image)[0]
464
+ table_result = self.table_model.predict(new_image, "html")
465
+ if len(table_result) > 0:
466
+ html_code = table_result[0]
386
467
  else:
387
468
  html_code = self.table_model.img2html(new_image)
388
469
 
389
470
  run_time = time.time() - single_table_start_time
390
- logger.info(f"------------table recognition processing ends within {run_time}s-----")
471
+ # logger.info(f"------------table recognition processing ends within {run_time}s-----")
391
472
  if run_time > self.table_max_time:
392
473
  logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
393
474
  # 判断是否返回正常
394
475
 
395
476
  if latex_code:
396
- expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
397
- 'end{table}')
477
+ expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
398
478
  if expected_ending:
399
479
  res["latex"] = latex_code
400
480
  else:
401
- logger.warning(f"------------table recognition processing fails----------")
481
+ logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
402
482
  elif html_code:
403
- res["html"] = html_code
483
+ expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
484
+ if expected_ending:
485
+ res["html"] = html_code
486
+ else:
487
+ logger.warning(f"table recognition processing fails, not found expected HTML table end")
404
488
  else:
405
- logger.warning(f"------------table recognition processing fails----------")
406
- table_cost = round(time.time() - table_start, 2)
407
- logger.info(f"table cost: {table_cost}")
489
+ logger.warning(f"table recognition processing fails, not get latex or html return")
490
+ logger.info(f"table time: {round(time.time() - table_start, 2)}")
491
+
492
+ logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
408
493
 
409
494
  return layout_res
495
+
496
+
@@ -1,21 +1,45 @@
1
- from struct_eqtable.model import StructTable
2
- from pypandoc import convert_text
1
+ import re
2
+
3
+ import torch
4
+ from struct_eqtable import build_model
5
+
6
+
3
7
  class StructTableModel:
4
- def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
8
+ def __init__(self, model_path, max_new_tokens=1024, max_time=60):
5
9
  # init
6
- self.model_path = model_path
7
- self.max_new_tokens = max_new_tokens # maximum output tokens length
8
- self.max_time = max_time # timeout for processing in seconds
9
- if device == 'cuda':
10
- self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
10
+ assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
11
+ self.model = build_model(
12
+ model_ckpt=model_path,
13
+ max_new_tokens=max_new_tokens,
14
+ max_time=max_time,
15
+ lmdeploy=False,
16
+ flash_attn=False,
17
+ batch_size=1,
18
+ ).cuda()
19
+ self.default_format = "html"
20
+
21
+ def predict(self, images, output_format=None, **kwargs):
22
+
23
+ if output_format is None:
24
+ output_format = self.default_format
11
25
  else:
12
- self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
26
+ if output_format not in ['latex', 'markdown', 'html']:
27
+ raise ValueError(f"Output format {output_format} is not supported.")
28
+
29
+ results = self.model(
30
+ images, output_format=output_format
31
+ )
32
+
33
+ if output_format == "html":
34
+ results = [self.minify_html(html) for html in results]
13
35
 
14
- def image2latex(self, image) -> str:
15
- table_latex = self.model.forward(image)
16
- return table_latex
36
+ return results
17
37
 
18
- def image2html(self, image) -> str:
19
- table_latex = self.image2latex(image)
20
- table_html = convert_text(table_latex, 'html', format='latex')
21
- return table_html
38
+ def minify_html(self, html):
39
+ # 移除多余的空白字符
40
+ html = re.sub(r'\s+', ' ', html)
41
+ # 移除行尾的空白字符
42
+ html = re.sub(r'\s*>\s*', '>', html)
43
+ # 移除标签前的空白字符
44
+ html = re.sub(r'\s*<\s*', '<', html)
45
+ return html.strip()
@@ -1,3 +1,4 @@
1
+ import cv2
1
2
  from paddleocr.ppstructure.table.predict_table import TableSystem
2
3
  from paddleocr.ppstructure.utility import init_args
3
4
  from magic_pdf.libs.Constants import *
@@ -36,12 +37,13 @@ class ppTableModel(object):
36
37
  - HTML (str): A string representing the HTML structure with content of the table.
37
38
  """
38
39
  if isinstance(image, Image.Image):
39
- image = np.array(image)
40
+ image = np.asarray(image)
41
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
40
42
  pred_res, _ = self.table_sys(image)
41
43
  pred_html = pred_res["html"]
42
- res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace("</table></body></html>",
43
- "") + "</table></td>\n"
44
- return res
44
+ # res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
45
+ # "</table></body></html>","") + "</table></td>\n"
46
+ return pred_html
45
47
 
46
48
  def parse_args(self, **kwargs):
47
49
  parser = init_args()
@@ -52,11 +54,11 @@ class ppTableModel(object):
52
54
  rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
53
55
  rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
54
56
  device = kwargs.get("device", "cpu")
55
- use_gpu = True if device == "cuda" else False
57
+ use_gpu = True if device.startswith("cuda") else False
56
58
  config = {
57
59
  "use_gpu": use_gpu,
58
60
  "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
59
- "table_algorithm": TABLE_MASTER,
61
+ "table_algorithm": "TableMaster",
60
62
  "table_model_dir": table_model_dir,
61
63
  "table_char_dict_path": table_char_dict_path,
62
64
  "det_model_dir": det_model_dir,
@@ -18,8 +18,11 @@ def region_to_bbox(region):
18
18
 
19
19
 
20
20
  class CustomPaddleModel:
21
- def __init__(self, ocr: bool = False, show_log: bool = False):
22
- self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
21
+ def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
22
+ if lang is not None:
23
+ self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
24
+ else:
25
+ self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
23
26
 
24
27
  def __call__(self, img):
25
28
  try:
File without changes