magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/post_proc/para_split_v3.py +16 -13
  82. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  83. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  84. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  85. magic_pdf/tools/cli.py +30 -12
  86. magic_pdf/tools/common.py +90 -12
  87. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
  88. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  90. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  91. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  92. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  93. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  94. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  95. magic_pdf-1.2.1.dist-info/RECORD +0 -147
  96. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  97. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  98. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  99. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  100. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  101. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  102. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
File: magic_pdf/model/batch_analyze.py (+175 -201)
@@ -1,23 +1,14 @@
1
1
  import time
2
-
3
2
  import cv2
4
- import numpy as np
5
- import torch
6
3
  from loguru import logger
7
- from PIL import Image
4
+ from tqdm import tqdm
8
5
 
9
6
  from magic_pdf.config.constants import MODEL_NAME
10
- # from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
11
- # from magic_pdf.data.dataset import Dataset
12
- # from magic_pdf.libs.clean_memory import clean_memory
13
- # from magic_pdf.libs.config_reader import get_device
14
- # from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
15
- from magic_pdf.model.pdf_extract_kit import CustomPEKModel
7
+ from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
16
8
  from magic_pdf.model.sub_modules.model_utils import (
17
9
  clean_vram, crop_img, get_res_list_from_layout_res)
18
- from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
10
+ from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
19
11
  get_adjusted_mfdetrec_res, get_ocr_result_list)
20
- # from magic_pdf.operators.models import InferenceResult
21
12
 
22
13
  YOLO_LAYOUT_BASE_BATCH_SIZE = 1
23
14
  MFD_BASE_BATCH_SIZE = 1
@@ -25,14 +16,25 @@ MFR_BASE_BATCH_SIZE = 16
25
16
 
26
17
 
27
18
  class BatchAnalyze:
28
- def __init__(self, model: CustomPEKModel, batch_ratio: int):
29
- self.model = model
19
+ def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
20
+ self.model_manager = model_manager
30
21
  self.batch_ratio = batch_ratio
31
-
32
- def __call__(self, images: list) -> list:
22
+ self.show_log = show_log
23
+ self.layout_model = layout_model
24
+ self.formula_enable = formula_enable
25
+ self.table_enable = table_enable
26
+
27
+ def __call__(self, images_with_extra_info: list) -> list:
28
+ if len(images_with_extra_info) == 0:
29
+ return []
30
+
33
31
  images_layout_res = []
34
-
35
32
  layout_start_time = time.time()
33
+ _, fst_ocr, fst_lang = images_with_extra_info[0]
34
+ self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
35
+
36
+ images = [image for image, _, _ in images_with_extra_info]
37
+
36
38
  if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
37
39
  # layoutlmv3
38
40
  for image in images:
@@ -41,39 +43,17 @@ class BatchAnalyze:
41
43
  elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
42
44
  # doclayout_yolo
43
45
  layout_images = []
44
- modified_images = []
45
46
  for image_index, image in enumerate(images):
46
- pil_img = Image.fromarray(image)
47
- # width, height = pil_img.size
48
- # if height > width:
49
- # input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
50
- # new_image, useful_list = crop_img(
51
- # input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
52
- # )
53
- # layout_images.append(new_image)
54
- # modified_images.append([image_index, useful_list])
55
- # else:
56
- layout_images.append(pil_img)
47
+ layout_images.append(image)
57
48
 
58
49
  images_layout_res += self.model.layout_model.batch_predict(
59
50
  # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
60
51
  layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
61
52
  )
62
53
 
63
- for image_index, useful_list in modified_images:
64
- for res in images_layout_res[image_index]:
65
- for i in range(len(res['poly'])):
66
- if i % 2 == 0:
67
- res['poly'][i] = (
68
- res['poly'][i] - useful_list[0] + useful_list[2]
69
- )
70
- else:
71
- res['poly'][i] = (
72
- res['poly'][i] - useful_list[1] + useful_list[3]
73
- )
74
- logger.info(
75
- f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
76
- )
54
+ # logger.info(
55
+ # f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
56
+ # )
77
57
 
78
58
  if self.model.apply_formula:
79
59
  # 公式检测
@@ -82,9 +62,9 @@ class BatchAnalyze:
82
62
  # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
83
63
  images, MFD_BASE_BATCH_SIZE
84
64
  )
85
- logger.info(
86
- f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
87
- )
65
+ # logger.info(
66
+ # f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
67
+ # )
88
68
 
89
69
  # 公式识别
90
70
  mfr_start_time = time.time()
@@ -97,183 +77,177 @@ class BatchAnalyze:
97
77
  for image_index in range(len(images)):
98
78
  images_layout_res[image_index] += images_formula_list[image_index]
99
79
  mfr_count += len(images_formula_list[image_index])
100
- logger.info(
101
- f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
102
- )
80
+ # logger.info(
81
+ # f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
82
+ # )
103
83
 
104
84
  # 清理显存
105
- clean_vram(self.model.device, vram_threshold=8)
85
+ # clean_vram(self.model.device, vram_threshold=8)
106
86
 
107
- ocr_time = 0
108
- ocr_count = 0
109
- table_time = 0
110
- table_count = 0
111
- # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
87
+ ocr_res_list_all_page = []
88
+ table_res_list_all_page = []
112
89
  for index in range(len(images)):
90
+ _, ocr_enable, _lang = images_with_extra_info[index]
113
91
  layout_res = images_layout_res[index]
114
- pil_img = Image.fromarray(images[index])
92
+ np_array_img = images[index]
115
93
 
116
94
  ocr_res_list, table_res_list, single_page_mfdetrec_res = (
117
95
  get_res_list_from_layout_res(layout_res)
118
96
  )
119
- # ocr识别
120
- ocr_start = time.time()
97
+
98
+ ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
99
+ 'lang':_lang,
100
+ 'ocr_enable':ocr_enable,
101
+ 'np_array_img':np_array_img,
102
+ 'single_page_mfdetrec_res':single_page_mfdetrec_res,
103
+ 'layout_res':layout_res,
104
+ })
105
+
106
+ for table_res in table_res_list:
107
+ table_img, _ = crop_img(table_res, np_array_img)
108
+ table_res_list_all_page.append({'table_res':table_res,
109
+ 'lang':_lang,
110
+ 'table_img':table_img,
111
+ })
112
+
113
+ # 文本框检测
114
+ det_start = time.time()
115
+ det_count = 0
116
+ # for ocr_res_list_dict in ocr_res_list_all_page:
117
+ for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
121
118
  # Process each area that requires OCR processing
122
- for res in ocr_res_list:
119
+ _lang = ocr_res_list_dict['lang']
120
+ # Get OCR results for this language's images
121
+ atom_model_manager = AtomModelSingleton()
122
+ ocr_model = atom_model_manager.get_atom_model(
123
+ atom_model_name='ocr',
124
+ ocr_show_log=False,
125
+ det_db_box_thresh=0.3,
126
+ lang=_lang
127
+ )
128
+ for res in ocr_res_list_dict['ocr_res_list']:
123
129
  new_image, useful_list = crop_img(
124
- res, pil_img, crop_paste_x=50, crop_paste_y=50
130
+ res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
125
131
  )
126
132
  adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
127
- single_page_mfdetrec_res, useful_list
133
+ ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
128
134
  )
129
135
 
130
- # OCR recognition
131
- new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
132
-
133
- if self.model.apply_ocr:
134
- ocr_res = self.model.ocr_model.ocr(
135
- new_image, mfd_res=adjusted_mfdetrec_res
136
- )[0]
137
- else:
138
- ocr_res = self.model.ocr_model.ocr(
139
- new_image, mfd_res=adjusted_mfdetrec_res, rec=False
140
- )[0]
136
+ # OCR-det
137
+ new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
138
+ ocr_res = ocr_model.ocr(
139
+ new_image, mfd_res=adjusted_mfdetrec_res, rec=False
140
+ )[0]
141
141
 
142
142
  # Integration results
143
143
  if ocr_res:
144
- ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
145
- layout_res.extend(ocr_result_list)
146
- ocr_time += time.time() - ocr_start
147
- ocr_count += len(ocr_res_list)
148
-
149
- # 表格识别 table recognition
150
- if self.model.apply_table:
151
- table_start = time.time()
152
- for res in table_res_list:
153
- new_image, _ = crop_img(res, pil_img)
154
- single_table_start_time = time.time()
155
- html_code = None
156
- if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
157
- with torch.no_grad():
158
- table_result = self.model.table_model.predict(
159
- new_image, 'html'
160
- )
161
- if len(table_result) > 0:
162
- html_code = table_result[0]
163
- elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
164
- html_code = self.model.table_model.img2html(new_image)
165
- elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
166
- html_code, table_cell_bboxes, logic_points, elapse = (
167
- self.model.table_model.predict(new_image)
168
- )
169
- run_time = time.time() - single_table_start_time
170
- if run_time > self.model.table_max_time:
171
- logger.warning(
172
- f'table recognition processing exceeds max time {self.model.table_max_time}s'
173
- )
174
- # 判断是否返回正常
175
- if html_code:
176
- expected_ending = html_code.strip().endswith(
177
- '</html>'
178
- ) or html_code.strip().endswith('</table>')
179
- if expected_ending:
180
- res['html'] = html_code
181
- else:
182
- logger.warning(
183
- 'table recognition processing fails, not found expected HTML table end'
184
- )
144
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
145
+ ocr_res_list_dict['layout_res'].extend(ocr_result_list)
146
+ det_count += len(ocr_res_list_dict['ocr_res_list'])
147
+ # logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')
148
+
149
+
150
+ # 表格识别 table recognition
151
+ if self.model.apply_table:
152
+ table_start = time.time()
153
+ table_count = 0
154
+ # for table_res_list_dict in table_res_list_all_page:
155
+ for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
156
+ _lang = table_res_dict['lang']
157
+ atom_model_manager = AtomModelSingleton()
158
+ ocr_engine = atom_model_manager.get_atom_model(
159
+ atom_model_name='ocr',
160
+ ocr_show_log=False,
161
+ det_db_box_thresh=0.5,
162
+ det_db_unclip_ratio=1.6,
163
+ lang=_lang
164
+ )
165
+ table_model = atom_model_manager.get_atom_model(
166
+ atom_model_name='table',
167
+ table_model_name='rapid_table',
168
+ table_model_path='',
169
+ table_max_time=400,
170
+ device='cpu',
171
+ ocr_engine=ocr_engine,
172
+ table_sub_model_name='slanet_plus'
173
+ )
174
+ html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
175
+ # 判断是否返回正常
176
+ if html_code:
177
+ expected_ending = html_code.strip().endswith(
178
+ '</html>'
179
+ ) or html_code.strip().endswith('</table>')
180
+ if expected_ending:
181
+ table_res_dict['table_res']['html'] = html_code
185
182
  else:
186
183
  logger.warning(
187
- 'table recognition processing fails, not get html return'
184
+ 'table recognition processing fails, not found expected HTML table end'
188
185
  )
189
- table_time += time.time() - table_start
190
- table_count += len(table_res_list)
186
+ else:
187
+ logger.warning(
188
+ 'table recognition processing fails, not get html return'
189
+ )
190
+ # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')
191
191
 
192
- if self.model.apply_ocr:
193
- logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
194
- else:
195
- logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
196
- if self.model.apply_table:
197
- logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
192
+ # Create dictionaries to store items by language
193
+ need_ocr_lists_by_lang = {} # Dict of lists for each language
194
+ img_crop_lists_by_lang = {} # Dict of lists for each language
198
195
 
199
- return images_layout_res
196
+ for layout_res in images_layout_res:
197
+ for layout_res_item in layout_res:
198
+ if layout_res_item['category_id'] in [15]:
199
+ if 'np_img' in layout_res_item and 'lang' in layout_res_item:
200
+ lang = layout_res_item['lang']
201
+
202
+ # Initialize lists for this language if not exist
203
+ if lang not in need_ocr_lists_by_lang:
204
+ need_ocr_lists_by_lang[lang] = []
205
+ img_crop_lists_by_lang[lang] = []
206
+
207
+ # Add to the appropriate language-specific lists
208
+ need_ocr_lists_by_lang[lang].append(layout_res_item)
209
+ img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
210
+
211
+ # Remove the fields after adding to lists
212
+ layout_res_item.pop('np_img')
213
+ layout_res_item.pop('lang')
214
+
215
+
216
+ if len(img_crop_lists_by_lang) > 0:
200
217
 
218
+ # Process OCR by language
219
+ rec_time = 0
220
+ rec_start = time.time()
221
+ total_processed = 0
201
222
 
202
- # def doc_batch_analyze(
203
- # dataset: Dataset,
204
- # ocr: bool = False,
205
- # show_log: bool = False,
206
- # start_page_id=0,
207
- # end_page_id=None,
208
- # lang=None,
209
- # layout_model=None,
210
- # formula_enable=None,
211
- # table_enable=None,
212
- # batch_ratio: int | None = None,
213
- # ) -> InferenceResult:
214
- # """Perform batch analysis on a document dataset.
215
- #
216
- # Args:
217
- # dataset (Dataset): The dataset containing document pages to be analyzed.
218
- # ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
219
- # show_log (bool, optional): Flag to enable logging. Defaults to False.
220
- # start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
221
- # end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
222
- # lang (str, optional): Language for OCR. Defaults to None.
223
- # layout_model (optional): Layout model to be used for analysis. Defaults to None.
224
- # formula_enable (optional): Flag to enable formula detection. Defaults to None.
225
- # table_enable (optional): Flag to enable table detection. Defaults to None.
226
- # batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
227
- #
228
- # Raises:
229
- # CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
230
- #
231
- # Returns:
232
- # InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
233
- # """
234
- #
235
- # if not torch.cuda.is_available():
236
- # raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
237
- #
238
- # lang = None if lang == '' else lang
239
- # # TODO: auto detect batch size
240
- # batch_ratio = 1 if batch_ratio is None else batch_ratio
241
- # end_page_id = end_page_id if end_page_id else len(dataset)
242
- #
243
- # model_manager = ModelSingleton()
244
- # custom_model: CustomPEKModel = model_manager.get_model(
245
- # ocr, show_log, lang, layout_model, formula_enable, table_enable
246
- # )
247
- # batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
248
- #
249
- # model_json = []
250
- #
251
- # # batch analyze
252
- # images = []
253
- # for index in range(len(dataset)):
254
- # if start_page_id <= index <= end_page_id:
255
- # page_data = dataset.get_page(index)
256
- # img_dict = page_data.get_image()
257
- # images.append(img_dict['img'])
258
- # analyze_result = batch_model(images)
259
- #
260
- # for index in range(len(dataset)):
261
- # page_data = dataset.get_page(index)
262
- # img_dict = page_data.get_image()
263
- # page_width = img_dict['width']
264
- # page_height = img_dict['height']
265
- # if start_page_id <= index <= end_page_id:
266
- # result = analyze_result.pop(0)
267
- # else:
268
- # result = []
269
- #
270
- # page_info = {'page_no': index, 'height': page_height, 'width': page_width}
271
- # page_dict = {'layout_dets': result, 'page_info': page_info}
272
- # model_json.append(page_dict)
273
- #
274
- # # TODO: clean memory when gpu memory is not enough
275
- # clean_memory_start_time = time.time()
276
- # clean_memory(get_device())
277
- # logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
278
- #
279
- # return InferenceResult(model_json, dataset)
223
+ # Process each language separately
224
+ for lang, img_crop_list in img_crop_lists_by_lang.items():
225
+ if len(img_crop_list) > 0:
226
+ # Get OCR results for this language's images
227
+ atom_model_manager = AtomModelSingleton()
228
+ ocr_model = atom_model_manager.get_atom_model(
229
+ atom_model_name='ocr',
230
+ ocr_show_log=False,
231
+ det_db_box_thresh=0.3,
232
+ lang=lang
233
+ )
234
+ ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
235
+
236
+ # Verify we have matching counts
237
+ assert len(ocr_res_list) == len(
238
+ need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
239
+
240
+ # Process OCR results for this language
241
+ for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
242
+ ocr_text, ocr_score = ocr_res_list[index]
243
+ layout_res_item['text'] = ocr_text
244
+ layout_res_item['score'] = float(round(ocr_score, 2))
245
+
246
+ total_processed += len(img_crop_list)
247
+
248
+ rec_time += time.time() - rec_start
249
+ # logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
250
+
251
+
252
+
253
+ return images_layout_res