magic-pdf 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -126,11 +126,32 @@ def detect_language(text):
126
126
  return 'empty'
127
127
 
128
128
 
129
+ def full_to_half(text: str) -> str:
130
+ """Convert full-width characters to half-width characters using code point manipulation.
131
+
132
+ Args:
133
+ text: String containing full-width characters
134
+
135
+ Returns:
136
+ String with full-width characters converted to half-width
137
+ """
138
+ result = []
139
+ for char in text:
140
+ code = ord(char)
141
+ # Full-width letters and numbers (FF21-FF3A for A-Z, FF41-FF5A for a-z, FF10-FF19 for 0-9)
142
+ if (0xFF21 <= code <= 0xFF3A) or (0xFF41 <= code <= 0xFF5A) or (0xFF10 <= code <= 0xFF19):
143
+ result.append(chr(code - 0xFEE0)) # Shift to ASCII range
144
+ else:
145
+ result.append(char)
146
+ return ''.join(result)
147
+
148
+
129
149
  def merge_para_with_text(para_block):
130
150
  block_text = ''
131
151
  for line in para_block['lines']:
132
152
  for span in line['spans']:
133
153
  if span['type'] in [ContentType.Text]:
154
+ span['content'] = full_to_half(span['content'])
134
155
  block_text += span['content']
135
156
  block_lang = detect_lang(block_text)
136
157
 
@@ -23,7 +23,7 @@ def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
23
23
  pdf_meta['image_info_per_page'],
24
24
  pdf_meta['text_len_per_page'],
25
25
  pdf_meta['imgs_per_page'],
26
- pdf_meta['text_layout_per_page'],
26
+ # pdf_meta['text_layout_per_page'],
27
27
  pdf_meta['invalid_chars'],
28
28
  )
29
29
  if is_text_pdf:
@@ -305,7 +305,8 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
305
305
 
306
306
 
307
307
  def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
308
- text_layout_list: list, invalid_chars: bool):
308
+ # text_layout_list: list,
309
+ invalid_chars: bool):
309
310
  """
310
311
  这里的图片和页面长度单位是pts
311
312
  :param total_page:
@@ -321,7 +322,7 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
321
322
  'by_text_len': classify_by_text_len(text_len_list, total_page),
322
323
  'by_avg_words': classify_by_avg_words(text_len_list),
323
324
  'by_img_num': classify_by_img_num(img_sz_list, img_num_list),
324
- 'by_text_layout': classify_by_text_layout(text_layout_list),
325
+ # 'by_text_layout': classify_by_text_layout(text_layout_list),
325
326
  'by_img_narrow_strips': classify_by_img_narrow_strips(page_width, page_height, img_sz_list),
326
327
  'by_invalid_chars': invalid_chars,
327
328
  }
@@ -332,9 +333,10 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
332
333
  return False, results
333
334
  else:
334
335
  logger.warning(
335
- f"pdf is not classified by area and text_len, by_image_area: {results['by_image_area']},"
336
+ f"OCR needed based on classification result, by_image_area: {results['by_image_area']},"
336
337
  f" by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']},"
337
- f" by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']},"
338
+ # f" by_text_layout: {results['by_text_layout']},"
339
+ f" by_img_narrow_strips: {results['by_img_narrow_strips']},"
338
340
  f" by_invalid_chars: {results['by_invalid_chars']}",
339
341
  file=sys.stderr) # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法
340
342
  return False, results
@@ -356,9 +356,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
356
356
  # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
357
357
  text_len_per_page = get_pdf_textlen_per_page(doc)
358
358
  # logger.info(f"text_len_per_page: {text_len_per_page}")
359
- text_layout_per_page = get_pdf_text_layout_per_page(doc)
359
+ # text_layout_per_page = get_pdf_text_layout_per_page(doc)
360
360
  # logger.info(f"text_layout_per_page: {text_layout_per_page}")
361
- text_language = get_language(doc)
361
+ # text_language = get_language(doc)
362
362
  # logger.info(f"text_language: {text_language}")
363
363
  invalid_chars = check_invalid_chars(pdf_bytes)
364
364
  # logger.info(f"invalid_chars: {invalid_chars}")
@@ -372,8 +372,8 @@ def pdf_meta_scan(pdf_bytes: bytes):
372
372
  'page_height_pts': int(page_height_pts),
373
373
  'image_info_per_page': image_info_per_page,
374
374
  'text_len_per_page': text_len_per_page,
375
- 'text_layout_per_page': text_layout_per_page,
376
- 'text_language': text_language,
375
+ # 'text_layout_per_page': text_layout_per_page,
376
+ # 'text_language': text_language,
377
377
  # "svgs_per_page": svgs_per_page,
378
378
  'imgs_per_page': imgs_per_page, # 增加每页img数量list
379
379
  'junk_img_bojids': junk_img_bojids, # 增加垃圾图片的bojid list
@@ -4,6 +4,7 @@ from loguru import logger
4
4
  import re
5
5
  from io import BytesIO
6
6
  from pdfminer.high_level import extract_text
7
+ from pdfminer.layout import LAParams
7
8
 
8
9
 
9
10
  def calculate_sample_count(total_page: int):
@@ -41,7 +42,16 @@ def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
41
42
  sample_docs = extract_pages(src_pdf_bytes)
42
43
  sample_pdf_bytes = sample_docs.tobytes()
43
44
  sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
44
- text = extract_text(sample_pdf_file_like_object)
45
+ laparams = LAParams(
46
+ line_overlap=0.5,
47
+ char_margin=2.0,
48
+ line_margin=0.5,
49
+ word_margin=0.1,
50
+ boxes_flow=None,
51
+ detect_vertical=False,
52
+ all_texts=False,
53
+ )
54
+ text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
45
55
  text = text.replace("\n", "")
46
56
  # logger.info(text)
47
57
  '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
@@ -0,0 +1,54 @@
1
+ import time
2
+ import functools
3
+ from collections import defaultdict
4
+ from typing import Dict, List
5
+
6
+
7
+ class PerformanceStats:
8
+ """性能统计类,用于收集和展示方法执行时间"""
9
+
10
+ _stats: Dict[str, List[float]] = defaultdict(list)
11
+
12
+ @classmethod
13
+ def add_execution_time(cls, func_name: str, execution_time: float):
14
+ """添加执行时间记录"""
15
+ cls._stats[func_name].append(execution_time)
16
+
17
+ @classmethod
18
+ def get_stats(cls) -> Dict[str, dict]:
19
+ """获取统计结果"""
20
+ results = {}
21
+ for func_name, times in cls._stats.items():
22
+ results[func_name] = {
23
+ 'count': len(times),
24
+ 'total_time': sum(times),
25
+ 'avg_time': sum(times) / len(times),
26
+ 'min_time': min(times),
27
+ 'max_time': max(times)
28
+ }
29
+ return results
30
+
31
+ @classmethod
32
+ def print_stats(cls):
33
+ """打印统计结果"""
34
+ stats = cls.get_stats()
35
+ print("\n性能统计结果:")
36
+ print("-" * 80)
37
+ print(f"{'方法名':<40} {'调用次数':>8} {'总时间(s)':>12} {'平均时间(s)':>12}")
38
+ print("-" * 80)
39
+ for func_name, data in stats.items():
40
+ print(f"{func_name:<40} {data['count']:8d} {data['total_time']:12.6f} {data['avg_time']:12.6f}")
41
+
42
+
43
+ def measure_time(func):
44
+ """测量方法执行时间的装饰器"""
45
+
46
+ @functools.wraps(func)
47
+ def wrapper(*args, **kwargs):
48
+ start_time = time.time()
49
+ result = func(*args, **kwargs)
50
+ execution_time = time.time() - start_time
51
+ PerformanceStats.add_execution_time(func.__name__, execution_time)
52
+ return result
53
+
54
+ return wrapper
magic_pdf/libs/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.1.0"
1
+ __version__ = "1.2.1"
@@ -1,21 +1,22 @@
1
1
  import os
2
2
  import time
3
+ import torch
3
4
 
5
+ os.environ['FLAGS_npu_jit_compile'] = '0' # 关闭paddle的jit编译
6
+ os.environ['FLAGS_use_stride_kernel'] = '0'
7
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
8
+ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
4
9
  # 关闭paddle的信号处理
5
10
  import paddle
6
- import torch
11
+ paddle.disable_signal_handler()
12
+
7
13
  from loguru import logger
8
14
 
9
15
  from magic_pdf.model.batch_analyze import BatchAnalyze
10
16
  from magic_pdf.model.sub_modules.model_utils import get_vram
11
17
 
12
- paddle.disable_signal_handler()
13
-
14
- os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
15
-
16
18
  try:
17
19
  import torchtext
18
-
19
20
  if torchtext.__version__ >= '0.18.0':
20
21
  torchtext.disable_torchtext_deprecation_warning()
21
22
  except ImportError:
@@ -32,20 +33,6 @@ from magic_pdf.model.model_list import MODEL
32
33
  from magic_pdf.operators.models import InferenceResult
33
34
 
34
35
 
35
- def dict_compare(d1, d2):
36
- return d1.items() == d2.items()
37
-
38
-
39
- def remove_duplicates_dicts(lst):
40
- unique_dicts = []
41
- for dict_item in lst:
42
- if not any(
43
- dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
44
- ):
45
- unique_dicts.append(dict_item)
46
- return unique_dicts
47
-
48
-
49
36
  class ModelSingleton:
50
37
  _instance = None
51
38
  _models = {}
@@ -158,7 +145,11 @@ def doc_analyze(
158
145
  table_enable=None,
159
146
  ) -> InferenceResult:
160
147
 
161
- end_page_id = end_page_id if end_page_id else len(dataset) - 1
148
+ end_page_id = (
149
+ end_page_id
150
+ if end_page_id is not None and end_page_id >= 0
151
+ else len(dataset) - 1
152
+ )
162
153
 
163
154
  model_manager = ModelSingleton()
164
155
  custom_model = model_manager.get_model(
@@ -166,6 +157,7 @@ def doc_analyze(
166
157
  )
167
158
 
168
159
  batch_analyze = False
160
+ batch_ratio = 1
169
161
  device = get_device()
170
162
 
171
163
  npu_support = False
@@ -178,21 +170,15 @@ def doc_analyze(
178
170
  gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
179
171
  if gpu_memory is not None and gpu_memory >= 8:
180
172
 
181
- if 8 <= gpu_memory < 10:
182
- batch_ratio = 2
183
- elif 10 <= gpu_memory <= 12:
184
- batch_ratio = 4
185
- elif 12 < gpu_memory <= 16:
173
+ if gpu_memory >= 16:
186
174
  batch_ratio = 8
187
- elif 16 < gpu_memory <= 24:
188
- batch_ratio = 16
175
+ elif gpu_memory >= 10:
176
+ batch_ratio = 4
189
177
  else:
190
- batch_ratio = 32
178
+ batch_ratio = 2
191
179
 
192
- if batch_ratio >= 1:
193
- logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
194
- batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
195
- batch_analyze = True
180
+ logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
181
+ batch_analyze = True
196
182
 
197
183
  model_json = []
198
184
  doc_analyze_start = time.time()
@@ -200,24 +186,26 @@ def doc_analyze(
200
186
  if batch_analyze:
201
187
  # batch analyze
202
188
  images = []
189
+ page_wh_list = []
203
190
  for index in range(len(dataset)):
204
191
  if start_page_id <= index <= end_page_id:
205
192
  page_data = dataset.get_page(index)
206
193
  img_dict = page_data.get_image()
207
194
  images.append(img_dict['img'])
195
+ page_wh_list.append((img_dict['width'], img_dict['height']))
196
+ batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
208
197
  analyze_result = batch_model(images)
209
198
 
210
199
  for index in range(len(dataset)):
211
- page_data = dataset.get_page(index)
212
- img_dict = page_data.get_image()
213
- page_width = img_dict['width']
214
- page_height = img_dict['height']
215
200
  if start_page_id <= index <= end_page_id:
216
201
  result = analyze_result.pop(0)
202
+ page_width, page_height = page_wh_list.pop(0)
217
203
  else:
218
204
  result = []
205
+ page_height = 0
206
+ page_width = 0
219
207
 
220
- page_info = {'page_no': index, 'height': page_height, 'width': page_width}
208
+ page_info = {'page_no': index, 'width': page_width, 'height': page_height}
221
209
  page_dict = {'layout_dets': result, 'page_info': page_info}
222
210
  model_json.append(page_dict)
223
211
 
@@ -237,7 +225,7 @@ def doc_analyze(
237
225
  else:
238
226
  result = []
239
227
 
240
- page_info = {'page_no': index, 'height': page_height, 'width': page_width}
228
+ page_info = {'page_no': index, 'width': page_width, 'height': page_height}
241
229
  page_dict = {'layout_dets': result, 'page_info': page_info}
242
230
  model_json.append(page_dict)
243
231
 
@@ -450,11 +450,167 @@ class MagicModel:
450
450
  )
451
451
  return ret
452
452
 
453
+
454
+ def __tie_up_category_by_distance_v3(
455
+ self,
456
+ page_no: int,
457
+ subject_category_id: int,
458
+ object_category_id: int,
459
+ priority_pos: PosRelationEnum,
460
+ ):
461
+ subjects = self.__reduct_overlap(
462
+ list(
463
+ map(
464
+ lambda x: {'bbox': x['bbox'], 'score': x['score']},
465
+ filter(
466
+ lambda x: x['category_id'] == subject_category_id,
467
+ self.__model_list[page_no]['layout_dets'],
468
+ ),
469
+ )
470
+ )
471
+ )
472
+ objects = self.__reduct_overlap(
473
+ list(
474
+ map(
475
+ lambda x: {'bbox': x['bbox'], 'score': x['score']},
476
+ filter(
477
+ lambda x: x['category_id'] == object_category_id,
478
+ self.__model_list[page_no]['layout_dets'],
479
+ ),
480
+ )
481
+ )
482
+ )
483
+
484
+ ret = []
485
+ N, M = len(subjects), len(objects)
486
+ subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
487
+ objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
488
+
489
+ OBJ_IDX_OFFSET = 10000
490
+ SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
491
+
492
+ all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
493
+ seen_idx = set()
494
+ seen_sub_idx = set()
495
+
496
+ while N > len(seen_sub_idx):
497
+ candidates = []
498
+ for idx, kind, x0, y0 in all_boxes_with_idx:
499
+ if idx in seen_idx:
500
+ continue
501
+ candidates.append((idx, kind, x0, y0))
502
+
503
+ if len(candidates) == 0:
504
+ break
505
+ left_x = min([v[2] for v in candidates])
506
+ top_y = min([v[3] for v in candidates])
507
+
508
+ candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)
509
+
510
+
511
+ fst_idx, fst_kind, left_x, top_y = candidates[0]
512
+ candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
513
+ nxt = None
514
+
515
+ for i in range(1, len(candidates)):
516
+ if candidates[i][1] ^ fst_kind == 1:
517
+ nxt = candidates[i]
518
+ break
519
+ if nxt is None:
520
+ break
521
+
522
+ if fst_kind == SUB_BIT_KIND:
523
+ sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
524
+
525
+ else:
526
+ sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
527
+
528
+ pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
529
+ nearest_dis = float('inf')
530
+ for i in range(N):
531
+ if i in seen_idx or i == sub_idx:continue
532
+ nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
533
+
534
+ if pair_dis >= 3*nearest_dis:
535
+ seen_idx.add(sub_idx)
536
+ continue
537
+
538
+ seen_idx.add(sub_idx)
539
+ seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
540
+ seen_sub_idx.add(sub_idx)
541
+
542
+ ret.append(
543
+ {
544
+ 'sub_bbox': {
545
+ 'bbox': subjects[sub_idx]['bbox'],
546
+ 'score': subjects[sub_idx]['score'],
547
+ },
548
+ 'obj_bboxes': [
549
+ {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
550
+ ],
551
+ 'sub_idx': sub_idx,
552
+ }
553
+ )
554
+
555
+ for i in range(len(objects)):
556
+ j = i + OBJ_IDX_OFFSET
557
+ if j in seen_idx:
558
+ continue
559
+ seen_idx.add(j)
560
+ nearest_dis, nearest_sub_idx = float('inf'), -1
561
+ for k in range(len(subjects)):
562
+ dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
563
+ if dis < nearest_dis:
564
+ nearest_dis = dis
565
+ nearest_sub_idx = k
566
+
567
+ for k in range(len(subjects)):
568
+ if k != nearest_sub_idx: continue
569
+ if k in seen_sub_idx:
570
+ for kk in range(len(ret)):
571
+ if ret[kk]['sub_idx'] == k:
572
+ ret[kk]['obj_bboxes'].append({'score': objects[i]['score'], 'bbox': objects[i]['bbox']})
573
+ break
574
+ else:
575
+ ret.append(
576
+ {
577
+ 'sub_bbox': {
578
+ 'bbox': subjects[k]['bbox'],
579
+ 'score': subjects[k]['score'],
580
+ },
581
+ 'obj_bboxes': [
582
+ {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
583
+ ],
584
+ 'sub_idx': k,
585
+ }
586
+ )
587
+ seen_sub_idx.add(k)
588
+ seen_idx.add(k)
589
+
590
+
591
+ for i in range(len(subjects)):
592
+ if i in seen_sub_idx:
593
+ continue
594
+ ret.append(
595
+ {
596
+ 'sub_bbox': {
597
+ 'bbox': subjects[i]['bbox'],
598
+ 'score': subjects[i]['score'],
599
+ },
600
+ 'obj_bboxes': [],
601
+ 'sub_idx': i,
602
+ }
603
+ )
604
+
605
+
606
+ return ret
607
+
608
+
453
609
  def get_imgs_v2(self, page_no: int):
454
- with_captions = self.__tie_up_category_by_distance_v2(
610
+ with_captions = self.__tie_up_category_by_distance_v3(
455
611
  page_no, 3, 4, PosRelationEnum.BOTTOM
456
612
  )
457
- with_footnotes = self.__tie_up_category_by_distance_v2(
613
+ with_footnotes = self.__tie_up_category_by_distance_v3(
458
614
  page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
459
615
  )
460
616
  ret = []
@@ -470,10 +626,10 @@ class MagicModel:
470
626
  return ret
471
627
 
472
628
  def get_tables_v2(self, page_no: int) -> list:
473
- with_captions = self.__tie_up_category_by_distance_v2(
629
+ with_captions = self.__tie_up_category_by_distance_v3(
474
630
  page_no, 5, 6, PosRelationEnum.UP
475
631
  )
476
- with_footnotes = self.__tie_up_category_by_distance_v2(
632
+ with_footnotes = self.__tie_up_category_by_distance_v3(
477
633
  page_no, 5, 7, PosRelationEnum.ALL
478
634
  )
479
635
  ret = []
@@ -89,13 +89,6 @@ class CustomPEKModel:
89
89
  # 初始化解析方案
90
90
  self.device = kwargs.get('device', 'cpu')
91
91
 
92
- if str(self.device).startswith("npu"):
93
- import torch_npu
94
- os.environ['FLAGS_npu_jit_compile'] = '0'
95
- os.environ['FLAGS_use_stride_kernel'] = '0'
96
- elif str(self.device).startswith("mps"):
97
- os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
98
-
99
92
  logger.info('using device: {}'.format(self.device))
100
93
  models_dir = kwargs.get(
101
94
  'models_dir', os.path.join(root_dir, 'resources', 'models')
@@ -1,4 +1,5 @@
1
1
  # Copyright (c) Opendatalab. All rights reserved.
2
+ import time
2
3
  from collections import Counter
3
4
  from uuid import uuid4
4
5
 
@@ -102,9 +103,9 @@ class YOLOv11LangDetModel(object):
102
103
  temp_images = split_images(image)
103
104
  for temp_image in temp_images:
104
105
  all_images.append(resize_images_to_224(temp_image))
105
-
106
- images_lang_res = self.batch_predict(all_images, batch_size=8)
107
- # logger.info(f"images_lang_res: {images_lang_res}")
106
+ # langdetect_start = time.time()
107
+ images_lang_res = self.batch_predict(all_images, batch_size=256)
108
+ # logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
108
109
  if len(images_lang_res) > 0:
109
110
  count_dict = Counter(images_lang_res)
110
111
  language = max(count_dict, key=count_dict.get)
@@ -100,20 +100,61 @@ class UnimernetModel(object):
100
100
  res["latex"] = latex_rm_whitespace(latex)
101
101
  return formula_list
102
102
 
103
- def batch_predict(
104
- self, images_mfd_res: list, images: list, batch_size: int = 64
105
- ) -> list:
103
+ # def batch_predict(
104
+ # self, images_mfd_res: list, images: list, batch_size: int = 64
105
+ # ) -> list:
106
+ # images_formula_list = []
107
+ # mf_image_list = []
108
+ # backfill_list = []
109
+ # for image_index in range(len(images_mfd_res)):
110
+ # mfd_res = images_mfd_res[image_index]
111
+ # pil_img = Image.fromarray(images[image_index])
112
+ # formula_list = []
113
+ #
114
+ # for xyxy, conf, cla in zip(
115
+ # mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
116
+ # ):
117
+ # xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
118
+ # new_item = {
119
+ # "category_id": 13 + int(cla.item()),
120
+ # "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
121
+ # "score": round(float(conf.item()), 2),
122
+ # "latex": "",
123
+ # }
124
+ # formula_list.append(new_item)
125
+ # bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
126
+ # mf_image_list.append(bbox_img)
127
+ #
128
+ # images_formula_list.append(formula_list)
129
+ # backfill_list += formula_list
130
+ #
131
+ # dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
132
+ # dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
133
+ # mfr_res = []
134
+ # for mf_img in dataloader:
135
+ # mf_img = mf_img.to(self.device)
136
+ # with torch.no_grad():
137
+ # output = self.model.generate({"image": mf_img})
138
+ # mfr_res.extend(output["pred_str"])
139
+ # for res, latex in zip(backfill_list, mfr_res):
140
+ # res["latex"] = latex_rm_whitespace(latex)
141
+ # return images_formula_list
142
+
143
+ def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
106
144
  images_formula_list = []
107
145
  mf_image_list = []
108
146
  backfill_list = []
147
+ image_info = [] # Store (area, original_index, image) tuples
148
+
149
+ # Collect images with their original indices
109
150
  for image_index in range(len(images_mfd_res)):
110
151
  mfd_res = images_mfd_res[image_index]
111
152
  pil_img = Image.fromarray(images[image_index])
112
153
  formula_list = []
113
154
 
114
- for xyxy, conf, cla in zip(
115
- mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
116
- ):
155
+ for idx, (xyxy, conf, cla) in enumerate(zip(
156
+ mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
157
+ )):
117
158
  xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
118
159
  new_item = {
119
160
  "category_id": 13 + int(cla.item()),
@@ -123,19 +164,43 @@ class UnimernetModel(object):
123
164
  }
124
165
  formula_list.append(new_item)
125
166
  bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
167
+ area = (xmax - xmin) * (ymax - ymin)
168
+
169
+ curr_idx = len(mf_image_list)
170
+ image_info.append((area, curr_idx, bbox_img))
126
171
  mf_image_list.append(bbox_img)
127
172
 
128
173
  images_formula_list.append(formula_list)
129
174
  backfill_list += formula_list
130
175
 
131
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
176
+ # Stable sort by area
177
+ image_info.sort(key=lambda x: x[0]) # sort by area
178
+ sorted_indices = [x[1] for x in image_info]
179
+ sorted_images = [x[2] for x in image_info]
180
+
181
+ # Create mapping for results
182
+ index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
183
+
184
+ # Create dataset with sorted images
185
+ dataset = MathDataset(sorted_images, transform=self.mfr_transform)
132
186
  dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
187
+
188
+ # Process batches and store results
133
189
  mfr_res = []
134
190
  for mf_img in dataloader:
135
191
  mf_img = mf_img.to(self.device)
136
192
  with torch.no_grad():
137
193
  output = self.model.generate({"image": mf_img})
138
194
  mfr_res.extend(output["pred_str"])
139
- for res, latex in zip(backfill_list, mfr_res):
140
- res["latex"] = latex_rm_whitespace(latex)
195
+
196
+ # Restore original order
197
+ unsorted_results = [""] * len(mfr_res)
198
+ for new_idx, latex in enumerate(mfr_res):
199
+ original_idx = index_mapping[new_idx]
200
+ unsorted_results[original_idx] = latex_rm_whitespace(latex)
201
+
202
+ # Fill results back
203
+ for res, latex in zip(backfill_list, unsorted_results):
204
+ res["latex"] = latex
205
+
141
206
  return images_formula_list