magic-pdf 0.10.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. magic_pdf/config/constants.py +7 -0
  2. magic_pdf/config/exceptions.py +7 -0
  3. magic_pdf/data/data_reader_writer/base.py +13 -1
  4. magic_pdf/data/data_reader_writer/filebase.py +1 -1
  5. magic_pdf/data/data_reader_writer/multi_bucket_s3.py +8 -6
  6. magic_pdf/data/dataset.py +188 -5
  7. magic_pdf/data/read_api.py +59 -12
  8. magic_pdf/data/utils.py +35 -0
  9. magic_pdf/dict2md/ocr_mkcontent.py +16 -15
  10. magic_pdf/filter/__init__.py +32 -0
  11. magic_pdf/filter/pdf_meta_scan.py +3 -2
  12. magic_pdf/libs/clean_memory.py +11 -4
  13. magic_pdf/libs/config_reader.py +9 -0
  14. magic_pdf/libs/draw_bbox.py +19 -22
  15. magic_pdf/libs/language.py +3 -0
  16. magic_pdf/libs/pdf_check.py +30 -30
  17. magic_pdf/libs/version.py +1 -1
  18. magic_pdf/model/__init__.py +1 -1
  19. magic_pdf/model/batch_analyze.py +275 -0
  20. magic_pdf/model/doc_analyze_by_custom_model.py +104 -92
  21. magic_pdf/model/magic_model.py +4 -435
  22. magic_pdf/model/model_list.py +1 -0
  23. magic_pdf/model/pdf_extract_kit.py +35 -5
  24. magic_pdf/model/sub_modules/language_detection/__init__.py +1 -0
  25. magic_pdf/model/sub_modules/language_detection/utils.py +82 -0
  26. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +139 -0
  27. magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +44 -7
  29. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +21 -2
  30. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +70 -27
  31. magic_pdf/model/sub_modules/model_init.py +43 -7
  32. magic_pdf/model/sub_modules/model_utils.py +17 -5
  33. magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +51 -1
  34. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +32 -6
  35. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +42 -7
  36. magic_pdf/operators/__init__.py +94 -0
  37. magic_pdf/operators/models.py +154 -0
  38. magic_pdf/operators/pipes.py +191 -0
  39. magic_pdf/pdf_parse_union_core_v2.py +77 -27
  40. magic_pdf/post_proc/__init__.py +1 -0
  41. magic_pdf/post_proc/llm_aided.py +133 -0
  42. magic_pdf/pre_proc/ocr_span_list_modify.py +8 -0
  43. magic_pdf/pre_proc/remove_bbox_overlap.py +1 -1
  44. magic_pdf/resources/yolov11-langdetect/yolo_v11_ft.pt +0 -0
  45. magic_pdf/tools/cli.py +36 -11
  46. magic_pdf/tools/common.py +120 -61
  47. magic_pdf/utils/office_to_pdf.py +29 -0
  48. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/METADATA +78 -25
  49. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/RECORD +54 -55
  50. magic_pdf/para/__init__.py +0 -0
  51. magic_pdf/pdf_parse_by_ocr.py +0 -23
  52. magic_pdf/pdf_parse_by_txt.py +0 -24
  53. magic_pdf/pipe/AbsPipe.py +0 -98
  54. magic_pdf/pipe/OCRPipe.py +0 -41
  55. magic_pdf/pipe/TXTPipe.py +0 -41
  56. magic_pdf/pipe/UNIPipe.py +0 -98
  57. magic_pdf/pipe/__init__.py +0 -0
  58. magic_pdf/rw/AbsReaderWriter.py +0 -17
  59. magic_pdf/rw/DiskReaderWriter.py +0 -74
  60. magic_pdf/rw/S3ReaderWriter.py +0 -142
  61. magic_pdf/rw/__init__.py +0 -0
  62. magic_pdf/user_api.py +0 -121
  63. /magic_pdf/{para → post_proc}/para_split_v3.py +0 -0
  64. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/LICENSE.md +0 -0
  65. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/WHEEL +0 -0
  66. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/entry_points.txt +0 -0
  67. {magic_pdf-0.10.5.dist-info → magic_pdf-1.0.0.dist-info}/top_level.txt +0 -0
@@ -3,12 +3,9 @@ import enum
3
3
  from magic_pdf.config.model_block_type import ModelBlockTypeEnum
4
4
  from magic_pdf.config.ocr_content_type import CategoryId, ContentType
5
5
  from magic_pdf.data.dataset import Dataset
6
- from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
7
- bbox_relative_pos, box_area, calculate_iou,
8
- calculate_overlap_area_in_bbox1_area_ratio,
9
- get_overlap_area)
6
+ from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
7
+ calculate_iou)
10
8
  from magic_pdf.libs.coordinate_transform import get_scale_ratio
11
- from magic_pdf.libs.local_math import float_gt
12
9
  from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
13
10
 
14
11
  CAPATION_OVERLAP_AREA_RATIO = 0.6
@@ -208,393 +205,6 @@ class MagicModel:
208
205
  keep[i] = False
209
206
  return [bboxes[i] for i in range(N) if keep[i]]
210
207
 
211
- def __tie_up_category_by_distance(
212
- self, page_no, subject_category_id, object_category_id
213
- ):
214
- """假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
215
- 只能属于一个 subject."""
216
- ret = []
217
- MAX_DIS_OF_POINT = 10**9 + 7
218
- """
219
- subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
220
- 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
221
- 再求出筛选出的 subjects 和 object 的最短距离
222
- """
223
-
224
- def search_overlap_between_boxes(subject_idx, object_idx):
225
- idxes = [subject_idx, object_idx]
226
- x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
227
- y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
228
- x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
229
- y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
230
-
231
- merged_bbox = [
232
- min(x0s),
233
- min(y0s),
234
- max(x1s),
235
- max(y1s),
236
- ]
237
- ratio = 0
238
-
239
- other_objects = list(
240
- map(
241
- lambda x: {'bbox': x['bbox'], 'score': x['score']},
242
- filter(
243
- lambda x: x['category_id']
244
- not in (object_category_id, subject_category_id),
245
- self.__model_list[page_no]['layout_dets'],
246
- ),
247
- )
248
- )
249
- for other_object in other_objects:
250
- ratio = max(
251
- ratio,
252
- get_overlap_area(merged_bbox, other_object['bbox'])
253
- * 1.0
254
- / box_area(all_bboxes[object_idx]['bbox']),
255
- )
256
- if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
257
- break
258
-
259
- return ratio
260
-
261
- def may_find_other_nearest_bbox(subject_idx, object_idx):
262
- ret = float('inf')
263
-
264
- x0 = min(
265
- all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
266
- )
267
- y0 = min(
268
- all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
269
- )
270
- x1 = max(
271
- all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
272
- )
273
- y1 = max(
274
- all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
275
- )
276
-
277
- object_area = abs(
278
- all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
279
- ) * abs(
280
- all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
281
- )
282
-
283
- for i in range(len(all_bboxes)):
284
- if (
285
- i == subject_idx
286
- or all_bboxes[i]['category_id'] != subject_category_id
287
- ):
288
- continue
289
- if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
290
- all_bboxes[i]['bbox'], [x0, y0, x1, y1]
291
- ):
292
-
293
- i_area = abs(
294
- all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
295
- ) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
296
- if i_area >= object_area:
297
- ret = min(float('inf'), dis[i][object_idx])
298
-
299
- return ret
300
-
301
- def expand_bbbox(idxes):
302
- x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
303
- y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
304
- x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
305
- y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
306
- return min(x0s), min(y0s), max(x1s), max(y1s)
307
-
308
- subjects = self.__reduct_overlap(
309
- list(
310
- map(
311
- lambda x: {'bbox': x['bbox'], 'score': x['score']},
312
- filter(
313
- lambda x: x['category_id'] == subject_category_id,
314
- self.__model_list[page_no]['layout_dets'],
315
- ),
316
- )
317
- )
318
- )
319
-
320
- objects = self.__reduct_overlap(
321
- list(
322
- map(
323
- lambda x: {'bbox': x['bbox'], 'score': x['score']},
324
- filter(
325
- lambda x: x['category_id'] == object_category_id,
326
- self.__model_list[page_no]['layout_dets'],
327
- ),
328
- )
329
- )
330
- )
331
- subject_object_relation_map = {}
332
-
333
- subjects.sort(
334
- key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2
335
- ) # get the distance !
336
-
337
- all_bboxes = []
338
-
339
- for v in subjects:
340
- all_bboxes.append(
341
- {
342
- 'category_id': subject_category_id,
343
- 'bbox': v['bbox'],
344
- 'score': v['score'],
345
- }
346
- )
347
-
348
- for v in objects:
349
- all_bboxes.append(
350
- {
351
- 'category_id': object_category_id,
352
- 'bbox': v['bbox'],
353
- 'score': v['score'],
354
- }
355
- )
356
-
357
- N = len(all_bboxes)
358
- dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
359
-
360
- for i in range(N):
361
- for j in range(i):
362
- if (
363
- all_bboxes[i]['category_id'] == subject_category_id
364
- and all_bboxes[j]['category_id'] == subject_category_id
365
- ):
366
- continue
367
-
368
- subject_idx, object_idx = i, j
369
- if all_bboxes[j]['category_id'] == subject_category_id:
370
- subject_idx, object_idx = j, i
371
-
372
- if (
373
- search_overlap_between_boxes(subject_idx, object_idx)
374
- >= MERGE_BOX_OVERLAP_AREA_RATIO
375
- ):
376
- dis[i][j] = float('inf')
377
- dis[j][i] = dis[i][j]
378
- continue
379
-
380
- dis[i][j] = self._bbox_distance(
381
- all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
382
- )
383
- dis[j][i] = dis[i][j]
384
-
385
- used = set()
386
- for i in range(N):
387
- # 求第 i 个 subject 所关联的 object
388
- if all_bboxes[i]['category_id'] != subject_category_id:
389
- continue
390
- seen = set()
391
- candidates = []
392
- arr = []
393
- for j in range(N):
394
-
395
- pos_flag_count = sum(
396
- list(
397
- map(
398
- lambda x: 1 if x else 0,
399
- bbox_relative_pos(
400
- all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
401
- ),
402
- )
403
- )
404
- )
405
- if pos_flag_count > 1:
406
- continue
407
- if (
408
- all_bboxes[j]['category_id'] != object_category_id
409
- or j in used
410
- or dis[i][j] == MAX_DIS_OF_POINT
411
- ):
412
- continue
413
- left, right, _, _ = bbox_relative_pos(
414
- all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
415
- ) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
416
- if left or right:
417
- one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
418
- else:
419
- one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
420
- if dis[i][j] > one_way_dis:
421
- continue
422
- arr.append((dis[i][j], j))
423
-
424
- arr.sort(key=lambda x: x[0])
425
- if len(arr) > 0:
426
- """
427
- bug: 离该subject 最近的 object 可能跨越了其它的 subject。
428
- 比如 [this subect] [some sbuject] [the nearest object of subject]
429
- """
430
- if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
431
-
432
- candidates.append(arr[0][1])
433
- seen.add(arr[0][1])
434
-
435
- # 已经获取初始种子
436
- for j in set(candidates):
437
- tmp = []
438
- for k in range(i + 1, N):
439
- pos_flag_count = sum(
440
- list(
441
- map(
442
- lambda x: 1 if x else 0,
443
- bbox_relative_pos(
444
- all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
445
- ),
446
- )
447
- )
448
- )
449
-
450
- if pos_flag_count > 1:
451
- continue
452
-
453
- if (
454
- all_bboxes[k]['category_id'] != object_category_id
455
- or k in used
456
- or k in seen
457
- or dis[j][k] == MAX_DIS_OF_POINT
458
- or dis[j][k] > dis[i][j]
459
- ):
460
- continue
461
-
462
- is_nearest = True
463
- for ni in range(i + 1, N):
464
- if ni in (j, k) or ni in used or ni in seen:
465
- continue
466
-
467
- if not float_gt(dis[ni][k], dis[j][k]):
468
- is_nearest = False
469
- break
470
-
471
- if is_nearest:
472
- nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
473
- n_dis = bbox_distance(
474
- all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
475
- )
476
- if float_gt(dis[i][j], n_dis):
477
- continue
478
- tmp.append(k)
479
- seen.add(k)
480
-
481
- candidates = tmp
482
- if len(candidates) == 0:
483
- break
484
-
485
- # 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
486
- # 先扩一下 bbox,
487
- ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
488
- ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']
489
-
490
- # 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
491
- caption_poses = [
492
- [ox0, oy0, ix0, oy1],
493
- [ox0, oy0, ox1, iy0],
494
- [ox0, iy1, ox1, oy1],
495
- [ix1, oy0, ox1, oy1],
496
- ]
497
-
498
- caption_areas = []
499
- for bbox in caption_poses:
500
- embed_arr = []
501
- for idx in seen:
502
- if (
503
- calculate_overlap_area_in_bbox1_area_ratio(
504
- all_bboxes[idx]['bbox'], bbox
505
- )
506
- > CAPATION_OVERLAP_AREA_RATIO
507
- ):
508
- embed_arr.append(idx)
509
-
510
- if len(embed_arr) > 0:
511
- embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
512
- embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
513
- embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
514
- embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
515
- caption_areas.append(
516
- int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
517
- )
518
- else:
519
- caption_areas.append(0)
520
-
521
- subject_object_relation_map[i] = []
522
- if max(caption_areas) > 0:
523
- max_area_idx = caption_areas.index(max(caption_areas))
524
- caption_bbox = caption_poses[max_area_idx]
525
-
526
- for j in seen:
527
- if (
528
- calculate_overlap_area_in_bbox1_area_ratio(
529
- all_bboxes[j]['bbox'], caption_bbox
530
- )
531
- > CAPATION_OVERLAP_AREA_RATIO
532
- ):
533
- used.add(j)
534
- subject_object_relation_map[i].append(j)
535
-
536
- for i in sorted(subject_object_relation_map.keys()):
537
- result = {
538
- 'subject_body': all_bboxes[i]['bbox'],
539
- 'all': all_bboxes[i]['bbox'],
540
- 'score': all_bboxes[i]['score'],
541
- }
542
-
543
- if len(subject_object_relation_map[i]) > 0:
544
- x0 = min(
545
- [all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
546
- )
547
- y0 = min(
548
- [all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
549
- )
550
- x1 = max(
551
- [all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
552
- )
553
- y1 = max(
554
- [all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
555
- )
556
- result['object_body'] = [x0, y0, x1, y1]
557
- result['all'] = [
558
- min(x0, all_bboxes[i]['bbox'][0]),
559
- min(y0, all_bboxes[i]['bbox'][1]),
560
- max(x1, all_bboxes[i]['bbox'][2]),
561
- max(y1, all_bboxes[i]['bbox'][3]),
562
- ]
563
- ret.append(result)
564
-
565
- total_subject_object_dis = 0
566
- # 计算已经配对的 distance 距离
567
- for i in subject_object_relation_map.keys():
568
- for j in subject_object_relation_map[i]:
569
- total_subject_object_dis += bbox_distance(
570
- all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
571
- )
572
-
573
- # 计算未匹配的 subject 和 object 的距离(非精确版)
574
- with_caption_subject = set(
575
- [
576
- key
577
- for key in subject_object_relation_map.keys()
578
- if len(subject_object_relation_map[i]) > 0
579
- ]
580
- )
581
- for i in range(N):
582
- if all_bboxes[i]['category_id'] != object_category_id or i in used:
583
- continue
584
- candidates = []
585
- for j in range(N):
586
- if (
587
- all_bboxes[j]['category_id'] != subject_category_id
588
- or j in with_caption_subject
589
- ):
590
- continue
591
- candidates.append((dis[i][j], j))
592
- if len(candidates) > 0:
593
- candidates.sort(key=lambda x: x[0])
594
- total_subject_object_dis += candidates[0][1]
595
- with_caption_subject.add(j)
596
- return ret, total_subject_object_dis
597
-
598
208
  def __tie_up_category_by_distance_v2(
599
209
  self,
600
210
  page_no: int,
@@ -879,52 +489,12 @@ class MagicModel:
879
489
  return ret
880
490
 
881
491
  def get_imgs(self, page_no: int):
882
- with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
883
- with_footnotes, _ = self.__tie_up_category_by_distance(
884
- page_no, 3, CategoryId.ImageFootnote
885
- )
886
- ret = []
887
- N, M = len(with_captions), len(with_footnotes)
888
- assert N == M
889
- for i in range(N):
890
- record = {
891
- 'score': with_captions[i]['score'],
892
- 'img_caption_bbox': with_captions[i].get('object_body', None),
893
- 'img_body_bbox': with_captions[i]['subject_body'],
894
- 'img_footnote_bbox': with_footnotes[i].get('object_body', None),
895
- }
896
-
897
- x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
898
- y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
899
- x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
900
- y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
901
- record['bbox'] = [x0, y0, x1, y1]
902
- ret.append(record)
903
- return ret
492
+ return self.get_imgs_v2(page_no)
904
493
 
905
494
  def get_tables(
906
495
  self, page_no: int
907
496
  ) -> list: # 3个坐标, caption, table主体,table-note
908
- with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
909
- with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
910
- ret = []
911
- N, M = len(with_captions), len(with_footnotes)
912
- assert N == M
913
- for i in range(N):
914
- record = {
915
- 'score': with_captions[i]['score'],
916
- 'table_caption_bbox': with_captions[i].get('object_body', None),
917
- 'table_body_bbox': with_captions[i]['subject_body'],
918
- 'table_footnote_bbox': with_footnotes[i].get('object_body', None),
919
- }
920
-
921
- x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
922
- y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
923
- x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
924
- y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
925
- record['bbox'] = [x0, y0, x1, y1]
926
- ret.append(record)
927
- return ret
497
+ return self.get_tables_v2(page_no)
928
498
 
929
499
  def get_equations(self, page_no: int) -> list: # 有坐标,也有字
930
500
  inline_equations = self.__get_blocks_by_type(
@@ -1043,4 +613,3 @@ class MagicModel:
1043
613
 
1044
614
  def get_model_list(self, page_no):
1045
615
  return self.__model_list[page_no]
1046
-
@@ -9,3 +9,4 @@ class AtomicModel:
9
9
  MFR = "mfr"
10
10
  OCR = "ocr"
11
11
  Table = "table"
12
+ LangDetect = "langdetect"
@@ -10,7 +10,6 @@ from loguru import logger
10
10
  from PIL import Image
11
11
 
12
12
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
13
- os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
14
13
 
15
14
  try:
16
15
  import torchtext
@@ -88,6 +87,14 @@ class CustomPEKModel:
88
87
  )
89
88
  # 初始化解析方案
90
89
  self.device = kwargs.get('device', 'cpu')
90
+
91
+ if str(self.device).startswith("npu"):
92
+ import torch_npu
93
+ os.environ['FLAGS_npu_jit_compile'] = '0'
94
+ os.environ['FLAGS_use_stride_kernel'] = '0'
95
+ elif str(self.device).startswith("mps"):
96
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
97
+
91
98
  logger.info('using device: {}'.format(self.device))
92
99
  models_dir = kwargs.get(
93
100
  'models_dir', os.path.join(root_dir, 'resources', 'models')
@@ -114,11 +121,12 @@ class CustomPEKModel:
114
121
  os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
115
122
  )
116
123
  mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
124
+
117
125
  self.mfr_model = atom_model_manager.get_atom_model(
118
126
  atom_model_name=AtomicModel.MFR,
119
127
  mfr_weight_dir=mfr_weight_dir,
120
128
  mfr_cfg_path=mfr_cfg_path,
121
- device=self.device,
129
+ device='cpu' if str(self.device).startswith("mps") else self.device,
122
130
  )
123
131
 
124
132
  # 初始化layout模型
@@ -165,12 +173,17 @@ class CustomPEKModel:
165
173
  table_model_path=str(os.path.join(models_dir, table_model_dir)),
166
174
  table_max_time=self.table_max_time,
167
175
  device=self.device,
176
+ ocr_engine=self.ocr_model,
168
177
  )
169
178
 
170
179
  logger.info('DocAnalysis init done!')
171
180
 
172
181
  def __call__(self, image):
173
182
 
183
+ pil_img = Image.fromarray(image)
184
+ width, height = pil_img.size
185
+ # logger.info(f'width: {width}, height: {height}')
186
+
174
187
  # layout检测
175
188
  layout_start = time.time()
176
189
  layout_res = []
@@ -179,12 +192,28 @@ class CustomPEKModel:
179
192
  layout_res = self.layout_model(image, ignore_catids=[])
180
193
  elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
181
194
  # doclayout_yolo
182
- layout_res = self.layout_model.predict(image)
195
+ if height > width:
196
+ input_res = {"poly":[0,0,width,0,width,height,0,height]}
197
+ new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
198
+ paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
199
+ layout_res = self.layout_model.predict(new_image)
200
+ for res in layout_res:
201
+ p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
202
+ p1 = p1 - paste_x + xmin
203
+ p2 = p2 - paste_y + ymin
204
+ p3 = p3 - paste_x + xmin
205
+ p4 = p4 - paste_y + ymin
206
+ p5 = p5 - paste_x + xmin
207
+ p6 = p6 - paste_y + ymin
208
+ p7 = p7 - paste_x + xmin
209
+ p8 = p8 - paste_y + ymin
210
+ res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
211
+ else:
212
+ layout_res = self.layout_model.predict(image)
213
+
183
214
  layout_cost = round(time.time() - layout_start, 2)
184
215
  logger.info(f'layout detection time: {layout_cost}')
185
216
 
186
- pil_img = Image.fromarray(image)
187
-
188
217
  if self.apply_formula:
189
218
  # 公式检测
190
219
  mfd_start = time.time()
@@ -215,6 +244,7 @@ class CustomPEKModel:
215
244
 
216
245
  # OCR recognition
217
246
  new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
247
+
218
248
  if self.apply_ocr:
219
249
  ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
220
250
  else:
@@ -0,0 +1 @@
1
+ # Copyright (c) Opendatalab. All rights reserved.
@@ -0,0 +1,82 @@
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import yaml
6
+ from PIL import Image
7
+
8
+ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
9
+
10
+ from magic_pdf.config.constants import MODEL_NAME
11
+ from magic_pdf.data.utils import load_images_from_pdf
12
+ from magic_pdf.libs.config_reader import get_local_models_dir, get_device
13
+ from magic_pdf.libs.pdf_check import extract_pages
14
+ from magic_pdf.model.model_list import AtomicModel
15
+ from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
16
+
17
+
18
+ def get_model_config():
19
+ local_models_dir = get_local_models_dir()
20
+ device = get_device()
21
+ current_file_path = os.path.abspath(__file__)
22
+ root_dir = Path(current_file_path).parents[3]
23
+ model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
24
+ config_path = os.path.join(model_config_dir, 'model_configs.yaml')
25
+ with open(config_path, 'r', encoding='utf-8') as f:
26
+ configs = yaml.load(f, Loader=yaml.FullLoader)
27
+ return root_dir, local_models_dir, device, configs
28
+
29
+
30
+ def get_text_images(simple_images):
31
+ _, local_models_dir, device, configs = get_model_config()
32
+ atom_model_manager = AtomModelSingleton()
33
+ temp_layout_model = atom_model_manager.get_atom_model(
34
+ atom_model_name=AtomicModel.Layout,
35
+ layout_model_name=MODEL_NAME.DocLayout_YOLO,
36
+ doclayout_yolo_weights=str(
37
+ os.path.join(
38
+ local_models_dir, configs['weights'][MODEL_NAME.DocLayout_YOLO]
39
+ )
40
+ ),
41
+ device=device,
42
+ )
43
+ text_images = []
44
+ for simple_image in simple_images:
45
+ image = Image.fromarray(simple_image['img'])
46
+ layout_res = temp_layout_model.predict(image)
47
+ # 给textblock截图
48
+ for res in layout_res:
49
+ if res['category_id'] in [1]:
50
+ x1, y1, _, _, x2, y2, _, _ = res['poly']
51
+ # 初步清洗(宽和高都小于100)
52
+ if x2 - x1 < 100 and y2 - y1 < 100:
53
+ continue
54
+ text_images.append(image.crop((x1, y1, x2, y2)))
55
+ return text_images
56
+
57
+
58
+ def auto_detect_lang(pdf_bytes: bytes):
59
+ sample_docs = extract_pages(pdf_bytes)
60
+ sample_pdf_bytes = sample_docs.tobytes()
61
+ simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
62
+ text_images = get_text_images(simple_images)
63
+ langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
64
+ lang = langdetect_model.do_detect(text_images)
65
+ return lang
66
+
67
+
68
+ def model_init(model_name: str):
69
+ atom_model_manager = AtomModelSingleton()
70
+
71
+ if model_name == MODEL_NAME.YOLO_V11_LangDetect:
72
+ root_dir, _, device, _ = get_model_config()
73
+ model = atom_model_manager.get_atom_model(
74
+ atom_model_name=AtomicModel.LangDetect,
75
+ langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
76
+ langdetect_model_weight=str(os.path.join(root_dir, 'resources', 'yolov11-langdetect', 'yolo_v11_ft.pt')),
77
+ device=device,
78
+ )
79
+ else:
80
+ raise ValueError(f"model_name {model_name} not found")
81
+ return model
82
+