evalscope 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (100)
  1. evalscope/api/benchmark/__init__.py +8 -1
  2. evalscope/api/benchmark/adapters/__init__.py +1 -0
  3. evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
  4. evalscope/api/benchmark/benchmark.py +14 -0
  5. evalscope/api/dataset/dataset.py +21 -0
  6. evalscope/api/dataset/loader.py +6 -2
  7. evalscope/api/mixin/sandbox_mixin.py +32 -54
  8. evalscope/api/model/generate_config.py +6 -0
  9. evalscope/benchmarks/aa_lcr/__init__.py +0 -0
  10. evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
  11. evalscope/benchmarks/bfcl/bfcl_adapter.py +1 -1
  12. evalscope/benchmarks/data_collection/data_collection_adapter.py +2 -1
  13. evalscope/benchmarks/general_arena/general_arena_adapter.py +1 -1
  14. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
  15. evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
  16. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +23 -4
  17. evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
  18. evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +158 -0
  19. evalscope/benchmarks/humaneval/humaneval_adapter.py +2 -1
  20. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +3 -1
  21. evalscope/benchmarks/math_verse/__init__.py +0 -0
  22. evalscope/benchmarks/math_verse/math_verse_adapter.py +100 -0
  23. evalscope/benchmarks/math_vision/__init__.py +0 -0
  24. evalscope/benchmarks/math_vision/math_vision_adapter.py +111 -0
  25. evalscope/benchmarks/math_vista/math_vista_adapter.py +6 -26
  26. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +1 -1
  27. evalscope/benchmarks/ner/__init__.py +0 -0
  28. evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
  29. evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
  30. evalscope/benchmarks/ner/copious_adapter.py +85 -0
  31. evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
  32. evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
  33. evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
  34. evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
  35. evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
  36. evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
  37. evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
  38. evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
  39. evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
  40. evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
  41. evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
  42. evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
  43. evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
  44. evalscope/benchmarks/ocr_bench_v2/utils.py +1 -0
  45. evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
  46. evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
  47. evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
  48. evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
  49. evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
  50. evalscope/benchmarks/poly_math/__init__.py +0 -0
  51. evalscope/benchmarks/poly_math/poly_math_adapter.py +127 -0
  52. evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
  53. evalscope/benchmarks/pope/__init__.py +0 -0
  54. evalscope/benchmarks/pope/pope_adapter.py +111 -0
  55. evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
  56. evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
  57. evalscope/benchmarks/simple_vqa/__init__.py +0 -0
  58. evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
  59. evalscope/benchmarks/tau_bench/tau_bench_adapter.py +1 -1
  60. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +1 -1
  61. evalscope/benchmarks/visu_logic/__init__.py +0 -0
  62. evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
  63. evalscope/benchmarks/zerobench/__init__.py +0 -0
  64. evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
  65. evalscope/constants.py +4 -0
  66. evalscope/evaluator/evaluator.py +72 -79
  67. evalscope/metrics/math_parser.py +14 -0
  68. evalscope/metrics/metric.py +1 -1
  69. evalscope/models/utils/openai.py +4 -0
  70. evalscope/perf/arguments.py +24 -4
  71. evalscope/perf/benchmark.py +74 -89
  72. evalscope/perf/http_client.py +31 -16
  73. evalscope/perf/main.py +15 -2
  74. evalscope/perf/plugin/api/base.py +9 -7
  75. evalscope/perf/plugin/api/custom_api.py +13 -58
  76. evalscope/perf/plugin/api/default_api.py +179 -79
  77. evalscope/perf/plugin/api/openai_api.py +4 -3
  78. evalscope/perf/plugin/datasets/base.py +21 -0
  79. evalscope/perf/plugin/datasets/custom.py +2 -3
  80. evalscope/perf/plugin/datasets/line_by_line.py +2 -3
  81. evalscope/perf/plugin/datasets/longalpaca.py +2 -3
  82. evalscope/perf/plugin/datasets/openqa.py +2 -4
  83. evalscope/perf/plugin/datasets/random_dataset.py +1 -3
  84. evalscope/perf/utils/benchmark_util.py +36 -22
  85. evalscope/perf/utils/db_util.py +14 -19
  86. evalscope/perf/utils/local_server.py +0 -44
  87. evalscope/perf/utils/log_utils.py +21 -6
  88. evalscope/report/__init__.py +2 -1
  89. evalscope/run.py +4 -0
  90. evalscope/utils/function_utils.py +195 -12
  91. evalscope/utils/io_utils.py +74 -0
  92. evalscope/utils/logger.py +49 -17
  93. evalscope/utils/ner.py +377 -0
  94. evalscope/version.py +2 -2
  95. {evalscope-1.1.0.dist-info → evalscope-1.1.1.dist-info}/METADATA +235 -363
  96. {evalscope-1.1.0.dist-info → evalscope-1.1.1.dist-info}/RECORD +100 -55
  97. {evalscope-1.1.0.dist-info → evalscope-1.1.1.dist-info}/WHEEL +1 -1
  98. {evalscope-1.1.0.dist-info → evalscope-1.1.1.dist-info}/entry_points.txt +0 -0
  99. {evalscope-1.1.0.dist-info → evalscope-1.1.1.dist-info/licenses}/LICENSE +0 -0
  100. {evalscope-1.1.0.dist-info → evalscope-1.1.1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/evalscope/benchmarks/omnidoc_bench/utils.py
@@ -0,0 +1,1937 @@
+# flake8: noqa
+import copy
+import html
+import json
+import Levenshtein
+import numpy as np
+import os
+import re
+import shutil
+import subprocess
+import unicodedata
+import uuid
+from bs4 import BeautifulSoup
+from collections import defaultdict
+from pylatexenc.latex2text import LatexNodes2Text
+from pylatexenc.latexwalker import (
+    LatexCharsNode,
+    LatexEnvironmentNode,
+    LatexGroupNode,
+    LatexMacroNode,
+    LatexSpecialsNode,
+    LatexWalker,
+)
+from scipy.optimize import linear_sum_assignment
+
+
+def read_md_file(filepath):
+    with open(filepath, 'r', encoding='utf-8') as file:
+        content = file.read()
+
+    return content
+
+
+def save_paired_result(preds, gts, save_path):
+    save_result = []
+    formula_id = 0
+    for gt, pred in zip(gts, preds):
+        save_result.append({'gt': gt, 'pred': pred, 'img_id': formula_id})
+        formula_id += 1
+    with open(save_path, 'w', encoding='utf-8') as f:
+        json.dump(save_result, f, indent=4, ensure_ascii=False)
+
+
+def remove_markdown_fences(content):
+    content = re.sub(r'^```markdown\n?', '', content, flags=re.MULTILINE)
+    content = re.sub(r'```\n?$', '', content, flags=re.MULTILINE)
+    return content
+
+
+# Standardize all consecutive characters
+def replace_repeated_chars(input_str):
+    input_str = re.sub(r'_{4,}', '____', input_str)  # Replace more than 4 consecutive underscores with 4 underscores
+    input_str = re.sub(r' {4,}', '    ', input_str)  # Replace more than 4 consecutive spaces with 4 spaces
+    return re.sub(
+        r'([^a-zA-Z0-9])\1{10,}', r'\1\1\1\1', input_str
+    )  # For other consecutive symbols (except numbers and letters), replace more than 10 occurrences with 4
+
+
+# Special Unicode handling
+def fullwidth_to_halfwidth(s):
+    result = []
+    for char in s:
+        code = ord(char)
+        # Convert full-width space to half-width space
+        if code == 0x3000:
+            code = 0x0020
+        # Convert other full-width characters to half-width
+        elif 0xFF01 <= code <= 0xFF5E:
+            code -= 0xFEE0
+        result.append(chr(code))
+    return ''.join(result)
+
+
+def find_special_unicode(s):
+    special_chars = {}
+    for char in s:
+        if ord(char) > 127:  # Non-ASCII characters
+            # unicode_name = unicodedata.name(char, None)
+            unicode_name = unicodedata.category(char)
+            special_chars[char] = f'U+{ord(char):04X} ({unicode_name})'
+    return special_chars
+
+
+inline_reg = re.compile(r'\$(.*?)\$|'
+                        r'\\\((.*?)\\\)', )
+
+
+def textblock2unicode(text):
+    inline_matches = inline_reg.finditer(text)
+    removal_positions = []
+    for match in inline_matches:
+        position = [match.start(), match.end()]
+        content = match.group(1) if match.group(1) is not None else match.group(2)
+        # print('-------- content-------', content)
+        # Remove escape characters \
+        clean_content = re.sub(r'\\([\\_&%^])', '', content)
+
+        try:
+            if any(char in clean_content for char in r'\^_'):
+                if clean_content.endswith('\\'):
+                    clean_content += ' '
+                # inline_array.append(match.group(0))
+                unicode_content = LatexNodes2Text().latex_to_text(clean_content)
+                removal_positions.append((position[0], position[1], unicode_content))
+        except:
+            continue
+
+    # Remove inline formulas from original text
+    for start, end, unicode_content in sorted(removal_positions, reverse=True):
+        text = text[:start] + unicode_content.strip() + text[end:]
+
+    return text
+
+
+def normalized_formula(text):
+    # Normalize math formulas before matching
+    filter_list = [
+        '\\mathbf', '\\mathrm', '\\mathnormal', '\\mathit', '\\mathbb', '\\mathcal', '\\mathscr', '\\mathfrak',
+        '\\mathsf', '\\mathtt', '\\textbf', '\\text', '\\boldmath', '\\boldsymbol', '\\operatorname', '\\bm',
+        '\\symbfit', '\\mathbfcal', '\\symbf', '\\scriptscriptstyle', '\\notag', '\\setlength', '\\coloneqq', '\\space',
+        '\\thickspace', '\\thinspace', '\\medspace', '\\nobreakspace', '\\negmedspace', '\\quad', '\\qquad',
+        '\\enspace', '\\substackw', ' '
+    ]
+    # '\\left', '\\right', '{', '}', ' ']
+
+    # delimiter_filter
+    pattern = re.compile(r'\\\[(.+?)(?<!\\)\\\]')
+    match = pattern.search(text)
+
+    if match:
+        text = match.group(1).strip()
+
+    tag_pattern = re.compile(r'\\tag\{.*?\}')
+    text = tag_pattern.sub('', text)
+    hspace_pattern = re.compile(r'\\hspace\{.*?\}')
+    text = hspace_pattern.sub('', text)
+    begin_pattern = re.compile(r'\\begin\{.*?\}')
+    text = begin_pattern.sub('', text)
+    end_pattern = re.compile(r'\\end\{.*?\}')
+    text = end_pattern.sub('', text)
+    col_sep = re.compile(r'\\arraycolsep.*?\}')
+    text = col_sep.sub('', text)
+    text = text.strip('.')
+
+    for filter_text in filter_list:
+        text = text.replace(filter_text, '')
+
+    # text = normalize_text(delimiter_filter(text))
+    # text = delimiter_filter(text)
+    text = text.lower()
+    return text
+
+
+def normalized_html_table(text):
+
+    def process_table_html(md_i):
+        """
+        pred_md format edit
+        """
+
+        def process_table_html(html_content):
+            soup = BeautifulSoup(html_content, 'html.parser')
+            th_tags = soup.find_all('th')
+            for th in th_tags:
+                th.name = 'td'
+            thead_tags = soup.find_all('thead')
+            for thead in thead_tags:
+                thead.unwrap()  # unwrap() removes the tag but keeps its contents
+            math_tags = soup.find_all('math')
+            for math_tag in math_tags:
+                alttext = math_tag.get('alttext', '')
+                alttext = f'${alttext}$'
+                if alttext:
+                    math_tag.replace_with(alttext)
+            span_tags = soup.find_all('span')
+            for span in span_tags:
+                span.unwrap()
+            return str(soup)
+
+        table_res = ''
+        table_res_no_space = ''
+        if '<table' in md_i.replace(' ', '').replace("'", '"'):
+            md_i = process_table_html(md_i)
+            table_res = html.unescape(md_i).replace('\n', '')
+            table_res = unicodedata.normalize('NFKC', table_res).strip()
+            pattern = r'<table\b[^>]*>(.*)</table>'
+            tables = re.findall(pattern, table_res, re.DOTALL | re.IGNORECASE)
+            table_res = ''.join(tables)
+            # table_res = re.sub('<table.*?>','',table_res)
+            table_res = re.sub('( style=".*?")', '', table_res)
+            table_res = re.sub('( height=".*?")', '', table_res)
+            table_res = re.sub('( width=".*?")', '', table_res)
+            table_res = re.sub('( align=".*?")', '', table_res)
+            table_res = re.sub('( class=".*?")', '', table_res)
+            table_res = re.sub('</?tbody>', '', table_res)
+
+            table_res = re.sub(r'\s+', ' ', table_res)
+            table_res_no_space = '<html><body><table border="1" >' + table_res.replace(' ', '') + '</table></body></html>'
+            # table_res_no_space = re.sub(' (style=".*?")',"",table_res_no_space)
+            # table_res_no_space = re.sub(r'[ ]', " ", table_res_no_space)
+            table_res_no_space = re.sub('colspan="', ' colspan="', table_res_no_space)
+            table_res_no_space = re.sub('rowspan="', ' rowspan="', table_res_no_space)
+            table_res_no_space = re.sub('border="', ' border="', table_res_no_space)
+
+            table_res = '<html><body><table border="1" >' + table_res + '</table></body></html>'
+            # table_flow.append(table_res)
+            # table_flow_no_space.append(table_res_no_space)
+
+        return table_res, table_res_no_space
+
+    def clean_table(input_str, flag=True):
+        if flag:
+            input_str = input_str.replace('<sup>', '').replace('</sup>', '')
+            input_str = input_str.replace('<sub>', '').replace('</sub>', '')
+            input_str = input_str.replace('<span>', '').replace('</span>', '')
+            input_str = input_str.replace('<div>', '').replace('</div>', '')
+            input_str = input_str.replace('<p>', '').replace('</p>', '')
+            input_str = input_str.replace('<spandata-span-identity="">', '')
+        input_str = re.sub('<colgroup>.*?</colgroup>', '', input_str)
+        return input_str
+
+    norm_text, _ = process_table_html(text)
+    norm_text = clean_table(norm_text)
+    return norm_text
+
+
+def normalized_latex_table(text):
+
+    def latex_template(latex_code):
+        template = r'''
+\documentclass[border=20pt]{article}
+\usepackage{subcaption}
+\usepackage{url}
+\usepackage{graphicx}
+\usepackage{caption}
+\usepackage{multirow}
+\usepackage{booktabs}
+\usepackage{color}
+\usepackage{colortbl}
+\usepackage{xcolor,soul,framed}
+\usepackage{fontspec}
+\usepackage{amsmath,amssymb,mathtools,bm,mathrsfs,textcomp}
+\setlength{\parindent}{0pt}''' + \
+            r'''
+\begin{document}
+''' + \
+            latex_code + \
+            r'''
+\end{document}'''
+
+        return template
+
+    def process_table_latex(latex_code):
+        SPECIAL_STRINGS = [['\\\\vspace\\{.*?\\}', ''], ['\\\\hspace\\{.*?\\}', ''], ['\\\\rule\\{.*?\\}\\{.*?\\}', ''],
+                           ['\\\\addlinespace\\[.*?\\]', ''], ['\\\\addlinespace', ''],
+                           ['\\\\renewcommand\\{\\\\arraystretch\\}\\{.*?\\}', ''], ['\\\\arraystretch\\{.*?\\}', ''],
+                           ['\\\\(row|column)?colors?\\{[^}]*\\}(\\{[^}]*\\}){0,2}', ''], ['\\\\color\\{.*?\\}', ''],
+                           ['\\\\textcolor\\{.*?\\}', ''], ['\\\\rowcolor(\\[.*?\\])?\\{.*?\\}', ''],
+                           ['\\\\columncolor(\\[.*?\\])?\\{.*?\\}', ''], ['\\\\cellcolor(\\[.*?\\])?\\{.*?\\}', ''],
+                           ['\\\\colorbox\\{.*?\\}', ''],
+                           ['\\\\(tiny|scriptsize|footnotesize|small|normalsize|large|Large|LARGE|huge|Huge)', ''],
+                           [r'\s+', ' '], ['\\\\centering', ''], ['\\\\begin\\{table\\}\\[.*?\\]', '\\\\begin{table}'],
+                           ['\t', ''], ['@{}', ''], ['\\\\toprule(\\[.*?\\])?', '\\\\hline'],
+                           ['\\\\bottomrule(\\[.*?\\])?', '\\\\hline'], ['\\\\midrule(\\[.*?\\])?', '\\\\hline'],
+                           ['p\\{[^}]*\\}', 'l'], ['m\\{[^}]*\\}', 'c'],
+                           ['\\\\scalebox\\{[^}]*\\}\\{([^}]*)\\}', '\\1'],
+                           ['\\\\textbf\\{([^}]*)\\}', '\\1'], ['\\\\textit\\{([^}]*)\\}', '\\1'],
+                           ['\\\\cmidrule(\\[.*?\\])?\\(.*?\\)\\{([0-9]-[0-9])\\}', '\\\\cline{\\2}'],
+                           ['\\\\hline', ''], [r'\\multicolumn\{1\}\{[^}]*\}\{((?:[^{}]|(?:\{[^{}]*\}))*)\}', r'\1']]
+        pattern = r'\\begin\{tabular\}.*\\end\{tabular\}'  # note: deliberately greedy here, not .*?
+        matches = re.findall(pattern, latex_code, re.DOTALL)
+        latex_code = ' '.join(matches)
+
+        for special_str in SPECIAL_STRINGS:
+            latex_code = re.sub(fr'{special_str[0]}', fr'{special_str[1]}', latex_code)
+
+        return latex_code
+
+    def convert_latex_to_html(latex_content, cache_dir='./temp'):
+        if not os.path.exists(cache_dir):
+            os.makedirs(cache_dir)
+
+        uuid_str = str(uuid.uuid1())
+        with open(f'{cache_dir}/{uuid_str}.tex', 'w') as f:
+            f.write(latex_template(latex_content))
+
+        cmd = [
+            'latexmlc', '--quiet', '--nocomments', f'--log={cache_dir}/{uuid_str}.log', f'{cache_dir}/{uuid_str}.tex',
+            f'--dest={cache_dir}/{uuid_str}.html'
+        ]
+        try:
+            subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            with open(f'{cache_dir}/{uuid_str}.html', 'r') as f:
+                html_content = f.read()
+
+            pattern = r'<table\b[^>]*>(.*)</table>'
+            tables = re.findall(pattern, html_content, re.DOTALL | re.IGNORECASE)
+            tables = [f'<table>{table}</table>' for table in tables]
+            html_content = '\n'.join(tables)
+
+        except Exception as e:
+            html_content = ''
+
+        shutil.rmtree(cache_dir)
+        return html_content
+
+    html_text = convert_latex_to_html(text)
+    normlized_tables = normalized_html_table(html_text)
+    return normlized_tables
+
+
+def normalized_table(text, format='html'):
+    if format not in ['html', 'latex']:
+        raise ValueError('Invalid format: {}'.format(format))
+    else:
+        return globals()['normalized_{}_table'.format(format)](text)
+
+
+def textblock_with_norm_formula(text):
+    inline_matches = inline_reg.finditer(text)
+    removal_positions = []
+    for match in inline_matches:
+        position = [match.start(), match.end()]
+        content = match.group(1) if match.group(1) is not None else match.group(2)
+        # print('-------- content-------', content)
+
+        norm_content = normalized_formula(content)
+        removal_positions.append((position[0], position[1], norm_content))
+
+    # Remove inline formulas from original text
+    for start, end, norm_content in sorted(removal_positions, reverse=True):
+        text = text[:start] + norm_content.strip() + text[end:]
+
+    return text
+
+
+def inline_filter_unicode(text):
+    # Ensure text is string type
+    if not isinstance(text, str):
+        text = str(text)
+
+    # Replace inline formula boundary markers
+    # print('--------text-------', text)
+    placeholder = '__INLINE_FORMULA_BOUNDARY__'
+    text_copy = text.replace('$', placeholder).replace('\\(', placeholder).replace('\\)', placeholder)
+    # print('--------text_copy-------', text_copy)
+    # Convert LaTeX content to Unicode representation
+    text_copy = LatexNodes2Text().latex_to_text(text_copy)
+    # print('--------text_copy---unicode----', text_copy)
+    # Restore boundary markers
+    text_copy = text_copy.replace(placeholder, '$')
+
+    inline_array = []
+    inline_matches = inline_reg.finditer(text_copy)
+    # Record positions of inline formulas to be removed
+    removal_positions = []
+
+    for match in inline_matches:
+        position = [match.start(), match.end()]
+        content = match.group(1) if match.group(1) is not None else match.group(2)
+        print('-------- content-------', content)
+        # Remove escape characters \
+        clean_content = re.sub(r'\\([\\_&%^])', '', content)
+
+        if any(char in clean_content for char in r'\^_'):
+            # inline_array.append(match.group(0))
+            inline_array.append({
+                'category_type': 'equation_inline',
+                'position': position,
+                'content': content,
+            })
+            removal_positions.append((position[0], position[1]))
+
+    # Remove inline formulas from original text
+    for start, end in sorted(removal_positions, reverse=True):
+        text = text[:start] + text[end:]
+
+    return text, inline_array
+
+
+def inline_filter(text):
+    # Ensure text is string type
+    if not isinstance(text, str):
+        text = str(text)
+
+    inline_array = []
+    inline_matches = inline_reg.finditer(text)
+
+    for match in inline_matches:
+        position = [match.start(), match.end()]
+        content = match.group(1) if match.group(1) is not None else match.group(2)
+        # print('inline_content: ', content)
+
+        # Remove escape characters \
+        clean_content = re.sub(r'\\([\\_&%^])', '', content)
+
+        if any(char in clean_content for char in r'\^_'):
+            # inline_array.append(match.group(0))
+            inline_array.append({
+                'category_type': 'equation_inline',
+                'position': position,
+                'content': match.group(0),
+            })
+            text = text.replace(match.group(0), '')
+            # print('-----Found inline formula: ', match.group(0))
+        else:
+            text = text.replace(match.group(0), content)
+
+    return text, inline_array
+
+
+# Text OCR quality check processing:
+def clean_string(input_string):
+    # Use regex to keep Chinese characters, English letters and numbers
+    input_string = input_string.replace('\\t', '').replace('\\n', '').replace('\t', '').replace('\n', '').replace('/t', '').replace('/n', '')
+    cleaned_string = re.sub(r'[^\w\u4e00-\u9fff]', '', input_string)
+    return cleaned_string
+
+
+def extract_tabular(text):
+    begin_pattern = r'\\begin{tabular}'
+    end_pattern = r'\\end{tabular}'
+
+    tabulars = []
+    positions = []
+    current_pos = 0
+    stack = []
+
+    while current_pos < len(text):
+        begin_match = re.search(begin_pattern, text[current_pos:])
+        end_match = re.search(end_pattern, text[current_pos:])
+
+        if not begin_match and not end_match:
+            break
+
+        if begin_match and (not end_match or begin_match.start() < end_match.start()):
+            stack.append(current_pos + begin_match.start())
+            current_pos += begin_match.start() + len(end_pattern)
+        elif end_match:
+            if stack:
+                start_pos = stack.pop()
+                if not stack:
+                    end_pos = current_pos + end_match.start() + len(end_pattern)
+                    tabular_code = text[start_pos:end_pos]
+                    tabulars.append(tabular_code)
+                    positions.append((start_pos, end_pos))
+                current_pos += end_match.start() + len(end_pattern)
+            else:
+                current_pos += 1
+
+    if stack:
+        new_start = stack[0] + len(begin_pattern)
+        new_tabulars, new_positions = extract_tabular(text[new_start:])
+        new_positions = [(start + new_start, end + new_start) for start, end in new_positions]
+        tabulars.extend(new_tabulars)
+        positions.extend(new_positions)
+
+    return tabulars, positions
+
+
+# math reg
+# r'\\begin{equation\*?}(.*?)\\end{equation\*?}|'
+# r'\\begin{align\*?}(.*?)\\end{align\*?}|'
+# r'\\begin{gather\*?}(.*?)\\end{gather\*?}|'
+display_reg = re.compile(r'\$\$(.*?)\$\$|'
+                         r'\\\[(.*?)\\\]|'
+                         r'\$(.*?)\$|'
+                         r'\\\((.*?)\\\)', re.DOTALL)
+
+# inline_reg = re.compile(
+#     r'(?<!\$)\$(?!\$)(.*?)(?<!\$)\$(?!\$)|'
+#     r'\\\((.*?)\\\)',
+# )
+inline_reg = re.compile(r'\$(.*?)\$|'
+                        r'\\\((.*?)\\\)', )
+
+# table
+table_reg = re.compile(
+    r'\\begin{table\*?}(.*?)\\end{table\*?}|'
+    r'\\begin{tabular\*?}(.*?)\\end{tabular\*?}', re.DOTALL
+)
+md_table_reg = re.compile(r'\|\s*.*?\s*\|\n', re.DOTALL)
+html_table_reg = re.compile(r'(<table.*?</table>)', re.DOTALL)
+
+# title
+title_reg = re.compile(r'^\s*#.*$', re.MULTILINE)
+
+# img
+img_pattern = r'!\[.*?\]\(.*?\)'
+
+# code block
+code_block_reg = re.compile(r'```(\w+)\n(.*?)```', re.DOTALL)
+
+
+def md_tex_filter(content):
+    '''
+    Input: 1 page md or tex content - String
+    Output: text, display, inline, table, title, code - list
+    '''
+    content = re.sub(img_pattern, '', content)  # remove image
+    content = remove_markdown_fences(content)  # remove markdown fences
+    content = replace_repeated_chars(content)  # replace all consecutive characters
+
+    pred_all = []
+    latex_table_array, table_positions = extract_tex_table(content)
+    for latex_table, position in zip(latex_table_array, table_positions):
+        position = [position[0], position[0] + len(latex_table)]  # !!!
+        pred_all.append({'category_type': 'latex_table', 'position': position, 'content': latex_table})
+        content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:]  # replace latex table with space
+
+    # extract html table
+    html_table_array, table_positions = extract_html_table(content)
+    for html_table, position in zip(html_table_array, table_positions):
+        position = [position[0], position[0] + len(html_table)]
+        pred_all.append({'category_type': 'html_table', 'position': position, 'content': html_table})
+        content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:]  # replace html table with space
+
+    # extract interline formula
+    display_matches = display_reg.finditer(content)
+    for match in display_matches:
+        matched = match.group(0)
+        if matched:
+            single_line = ''.join(matched.split())
+            position = [match.start(), match.end()]
+            # replace $$ with \[\]
+            dollar_pattern = re.compile(r'\$\$(.*?)\$\$|\$(.*?)\$|\\\((.*?)\\\)', re.DOTALL)
+            sub_match = dollar_pattern.search(single_line)
+            if sub_match is None:
+                # pass
+                content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:]
+                pred_all.append({'category_type': 'equation_isolated', 'position': position, 'content': single_line})
+            elif sub_match.group(1):
+                single_line = re.sub(dollar_pattern, r'\\[\1\\]', single_line)
+                content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:]  # replace equation with space
+                pred_all.append({'category_type': 'equation_isolated', 'position': position, 'content': single_line})
+            else:
+                single_line = re.sub(dollar_pattern, r'\\[\2\3\\]', single_line)
+                pred_all.append({
+                    'category_type': 'equation_isolated',
+                    'position': position,
+                    'content': single_line,
+                    'fine_category_type': 'equation_inline'
+                })
+
+    # extract md table with ||
+    md_table_mathces = md_table_reg.findall(content + '\n')
+    if len(md_table_mathces) >= 2:
+        # print("md table found!")
+        # print("content:", content)
+        content = convert_markdown_to_html(content)
+        # print('----------content after converting md table to html:', content)
+        html_table_matches = html_table_reg.finditer(content)
+        if html_table_matches:
+            for match in html_table_matches:
+                matched = match.group(0)
+                position = [match.start(), match.end()]
+                # content = content.replace(match, '')
+                # print('content after removing the md table:', content)
+                content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:]  # replace md table with space
+                pred_all.append({
+                    'category_type': 'html_table',
+                    'position': position,
+                    'content': matched.strip(),
+                    'fine_category_type': 'md2html_table'
+                })
+        # print('---------After md table: \n', content)
+
+    # extract code blocks
+    code_matches = code_block_reg.finditer(content)
+    if code_matches:
+        for match in code_matches:
+            position = [match.start(), match.end()]
+            language = match.group(1)
+            code = match.group(2).strip()
+            # content = content.replace(match.group(0), '')
+            content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:]  # replace code block with space
+            pred_all.append({
+                'category_type': 'text_all',
+                'position': position,
+                'content': code,
+                'language': language,
+                'fine_category_type': 'code'
+            })
+
+    # Remove latex style
+    content = re.sub(r'\\title\{(.*?)\}', r'\1', content)
+    content = re.sub(r'\\title\s*\{\s*(.*?)\s*\}', r'\1', content, flags=re.DOTALL)
+    content = re.sub(r'\\text\s*\{\s*(.*?)\s*\}', r'\1', content, flags=re.DOTALL)
+    content = re.sub(r'\\section\*?\{(.*?)\}', r'\1', content)
+    content = re.sub(r'\\section\*?\{\s*(.*?)\s*\}', r'\1', content, flags=re.DOTALL)
+
+    # extract texts
+    res = content.split('\n\n')
+    if len(res) == 1:
+        res = content.split('\n')  # some models do not use double newlines, so use single newlines to split
+
+    content_position = 0
+    for text in res:
+        position = [content_position, content_position + len(text)]
+        content_position += len(text)
+        text = text.strip()
+        text = text.strip('\n')
+        # print('ori_text: ', text)
+        text = '\n'.join([_.strip() for _ in text.split('\n') if _.strip()])  # avoid some single newline content with many spaces
+        # print('after strip text: ', text)
+
+        if text:  # Check if the stripped text is not empty
+            if text.startswith('<table') and text.endswith('</table>'):
+                pred_all.append({
+                    'category_type': 'html_table',
+                    'position': position,
+                    'content': text,
+                })
+
+            elif text.startswith('$') and text.endswith('$'):
+                if text.replace('$', '').strip():
+                    pred_all.append({
+                        'category_type': 'equation_isolated',
+                        'position': position,
+                        'content': text.strip(),
+                    })
+            else:
+                text = text.strip()
+                if text:
+                    pred_all.append({
+                        'category_type': 'text_all',
+                        'position': position,
+                        'content': text,
+                        'fine_category_type': 'text_block'
+                    })
+
+    pred_dataset = defaultdict(list)
+    pred_all = sorted(pred_all, key=lambda x: x['position'][0])
+    for item in pred_all:
+        pred_dataset[item['category_type']].append(item)
+    return pred_dataset
+
+
+def extract_tex_table(content):
+    tables = []
+    tables_positions = []
+
+    pattern = r'\\begin{table}(.*?)\\end{table}'
+    for match in re.finditer(pattern, content, re.DOTALL):
+        start_pos = match.start()
+        end_pos = match.end()
+        table_content = match.group(0)
+        tables.append(table_content)
+        tables_positions.append((start_pos, end_pos))
+        content = content[:start_pos] + ' ' * (end_pos - start_pos) + content[end_pos:]
+
+    tabulars, tabular_positions = extract_tabular(content)
+    all_tables = tables + tabulars
+    all_positions = tables_positions + tabular_positions
+
+    all_result = sorted([[pos, table] for pos, table in zip(all_positions, all_tables)], key=lambda x: x[0][0])
+    all_tables = [x[1] for x in all_result]
+    all_positions = [x[0] for x in all_result]
+
+    return all_tables, all_positions
+
+
+def extract_html_table(text):
+    begin_pattern = r'<table(?:[^>]*)>'
+    end_pattern = r'</table>'
+
+    tabulars = []
+    positions = []
+    current_pos = 0
+    stack = []
+
+    while current_pos < len(text):
+        begin_match = re.search(begin_pattern, text[current_pos:])
+        end_match = re.search(end_pattern, text[current_pos:])
+
+        if not begin_match and not end_match:
+            break
+
+        if begin_match and (not end_match or begin_match.start() < end_match.start()):
+            stack.append(current_pos + begin_match.start())
+            current_pos += begin_match.start() + len(end_pattern)
+        elif end_match:
+            if stack:
+                start_pos = stack.pop()
+                if not stack:
+                    end_pos = current_pos + end_match.start() + len(end_pattern)
+                    tabular_code = text[start_pos:end_pos]
+                    tabulars.append(tabular_code)
+                    positions.append((start_pos, end_pos))
+                current_pos += end_match.start() + len(end_pattern)
+            else:
+                current_pos += 1
+
+    if stack:
+        new_start = stack[0] + len(begin_pattern)
+        new_tabulars, new_positions = extract_html_table(text[new_start:])
+        new_positions = [(start + new_start, end + new_start) for start, end in new_positions]
+        tabulars.extend(new_tabulars)
+        positions.extend(new_positions)
+
+    return tabulars, positions
+
+
+def extract_node_content(node):
+    """ Recursively extract content from LatexEnvironmentNode and rebuild LaTeX table representation """
+    if isinstance(node, LatexCharsNode):
+        return node.chars  # Use chars attribute
+    elif isinstance(node, LatexGroupNode):
+        return '{' + ''.join(extract_node_content(n) for n in node.nodelist) + '}'
+    elif isinstance(node, LatexMacroNode):
+        # Extract macro command and its arguments
+        macro_content = '\\' + node.macroname
+        if node.nodeargs:
+            macro_content += ''.join([extract_node_content(arg) for arg in node.nodeargs])
+        return macro_content
+    elif isinstance(node, LatexEnvironmentNode):
+        # Extract environment, preserve environment name and arguments
+        content = '\\begin{' + node.environmentname + '}'
+        if node.nodeargd and node.nodeargd.argnlist:
+            # content += "".join("{" + extract_node_content(arg) + "}" for arg in node.nodeargd)
+            # content += "".join("{" + extract_node_content(node.nodeargd) + "}")
+            content += '{' + extract_node_content(node.nodeargd.argnlist[0]) + '}'
+        if node.nodelist:
+            content += ''.join(extract_node_content(n) for n in node.nodelist)
+        content += '\\end{' + node.environmentname + '}'
+        return content
+    elif isinstance(node, LatexSpecialsNode):  # Changed to LatexSpecialsNode
+        return node.specials_chars
+    else:
+        return ''
+
+
+def get_node_end_pos(node):
+    """Recursively determine the end position of a node"""
+    if hasattr(node, 'nodelist') and node.nodelist:
+        # If the node has child nodes, recursively find the end position of the last child node
+        return get_node_end_pos(node.nodelist[-1])
+    elif hasattr(node, 'pos_end'):
+        # If the node has pos_end attribute, return it directly
+        return node.pos_end
+    else:
+        # If there are no child nodes, assume the node ends at the last character of its content
+        return node.pos + len(str(node))
+
+
+def remove_tex_table(content):
+    tables, positions = extract_tex_table(content)
+
+    # Delete in reverse order by position to avoid affecting unprocessed start positions
+    for start, end in sorted(positions, reverse=True):
+        content = content[:start] + content[end:]  # Remove table content
+
+    return content
+
+
+def get_pred_category_type(pred_idx, pred_items):
+    # if pred_idx:
+    if pred_items[pred_idx].get('fine_category_type'):
+        pred_pred_category_type = pred_items[pred_idx]['fine_category_type']
+    else:
+        pred_pred_category_type = pred_items[pred_idx]['category_type']
+    # else:
+    #     pred_pred_category_type = ""
+    return pred_pred_category_type
+
+
+def compute_edit_distance_matrix_new(gt_lines, matched_lines):
+    try:
+        distance_matrix = np.zeros((len(gt_lines), len(matched_lines)))
+        for i, gt_line in enumerate(gt_lines):
+            for j, matched_line in enumerate(matched_lines):
+                if len(gt_line) == 0 and len(matched_line) == 0:
+                    distance_matrix[i][j] = 0
+                else:
+                    distance_matrix[i][j] = Levenshtein.distance(gt_line,
+                                                                 matched_line) / max(len(matched_line), len(gt_line))
+        return distance_matrix
+    except ZeroDivisionError:
+        # print("ZeroDivisionError occurred. Outputting norm_gt_lines and norm_pred_lines:")
+        # print("norm_gt_lines:", gt_lines)
+        # print("norm_pred_lines:", matched_lines)
+        raise
+
+
+def get_gt_pred_lines(gt_items, pred_items, line_type):
+    norm_html_lines = []
+    gt_lines = []
+    gt_cat_list = []
+    for item in gt_items:
+        if item.get('fine_category_type'):
+            gt_cat_list.append(item['fine_category_type'])
+        else:
+            gt_cat_list.append(item['category_type'])
+        if item.get('content'):
+            gt_lines.append(str(item['content']))
+            norm_html_lines.append(str(item['content']))
+        elif line_type == 'text':
+            gt_lines.append(str(item['text']))
+        elif line_type == 'html_table':
+            gt_lines.append(str(item['html']))
+        elif line_type == 'formula':
+            gt_lines.append(str(item['latex']))
+        elif line_type == 'latex_table':
+            gt_lines.append(str(item['latex']))
+            norm_html_lines.append(str(item['html']))
+
+    pred_lines = [str(item['content']) for item in pred_items]
+
+    if line_type == 'formula':
+        norm_gt_lines = [normalized_formula(_) for _ in gt_lines]
+        norm_pred_lines = [normalized_formula(_) for _ in pred_lines]
+    elif line_type == 'text':
+        # norm_gt_lines = [textblock_with_norm_formula(_) for _ in gt_lines]
+        # norm_pred_lines = [textblock_with_norm_formula(_) for _ in pred_lines]
+        norm_gt_lines = [clean_string(textblock2unicode(_)) for _ in gt_lines]
+        norm_pred_lines = [clean_string(textblock2unicode(_)) for _ in pred_lines]
+        # norm_gt_lines = get_norm_text_lines(gt_lines)
+        # norm_pred_lines = get_norm_text_lines(pred_lines)
+    else:
+        norm_gt_lines = gt_lines
+        norm_pred_lines = pred_lines
+
+    if line_type == 'latex_table':
+        gt_lines = norm_html_lines
+
+    filtered_lists = [(a, b, c) for a, b, c in zip(gt_lines, norm_gt_lines, gt_cat_list) if a and b]
+
+    # decompress to three lists
+    if filtered_lists:
+        gt_lines_c, norm_gt_lines_c, gt_cat_list_c = zip(*filtered_lists)
+
+        # convert to lists
+        gt_lines_c = list(gt_lines_c)
+        norm_gt_lines_c = list(norm_gt_lines_c)
+        gt_cat_list_c = list(gt_cat_list_c)
+    else:
+        gt_lines_c = []
+        norm_gt_lines_c = []
+        gt_cat_list_c = []
+
+    # pred's empty values
+    filtered_lists = [(a, b) for a, b in zip(pred_lines, norm_pred_lines) if a and b]
+
+    # decompress to two lists
+    if filtered_lists:
+        pred_lines_c, norm_pred_lines_c = zip(*filtered_lists)
+
+        # convert to lists
+        pred_lines_c = list(pred_lines_c)
+        norm_pred_lines_c = list(norm_pred_lines_c)
+    else:
+        pred_lines_c = []
+        norm_pred_lines_c = []
+
+    return gt_lines_c, norm_gt_lines_c, gt_cat_list_c, pred_lines_c, norm_pred_lines_c
+    # return gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines
+
+
+def match_gt2pred_simple(gt_items, pred_items, line_type, img_name):
+
+    gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines = get_gt_pred_lines(
+        gt_items, pred_items, line_type
+    )
+
+    match_list = []
+    if not norm_gt_lines:  # not matched pred should be concatenated
+        # print("One of the lists is empty. Returning an empty gt result.")
+        # for pred_idx in range(len(norm_pred_lines)):
+        pred_idx_list = range(len(norm_pred_lines))
+        match_list.append({
+            'gt_idx': [''],
+            'gt': '',
+            'pred_idx': pred_idx_list,
+            'pred': ''.join(pred_lines[_] for _ in pred_idx_list),
+            'gt_position': [''],
+            'pred_position': pred_items[pred_idx_list[0]]['position'][0],  # get the first pred's position
+            'norm_gt': '',
+            'norm_pred': ''.join(norm_pred_lines[_] for _ in pred_idx_list),
+            'gt_category_type': '',
+            'pred_category_type': get_pred_category_type(pred_idx_list[0], pred_items),  # get the first pred's category
+            'gt_attribute': [{}],
+            'edit': 1,
+            'img_id': img_name
+        })
+        return match_list
+    elif not norm_pred_lines:  # not matched gt should be separated
+        # print("One of the lists is empty. Returning an empty pred result.")
+        for gt_idx in range(len(norm_gt_lines)):
+            match_list.append({
+                'gt_idx': [gt_idx],
+                'gt': gt_lines[gt_idx],
+                'pred_idx': [''],
+                'pred': '',
+                'gt_position': [
+                    gt_items[gt_idx].get('order')
+                    if gt_items[gt_idx].get('order') else gt_items[gt_idx].get('position', [''])[0]
+                ],
+                'pred_position': '',
+                'norm_gt': norm_gt_lines[gt_idx],
+                'norm_pred': '',
+                'gt_category_type': gt_cat_list[gt_idx],
+                'pred_category_type': '',
+                'gt_attribute': [gt_items[gt_idx].get('attribute', {})],
+                'edit': 1,
+                'img_id': img_name
+            })
+        return match_list
+
+    cost_matrix = compute_edit_distance_matrix_new(norm_gt_lines, norm_pred_lines)
+
+    row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+    for gt_idx in range(len(norm_gt_lines)):
+        if gt_idx in row_ind:
+            row_i = list(row_ind).index(gt_idx)
+            pred_idx = int(col_ind[row_i])
+            pred_line = pred_lines[pred_idx]
+            norm_pred_line = norm_pred_lines[pred_idx]
+            edit = cost_matrix[gt_idx][pred_idx]
+            # print('edit_dist', edit)
+            # if edit > 0.7:
+            #     print('! Not match')
+        else:
+            # print('No match pred')
+            pred_idx = ''
+            pred_line = ''
+            norm_pred_line = ''
+            edit = 1
+
+        match_list.append({
+            'gt_idx': [gt_idx],
+            'gt': gt_lines[gt_idx],
+            'norm_gt': norm_gt_lines[gt_idx],
+            'gt_category_type': gt_cat_list[gt_idx],
+            'gt_position': [
+                gt_items[gt_idx].get('order')
+                if gt_items[gt_idx].get('order') else gt_items[gt_idx].get('position', [''])[0]
+            ],
+            'gt_attribute': [gt_items[gt_idx].get('attribute', {})],
+            'pred_idx': [pred_idx],
+            'pred': pred_line,
+            'norm_pred': norm_pred_line,
+            'pred_category_type': get_pred_category_type(pred_idx, pred_items) if pred_idx else '',
+            'pred_position': pred_items[pred_idx]['position'][0] if pred_idx else '',
+            'edit': edit,
+            'img_id': img_name
+        })
+        # print('-'*10)
+    # [([0,1], 0),(2, 1), (1,2)] --> [0,2,1]/[0,1,2]
+
+    pred_idx_list = [
+        pred_idx for pred_idx in range(len(norm_pred_lines)) if pred_idx not in col_ind
+    ]  # get not matched preds
+    if pred_idx_list:  # if there are still remaining pred_idx, concatenate all preds
+        match_list.append({
+            'gt_idx': [''],
+            'gt': '',
+            'pred_idx': pred_idx_list,
+            'pred': ''.join(pred_lines[_] for _ in pred_idx_list),
+            'gt_position': [''],
+            'pred_position': pred_items[pred_idx_list[0]]['position'][0],  # get the first pred's position
+            'norm_gt': '',
+            'norm_pred': ''.join(norm_pred_lines[_] for _ in pred_idx_list),
+            'gt_category_type': '',
+            'pred_category_type': get_pred_category_type(pred_idx_list[0], pred_items),  # get the first pred's category
+            'gt_attribute': [{}],
+            'edit': 1,
+            'img_id': img_name
+        })
+    return match_list
+
+
+def match_gt2pred_no_split(gt_items, pred_items, line_type, img_name):
+    # directly concatenate gt and pred by position
+    gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines = get_gt_pred_lines(
+        gt_items, pred_items, line_type
+    )
+    gt_line_with_position = []
+    for gt_line, norm_gt_line, gt_item in zip(gt_lines, norm_gt_lines, gt_items):
+        gt_position = gt_item['order'] if gt_item.get('order') else gt_item.get('position', [''])[0]
+        if gt_position:
+            gt_line_with_position.append((gt_position, gt_line, norm_gt_line))
+    sorted_gt_lines = sorted(gt_line_with_position, key=lambda x: x[0])
+    gt = '\n\n'.join([_[1] for _ in sorted_gt_lines])
+    norm_gt = '\n\n'.join([_[2] for _ in sorted_gt_lines])
+    pred_line_with_position = [(pred_item['position'], pred_line, pred_norm_line)
+                               for pred_line, pred_norm_line, pred_item in zip(pred_lines, norm_pred_lines, pred_items)]
+    sorted_pred_lines = sorted(pred_line_with_position, key=lambda x: x[0])
+    pred = '\n\n'.join([_[1] for _ in sorted_pred_lines])
+    norm_pred = '\n\n'.join([_[2] for _ in sorted_pred_lines])
+    # edit = Levenshtein.distance(norm_gt, norm_pred)/max(len(norm_gt), len(norm_pred))
+    if norm_gt or norm_pred:
+        return [{
+            'gt_idx': [0],
+            'gt': gt,
+            'norm_gt': norm_gt,
+            'gt_category_type': 'text_merge',
+            'gt_position': [''],
+            'gt_attribute': [{}],
+            'pred_idx': [0],
+            'pred': pred,
+            'norm_pred': norm_pred,
+            'pred_category_type': 'text_merge',
+            'pred_position': '',
+            # 'edit': edit,
+            'img_id': img_name
+        }]
+    else:
+        return []
+
+
+import copy
+import evaluate
+
+# from rapidfuzz.distance import Levenshtein
+import Levenshtein
+import numpy as np
+import pdb
+from collections import Counter, defaultdict
+from Levenshtein import distance as Levenshtein_distance
+from scipy.optimize import linear_sum_assignment
+
+
+def match_gt2pred_quick(gt_items, pred_items, line_type, img_name):
+
+    gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines = get_gt_pred_lines(
+        gt_items, pred_items, line_type
+    )
+    all_gt_indices = set(range(len(norm_gt_lines)))
+    all_pred_indices = set(range(len(norm_pred_lines)))
+
+    if not norm_gt_lines:
+        match_list = []
+        for pred_idx in range(len(norm_pred_lines)):
+            match_list.append({
+                'gt_idx': [''],
+                'gt': '',
+                'pred_idx': [pred_idx],
+                'pred': pred_lines[pred_idx],
+                'gt_position': '',
+                'pred_position': pred_items[pred_idx]['position'][0],
+                'norm_gt': '',
+                'norm_pred': norm_pred_lines[pred_idx],
+                'gt_category_type': '',
+                'pred_category_type': get_pred_category_type(pred_idx, pred_items),
+                'gt_attribute': [{}],
+                'edit': 1,
+                'img_id': img_name
+            })
+        return match_list
+    elif not norm_pred_lines:
+        match_list = []
+        for gt_idx in range(len(norm_gt_lines)):
+            match_list.append({
+                'gt_idx': [gt_idx],
+                'gt': gt_lines[gt_idx],
+                'pred_idx': [''],
+                'pred': '',
+                'gt_position': [
+                    gt_items[gt_idx].get('order')
+                    if gt_items[gt_idx].get('order') else gt_items[gt_idx].get('position', [''])[0]
+                ],
+                'pred_position': '',
+                'norm_gt': norm_gt_lines[gt_idx],
+                'norm_pred': '',
+                'gt_category_type': gt_cat_list[gt_idx],
+                'pred_category_type': '',
+                'gt_attribute': [gt_items[gt_idx].get('attribute', {})],
+                'edit': 1,
+                'img_id': img_name
+            })
+        return match_list
+    elif len(norm_gt_lines) == 1 and len(norm_pred_lines) == 1:
+        edit_distance = Levenshtein_distance(norm_gt_lines[0], norm_pred_lines[0])
+        normalized_edit_distance = edit_distance / max(len(norm_gt_lines[0]), len(norm_pred_lines[0]))
+        return [{
+            'gt_idx': [0],
+            'gt': gt_lines[0],
+            'pred_idx': [0],
+            'pred': pred_lines[0],
+            'gt_position': [gt_items[0].get('order') if gt_items[0].get('order') else gt_items[0].get('position', [''])[0]],
+            'pred_position': pred_items[0]['position'][0],
+            'norm_gt': norm_gt_lines[0],
+            'norm_pred': norm_pred_lines[0],
+            'gt_category_type': gt_cat_list[0],
+            'pred_category_type': get_pred_category_type(0, pred_items),
+            'gt_attribute': [gt_items[0].get('attribute', {})],
+            'edit': normalized_edit_distance,
+            'img_id': img_name
+        }]
+
+    cost_matrix = compute_edit_distance_matrix_new(norm_gt_lines, norm_pred_lines)
+
+    matched_col_idx, row_ind, cost_list = cal_final_match(cost_matrix, norm_gt_lines, norm_pred_lines)
+
+    gt_lens_dict, pred_lens_dict = initialize_indices(norm_gt_lines, norm_pred_lines)
+
+    matches, unmatched_gt_indices, unmatched_pred_indices = process_matches(
+        matched_col_idx, row_ind, cost_list, norm_gt_lines, norm_pred_lines, pred_lines
+    )
+
+    matching_dict = fuzzy_match_unmatched_items(unmatched_gt_indices, norm_gt_lines, norm_pred_lines)
+
+    final_matches = merge_matches(matches, matching_dict)
+
+    recalculate_edit_distances(final_matches, gt_lens_dict, norm_gt_lines, norm_pred_lines)
+
+    converted_results = convert_final_matches(final_matches, norm_gt_lines, norm_pred_lines)
+
+    merged_results = merge_duplicates_add_unmatched(
+        converted_results, norm_gt_lines, norm_pred_lines, gt_lines, pred_lines, all_gt_indices, all_pred_indices
+    )
+
+    for entry in merged_results:
+        entry['gt_idx'] = [entry['gt_idx']] if not isinstance(entry['gt_idx'], list) else entry['gt_idx']
+        entry['pred_idx'] = [entry['pred_idx']] if not isinstance(entry['pred_idx'], list) else entry['pred_idx']
+        entry['gt_position'] = [
+            gt_items[_].get('order') if gt_items[_].get('order') else gt_items[_].get('position', [''])[0]
+            for _ in entry['gt_idx']
+        ] if entry['gt_idx'] != [''] else ['']
+        entry['pred_position'] = pred_items[entry['pred_idx'][0]]['position'][0] if entry['pred_idx'] != [''] else ''
+        entry['gt'] = ''.join([gt_lines[_] for _ in entry['gt_idx']]) if entry['gt_idx'] != [''] else ''
+        entry['pred'] = ''.join([pred_lines[_] for _ in entry['pred_idx']]) if entry['pred_idx'] != [''] else ''
+        entry['norm_gt'] = ''.join([norm_gt_lines[_] for _ in entry['gt_idx']]) if entry['gt_idx'] != [''] else ''
+        entry['norm_pred'] = ''.join([norm_pred_lines[_]
+                                      for _ in entry['pred_idx']]) if entry['pred_idx'] != [''] else ''
+
+        if entry['gt_idx'] != ['']:
+            ignore_type = [
+                'figure_caption', 'figure_footnote', 'table_caption', 'table_footnote', 'code_algorithm',
+                'code_algorithm_caption', 'header', 'footer', 'page_footnote', 'page_number', 'equation_caption'
+            ]
+            gt_cagegory_clean = [gt_cat_list[_] for _ in entry['gt_idx'] if gt_cat_list[_] not in ignore_type]
+            if gt_cagegory_clean:
+                entry['gt_category_type'] = Counter(gt_cagegory_clean).most_common(1)[0][0]
+            else:
+                entry['gt_category_type'] = Counter([gt_cat_list[_] for _ in entry['gt_idx']]).most_common(1)[0][0]
+        else:
+            entry['gt_category_type'] = ''
+        entry['pred_category_type'] = get_pred_category_type(entry['pred_idx'][0],
+                                                             pred_items) if entry['pred_idx'] != [''] else ''
+        entry['gt_attribute'] = [gt_items[_].get('attribute', {})
+                                 for _ in entry['gt_idx']] if entry['gt_idx'] != [''] else [{}]
+        entry['img_id'] = img_name
+
+    return merged_results
+
+
+def merge_duplicates_add_unmatched(
+    converted_results, norm_gt_lines, norm_pred_lines, gt_lines, pred_lines, all_gt_indices, all_pred_indices
+):
+    merged_results = []
+    processed_pred = set()
+    processed_gt = set()
+
+    for entry in converted_results:
+        pred_idx = tuple(entry['pred_idx']) if isinstance(entry['pred_idx'], list) else (entry['pred_idx'], )
+        if pred_idx not in processed_pred and pred_idx != ('', ):
+            merged_entry = {
+                'gt_idx': [entry['gt_idx']],
+                'gt': entry['gt'],
+                'pred_idx': entry['pred_idx'],
+                'pred': entry['pred'],
+                'edit': entry['edit']
+            }
+            for other_entry in converted_results:
+                other_pred_idx = tuple(other_entry['pred_idx']) if isinstance(other_entry['pred_idx'],
+                                                                              list) else (other_entry['pred_idx'], )
+                if other_pred_idx == pred_idx and other_entry is not entry:
+                    merged_entry['gt_idx'].append(other_entry['gt_idx'])
+                    merged_entry['gt'] += other_entry['gt']
+                    processed_gt.add(other_entry['gt_idx'])
+            merged_results.append(merged_entry)
+            processed_pred.add(pred_idx)
+            processed_gt.add(entry['gt_idx'])
+
+    for entry in converted_results:
+        if entry['gt_idx'] not in processed_gt:
+            merged_results.append(entry)
+
+    for gt_idx in range(len(norm_gt_lines)):
+        if gt_idx not in processed_gt:
+            merged_results.append({'gt_idx': [gt_idx], 'gt': gt_lines[gt_idx], 'pred_idx': [''], 'pred': '', 'edit': 1})
+    return merged_results
+
+
+def formula_format(formula_matches, img_name):
+    return [{
+        'gt': item['gt'],
+        'pred': item['pred'],
+        'img_id': f'{img_name}_{i}'
+    } for i, item in enumerate(formula_matches)]
+
+
+def merge_lists_with_sublists(main_list, sub_lists):
+    main_list_final = list(copy.deepcopy(main_list))
+    for sub_list in sub_lists:
+        pop_idx = main_list_final.index(sub_list[0])
+        for _ in sub_list:
+            main_list_final.pop(pop_idx)
+        main_list_final.insert(pop_idx, sub_list)
+    return main_list_final
+
+
+def sub_pred_fuzzy_matching(gt, pred):
+
+    min_d = float('inf')
+    # pos = -1
+
+    gt_len = len(gt)
+    pred_len = len(pred)
+
+    if gt_len >= pred_len and pred_len > 0:
+        for i in range(gt_len - pred_len + 1):
+            sub = gt[i:i + pred_len]
+            dist = Levenshtein_distance(sub, pred) / pred_len
+            if dist < min_d:
+                min_d = dist
+                pos = i
+
+        return min_d
+    else:
+        return False
+
+
+def sub_gt_fuzzy_matching(pred, gt):
+
+    min_d = float('inf')
+    pos = ''
+    matched_sub = ''
+    gt_len = len(gt)
+    pred_len = len(pred)
+
+    if pred_len >= gt_len and gt_len > 0:
+        for i in range(pred_len - gt_len + 1):
+            sub = pred[i:i + gt_len]
+            dist = Levenshtein.distance(sub, gt) / gt_len
+            if dist < min_d:
+                min_d = dist
+                pos = i
+                matched_sub = sub
+        return min_d, pos, gt_len, matched_sub
+    else:
+        return 1, '', gt_len, ''
+
+
+def get_final_subset(subset_certain, subset_certain_cost):
+    if not subset_certain or not subset_certain_cost:
+        return []
+
+    subset_turple = sorted([(a, b) for a, b in zip(subset_certain, subset_certain_cost)], key=lambda x: x[0][0])
+
+    group_list = defaultdict(list)
+    group_idx = 0
+    group_list[group_idx].append(subset_turple[0])
+
+    for item in subset_turple[1:]:
+        overlap_flag = False
+        for subset in group_list[group_idx]:
+            for idx in item[0]:
+                if idx in subset[0]:
+                    overlap_flag = True
+                    break
+            if overlap_flag:
+                break
+        if overlap_flag:
+            group_list[group_idx].append(item)
+        else:
+            group_idx += 1
+            group_list[group_idx].append(item)
+
+    final_subset = []
+    for _, group in group_list.items():
+        if len(group) == 1:
+            final_subset.append(group[0][0])
+        else:
+            path_dict = defaultdict(list)
+            path_idx = 0
+            path_dict[path_idx].append(group[0])
+
+            for subset in group[1:]:
+                new_path = True
+                for path_idx_s, path_items in path_dict.items():
+                    is_dup = False
+                    is_same = False
+                    for path_item in path_items:
+                        if path_item[0] == subset[0]:
+                            is_dup = True
+                            is_same = True
+                            if path_item[1] > subset[1]:
+                                path_dict[path_idx_s].pop(path_dict[path_idx_s].index(path_item))
+                                path_dict[path_idx_s].append(subset)
+                        else:
+                            for num_1 in path_item[0]:
+                                for num_2 in subset[0]:
+                                    if num_1 == num_2:
+                                        is_dup = True
+                    if not is_dup:
+                        path_dict[path_idx_s].append(subset)
+                        new_path = False
+                    if is_same:
+                        new_path = False
+                if new_path:
+                    path_idx = len(path_dict.keys())
+                    path_dict[path_idx].append(subset)
+
+            saved_cost = float('inf')
+            saved_subset = []
+            for path_idx, path in path_dict.items():
+                avg_cost = sum([i[1] for i in path]) / len(path)
+                if avg_cost < saved_cost:
+                    saved_subset = [i[0] for i in path]
+                    saved_cost = avg_cost
+
+            final_subset.extend(saved_subset)
+
+    return final_subset
+
+
1370
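+ # Illustrative example: judge_pred_merge(['the quick brown fox'],
+ # ['the quick', 'brown fox']) returns (True, True): joining the two pred
+ # fragments matches the GT line exactly and is still short enough to extend.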
+ def judge_pred_merge(gt_list, pred_list, threshold=0.6):
+     # Decide whether appending the last prediction line improves the match to
+     # the GT line; returns (merge_is_valid, keep_extending).
+     if len(pred_list) == 1:
+         return False, False
+
+     cur_pred = ' '.join(pred_list[:-1])
+     merged_pred = ' '.join(pred_list)
+
+     cur_dist = Levenshtein.distance(gt_list[0], cur_pred) / max(len(gt_list[0]), len(cur_pred))
+     merged_dist = Levenshtein.distance(gt_list[0], merged_pred) / max(len(gt_list[0]), len(merged_pred))
+
+     if merged_dist > cur_dist:
+         return False, False
+
+     cur_fuzzy_dists = [sub_pred_fuzzy_matching(gt_list[0], pred) for pred in pred_list[:-1]]
+     if any(dist is False or dist > threshold for dist in cur_fuzzy_dists):
+         return False, False
+
+     add_fuzzy_dist = sub_pred_fuzzy_matching(gt_list[0], pred_list[-1])
+     if add_fuzzy_dist is False:
+         return False, False
+
+     merged_pred_flag = add_fuzzy_dist < threshold
+     continue_flag = len(merged_pred) <= len(gt_list[0])
+
+     return merged_pred_flag, continue_flag
+
+
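+ # Note: when no beneficial merge is found, deal_with_truncated returns its
+ # inputs unchanged as (cost_matrix, norm_pred_lines,
+ # range(len(norm_pred_lines))), so callers can treat its output uniformly.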
+ def deal_with_truncated(cost_matrix, norm_gt_lines, norm_pred_lines):
+     # Detect prediction lines that were truncated or split: try merging runs of
+     # unmatched prediction lines, then rebuild the cost matrix on the result.
+     matched_first = np.argwhere(cost_matrix < 0.25)
+     masked_gt_idx = [i[0] for i in matched_first]
+     unmasked_gt_idx = [i for i in range(cost_matrix.shape[0]) if i not in masked_gt_idx]
+     masked_pred_idx = [i[1] for i in matched_first]
+     unmasked_pred_idx = [i for i in range(cost_matrix.shape[1]) if i not in masked_pred_idx]
+
+     merges_gt_dict = {}
+
+     for gt_idx in unmasked_gt_idx:
+         check_merge_subset = []
+         merged_dist = []
+
+         for pred_idx in unmasked_pred_idx:
+             step = 1
+             merged_pred = [norm_pred_lines[pred_idx]]
+
+             while True:
+                 if pred_idx + step in masked_pred_idx or pred_idx + step >= len(norm_pred_lines):
+                     break
+                 else:
+                     merged_pred.append(norm_pred_lines[pred_idx + step])
+                     merged_pred_flag, continue_flag = judge_pred_merge([norm_gt_lines[gt_idx]], merged_pred)
+                     if not merged_pred_flag:
+                         break
+                     else:
+                         step += 1
+                         if not continue_flag:
+                             break
+
+             check_merge_subset.append(list(range(pred_idx, pred_idx + step)))
+             matched_line = ' '.join([norm_pred_lines[i] for i in range(pred_idx, pred_idx + step)])
+             dist = Levenshtein_distance(norm_gt_lines[gt_idx],
+                                         matched_line) / max(len(matched_line), len(norm_gt_lines[gt_idx]))
+             merged_dist.append(dist)
+
+         if not merged_dist:
+             subset_certain = []
+             min_cost_idx = ''
+             min_cost = float('inf')
+         else:
+             min_cost = min(merged_dist)
+             min_cost_idx = merged_dist.index(min_cost)
+             subset_certain = check_merge_subset[min_cost_idx]
+
+         merges_gt_dict[gt_idx] = {
+             'merge_subset': check_merge_subset,
+             'merged_cost': merged_dist,
+             'min_cost_idx': min_cost_idx,
+             'subset_certain': subset_certain,
+             'min_cost': min_cost
+         }
+
+     subset_certain = [
+         merges_gt_dict[gt_idx]['subset_certain']
+         for gt_idx in unmasked_gt_idx
+         if merges_gt_dict[gt_idx]['subset_certain']
+     ]
+     subset_certain_cost = [
+         merges_gt_dict[gt_idx]['min_cost'] for gt_idx in unmasked_gt_idx if merges_gt_dict[gt_idx]['subset_certain']
+     ]
+
+     subset_certain_final = get_final_subset(subset_certain, subset_certain_cost)
+
+     if not subset_certain_final:
+         return cost_matrix, norm_pred_lines, range(len(norm_pred_lines))
+
+     final_pred_idx_list = merge_lists_with_sublists(range(len(norm_pred_lines)), subset_certain_final)
+     final_norm_pred_lines = [
+         ' '.join(norm_pred_lines[idx_list[0]:idx_list[-1] + 1]) if isinstance(idx_list, list) else norm_pred_lines[idx_list]
+         for idx_list in final_pred_idx_list
+     ]
+
+     new_cost_matrix = compute_edit_distance_matrix_new(norm_gt_lines, final_norm_pred_lines)
+
+     return new_cost_matrix, final_norm_pred_lines, final_pred_idx_list
+
+
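+ # Illustrative example: cal_move_dist(list('abcd'), list('badc')) performs two
+ # unit swaps (a with b, then c with d) and returns 2 / 4 = 0.5. Note that
+ # `pred` is reordered in place.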
+ def cal_move_dist(gt, pred):
+     # Average number of positions each element must move to turn `pred` into
+     # `gt`; both sequences must be permutations of each other.
+     assert len(gt) == len(pred), 'Not right length'
+     step = 0
+     for i, gt_c in enumerate(gt):
+         if gt_c != pred[i]:
+             # Resolve the index once before swapping; re-evaluating
+             # pred.index(gt_c) inside the swap targets would undo the swap.
+             j = pred.index(gt_c)
+             step += abs(i - j)
+             pred[i], pred[j] = pred[j], pred[i]
+     return step / len(gt)
+
+
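+ # Pipeline note: cal_final_match first merges truncated prediction lines via
+ # deal_with_truncated, then solves the optimal GT/pred assignment with
+ # scipy.optimize.linear_sum_assignment on the rebuilt cost matrix.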
+ def cal_final_match(cost_matrix, norm_gt_lines, norm_pred_lines):
+     new_cost_matrix, final_norm_pred_lines, final_pred_idx_list = deal_with_truncated(
+         cost_matrix, norm_gt_lines, norm_pred_lines
+     )
+
+     row_ind, col_ind = linear_sum_assignment(new_cost_matrix)
+
+     cost_list = [new_cost_matrix[r][c] for r, c in zip(row_ind, col_ind)]
+     matched_col_idx = [final_pred_idx_list[i] for i in col_ind]
+
+     return matched_col_idx, row_ind, cost_list
+
+
+ def initialize_indices(norm_gt_lines, norm_pred_lines):
+     gt_lens_dict = {idx: len(gt_line) for idx, gt_line in enumerate(norm_gt_lines)}
+     pred_lens_dict = {idx: len(pred_line) for idx, pred_line in enumerate(norm_pred_lines)}
+     return gt_lens_dict, pred_lens_dict
+
+
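+ # Pairs whose normalized edit distance exceeds 0.7 are rejected here and fed
+ # back into the unmatched pools for the later fuzzy-matching pass.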
+ def process_matches(matched_col_idx, row_ind, cost_list, norm_gt_lines, norm_pred_lines, pred_lines):
+     # Turn the assignment result into per-GT matches, collecting the GT and
+     # prediction indices that remain unmatched.
+     matches = {}
+     unmatched_gt_indices = []
+     unmatched_pred_indices = []
+
+     for i in range(len(norm_gt_lines)):
+         if i in row_ind:
+             idx = list(row_ind).index(i)
+             pred_idx = matched_col_idx[idx]
+
+             if pred_idx is None or (isinstance(pred_idx, list) and None in pred_idx):
+                 unmatched_pred_indices.append(pred_idx)
+                 continue
+
+             if isinstance(pred_idx, list):
+                 pred_line = ' | '.join(norm_pred_lines[pred_idx[0]:pred_idx[-1] + 1])
+                 ori_pred_line = ' | '.join(pred_lines[pred_idx[0]:pred_idx[-1] + 1])
+                 matched_pred_indices_range = list(range(pred_idx[0], pred_idx[-1] + 1))
+             else:
+                 pred_line = norm_pred_lines[pred_idx]
+                 ori_pred_line = pred_lines[pred_idx]
+                 matched_pred_indices_range = [pred_idx]
+
+             edit = cost_list[idx]
+
+             if edit > 0.7:
+                 unmatched_pred_indices.extend(matched_pred_indices_range)
+                 unmatched_gt_indices.append(i)
+             else:
+                 matches[i] = {
+                     'pred_indices': matched_pred_indices_range,
+                     'edit_distance': edit,
+                 }
+                 for matched_pred_idx in matched_pred_indices_range:
+                     if matched_pred_idx in unmatched_pred_indices:
+                         unmatched_pred_indices.remove(matched_pred_idx)
+         else:
+             unmatched_gt_indices.append(i)
+
+     return matches, unmatched_gt_indices, unmatched_pred_indices
+
+
+ def fuzzy_match_unmatched_items(unmatched_gt_indices, norm_gt_lines, norm_pred_lines):
+     # For every prediction line, collect the GT lines that appear inside it
+     # with a normalized sub-string edit distance below 0.4.
+     matching_dict = {}
+
+     for pred_idx, pred_content in enumerate(norm_pred_lines):
+         matching_indices = []
+
+         for unmatched_gt_idx in unmatched_gt_indices:
+             gt_content = norm_gt_lines[unmatched_gt_idx]
+             cur_fuzzy_dist_unmatch, cur_pos, gt_lens, matched_field = sub_gt_fuzzy_matching(pred_content, gt_content)
+             if cur_fuzzy_dist_unmatch < 0.4:
+                 matching_indices.append(unmatched_gt_idx)
+
+         if matching_indices:
+             matching_dict[pred_idx] = matching_indices
+
+     return matching_dict
+
+
+ def merge_matches(matches, matching_dict):
+     # Combine assignment matches with fuzzy matches, keyed by prediction index
+     # tuple; each GT index is consumed at most once.
+     final_matches = {}
+     processed_gt_indices = set()
+
+     for gt_idx, match_info in matches.items():
+         pred_indices = match_info['pred_indices']
+         edit_distance = match_info['edit_distance']
+
+         pred_key = tuple(sorted(pred_indices))
+
+         if pred_key in final_matches:
+             if gt_idx not in processed_gt_indices:
+                 final_matches[pred_key]['gt_indices'].append(gt_idx)
+                 processed_gt_indices.add(gt_idx)
+         else:
+             final_matches[pred_key] = {'gt_indices': [gt_idx], 'edit_distance': edit_distance}
+             processed_gt_indices.add(gt_idx)
+
+     for pred_idx, gt_indices in matching_dict.items():
+         pred_key = (pred_idx, ) if not isinstance(pred_idx, (list, tuple)) else tuple(sorted(pred_idx))
+
+         if pred_key in final_matches:
+             for gt_idx in gt_indices:
+                 if gt_idx not in processed_gt_indices:
+                     final_matches[pred_key]['gt_indices'].append(gt_idx)
+                     processed_gt_indices.add(gt_idx)
+         else:
+             final_matches[pred_key] = {
+                 'gt_indices': [gt_idx for gt_idx in gt_indices if gt_idx not in processed_gt_indices],
+                 'edit_distance': None
+             }
+             processed_gt_indices.update(final_matches[pred_key]['gt_indices'])
+
+     return final_matches
+
+
+ def recalculate_edit_distances(final_matches, gt_lens_dict, norm_gt_lines, norm_pred_lines):
+     # Recompute normalized edit distances after GT lines have been merged into
+     # shared prediction keys.
+     for pred_key, info in final_matches.items():
+         gt_indices = sorted(set(info['gt_indices']))
+
+         if not gt_indices:
+             info['edit_distance'] = 1
+             continue
+
+         if len(gt_indices) > 1:
+             merged_gt_content = ''.join(norm_gt_lines[gt_idx] for gt_idx in gt_indices)
+             pred_content = norm_pred_lines[pred_key[0]] if isinstance(pred_key[0], int) else ''
+
+             try:
+                 edit_distance = Levenshtein_distance(merged_gt_content, pred_content)
+                 normalized_edit_distance = edit_distance / max(len(merged_gt_content), len(pred_content))
+             except ZeroDivisionError:
+                 normalized_edit_distance = 1
+
+             info['edit_distance'] = normalized_edit_distance
+         else:
+             gt_idx = gt_indices[0]
+             pred_content = ' '.join(norm_pred_lines[pred_idx] for pred_idx in pred_key if isinstance(pred_idx, int))
+
+             try:
+                 edit_distance = Levenshtein_distance(norm_gt_lines[gt_idx], pred_content)
+                 normalized_edit_distance = edit_distance / max(len(norm_gt_lines[gt_idx]), len(pred_content))
+             except ZeroDivisionError:
+                 normalized_edit_distance = 1
+
+             info['edit_distance'] = normalized_edit_distance
+         info['pred_content'] = pred_content
+
+
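+ # Each converted entry is a dict with keys gt_idx, gt, pred_idx, pred and
+ # edit; lines left unmatched on either side are emitted with edit = 1.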
+ def convert_final_matches(final_matches, norm_gt_lines, norm_pred_lines):
+     # Flatten the match dict into a list of result entries and give every
+     # leftover GT/pred line an explicit unmatched entry.
+     converted_results = []
+
+     all_gt_indices = set(range(len(norm_gt_lines)))
+     all_pred_indices = set(range(len(norm_pred_lines)))
+
+     for pred_key, info in final_matches.items():
+         pred_content = ' '.join(norm_pred_lines[pred_idx] for pred_idx in pred_key if isinstance(pred_idx, int))
+
+         for gt_idx in sorted(set(info['gt_indices'])):
+             result_entry = {
+                 'gt_idx': int(gt_idx),
+                 'gt': norm_gt_lines[gt_idx],
+                 'pred_idx': list(pred_key),
+                 'pred': pred_content,
+                 'edit': info['edit_distance']
+             }
+             converted_results.append(result_entry)
+
+     matched_gt_indices = set().union(*[set(info['gt_indices']) for info in final_matches.values()])
+     unmatched_gt_indices = all_gt_indices - matched_gt_indices
+     matched_pred_indices = set(idx for pred_key in final_matches.keys() for idx in pred_key if isinstance(idx, int))
+     unmatched_pred_indices = all_pred_indices - matched_pred_indices
+
+     if unmatched_pred_indices:
+         if unmatched_gt_indices:
+             # Materialize the sets once so matrix rows/columns and the
+             # assignment results index the same orderings.
+             unmatched_gt_list = list(unmatched_gt_indices)
+             unmatched_pred_list = list(unmatched_pred_indices)
+             distance_matrix = [[
+                 Levenshtein_distance(norm_gt_lines[gt_idx], norm_pred_lines[pred_idx])
+                 for pred_idx in unmatched_pred_list
+             ] for gt_idx in unmatched_gt_list]
+
+             row_ind, col_ind = linear_sum_assignment(distance_matrix)
+
+             for i, j in zip(row_ind, col_ind):
+                 gt_idx = unmatched_gt_list[i]
+                 pred_idx = unmatched_pred_list[j]
+                 result_entry = {
+                     'gt_idx': int(gt_idx),
+                     'gt': norm_gt_lines[gt_idx],
+                     'pred_idx': [pred_idx],
+                     'pred': norm_pred_lines[pred_idx],
+                     'edit': 1
+                 }
+                 converted_results.append(result_entry)
+
+             matched_gt_indices.update(unmatched_gt_list[i] for i in row_ind)
+         else:
+             result_entry = {
+                 'gt_idx': '',
+                 'gt': '',
+                 'pred_idx': list(unmatched_pred_indices),
+                 'pred': ' '.join(norm_pred_lines[pred_idx] for pred_idx in unmatched_pred_indices),
+                 'edit': 1
+             }
+             converted_results.append(result_entry)
+     else:
+         for gt_idx in unmatched_gt_indices:
+             result_entry = {'gt_idx': int(gt_idx), 'gt': norm_gt_lines[gt_idx], 'pred_idx': '', 'pred': '', 'edit': 1}
+             converted_results.append(result_entry)
+
+     return converted_results
+
+
+ import json
+
+
+ def read_md_file(filepath):
+     with open(filepath, 'r', encoding='utf-8') as file:
+         content = file.read()
+     return content
+
+
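+ # Illustrative example: save_paired_result(['pred_a'], ['gt_a'], 'out.json')
+ # writes [{"gt": "gt_a", "pred": "pred_a", "img_id": 0}] to out.json
+ # (the path is hypothetical).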
+ def save_paired_result(preds, gts, save_path):
+     save_result = []
+     for formula_id, (gt, pred) in enumerate(zip(gts, preds)):
+         save_result.append({'gt': gt, 'pred': pred, 'img_id': formula_id})
+     with open(save_path, 'w', encoding='utf-8') as f:
+         json.dump(save_result, f, indent=4, ensure_ascii=False)
+
+
+ import numpy as np
+ import os
+ import re
+
+
+ def print_aligned_dict(data):
+     # Align all columns on the longest key of the first test case; assumes
+     # every sub-dict shares the keys of data['testcase1'].
+     max_key_length = max(len(key) for key in data['testcase1'])
+
+     # Print header
+     print(f"{' ' * (max_key_length + 4)}", end='')
+     for key in data:
+         print(f'{key:>{max_key_length}}', end='')
+     print()
+
+     # Print dictionary content
+     for subkey in data['testcase1']:
+         print(f'{subkey:<{max_key_length + 4}}', end='')
+         for key in data:
+             print(f'{data[key][subkey]:>{max_key_length}}', end='')
+         print()
+
+
+ def create_dict_from_folders(directory):
+     body = {}
+     for folder_name in os.listdir(directory):
+         folder_path = os.path.join(directory, folder_name)
+         if os.path.isdir(folder_path):
+             body[folder_name] = {}
+     return body
+
+
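+ # Illustrative example: a pipe table '| a | b |\n| --- | --- |\n| 1 | 2 |'
+ # becomes a <table> with header cells a, b and one body row 1, 2; the second
+ # Markdown row is assumed to be the separator line and is skipped.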
+ # The function is from https://github.com/intsig-textin/markdown_tester
+ def markdown_to_html(markdown_table):
+     rows = [row.strip() for row in markdown_table.strip().split('\n')]
+
+     html_table = '<table>\n  <thead>\n    <tr>\n'
+
+     header_cells = [cell.strip() for cell in rows[0].split('|')[1:-1]]
+     for cell in header_cells:
+         html_table += f'      <th>{cell}</th>\n'
+     html_table += '    </tr>\n  </thead>\n  <tbody>\n'
+
+     # rows[1] is the Markdown separator row, so data rows start at rows[2]
+     for row in rows[2:]:
+         cells = [cell.strip() for cell in row.split('|')[1:-1]]
+         html_table += '    <tr>\n'
+         for cell in cells:
+             html_table += f'      <td>{cell}</td>\n'
+         html_table += '    </tr>\n'
+
+     html_table += '  </tbody>\n</table>\n'
+     return html_table
+
+
+ def convert_table_str(s):
+     # Normalize table markup: strip <table> attributes, turn <th> into <td>,
+     # then emit one tag/cell per line (keeping inline "$" cells intact).
+     s = re.sub(r'<table.*?>', '<table>', s)
+     s = re.sub(r'<th', '<td', s)
+     s = re.sub(r'</th>', '</td>', s)
+     res = '\n\n'
+     temp_item = ''
+     for c in s:
+         temp_item += c
+         if c == '>' and not re.search(r'<td.*?>\$', temp_item):
+             res += temp_item + '\n'
+             temp_item = ''
+     return res + '\n'
+
+
+ def merge_table(md):
+     table_temp = ''.join(md)
+     return convert_table_str(table_temp)
+
+
+ def find_md_table_mode(line):
+     # Detect a Markdown table separator row (e.g. |---|:---:|).
+     if re.search(r'-*?:', line) or re.search(r'---', line) or re.search(r':-*?', line):
+         return True
+     return False
+
+
+ def delete_table_and_body(input_list):
+     # Drop <table>/<thead>/<tbody> wrapper tags, keeping only row content.
+     res = []
+     for line in input_list:
+         if not re.search(r'</?t(able|head|body)>', line):
+             res.append(line)
+     return res
+
+
+ def merge_tables(input_str):
+     # Delete HTML comments
+     input_str = re.sub(r'<!--[\s\S]*?-->', '', input_str)
+
+     # Use regex to find each <table> block
+     table_blocks = re.findall(r'<table>[\s\S]*?</table>', input_str)
+
+     # Process each <table> block, replacing <th> with <td>
+     output_lines = []
+     for block in table_blocks:
+         block_lines = block.split('\n')
+         for i, line in enumerate(block_lines):
+             if '<th>' in line:
+                 block_lines[i] = line.replace('<th>', '<td>').replace('</th>', '</td>')
+         final_tr = delete_table_and_body(block_lines)
+         if len(final_tr) > 2:
+             output_lines.extend(final_tr)  # Keep only table content, not the wrapper tags
+
+     # Rejoin the processed rows under a single <table> wrapper
+     merged_output = '<table>\n{}\n</table>'.format('\n'.join(output_lines))
+
+     return '\n\n' + merged_output + '\n\n'
+
+
+ def replace_table_with_placeholder(input_string):
+     # Walk the lines one step behind (via last_line), buffering consecutive
+     # <table> blocks and replacing each buffer with its merged form.
+     lines = input_string.split('\n')
+     output_lines = []
+
+     in_table_block = False
+     temp_block = ''
+     last_line = ''
+
+     for line in lines:
+         if '<table>' in line:
+             in_table_block = True
+             temp_block += last_line
+         elif in_table_block:
+             if not find_md_table_mode(last_line) and '</thead>' not in last_line:
+                 temp_block += '\n' + last_line
+             if '</table>' in last_line:
+                 if '<table>' not in line:
+                     in_table_block = False
+                     output_lines.append(merge_tables(temp_block))
+                     temp_block = ''
+         else:
+             output_lines.append(last_line)
+
+         last_line = line
+
+     # Flush the final buffered line or table block
+     if last_line:
+         if in_table_block or '</table>' in last_line:
+             temp_block += '\n' + last_line
+             output_lines.append(merge_tables(temp_block))
+         else:
+             output_lines.append(last_line)
+
+     return '\n'.join(output_lines)
+
+
+ def convert_table(input_str):
+     # Add the explicit border/colspan/rowspan attributes expected downstream.
+     output_str = input_str.replace('<table>', '<table border="1" >')
+     output_str = output_str.replace('<td>', '<td colspan="1" rowspan="1">')
+     return output_str
+
+
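+ # End-to-end sketch: convert_markdown_to_html first rewrites Markdown pipe
+ # tables into HTML via markdown_to_html, then merges adjacent <table> blocks
+ # and normalizes tags with replace_table_with_placeholder and convert_table.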
+ def convert_markdown_to_html(markdown_content):
+     # Define a regex pattern to find Markdown tables with newlines
+     markdown_content = markdown_content.replace('\r', '') + '\n'
+     pattern = re.compile(r'\|\s*.*?\s*\|\n', re.DOTALL)
+
+     # Find all matches in the Markdown content
+     matches = pattern.findall(markdown_content)
+
+     for match in matches:
+         html_table = markdown_to_html(match)
+         markdown_content = markdown_content.replace(match, html_table, 1)  # Only replace the first occurrence
+
+     res_html = convert_table(replace_table_with_placeholder(markdown_content))
+
+     return res_html