magic-pdf 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/dict2md/ocr_mkcontent.py +21 -0
- magic_pdf/filter/__init__.py +1 -1
- magic_pdf/filter/pdf_classify_by_type.py +6 -4
- magic_pdf/filter/pdf_meta_scan.py +4 -4
- magic_pdf/libs/pdf_check.py +11 -1
- magic_pdf/libs/performance_stats.py +54 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/doc_analyze_by_custom_model.py +27 -39
- magic_pdf/model/magic_model.py +160 -4
- magic_pdf/model/pdf_extract_kit.py +0 -7
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +4 -3
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +74 -9
- magic_pdf/model/sub_modules/model_init.py +28 -14
- magic_pdf/pdf_parse_union_core_v2.py +51 -34
- magic_pdf/post_proc/llm_aided.py +14 -16
- magic_pdf/pre_proc/ocr_dict_merge.py +14 -2
- {magic_pdf-1.1.0.dist-info → magic_pdf-1.2.1.dist-info}/METADATA +53 -41
- {magic_pdf-1.1.0.dist-info → magic_pdf-1.2.1.dist-info}/RECORD +22 -22
- {magic_pdf-1.1.0.dist-info → magic_pdf-1.2.1.dist-info}/WHEEL +1 -1
- magic_pdf/post_proc/llm_aided_ocr.py +0 -689
- {magic_pdf-1.1.0.dist-info → magic_pdf-1.2.1.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.1.0.dist-info → magic_pdf-1.2.1.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.1.0.dist-info → magic_pdf-1.2.1.dist-info}/top_level.txt +0 -0
@@ -126,11 +126,32 @@ def detect_language(text):
|
|
126
126
|
return 'empty'
|
127
127
|
|
128
128
|
|
129
|
+
# Translation table for str.translate(): maps the code points of full-width
# digits (U+FF10-U+FF19), upper-case letters (U+FF21-U+FF3A) and lower-case
# letters (U+FF41-U+FF5A) to their ASCII counterparts, which sit exactly
# 0xFEE0 code points lower. Built once at import time instead of per call.
_FULL_TO_HALF_TABLE = {
    code: code - 0xFEE0
    for first, last in ((0xFF10, 0xFF19), (0xFF21, 0xFF3A), (0xFF41, 0xFF5A))
    for code in range(first, last + 1)
}


def full_to_half(text: str) -> str:
    """Convert full-width letters and digits to their half-width (ASCII) form.

    Only the full-width ranges for ``0-9``, ``A-Z`` and ``a-z`` are converted.
    Every other character, including full-width punctuation, is returned
    unchanged.

    Args:
        text: String that may contain full-width letters or digits.

    Returns:
        A new string with full-width letters and digits replaced by their
        half-width equivalents.
    """
    # A single C-level pass over the string replaces the manual
    # ord()/chr()/append loop.
    return text.translate(_FULL_TO_HALF_TABLE)
|
147
|
+
|
148
|
+
|
129
149
|
def merge_para_with_text(para_block):
|
130
150
|
block_text = ''
|
131
151
|
for line in para_block['lines']:
|
132
152
|
for span in line['spans']:
|
133
153
|
if span['type'] in [ContentType.Text]:
|
154
|
+
span['content'] = full_to_half(span['content'])
|
134
155
|
block_text += span['content']
|
135
156
|
block_lang = detect_lang(block_text)
|
136
157
|
|
magic_pdf/filter/__init__.py
CHANGED
@@ -23,7 +23,7 @@ def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
|
|
23
23
|
pdf_meta['image_info_per_page'],
|
24
24
|
pdf_meta['text_len_per_page'],
|
25
25
|
pdf_meta['imgs_per_page'],
|
26
|
-
pdf_meta['text_layout_per_page'],
|
26
|
+
# pdf_meta['text_layout_per_page'],
|
27
27
|
pdf_meta['invalid_chars'],
|
28
28
|
)
|
29
29
|
if is_text_pdf:
|
@@ -305,7 +305,8 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
|
|
305
305
|
|
306
306
|
|
307
307
|
def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
|
308
|
-
text_layout_list: list,
|
308
|
+
# text_layout_list: list,
|
309
|
+
invalid_chars: bool):
|
309
310
|
"""
|
310
311
|
这里的图片和页面长度单位是pts
|
311
312
|
:param total_page:
|
@@ -321,7 +322,7 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
|
|
321
322
|
'by_text_len': classify_by_text_len(text_len_list, total_page),
|
322
323
|
'by_avg_words': classify_by_avg_words(text_len_list),
|
323
324
|
'by_img_num': classify_by_img_num(img_sz_list, img_num_list),
|
324
|
-
'by_text_layout': classify_by_text_layout(text_layout_list),
|
325
|
+
# 'by_text_layout': classify_by_text_layout(text_layout_list),
|
325
326
|
'by_img_narrow_strips': classify_by_img_narrow_strips(page_width, page_height, img_sz_list),
|
326
327
|
'by_invalid_chars': invalid_chars,
|
327
328
|
}
|
@@ -332,9 +333,10 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
|
|
332
333
|
return False, results
|
333
334
|
else:
|
334
335
|
logger.warning(
|
335
|
-
f"
|
336
|
+
f"OCR needed based on classification result, by_image_area: {results['by_image_area']},"
|
336
337
|
f" by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']},"
|
337
|
-
f" by_text_layout: {results['by_text_layout']},
|
338
|
+
# f" by_text_layout: {results['by_text_layout']},"
|
339
|
+
f" by_img_narrow_strips: {results['by_img_narrow_strips']},"
|
338
340
|
f" by_invalid_chars: {results['by_invalid_chars']}",
|
339
341
|
file=sys.stderr) # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法
|
340
342
|
return False, results
|
@@ -356,9 +356,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
|
|
356
356
|
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
|
357
357
|
text_len_per_page = get_pdf_textlen_per_page(doc)
|
358
358
|
# logger.info(f"text_len_per_page: {text_len_per_page}")
|
359
|
-
text_layout_per_page = get_pdf_text_layout_per_page(doc)
|
359
|
+
# text_layout_per_page = get_pdf_text_layout_per_page(doc)
|
360
360
|
# logger.info(f"text_layout_per_page: {text_layout_per_page}")
|
361
|
-
text_language = get_language(doc)
|
361
|
+
# text_language = get_language(doc)
|
362
362
|
# logger.info(f"text_language: {text_language}")
|
363
363
|
invalid_chars = check_invalid_chars(pdf_bytes)
|
364
364
|
# logger.info(f"invalid_chars: {invalid_chars}")
|
@@ -372,8 +372,8 @@ def pdf_meta_scan(pdf_bytes: bytes):
|
|
372
372
|
'page_height_pts': int(page_height_pts),
|
373
373
|
'image_info_per_page': image_info_per_page,
|
374
374
|
'text_len_per_page': text_len_per_page,
|
375
|
-
'text_layout_per_page': text_layout_per_page,
|
376
|
-
'text_language': text_language,
|
375
|
+
# 'text_layout_per_page': text_layout_per_page,
|
376
|
+
# 'text_language': text_language,
|
377
377
|
# "svgs_per_page": svgs_per_page,
|
378
378
|
'imgs_per_page': imgs_per_page, # 增加每页img数量list
|
379
379
|
'junk_img_bojids': junk_img_bojids, # 增加垃圾图片的bojid list
|
magic_pdf/libs/pdf_check.py
CHANGED
@@ -4,6 +4,7 @@ from loguru import logger
|
|
4
4
|
import re
|
5
5
|
from io import BytesIO
|
6
6
|
from pdfminer.high_level import extract_text
|
7
|
+
from pdfminer.layout import LAParams
|
7
8
|
|
8
9
|
|
9
10
|
def calculate_sample_count(total_page: int):
|
@@ -41,7 +42,16 @@ def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
|
|
41
42
|
sample_docs = extract_pages(src_pdf_bytes)
|
42
43
|
sample_pdf_bytes = sample_docs.tobytes()
|
43
44
|
sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
|
44
|
-
|
45
|
+
laparams = LAParams(
|
46
|
+
line_overlap=0.5,
|
47
|
+
char_margin=2.0,
|
48
|
+
line_margin=0.5,
|
49
|
+
word_margin=0.1,
|
50
|
+
boxes_flow=None,
|
51
|
+
detect_vertical=False,
|
52
|
+
all_texts=False,
|
53
|
+
)
|
54
|
+
text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
|
45
55
|
text = text.replace("\n", "")
|
46
56
|
# logger.info(text)
|
47
57
|
'''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
|
@@ -0,0 +1,54 @@
|
|
1
|
+
import time
|
2
|
+
import functools
|
3
|
+
from collections import defaultdict
|
4
|
+
from typing import Dict, List
|
5
|
+
|
6
|
+
|
7
|
+
class PerformanceStats:
    """Collect and display execution times of instrumented callables.

    All timings live in one class-level registry, so every caller of the
    class methods reads and writes the same statistics.
    """

    # Maps a function name to the list of execution times recorded for it.
    _stats: Dict[str, List[float]] = defaultdict(list)

    @classmethod
    def add_execution_time(cls, func_name: str, execution_time: float):
        """Append one execution time to the record kept for ``func_name``."""
        cls._stats[func_name].append(execution_time)

    @classmethod
    def get_stats(cls) -> Dict[str, dict]:
        """Return summary statistics for every function with recorded times.

        Returns:
            A dict keyed by function name. Each value holds ``count``,
            ``total_time``, ``avg_time``, ``min_time`` and ``max_time``.
            Names without any recorded time are omitted.
        """
        results = {}
        for func_name, times in cls._stats.items():
            if not times:
                # Merely reading ``_stats[name]`` on the defaultdict creates
                # an empty list; skip it instead of dividing by zero or
                # calling min()/max() on an empty sequence.
                continue
            total_time = sum(times)
            results[func_name] = {
                'count': len(times),
                'total_time': total_time,
                'avg_time': total_time / len(times),
                'min_time': min(times),
                'max_time': max(times)
            }
        return results

    @classmethod
    def print_stats(cls):
        """Print call count, total time and average time per function to stdout."""
        stats = cls.get_stats()
        print("\n性能统计结果:")
        print("-" * 80)
        print(f"{'方法名':<40} {'调用次数':>8} {'总时间(s)':>12} {'平均时间(s)':>12}")
        print("-" * 80)
        for func_name, data in stats.items():
            print(f"{func_name:<40} {data['count']:8d} {data['total_time']:12.6f} {data['avg_time']:12.6f}")
|
41
|
+
|
42
|
+
|
43
|
+
def measure_time(func):
    """Decorator that records how long each successful call to ``func`` takes.

    The elapsed time is passed to ``PerformanceStats.add_execution_time``
    under ``func.__name__``. A call that raises propagates the exception and
    records nothing.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is a monotonic, high-resolution clock intended for
        # measuring durations; time.time() can jump when the system clock is
        # adjusted and would then yield wrong or even negative durations.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        PerformanceStats.add_execution_time(func.__name__, execution_time)
        return result

    return wrapper
|
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "1.1.0"
|
1
|
+
__version__ = "1.2.1"
|
@@ -1,21 +1,22 @@
|
|
1
1
|
import os
|
2
2
|
import time
|
3
|
+
import torch
|
3
4
|
|
5
|
+
os.environ['FLAGS_npu_jit_compile'] = '0' # 关闭paddle的jit编译
|
6
|
+
os.environ['FLAGS_use_stride_kernel'] = '0'
|
7
|
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
|
8
|
+
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
4
9
|
# 关闭paddle的信号处理
|
5
10
|
import paddle
|
6
|
-
|
11
|
+
paddle.disable_signal_handler()
|
12
|
+
|
7
13
|
from loguru import logger
|
8
14
|
|
9
15
|
from magic_pdf.model.batch_analyze import BatchAnalyze
|
10
16
|
from magic_pdf.model.sub_modules.model_utils import get_vram
|
11
17
|
|
12
|
-
paddle.disable_signal_handler()
|
13
|
-
|
14
|
-
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
15
|
-
|
16
18
|
try:
|
17
19
|
import torchtext
|
18
|
-
|
19
20
|
if torchtext.__version__ >= '0.18.0':
|
20
21
|
torchtext.disable_torchtext_deprecation_warning()
|
21
22
|
except ImportError:
|
@@ -32,20 +33,6 @@ from magic_pdf.model.model_list import MODEL
|
|
32
33
|
from magic_pdf.operators.models import InferenceResult
|
33
34
|
|
34
35
|
|
35
|
-
def dict_compare(d1, d2):
|
36
|
-
return d1.items() == d2.items()
|
37
|
-
|
38
|
-
|
39
|
-
def remove_duplicates_dicts(lst):
|
40
|
-
unique_dicts = []
|
41
|
-
for dict_item in lst:
|
42
|
-
if not any(
|
43
|
-
dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
|
44
|
-
):
|
45
|
-
unique_dicts.append(dict_item)
|
46
|
-
return unique_dicts
|
47
|
-
|
48
|
-
|
49
36
|
class ModelSingleton:
|
50
37
|
_instance = None
|
51
38
|
_models = {}
|
@@ -158,7 +145,11 @@ def doc_analyze(
|
|
158
145
|
table_enable=None,
|
159
146
|
) -> InferenceResult:
|
160
147
|
|
161
|
-
end_page_id =
|
148
|
+
end_page_id = (
|
149
|
+
end_page_id
|
150
|
+
if end_page_id is not None and end_page_id >= 0
|
151
|
+
else len(dataset) - 1
|
152
|
+
)
|
162
153
|
|
163
154
|
model_manager = ModelSingleton()
|
164
155
|
custom_model = model_manager.get_model(
|
@@ -166,6 +157,7 @@ def doc_analyze(
|
|
166
157
|
)
|
167
158
|
|
168
159
|
batch_analyze = False
|
160
|
+
batch_ratio = 1
|
169
161
|
device = get_device()
|
170
162
|
|
171
163
|
npu_support = False
|
@@ -178,21 +170,15 @@ def doc_analyze(
|
|
178
170
|
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
|
179
171
|
if gpu_memory is not None and gpu_memory >= 8:
|
180
172
|
|
181
|
-
if
|
182
|
-
batch_ratio = 2
|
183
|
-
elif 10 <= gpu_memory <= 12:
|
184
|
-
batch_ratio = 4
|
185
|
-
elif 12 < gpu_memory <= 16:
|
173
|
+
if gpu_memory >= 16:
|
186
174
|
batch_ratio = 8
|
187
|
-
elif
|
188
|
-
batch_ratio =
|
175
|
+
elif gpu_memory >= 10:
|
176
|
+
batch_ratio = 4
|
189
177
|
else:
|
190
|
-
batch_ratio =
|
178
|
+
batch_ratio = 2
|
191
179
|
|
192
|
-
|
193
|
-
|
194
|
-
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
|
195
|
-
batch_analyze = True
|
180
|
+
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
|
181
|
+
batch_analyze = True
|
196
182
|
|
197
183
|
model_json = []
|
198
184
|
doc_analyze_start = time.time()
|
@@ -200,24 +186,26 @@ def doc_analyze(
|
|
200
186
|
if batch_analyze:
|
201
187
|
# batch analyze
|
202
188
|
images = []
|
189
|
+
page_wh_list = []
|
203
190
|
for index in range(len(dataset)):
|
204
191
|
if start_page_id <= index <= end_page_id:
|
205
192
|
page_data = dataset.get_page(index)
|
206
193
|
img_dict = page_data.get_image()
|
207
194
|
images.append(img_dict['img'])
|
195
|
+
page_wh_list.append((img_dict['width'], img_dict['height']))
|
196
|
+
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
|
208
197
|
analyze_result = batch_model(images)
|
209
198
|
|
210
199
|
for index in range(len(dataset)):
|
211
|
-
page_data = dataset.get_page(index)
|
212
|
-
img_dict = page_data.get_image()
|
213
|
-
page_width = img_dict['width']
|
214
|
-
page_height = img_dict['height']
|
215
200
|
if start_page_id <= index <= end_page_id:
|
216
201
|
result = analyze_result.pop(0)
|
202
|
+
page_width, page_height = page_wh_list.pop(0)
|
217
203
|
else:
|
218
204
|
result = []
|
205
|
+
page_height = 0
|
206
|
+
page_width = 0
|
219
207
|
|
220
|
-
page_info = {'page_no': index, '
|
208
|
+
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
|
221
209
|
page_dict = {'layout_dets': result, 'page_info': page_info}
|
222
210
|
model_json.append(page_dict)
|
223
211
|
|
@@ -237,7 +225,7 @@ def doc_analyze(
|
|
237
225
|
else:
|
238
226
|
result = []
|
239
227
|
|
240
|
-
page_info = {'page_no': index, '
|
228
|
+
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
|
241
229
|
page_dict = {'layout_dets': result, 'page_info': page_info}
|
242
230
|
model_json.append(page_dict)
|
243
231
|
|
magic_pdf/model/magic_model.py
CHANGED
@@ -450,11 +450,167 @@ class MagicModel:
|
|
450
450
|
)
|
451
451
|
return ret
|
452
452
|
|
453
|
+
|
454
|
+
def __tie_up_category_by_distance_v3(
    self,
    page_no: int,
    subject_category_id: int,
    object_category_id: int,
    priority_pos: PosRelationEnum,
):
    """Group the object boxes of a page under the subject box they belong to.

    The detections of the two categories are read from the page's
    ``layout_dets``, passed through ``__reduct_overlap`` and then matched in
    three phases:

    1. Greedy pairing: repeatedly take the unprocessed box closest to the
       minimum (bbox[0], bbox[1]) corner of all unprocessed boxes, find the
       nearest unprocessed box of the opposite kind and pair them, unless
       another unprocessed subject is at least three times closer to that
       object.
    2. Every object left over is attached to its nearest subject.
    3. Every subject left over is reported with an empty object list.

    Args:
        page_no: Index of the page in the model output.
        subject_category_id: ``category_id`` of the subject detections.
        object_category_id: ``category_id`` of the object detections.
        priority_pos: Accepted for interface compatibility; this
            implementation does not read it.

    Returns:
        A list of dicts with the keys ``sub_bbox`` (``bbox`` and ``score``
        of the subject), ``obj_bboxes`` (list of ``score``/``bbox`` dicts)
        and ``sub_idx`` (index of the subject after sorting).
    """
    subjects = self.__reduct_overlap(
        [
            {'bbox': det['bbox'], 'score': det['score']}
            for det in self.__model_list[page_no]['layout_dets']
            if det['category_id'] == subject_category_id
        ]
    )
    objects = self.__reduct_overlap(
        [
            {'bbox': det['bbox'], 'score': det['score']}
            for det in self.__model_list[page_no]['layout_dets']
            if det['category_id'] == object_category_id
        ]
    )

    ret = []
    N = len(subjects)
    # Order both lists by the squared distance of (bbox[0], bbox[1]) from the
    # origin; all indices used below refer to this order.
    subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
    objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

    # Subjects and objects share one index space; objects are shifted by this
    # offset so that both kinds can be tracked in the same ``seen_idx`` set.
    OBJ_IDX_OFFSET = 10000
    SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

    all_boxes_with_idx = [
        (i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1])
        for i, sub in enumerate(subjects)
    ] + [
        (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1])
        for i, obj in enumerate(objects)
    ]
    seen_idx = set()      # processed boxes (subjects and offset objects)
    seen_sub_idx = set()  # subjects that already own an entry in ``ret``
    # sub_idx -> its entry in ``ret``; replaces a linear search through ``ret``.
    entry_by_sub_idx = {}

    # Phase 1: greedy pairing of a box with the nearest box of the other kind.
    while N > len(seen_sub_idx):
        candidates = [box for box in all_boxes_with_idx if box[0] not in seen_idx]
        if not candidates:
            break
        left_x = min(v[2] for v in candidates)
        top_y = min(v[3] for v in candidates)

        # First sort: pick the box closest to the top-left corner of what is
        # left. Second sort: order the rest by distance from that box. Both
        # sorts are kept because the stable ordering decides ties.
        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
        fst_idx, fst_kind, left_x, top_y = candidates[0]
        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)

        # Nearest remaining box of the opposite kind, if any.
        nxt = next((c for c in candidates[1:] if c[1] != fst_kind), None)
        if nxt is None:
            break

        if fst_kind == SUB_BIT_KIND:
            sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
        else:
            sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

        pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
        nearest_dis = float('inf')
        for i in range(N):
            if i in seen_idx or i == sub_idx:
                continue
            nearest_dis = min(
                nearest_dis,
                bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']),
            )

        # The subject counts as processed whether or not the pair is accepted.
        seen_idx.add(sub_idx)
        if pair_dis >= 3 * nearest_dis:
            # Another unprocessed subject is far closer to this object: drop
            # the pair and leave the object available for later rounds.
            continue

        seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
        seen_sub_idx.add(sub_idx)

        entry = {
            'sub_bbox': {
                'bbox': subjects[sub_idx]['bbox'],
                'score': subjects[sub_idx]['score'],
            },
            'obj_bboxes': [
                {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
            ],
            'sub_idx': sub_idx,
        }
        ret.append(entry)
        entry_by_sub_idx[sub_idx] = entry

    # Phase 2: attach every remaining object to its nearest subject.
    for i, obj in enumerate(objects):
        j = i + OBJ_IDX_OFFSET
        if j in seen_idx:
            continue
        seen_idx.add(j)
        nearest_dis, nearest_sub_idx = float('inf'), -1
        for k, sub in enumerate(subjects):
            dis = bbox_distance(obj['bbox'], sub['bbox'])
            if dis < nearest_dis:
                nearest_dis = dis
                nearest_sub_idx = k

        if nearest_sub_idx == -1:
            # No subject qualified (e.g. there are no subjects at all).
            continue

        k = nearest_sub_idx
        if k in seen_sub_idx:
            entry_by_sub_idx[k]['obj_bboxes'].append(
                {'score': obj['score'], 'bbox': obj['bbox']}
            )
        else:
            entry = {
                'sub_bbox': {
                    'bbox': subjects[k]['bbox'],
                    'score': subjects[k]['score'],
                },
                'obj_bboxes': [
                    {'score': obj['score'], 'bbox': obj['bbox']}
                ],
                'sub_idx': k,
            }
            ret.append(entry)
            entry_by_sub_idx[k] = entry
        seen_sub_idx.add(k)
        seen_idx.add(k)

    # Phase 3: subjects that never received an object.
    for i, sub in enumerate(subjects):
        if i in seen_sub_idx:
            continue
        ret.append(
            {
                'sub_bbox': {
                    'bbox': sub['bbox'],
                    'score': sub['score'],
                },
                'obj_bboxes': [],
                'sub_idx': i,
            }
        )

    return ret
|
607
|
+
|
608
|
+
|
453
609
|
def get_imgs_v2(self, page_no: int):
|
454
|
-
with_captions = self.
|
610
|
+
with_captions = self.__tie_up_category_by_distance_v3(
|
455
611
|
page_no, 3, 4, PosRelationEnum.BOTTOM
|
456
612
|
)
|
457
|
-
with_footnotes = self.
|
613
|
+
with_footnotes = self.__tie_up_category_by_distance_v3(
|
458
614
|
page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
|
459
615
|
)
|
460
616
|
ret = []
|
@@ -470,10 +626,10 @@ class MagicModel:
|
|
470
626
|
return ret
|
471
627
|
|
472
628
|
def get_tables_v2(self, page_no: int) -> list:
|
473
|
-
with_captions = self.
|
629
|
+
with_captions = self.__tie_up_category_by_distance_v3(
|
474
630
|
page_no, 5, 6, PosRelationEnum.UP
|
475
631
|
)
|
476
|
-
with_footnotes = self.
|
632
|
+
with_footnotes = self.__tie_up_category_by_distance_v3(
|
477
633
|
page_no, 5, 7, PosRelationEnum.ALL
|
478
634
|
)
|
479
635
|
ret = []
|
@@ -89,13 +89,6 @@ class CustomPEKModel:
|
|
89
89
|
# 初始化解析方案
|
90
90
|
self.device = kwargs.get('device', 'cpu')
|
91
91
|
|
92
|
-
if str(self.device).startswith("npu"):
|
93
|
-
import torch_npu
|
94
|
-
os.environ['FLAGS_npu_jit_compile'] = '0'
|
95
|
-
os.environ['FLAGS_use_stride_kernel'] = '0'
|
96
|
-
elif str(self.device).startswith("mps"):
|
97
|
-
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
98
|
-
|
99
92
|
logger.info('using device: {}'.format(self.device))
|
100
93
|
models_dir = kwargs.get(
|
101
94
|
'models_dir', os.path.join(root_dir, 'resources', 'models')
|
@@ -1,4 +1,5 @@
|
|
1
1
|
# Copyright (c) Opendatalab. All rights reserved.
|
2
|
+
import time
|
2
3
|
from collections import Counter
|
3
4
|
from uuid import uuid4
|
4
5
|
|
@@ -102,9 +103,9 @@ class YOLOv11LangDetModel(object):
|
|
102
103
|
temp_images = split_images(image)
|
103
104
|
for temp_image in temp_images:
|
104
105
|
all_images.append(resize_images_to_224(temp_image))
|
105
|
-
|
106
|
-
images_lang_res = self.batch_predict(all_images, batch_size=
|
107
|
-
# logger.info(f"
|
106
|
+
# langdetect_start = time.time()
|
107
|
+
images_lang_res = self.batch_predict(all_images, batch_size=256)
|
108
|
+
# logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
|
108
109
|
if len(images_lang_res) > 0:
|
109
110
|
count_dict = Counter(images_lang_res)
|
110
111
|
language = max(count_dict, key=count_dict.get)
|
@@ -100,20 +100,61 @@ class UnimernetModel(object):
|
|
100
100
|
res["latex"] = latex_rm_whitespace(latex)
|
101
101
|
return formula_list
|
102
102
|
|
103
|
-
def batch_predict(
|
104
|
-
|
105
|
-
) -> list:
|
103
|
+
# def batch_predict(
|
104
|
+
# self, images_mfd_res: list, images: list, batch_size: int = 64
|
105
|
+
# ) -> list:
|
106
|
+
# images_formula_list = []
|
107
|
+
# mf_image_list = []
|
108
|
+
# backfill_list = []
|
109
|
+
# for image_index in range(len(images_mfd_res)):
|
110
|
+
# mfd_res = images_mfd_res[image_index]
|
111
|
+
# pil_img = Image.fromarray(images[image_index])
|
112
|
+
# formula_list = []
|
113
|
+
#
|
114
|
+
# for xyxy, conf, cla in zip(
|
115
|
+
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
116
|
+
# ):
|
117
|
+
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
118
|
+
# new_item = {
|
119
|
+
# "category_id": 13 + int(cla.item()),
|
120
|
+
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
121
|
+
# "score": round(float(conf.item()), 2),
|
122
|
+
# "latex": "",
|
123
|
+
# }
|
124
|
+
# formula_list.append(new_item)
|
125
|
+
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
126
|
+
# mf_image_list.append(bbox_img)
|
127
|
+
#
|
128
|
+
# images_formula_list.append(formula_list)
|
129
|
+
# backfill_list += formula_list
|
130
|
+
#
|
131
|
+
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
132
|
+
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|
133
|
+
# mfr_res = []
|
134
|
+
# for mf_img in dataloader:
|
135
|
+
# mf_img = mf_img.to(self.device)
|
136
|
+
# with torch.no_grad():
|
137
|
+
# output = self.model.generate({"image": mf_img})
|
138
|
+
# mfr_res.extend(output["pred_str"])
|
139
|
+
# for res, latex in zip(backfill_list, mfr_res):
|
140
|
+
# res["latex"] = latex_rm_whitespace(latex)
|
141
|
+
# return images_formula_list
|
142
|
+
|
143
|
+
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
|
106
144
|
images_formula_list = []
|
107
145
|
mf_image_list = []
|
108
146
|
backfill_list = []
|
147
|
+
image_info = [] # Store (area, original_index, image) tuples
|
148
|
+
|
149
|
+
# Collect images with their original indices
|
109
150
|
for image_index in range(len(images_mfd_res)):
|
110
151
|
mfd_res = images_mfd_res[image_index]
|
111
152
|
pil_img = Image.fromarray(images[image_index])
|
112
153
|
formula_list = []
|
113
154
|
|
114
|
-
for xyxy, conf, cla in zip(
|
115
|
-
|
116
|
-
):
|
155
|
+
for idx, (xyxy, conf, cla) in enumerate(zip(
|
156
|
+
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
|
157
|
+
)):
|
117
158
|
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
118
159
|
new_item = {
|
119
160
|
"category_id": 13 + int(cla.item()),
|
@@ -123,19 +164,43 @@ class UnimernetModel(object):
|
|
123
164
|
}
|
124
165
|
formula_list.append(new_item)
|
125
166
|
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
|
167
|
+
area = (xmax - xmin) * (ymax - ymin)
|
168
|
+
|
169
|
+
curr_idx = len(mf_image_list)
|
170
|
+
image_info.append((area, curr_idx, bbox_img))
|
126
171
|
mf_image_list.append(bbox_img)
|
127
172
|
|
128
173
|
images_formula_list.append(formula_list)
|
129
174
|
backfill_list += formula_list
|
130
175
|
|
131
|
-
|
176
|
+
# Stable sort by area
|
177
|
+
image_info.sort(key=lambda x: x[0]) # sort by area
|
178
|
+
sorted_indices = [x[1] for x in image_info]
|
179
|
+
sorted_images = [x[2] for x in image_info]
|
180
|
+
|
181
|
+
# Create mapping for results
|
182
|
+
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
|
183
|
+
|
184
|
+
# Create dataset with sorted images
|
185
|
+
dataset = MathDataset(sorted_images, transform=self.mfr_transform)
|
132
186
|
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|
187
|
+
|
188
|
+
# Process batches and store results
|
133
189
|
mfr_res = []
|
134
190
|
for mf_img in dataloader:
|
135
191
|
mf_img = mf_img.to(self.device)
|
136
192
|
with torch.no_grad():
|
137
193
|
output = self.model.generate({"image": mf_img})
|
138
194
|
mfr_res.extend(output["pred_str"])
|
139
|
-
|
140
|
-
|
195
|
+
|
196
|
+
# Restore original order
|
197
|
+
unsorted_results = [""] * len(mfr_res)
|
198
|
+
for new_idx, latex in enumerate(mfr_res):
|
199
|
+
original_idx = index_mapping[new_idx]
|
200
|
+
unsorted_results[original_idx] = latex_rm_whitespace(latex)
|
201
|
+
|
202
|
+
# Fill results back
|
203
|
+
for res, latex in zip(backfill_list, unsorted_results):
|
204
|
+
res["latex"] = latex
|
205
|
+
|
141
206
|
return images_formula_list
|