evalscope 1.0.2__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
- evalscope/api/benchmark/__init__.py +8 -1
- evalscope/api/benchmark/adapters/__init__.py +1 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +12 -0
- evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
- evalscope/api/benchmark/benchmark.py +14 -0
- evalscope/api/dataset/dataset.py +21 -0
- evalscope/api/dataset/loader.py +6 -2
- evalscope/api/mixin/sandbox_mixin.py +32 -54
- evalscope/api/model/generate_config.py +6 -0
- evalscope/app/ui/multi_model.py +6 -1
- evalscope/app/ui/single_model.py +8 -2
- evalscope/app/utils/data_utils.py +3 -2
- evalscope/app/utils/visualization.py +2 -2
- evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
- evalscope/benchmarks/ai2d/ai2d_adapter.py +3 -2
- evalscope/benchmarks/bfcl/bfcl_adapter.py +11 -46
- evalscope/benchmarks/blink/__init__.py +0 -0
- evalscope/benchmarks/blink/blink_adapter.py +61 -0
- evalscope/benchmarks/chartqa/__init__.py +0 -0
- evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
- evalscope/benchmarks/chartqa/utils.py +38 -0
- evalscope/benchmarks/data_collection/data_collection_adapter.py +2 -1
- evalscope/benchmarks/docvqa/__init__.py +0 -0
- evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
- evalscope/benchmarks/general_arena/general_arena_adapter.py +1 -1
- evalscope/benchmarks/general_arena/utils.py +2 -1
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
- evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +23 -4
- evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
- evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +158 -0
- evalscope/benchmarks/hle/hle_adapter.py +3 -2
- evalscope/benchmarks/humaneval/humaneval_adapter.py +2 -1
- evalscope/benchmarks/infovqa/__init__.py +0 -0
- evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +3 -1
- evalscope/benchmarks/math_verse/__init__.py +0 -0
- evalscope/benchmarks/math_verse/math_verse_adapter.py +100 -0
- evalscope/benchmarks/math_vision/__init__.py +0 -0
- evalscope/benchmarks/math_vision/math_vision_adapter.py +111 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +6 -26
- evalscope/benchmarks/mm_bench/mm_bench_adapter.py +2 -2
- evalscope/benchmarks/mmmu/mmmu_adapter.py +1 -1
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +1 -1
- evalscope/benchmarks/ner/__init__.py +0 -0
- evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
- evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
- evalscope/benchmarks/ner/copious_adapter.py +85 -0
- evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
- evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
- evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
- evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
- evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
- evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
- evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
- evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
- evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
- evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
- evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
- evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
- evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
- evalscope/benchmarks/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_adapter.py +101 -0
- evalscope/benchmarks/ocr_bench_v2/IoUscore_metric.py +87 -0
- evalscope/benchmarks/ocr_bench_v2/TEDS_metric.py +963 -0
- evalscope/benchmarks/ocr_bench_v2/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
- evalscope/benchmarks/ocr_bench_v2/page_ocr_metric.py +50 -0
- evalscope/benchmarks/ocr_bench_v2/parallel.py +46 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/readme.txt +26 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/script.py +481 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_metric.py +179 -0
- evalscope/benchmarks/ocr_bench_v2/utils.py +433 -0
- evalscope/benchmarks/ocr_bench_v2/vqa_metric.py +254 -0
- evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
- evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
- evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
- evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
- evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
- evalscope/benchmarks/poly_math/__init__.py +0 -0
- evalscope/benchmarks/poly_math/poly_math_adapter.py +127 -0
- evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
- evalscope/benchmarks/pope/__init__.py +0 -0
- evalscope/benchmarks/pope/pope_adapter.py +111 -0
- evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
- evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
- evalscope/benchmarks/simple_vqa/__init__.py +0 -0
- evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +1 -1
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +1 -1
- evalscope/benchmarks/visu_logic/__init__.py +0 -0
- evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
- evalscope/benchmarks/zerobench/__init__.py +0 -0
- evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
- evalscope/constants.py +4 -0
- evalscope/evaluator/evaluator.py +72 -79
- evalscope/metrics/math_parser.py +14 -0
- evalscope/metrics/metric.py +52 -1
- evalscope/metrics/metrics.py +16 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +2 -6
- evalscope/models/utils/openai.py +4 -0
- evalscope/perf/arguments.py +24 -4
- evalscope/perf/benchmark.py +74 -89
- evalscope/perf/http_client.py +31 -16
- evalscope/perf/main.py +15 -2
- evalscope/perf/plugin/api/base.py +9 -7
- evalscope/perf/plugin/api/custom_api.py +13 -58
- evalscope/perf/plugin/api/default_api.py +179 -79
- evalscope/perf/plugin/api/openai_api.py +4 -3
- evalscope/perf/plugin/datasets/base.py +21 -0
- evalscope/perf/plugin/datasets/custom.py +2 -3
- evalscope/perf/plugin/datasets/line_by_line.py +2 -3
- evalscope/perf/plugin/datasets/longalpaca.py +2 -3
- evalscope/perf/plugin/datasets/openqa.py +2 -4
- evalscope/perf/plugin/datasets/random_dataset.py +1 -3
- evalscope/perf/utils/benchmark_util.py +36 -22
- evalscope/perf/utils/db_util.py +14 -19
- evalscope/perf/utils/local_server.py +0 -44
- evalscope/perf/utils/log_utils.py +21 -6
- evalscope/report/__init__.py +11 -2
- evalscope/report/combinator.py +52 -2
- evalscope/run.py +4 -0
- evalscope/utils/function_utils.py +195 -12
- evalscope/utils/io_utils.py +74 -0
- evalscope/utils/json_schema.py +8 -6
- evalscope/utils/logger.py +49 -17
- evalscope/utils/multi_choices.py +16 -1
- evalscope/utils/ner.py +377 -0
- evalscope/version.py +2 -2
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/METADATA +239 -393
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/RECORD +140 -98
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/WHEEL +1 -1
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/top_level.txt +0 -1
- tests/__init__.py +0 -1
- tests/benchmark/__init__.py +0 -1
- tests/benchmark/test_eval.py +0 -429
- tests/benchmark/test_image_edit.py +0 -65
- tests/benchmark/test_sandbox.py +0 -81
- tests/benchmark/test_t2i.py +0 -142
- tests/benchmark/test_vlm.py +0 -137
- tests/cli/__init__.py +0 -1
- tests/cli/test_all.py +0 -269
- tests/cli/test_collection.py +0 -99
- tests/cli/test_custom.py +0 -268
- tests/cli/test_reasoning.py +0 -81
- tests/common.py +0 -73
- tests/perf/__init__.py +0 -1
- tests/perf/test_perf.py +0 -206
- tests/rag/test_clip_benchmark.py +0 -87
- tests/rag/test_mteb.py +0 -213
- tests/rag/test_ragas.py +0 -128
- tests/swift/__init__.py +0 -1
- tests/swift/test_run_swift_eval.py +0 -146
- tests/swift/test_run_swift_vlm_eval.py +0 -128
- tests/swift/test_run_swift_vlm_jugde_eval.py +0 -157
- tests/test_run_all.py +0 -12
- tests/utils.py +0 -13
- tests/vlm/__init__.py +0 -1
- tests/vlm/test_vlmeval.py +0 -102
- {tests/rag → evalscope/benchmarks/aa_lcr}/__init__.py +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/entry_points.txt +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,1937 @@
|
|
|
1
|
+
# flake8: noqa
|
|
2
|
+
import copy
|
|
3
|
+
import html
|
|
4
|
+
import json
|
|
5
|
+
import Levenshtein
|
|
6
|
+
import numpy as np
|
|
7
|
+
import os
|
|
8
|
+
import re
|
|
9
|
+
import shutil
|
|
10
|
+
import subprocess
|
|
11
|
+
import unicodedata
|
|
12
|
+
import uuid
|
|
13
|
+
from bs4 import BeautifulSoup
|
|
14
|
+
from collections import defaultdict
|
|
15
|
+
from pylatexenc.latex2text import LatexNodes2Text
|
|
16
|
+
from pylatexenc.latexwalker import (
|
|
17
|
+
LatexCharsNode,
|
|
18
|
+
LatexEnvironmentNode,
|
|
19
|
+
LatexGroupNode,
|
|
20
|
+
LatexMacroNode,
|
|
21
|
+
LatexSpecialsNode,
|
|
22
|
+
LatexWalker,
|
|
23
|
+
)
|
|
24
|
+
from scipy.optimize import linear_sum_assignment
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def read_md_file(filepath):
|
|
28
|
+
with open(filepath, 'r', encoding='utf-8') as file:
|
|
29
|
+
content = file.read()
|
|
30
|
+
|
|
31
|
+
return content
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def save_paired_result(preds, gts, save_path):
|
|
35
|
+
save_result = []
|
|
36
|
+
formula_id = 0
|
|
37
|
+
for gt, pred in zip(gts, preds):
|
|
38
|
+
save_result.append({'gt': gt, 'pred': pred, 'img_id': formula_id})
|
|
39
|
+
formula_id += 1
|
|
40
|
+
with open(save_path, 'w', encoding='utf-8') as f:
|
|
41
|
+
json.dump(save_result, f, indent=4, ensure_ascii=False)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def remove_markdown_fences(content):
|
|
45
|
+
content = re.sub(r'^```markdown\n?', '', content, flags=re.MULTILINE)
|
|
46
|
+
content = re.sub(r'```\n?$', '', content, flags=re.MULTILINE)
|
|
47
|
+
return content
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Standardize all consecutive characters
|
|
51
|
+
def replace_repeated_chars(input_str):
|
|
52
|
+
input_str = re.sub(r'_{4,}', '____', input_str) # Replace more than 4 consecutive underscores with 4 underscores
|
|
53
|
+
input_str = re.sub(r' {4,}', ' ', input_str) # Replace more than 4 consecutive spaces with 4 spaces
|
|
54
|
+
return re.sub(
|
|
55
|
+
r'([^a-zA-Z0-9])\1{10,}', r'\1\1\1\1', input_str
|
|
56
|
+
) # For other consecutive symbols (except numbers and letters), replace more than 10 occurrences with 4
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Special Unicode handling
|
|
60
|
+
def fullwidth_to_halfwidth(s):
|
|
61
|
+
result = []
|
|
62
|
+
for char in s:
|
|
63
|
+
code = ord(char)
|
|
64
|
+
# Convert full-width space to half-width space
|
|
65
|
+
if code == 0x3000:
|
|
66
|
+
code = 0x0020
|
|
67
|
+
# Convert other full-width characters to half-width
|
|
68
|
+
elif 0xFF01 <= code <= 0xFF5E:
|
|
69
|
+
code -= 0xFEE0
|
|
70
|
+
result.append(chr(code))
|
|
71
|
+
return ''.join(result)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def find_special_unicode(s):
|
|
75
|
+
special_chars = {}
|
|
76
|
+
for char in s:
|
|
77
|
+
if ord(char) > 127: # Non-ASCII characters
|
|
78
|
+
# unicode_name = unicodedata.name(char, None)
|
|
79
|
+
unicode_name = unicodedata.category(char)
|
|
80
|
+
special_chars[char] = f'U+{ord(char):04X} ({unicode_name})'
|
|
81
|
+
return special_chars
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
inline_reg = re.compile(r'\$(.*?)\$|'
|
|
85
|
+
r'\\\((.*?)\\\)', )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def textblock2unicode(text):
|
|
89
|
+
inline_matches = inline_reg.finditer(text)
|
|
90
|
+
removal_positions = []
|
|
91
|
+
for match in inline_matches:
|
|
92
|
+
position = [match.start(), match.end()]
|
|
93
|
+
content = match.group(1) if match.group(1) is not None else match.group(2)
|
|
94
|
+
# print('-------- content-------', content)
|
|
95
|
+
# Remove escape characters \
|
|
96
|
+
clean_content = re.sub(r'\\([\\_&%^])', '', content)
|
|
97
|
+
|
|
98
|
+
try:
|
|
99
|
+
if any(char in clean_content for char in r'\^_'):
|
|
100
|
+
if clean_content.endswith('\\'):
|
|
101
|
+
clean_content += ' '
|
|
102
|
+
# inline_array.append(match.group(0))
|
|
103
|
+
unicode_content = LatexNodes2Text().latex_to_text(clean_content)
|
|
104
|
+
removal_positions.append((position[0], position[1], unicode_content))
|
|
105
|
+
except:
|
|
106
|
+
continue
|
|
107
|
+
|
|
108
|
+
# Remove inline formulas from original text
|
|
109
|
+
for start, end, unicode_content in sorted(removal_positions, reverse=True):
|
|
110
|
+
text = text[:start] + unicode_content.strip() + text[end:]
|
|
111
|
+
|
|
112
|
+
return text
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def normalized_formula(text):
|
|
116
|
+
# Normalize math formulas before matching
|
|
117
|
+
filter_list = [
|
|
118
|
+
'\\mathbf', '\\mathrm', '\\mathnormal', '\\mathit', '\\mathbb', '\\mathcal', '\\mathscr', '\\mathfrak',
|
|
119
|
+
'\\mathsf', '\\mathtt', '\\textbf', '\\text', '\\boldmath', '\\boldsymbol', '\\operatorname', '\\bm',
|
|
120
|
+
'\\symbfit', '\\mathbfcal', '\\symbf', '\\scriptscriptstyle', '\\notag', '\\setlength', '\\coloneqq', '\\space',
|
|
121
|
+
'\\thickspace', '\\thinspace', '\\medspace', '\\nobreakspace', '\\negmedspace', '\\quad', '\\qquad',
|
|
122
|
+
'\\enspace', '\\substackw', ' '
|
|
123
|
+
]
|
|
124
|
+
# '\\left', '\\right', '{', '}', ' ']
|
|
125
|
+
|
|
126
|
+
# delimiter_filter
|
|
127
|
+
pattern = re.compile(r'\\\[(.+?)(?<!\\)\\\]')
|
|
128
|
+
match = pattern.search(text)
|
|
129
|
+
|
|
130
|
+
if match:
|
|
131
|
+
text = match.group(1).strip()
|
|
132
|
+
|
|
133
|
+
tag_pattern = re.compile(r'\\tag\{.*?\}')
|
|
134
|
+
text = tag_pattern.sub('', text)
|
|
135
|
+
hspace_pattern = re.compile(r'\\hspace\{.*?\}')
|
|
136
|
+
text = hspace_pattern.sub('', text)
|
|
137
|
+
begin_pattern = re.compile(r'\\begin\{.*?\}')
|
|
138
|
+
text = begin_pattern.sub('', text)
|
|
139
|
+
end_pattern = re.compile(r'\\end\{.*?\}')
|
|
140
|
+
text = end_pattern.sub('', text)
|
|
141
|
+
col_sep = re.compile(r'\\arraycolsep.*?\}')
|
|
142
|
+
text = col_sep.sub('', text)
|
|
143
|
+
text = text.strip('.')
|
|
144
|
+
|
|
145
|
+
for filter_text in filter_list:
|
|
146
|
+
text = text.replace(filter_text, '')
|
|
147
|
+
|
|
148
|
+
# text = normalize_text(delimiter_filter(text))
|
|
149
|
+
# text = delimiter_filter(text)
|
|
150
|
+
text = text.lower()
|
|
151
|
+
return text
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def normalized_html_table(text):
|
|
155
|
+
|
|
156
|
+
def process_table_html(md_i):
|
|
157
|
+
"""
|
|
158
|
+
pred_md format edit
|
|
159
|
+
"""
|
|
160
|
+
|
|
161
|
+
def process_table_html(html_content):
|
|
162
|
+
soup = BeautifulSoup(html_content, 'html.parser')
|
|
163
|
+
th_tags = soup.find_all('th')
|
|
164
|
+
for th in th_tags:
|
|
165
|
+
th.name = 'td'
|
|
166
|
+
thead_tags = soup.find_all('thead')
|
|
167
|
+
for thead in thead_tags:
|
|
168
|
+
thead.unwrap() # unwrap()会移除标签但保留其内容
|
|
169
|
+
math_tags = soup.find_all('math')
|
|
170
|
+
for math_tag in math_tags:
|
|
171
|
+
alttext = math_tag.get('alttext', '')
|
|
172
|
+
alttext = f'${alttext}$'
|
|
173
|
+
if alttext:
|
|
174
|
+
math_tag.replace_with(alttext)
|
|
175
|
+
span_tags = soup.find_all('span')
|
|
176
|
+
for span in span_tags:
|
|
177
|
+
span.unwrap()
|
|
178
|
+
return str(soup)
|
|
179
|
+
|
|
180
|
+
table_res = ''
|
|
181
|
+
table_res_no_space = ''
|
|
182
|
+
if '<table' in md_i.replace(' ', '').replace("'", '"'):
|
|
183
|
+
md_i = process_table_html(md_i)
|
|
184
|
+
table_res = html.unescape(md_i).replace('\n', '')
|
|
185
|
+
table_res = unicodedata.normalize('NFKC', table_res).strip()
|
|
186
|
+
pattern = r'<table\b[^>]*>(.*)</table>'
|
|
187
|
+
tables = re.findall(pattern, table_res, re.DOTALL | re.IGNORECASE)
|
|
188
|
+
table_res = ''.join(tables)
|
|
189
|
+
# table_res = re.sub('<table.*?>','',table_res)
|
|
190
|
+
table_res = re.sub('( style=".*?")', '', table_res)
|
|
191
|
+
table_res = re.sub('( height=".*?")', '', table_res)
|
|
192
|
+
table_res = re.sub('( width=".*?")', '', table_res)
|
|
193
|
+
table_res = re.sub('( align=".*?")', '', table_res)
|
|
194
|
+
table_res = re.sub('( class=".*?")', '', table_res)
|
|
195
|
+
table_res = re.sub('</?tbody>', '', table_res)
|
|
196
|
+
|
|
197
|
+
table_res = re.sub(r'\s+', ' ', table_res)
|
|
198
|
+
table_res_no_space = '<html><body><table border="1" >' + table_res.replace(
|
|
199
|
+
' ', ''
|
|
200
|
+
) + '</table></body></html>'
|
|
201
|
+
# table_res_no_space = re.sub(' (style=".*?")',"",table_res_no_space)
|
|
202
|
+
# table_res_no_space = re.sub(r'[ ]', " ", table_res_no_space)
|
|
203
|
+
table_res_no_space = re.sub('colspan="', ' colspan="', table_res_no_space)
|
|
204
|
+
table_res_no_space = re.sub('rowspan="', ' rowspan="', table_res_no_space)
|
|
205
|
+
table_res_no_space = re.sub('border="', ' border="', table_res_no_space)
|
|
206
|
+
|
|
207
|
+
table_res = '<html><body><table border="1" >' + table_res + '</table></body></html>'
|
|
208
|
+
# table_flow.append(table_res)
|
|
209
|
+
# table_flow_no_space.append(table_res_no_space)
|
|
210
|
+
|
|
211
|
+
return table_res, table_res_no_space
|
|
212
|
+
|
|
213
|
+
def clean_table(input_str, flag=True):
|
|
214
|
+
if flag:
|
|
215
|
+
input_str = input_str.replace('<sup>', '').replace('</sup>', '')
|
|
216
|
+
input_str = input_str.replace('<sub>', '').replace('</sub>', '')
|
|
217
|
+
input_str = input_str.replace('<span>', '').replace('</span>', '')
|
|
218
|
+
input_str = input_str.replace('<div>', '').replace('</div>', '')
|
|
219
|
+
input_str = input_str.replace('<p>', '').replace('</p>', '')
|
|
220
|
+
input_str = input_str.replace('<spandata-span-identity="">', '')
|
|
221
|
+
input_str = re.sub('<colgroup>.*?</colgroup>', '', input_str)
|
|
222
|
+
return input_str
|
|
223
|
+
|
|
224
|
+
norm_text, _ = process_table_html(text)
|
|
225
|
+
norm_text = clean_table(norm_text)
|
|
226
|
+
return norm_text
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def normalized_latex_table(text):
|
|
230
|
+
|
|
231
|
+
def latex_template(latex_code):
|
|
232
|
+
template = r'''
|
|
233
|
+
\documentclass[border=20pt]{article}
|
|
234
|
+
\usepackage{subcaption}
|
|
235
|
+
\usepackage{url}
|
|
236
|
+
\usepackage{graphicx}
|
|
237
|
+
\usepackage{caption}
|
|
238
|
+
\usepackage{multirow}
|
|
239
|
+
\usepackage{booktabs}
|
|
240
|
+
\usepackage{color}
|
|
241
|
+
\usepackage{colortbl}
|
|
242
|
+
\usepackage{xcolor,soul,framed}
|
|
243
|
+
\usepackage{fontspec}
|
|
244
|
+
\usepackage{amsmath,amssymb,mathtools,bm,mathrsfs,textcomp}
|
|
245
|
+
\setlength{\parindent}{0pt}''' + \
|
|
246
|
+
r'''
|
|
247
|
+
\begin{document}
|
|
248
|
+
''' + \
|
|
249
|
+
latex_code + \
|
|
250
|
+
r'''
|
|
251
|
+
\end{document}'''
|
|
252
|
+
|
|
253
|
+
return template
|
|
254
|
+
|
|
255
|
+
def process_table_latex(latex_code):
|
|
256
|
+
SPECIAL_STRINGS = [['\\\\vspace\\{.*?\\}', ''], ['\\\\hspace\\{.*?\\}', ''], ['\\\\rule\{.*?\\}\\{.*?\\}', ''],
|
|
257
|
+
['\\\\addlinespace\\[.*?\\]', ''], ['\\\\addlinespace', ''],
|
|
258
|
+
['\\\\renewcommand\\{\\\\arraystretch\\}\\{.*?\\}', ''], ['\\\\arraystretch\\{.*?\\}', ''],
|
|
259
|
+
['\\\\(row|column)?colors?\\{[^}]*\\}(\\{[^}]*\\}){0,2}', ''], ['\\\\color\\{.*?\\}', ''],
|
|
260
|
+
['\\\\textcolor\\{.*?\\}', ''], ['\\\\rowcolor(\\[.*?\\])?\\{.*?\\}', ''],
|
|
261
|
+
['\\\\columncolor(\\[.*?\\])?\\{.*?\\}', ''], ['\\\\cellcolor(\\[.*?\\])?\\{.*?\\}', ''],
|
|
262
|
+
['\\\\colorbox\\{.*?\\}', ''],
|
|
263
|
+
['\\\\(tiny|scriptsize|footnotesize|small|normalsize|large|Large|LARGE|huge|Huge)', ''],
|
|
264
|
+
[r'\s+', ' '], ['\\\\centering', ''], ['\\\\begin\\{table\\}\\[.*?\\]', '\\\\begin{table}'],
|
|
265
|
+
['\t', ''], ['@{}', ''], ['\\\\toprule(\\[.*?\\])?', '\\\\hline'],
|
|
266
|
+
['\\\\bottomrule(\\[.*?\\])?', '\\\\hline'], ['\\\\midrule(\\[.*?\\])?', '\\\\hline'],
|
|
267
|
+
['p\\{[^}]*\\}', 'l'], ['m\\{[^}]*\\}',
|
|
268
|
+
'c'], ['\\\\scalebox\\{[^}]*\\}\\{([^}]*)\\}', '\\1'],
|
|
269
|
+
['\\\\textbf\\{([^}]*)\\}', '\\1'], ['\\\\textit\\{([^}]*)\\}', '\\1'],
|
|
270
|
+
['\\\\cmidrule(\\[.*?\\])?\\(.*?\\)\\{([0-9]-[0-9])\\}', '\\\\cline{\\2}'],
|
|
271
|
+
['\\\\hline', ''], [r'\\multicolumn\{1\}\{[^}]*\}\{((?:[^{}]|(?:\{[^{}]*\}))*)\}', r'\1']]
|
|
272
|
+
pattern = r'\\begin\{tabular\}.*\\end\{tabular\}' # 注意这里不用 .*?
|
|
273
|
+
matches = re.findall(pattern, latex_code, re.DOTALL)
|
|
274
|
+
latex_code = ' '.join(matches)
|
|
275
|
+
|
|
276
|
+
for special_str in SPECIAL_STRINGS:
|
|
277
|
+
latex_code = re.sub(fr'{special_str[0]}', fr'{special_str[1]}', latex_code)
|
|
278
|
+
|
|
279
|
+
return latex_code
|
|
280
|
+
|
|
281
|
+
def convert_latex_to_html(latex_content, cache_dir='./temp'):
|
|
282
|
+
if not os.path.exists(cache_dir):
|
|
283
|
+
os.makedirs(cache_dir)
|
|
284
|
+
|
|
285
|
+
uuid_str = str(uuid.uuid1())
|
|
286
|
+
with open(f'{cache_dir}/{uuid_str}.tex', 'w') as f:
|
|
287
|
+
f.write(latex_template(latex_content))
|
|
288
|
+
|
|
289
|
+
cmd = [
|
|
290
|
+
'latexmlc', '--quiet', '--nocomments', f'--log={cache_dir}/{uuid_str}.log', f'{cache_dir}/{uuid_str}.tex',
|
|
291
|
+
f'--dest={cache_dir}/{uuid_str}.html'
|
|
292
|
+
]
|
|
293
|
+
try:
|
|
294
|
+
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
|
295
|
+
with open(f'{cache_dir}/{uuid_str}.html', 'r') as f:
|
|
296
|
+
html_content = f.read()
|
|
297
|
+
|
|
298
|
+
pattern = r'<table\b[^>]*>(.*)</table>'
|
|
299
|
+
tables = re.findall(pattern, html_content, re.DOTALL | re.IGNORECASE)
|
|
300
|
+
tables = [f'<table>{table}</table>' for table in tables]
|
|
301
|
+
html_content = '\n'.join(tables)
|
|
302
|
+
|
|
303
|
+
except Exception as e:
|
|
304
|
+
html_content = ''
|
|
305
|
+
|
|
306
|
+
shutil.rmtree(cache_dir)
|
|
307
|
+
return html_content
|
|
308
|
+
|
|
309
|
+
html_text = convert_latex_to_html(text)
|
|
310
|
+
normlized_tables = normalized_html_table(html_text)
|
|
311
|
+
return normlized_tables
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def normalized_table(text, format='html'):
|
|
315
|
+
if format not in ['html', 'latex']:
|
|
316
|
+
raise ValueError('Invalid format: {}'.format(format))
|
|
317
|
+
else:
|
|
318
|
+
return globals()['normalized_{}_table'.format(format)](text)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def textblock_with_norm_formula(text):
|
|
322
|
+
inline_matches = inline_reg.finditer(text)
|
|
323
|
+
removal_positions = []
|
|
324
|
+
for match in inline_matches:
|
|
325
|
+
position = [match.start(), match.end()]
|
|
326
|
+
content = match.group(1) if match.group(1) is not None else match.group(2)
|
|
327
|
+
# print('-------- content-------', content)
|
|
328
|
+
|
|
329
|
+
norm_content = normalized_formula(content)
|
|
330
|
+
removal_positions.append((position[0], position[1], norm_content))
|
|
331
|
+
|
|
332
|
+
# Remove inline formulas from original text
|
|
333
|
+
for start, end, norm_content in sorted(removal_positions, reverse=True):
|
|
334
|
+
text = text[:start] + norm_content.strip() + text[end:]
|
|
335
|
+
|
|
336
|
+
return text
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def inline_filter_unicode(text):
|
|
340
|
+
# Ensure text is string type
|
|
341
|
+
if not isinstance(text, str):
|
|
342
|
+
text = str(text)
|
|
343
|
+
|
|
344
|
+
# Replace inline formula boundary markers
|
|
345
|
+
#print('--------text-------',text)
|
|
346
|
+
placeholder = '__INLINE_FORMULA_BOUNDARY__'
|
|
347
|
+
text_copy = text.replace('$', placeholder).replace('\\(', placeholder).replace('\\)', placeholder)
|
|
348
|
+
#print('--------text_copy-------',text_copy)
|
|
349
|
+
# Convert LaTeX content to Unicode representation
|
|
350
|
+
text_copy = LatexNodes2Text().latex_to_text(text_copy)
|
|
351
|
+
#print('--------text_copy---unicode----',text_copy)
|
|
352
|
+
# Restore boundary markers
|
|
353
|
+
text_copy = text_copy.replace(placeholder, '$')
|
|
354
|
+
|
|
355
|
+
inline_array = []
|
|
356
|
+
inline_matches = inline_reg.finditer(text_copy)
|
|
357
|
+
# Record positions of inline formulas to be removed
|
|
358
|
+
removal_positions = []
|
|
359
|
+
|
|
360
|
+
for match in inline_matches:
|
|
361
|
+
position = [match.start(), match.end()]
|
|
362
|
+
content = match.group(1) if match.group(1) is not None else match.group(2)
|
|
363
|
+
print('-------- content-------', content)
|
|
364
|
+
# Remove escape characters \
|
|
365
|
+
clean_content = re.sub(r'\\([\\_&%^])', '', content)
|
|
366
|
+
|
|
367
|
+
if any(char in clean_content for char in r'\^_'):
|
|
368
|
+
# inline_array.append(match.group(0))
|
|
369
|
+
inline_array.append({
|
|
370
|
+
'category_type': 'equation_inline',
|
|
371
|
+
'position': position,
|
|
372
|
+
'content': content,
|
|
373
|
+
})
|
|
374
|
+
removal_positions.append((position[0], position[1]))
|
|
375
|
+
|
|
376
|
+
# Remove inline formulas from original text
|
|
377
|
+
for start, end in sorted(removal_positions, reverse=True):
|
|
378
|
+
text = text[:start] + text[end:]
|
|
379
|
+
|
|
380
|
+
return text, inline_array
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def inline_filter(text):
|
|
384
|
+
# Ensure text is string type
|
|
385
|
+
if not isinstance(text, str):
|
|
386
|
+
text = str(text)
|
|
387
|
+
|
|
388
|
+
inline_array = []
|
|
389
|
+
inline_matches = inline_reg.finditer(text)
|
|
390
|
+
|
|
391
|
+
for match in inline_matches:
|
|
392
|
+
position = [match.start(), match.end()]
|
|
393
|
+
content = match.group(1) if match.group(1) is not None else match.group(2)
|
|
394
|
+
# print('inline_content: ', content)
|
|
395
|
+
|
|
396
|
+
# Remove escape characters \
|
|
397
|
+
clean_content = re.sub(r'\\([\\_&%^])', '', content)
|
|
398
|
+
|
|
399
|
+
if any(char in clean_content for char in r'\^_'):
|
|
400
|
+
# inline_array.append(match.group(0))
|
|
401
|
+
inline_array.append({
|
|
402
|
+
'category_type': 'equation_inline',
|
|
403
|
+
'position': position,
|
|
404
|
+
'content': match.group(0),
|
|
405
|
+
})
|
|
406
|
+
text = text.replace(match.group(0), '')
|
|
407
|
+
# print('-----Found inline formula: ', match.group(0))
|
|
408
|
+
else:
|
|
409
|
+
text = text.replace(match.group(0), content)
|
|
410
|
+
|
|
411
|
+
return text, inline_array
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
# Text OCR quality check processing:
|
|
415
|
+
def clean_string(input_string):
|
|
416
|
+
# Use regex to keep Chinese characters, English letters and numbers
|
|
417
|
+
input_string = input_string.replace('\\t', '').replace('\\n',
|
|
418
|
+
'').replace('\t',
|
|
419
|
+
'').replace('\n',
|
|
420
|
+
'').replace('/t',
|
|
421
|
+
'').replace('/n', '')
|
|
422
|
+
cleaned_string = re.sub(r'[^\w\u4e00-\u9fff]', '', input_string)
|
|
423
|
+
return cleaned_string
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def extract_tabular(text):
|
|
427
|
+
begin_pattern = r'\\begin{tabular}'
|
|
428
|
+
end_pattern = r'\\end{tabular}'
|
|
429
|
+
|
|
430
|
+
tabulars = []
|
|
431
|
+
positions = []
|
|
432
|
+
current_pos = 0
|
|
433
|
+
stack = []
|
|
434
|
+
|
|
435
|
+
while current_pos < len(text):
|
|
436
|
+
begin_match = re.search(begin_pattern, text[current_pos:])
|
|
437
|
+
end_match = re.search(end_pattern, text[current_pos:])
|
|
438
|
+
|
|
439
|
+
if not begin_match and not end_match:
|
|
440
|
+
break
|
|
441
|
+
|
|
442
|
+
if begin_match and (not end_match or begin_match.start() < end_match.start()):
|
|
443
|
+
stack.append(current_pos + begin_match.start())
|
|
444
|
+
current_pos += begin_match.start() + len(end_pattern)
|
|
445
|
+
elif end_match:
|
|
446
|
+
if stack:
|
|
447
|
+
start_pos = stack.pop()
|
|
448
|
+
if not stack:
|
|
449
|
+
end_pos = current_pos + end_match.start() + len(end_pattern)
|
|
450
|
+
tabular_code = text[start_pos:end_pos]
|
|
451
|
+
tabulars.append(tabular_code)
|
|
452
|
+
positions.append((start_pos, end_pos))
|
|
453
|
+
current_pos += end_match.start() + len(end_pattern)
|
|
454
|
+
else:
|
|
455
|
+
current_pos += 1
|
|
456
|
+
|
|
457
|
+
if stack:
|
|
458
|
+
new_start = stack[0] + len(begin_pattern)
|
|
459
|
+
new_tabulars, new_positions = extract_tabular(text[new_start:])
|
|
460
|
+
new_positions = [(start + new_start, end + new_start) for start, end in new_positions]
|
|
461
|
+
tabulars.extend(new_tabulars)
|
|
462
|
+
positions.extend(new_positions)
|
|
463
|
+
|
|
464
|
+
return tabulars, positions
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
# math reg
|
|
468
|
+
# r'\\begin{equation\*?}(.*?)\\end{equation\*?}|'
|
|
469
|
+
# r'\\begin{align\*?}(.*?)\\end{align\*?}|'
|
|
470
|
+
# r'\\begin{gather\*?}(.*?)\\end{gather\*?}|'
|
|
471
|
+
display_reg = re.compile(r'\$\$(.*?)\$\$|'
|
|
472
|
+
r'\\\[(.*?)\\\]|'
|
|
473
|
+
r'\$(.*?)\$|'
|
|
474
|
+
r'\\\((.*?)\\\)', re.DOTALL)
|
|
475
|
+
|
|
476
|
+
# inline_reg = re.compile(
|
|
477
|
+
# r'(?<!\$)\$(?!\$)(.*?)(?<!\$)\$(?!\$)|'
|
|
478
|
+
# r'\\\((.*?)\\\)',
|
|
479
|
+
# )
|
|
480
|
+
inline_reg = re.compile(r'\$(.*?)\$|'
|
|
481
|
+
r'\\\((.*?)\\\)', )
|
|
482
|
+
|
|
483
|
+
# table
|
|
484
|
+
table_reg = re.compile(
|
|
485
|
+
r'\\begin{table\*?}(.*?)\\end{table\*?}|'
|
|
486
|
+
r'\\begin{tabular\*?}(.*?)\\end{tabular\*?}', re.DOTALL
|
|
487
|
+
)
|
|
488
|
+
md_table_reg = re.compile(r'\|\s*.*?\s*\|\n', re.DOTALL)
|
|
489
|
+
html_table_reg = re.compile(r'(<table.*?</table>)', re.DOTALL)
|
|
490
|
+
|
|
491
|
+
# title
|
|
492
|
+
title_reg = re.compile(r'^\s*#.*$', re.MULTILINE)
|
|
493
|
+
|
|
494
|
+
# img
|
|
495
|
+
img_pattern = r'!\[.*?\]\(.*?\)'
|
|
496
|
+
|
|
497
|
+
# code block
|
|
498
|
+
code_block_reg = re.compile(r'```(\w+)\n(.*?)```', re.DOTALL)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def md_tex_filter(content):
|
|
502
|
+
'''
|
|
503
|
+
Input: 1 page md or tex content - String
|
|
504
|
+
Output: text, display, inline, table, title, code - list
|
|
505
|
+
'''
|
|
506
|
+
content = re.sub(img_pattern, '', content) # remove image
|
|
507
|
+
content = remove_markdown_fences(content) # remove markdown fences
|
|
508
|
+
content = replace_repeated_chars(content) # replace all consecutive characters
|
|
509
|
+
|
|
510
|
+
pred_all = []
|
|
511
|
+
latex_table_array, table_positions = extract_tex_table(content)
|
|
512
|
+
for latex_table, position in zip(latex_table_array, table_positions):
|
|
513
|
+
position = [position[0], position[0] + len(latex_table)] # !!!
|
|
514
|
+
pred_all.append({'category_type': 'latex_table', 'position': position, 'content': latex_table})
|
|
515
|
+
content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[
|
|
516
|
+
position[1]:] # replace latex table with space
|
|
517
|
+
|
|
518
|
+
# extract html table
|
|
519
|
+
html_table_array, table_positions = extract_html_table(content)
|
|
520
|
+
for html_table, position in zip(html_table_array, table_positions):
|
|
521
|
+
position = [position[0], position[0] + len(html_table)]
|
|
522
|
+
pred_all.append({'category_type': 'html_table', 'position': position, 'content': html_table})
|
|
523
|
+
content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:
|
|
524
|
+
] # replace html table with space
|
|
525
|
+
|
|
526
|
+
# extract interline formula
|
|
527
|
+
display_matches = display_reg.finditer(content)
|
|
528
|
+
for match in display_matches:
|
|
529
|
+
matched = match.group(0)
|
|
530
|
+
if matched:
|
|
531
|
+
single_line = ''.join(matched.split())
|
|
532
|
+
position = [match.start(), match.end()]
|
|
533
|
+
# replace $$ with \[\]
|
|
534
|
+
dollar_pattern = re.compile(r'\$\$(.*?)\$\$|\$(.*?)\$|\\\((.*?)\\\)', re.DOTALL)
|
|
535
|
+
sub_match = dollar_pattern.search(single_line)
|
|
536
|
+
if sub_match is None:
|
|
537
|
+
# pass
|
|
538
|
+
content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[position[1]:]
|
|
539
|
+
pred_all.append({'category_type': 'equation_isolated', 'position': position, 'content': single_line})
|
|
540
|
+
elif sub_match.group(1):
|
|
541
|
+
single_line = re.sub(dollar_pattern, r'\\[\1\\]', single_line)
|
|
542
|
+
content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[
|
|
543
|
+
position[1]:] # replace equation with space
|
|
544
|
+
pred_all.append({'category_type': 'equation_isolated', 'position': position, 'content': single_line})
|
|
545
|
+
else:
|
|
546
|
+
single_line = re.sub(dollar_pattern, r'\\[\2\3\\]', single_line)
|
|
547
|
+
pred_all.append({
|
|
548
|
+
'category_type': 'equation_isolated',
|
|
549
|
+
'position': position,
|
|
550
|
+
'content': single_line,
|
|
551
|
+
'fine_category_type': 'equation_inline'
|
|
552
|
+
})
|
|
553
|
+
|
|
554
|
+
# extract md table with ||
|
|
555
|
+
md_table_mathces = md_table_reg.findall(content + '\n')
|
|
556
|
+
if len(md_table_mathces) >= 2:
|
|
557
|
+
# print("md table found!")
|
|
558
|
+
# print("content:", content)
|
|
559
|
+
content = convert_markdown_to_html(content)
|
|
560
|
+
# print('----------content after converting md table to html:', content)
|
|
561
|
+
html_table_matches = html_table_reg.finditer(content)
|
|
562
|
+
if html_table_matches:
|
|
563
|
+
for match in html_table_matches:
|
|
564
|
+
matched = match.group(0)
|
|
565
|
+
position = [match.start(), match.end()]
|
|
566
|
+
# content = content.replace(match, '')
|
|
567
|
+
# print('content after removing the md table:', content)
|
|
568
|
+
content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[
|
|
569
|
+
position[1]:] # replace md table with space
|
|
570
|
+
pred_all.append({
|
|
571
|
+
'category_type': 'html_table',
|
|
572
|
+
'position': position,
|
|
573
|
+
'content': matched.strip(),
|
|
574
|
+
'fine_category_type': 'md2html_table'
|
|
575
|
+
})
|
|
576
|
+
# print('---------After md table: \n', content)
|
|
577
|
+
|
|
578
|
+
# extract code blocks
|
|
579
|
+
code_matches = code_block_reg.finditer(content)
|
|
580
|
+
if code_matches:
|
|
581
|
+
for match in code_matches:
|
|
582
|
+
position = [match.start(), match.end()]
|
|
583
|
+
language = match.group(1)
|
|
584
|
+
code = match.group(2).strip()
|
|
585
|
+
# content = content.replace(match.group(0), '')
|
|
586
|
+
content = content[:position[0]] + ' ' * (position[1] - position[0]) + content[
|
|
587
|
+
position[1]:] # replace code block with space
|
|
588
|
+
pred_all.append({
|
|
589
|
+
'category_type': 'text_all',
|
|
590
|
+
'position': position,
|
|
591
|
+
'content': code,
|
|
592
|
+
'language': language,
|
|
593
|
+
'fine_category_type': 'code'
|
|
594
|
+
})
|
|
595
|
+
|
|
596
|
+
# Remove latex style
|
|
597
|
+
content = re.sub(r'\\title\{(.*?)\}', r'\1', content)
|
|
598
|
+
content = re.sub(r'\\title\s*\{\s*(.*?)\s*\}', r'\1', content, flags=re.DOTALL)
|
|
599
|
+
content = re.sub(r'\\text\s*\{\s*(.*?)\s*\}', r'\1', content, flags=re.DOTALL)
|
|
600
|
+
content = re.sub(r'\\section\*?\{(.*?)\}', r'\1', content)
|
|
601
|
+
content = re.sub(r'\\section\*?\{\s*(.*?)\s*\}', r'\1', content, flags=re.DOTALL)
|
|
602
|
+
|
|
603
|
+
# extract texts
|
|
604
|
+
res = content.split('\n\n')
|
|
605
|
+
if len(res) == 1:
|
|
606
|
+
res = content.split('\n') # some models do not use double newlines, so use single newlines to split
|
|
607
|
+
|
|
608
|
+
content_position = 0
|
|
609
|
+
for text in res:
|
|
610
|
+
position = [content_position, content_position + len(text)]
|
|
611
|
+
content_position += len(text)
|
|
612
|
+
text = text.strip()
|
|
613
|
+
text = text.strip('\n')
|
|
614
|
+
# print('ori_text: ', text)
|
|
615
|
+
text = '\n'.join([_.strip()
|
|
616
|
+
for _ in text.split('\n')
|
|
617
|
+
if _.strip()]) # avoid some single newline content with many spaces
|
|
618
|
+
# print('after strip text: ', text)
|
|
619
|
+
|
|
620
|
+
if text: # Check if the stripped text is not empty
|
|
621
|
+
if text.startswith('<table') and text.endswith('</table>'):
|
|
622
|
+
pred_all.append({
|
|
623
|
+
'category_type': 'html_table',
|
|
624
|
+
'position': position,
|
|
625
|
+
'content': text,
|
|
626
|
+
})
|
|
627
|
+
|
|
628
|
+
elif text.startswith('$') and text.endswith('$'):
|
|
629
|
+
if text.replace('$', '').strip():
|
|
630
|
+
pred_all.append({
|
|
631
|
+
'category_type': 'equation_isolated',
|
|
632
|
+
'position': position,
|
|
633
|
+
'content': text.strip(),
|
|
634
|
+
})
|
|
635
|
+
else:
|
|
636
|
+
text = text.strip()
|
|
637
|
+
if text:
|
|
638
|
+
pred_all.append({
|
|
639
|
+
'category_type': 'text_all',
|
|
640
|
+
'position': position,
|
|
641
|
+
'content': text,
|
|
642
|
+
'fine_category_type': 'text_block'
|
|
643
|
+
})
|
|
644
|
+
|
|
645
|
+
pred_dataset = defaultdict(list)
|
|
646
|
+
pred_all = sorted(pred_all, key=lambda x: x['position'][0])
|
|
647
|
+
for item in pred_all:
|
|
648
|
+
pred_dataset[item['category_type']].append(item)
|
|
649
|
+
return pred_dataset
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def extract_tex_table(content):
|
|
653
|
+
tables = []
|
|
654
|
+
tables_positions = []
|
|
655
|
+
|
|
656
|
+
pattern = r'\\begin{table}(.*?)\\end{table}'
|
|
657
|
+
for match in re.finditer(pattern, content, re.DOTALL):
|
|
658
|
+
start_pos = match.start()
|
|
659
|
+
end_pos = match.end()
|
|
660
|
+
table_content = match.group(0)
|
|
661
|
+
tables.append(table_content)
|
|
662
|
+
tables_positions.append((start_pos, end_pos))
|
|
663
|
+
content = content[:start_pos] + ' ' * (end_pos - start_pos) + content[end_pos:]
|
|
664
|
+
|
|
665
|
+
tabulars, tabular_positions = extract_tabular(content)
|
|
666
|
+
all_tables = tables + tabulars
|
|
667
|
+
all_positions = tables_positions + tabular_positions
|
|
668
|
+
|
|
669
|
+
all_result = sorted([[pos, table] for pos, table in zip(all_positions, all_tables)], key=lambda x: x[0][0])
|
|
670
|
+
all_tables = [x[1] for x in all_result]
|
|
671
|
+
all_positions = [x[0] for x in all_result]
|
|
672
|
+
|
|
673
|
+
return all_tables, all_positions
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def extract_html_table(text):
|
|
677
|
+
begin_pattern = r'<table(?:[^>]*)>'
|
|
678
|
+
end_pattern = r'</table>'
|
|
679
|
+
|
|
680
|
+
tabulars = []
|
|
681
|
+
positions = []
|
|
682
|
+
current_pos = 0
|
|
683
|
+
stack = []
|
|
684
|
+
|
|
685
|
+
while current_pos < len(text):
|
|
686
|
+
begin_match = re.search(begin_pattern, text[current_pos:])
|
|
687
|
+
end_match = re.search(end_pattern, text[current_pos:])
|
|
688
|
+
|
|
689
|
+
if not begin_match and not end_match:
|
|
690
|
+
break
|
|
691
|
+
|
|
692
|
+
if begin_match and (not end_match or begin_match.start() < end_match.start()):
|
|
693
|
+
stack.append(current_pos + begin_match.start())
|
|
694
|
+
current_pos += begin_match.start() + len(end_pattern)
|
|
695
|
+
elif end_match:
|
|
696
|
+
if stack:
|
|
697
|
+
start_pos = stack.pop()
|
|
698
|
+
if not stack:
|
|
699
|
+
end_pos = current_pos + end_match.start() + len(end_pattern)
|
|
700
|
+
tabular_code = text[start_pos:end_pos]
|
|
701
|
+
tabulars.append(tabular_code)
|
|
702
|
+
positions.append((start_pos, end_pos))
|
|
703
|
+
current_pos += end_match.start() + len(end_pattern)
|
|
704
|
+
else:
|
|
705
|
+
current_pos += 1
|
|
706
|
+
|
|
707
|
+
if stack:
|
|
708
|
+
new_start = stack[0] + len(begin_pattern)
|
|
709
|
+
new_tabulars, new_positions = extract_html_table(text[new_start:])
|
|
710
|
+
new_positions = [(start + new_start, end + new_start) for start, end in new_positions]
|
|
711
|
+
tabulars.extend(new_tabulars)
|
|
712
|
+
positions.extend(new_positions)
|
|
713
|
+
|
|
714
|
+
return tabulars, positions
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
def extract_node_content(node):
|
|
718
|
+
""" Recursively extract content from LatexEnvironmentNode and rebuild LaTeX table representation """
|
|
719
|
+
if isinstance(node, LatexCharsNode):
|
|
720
|
+
return node.chars # Use chars attribute
|
|
721
|
+
elif isinstance(node, LatexGroupNode):
|
|
722
|
+
return '{' + ''.join(extract_node_content(n) for n in node.nodelist) + '}'
|
|
723
|
+
elif isinstance(node, LatexMacroNode):
|
|
724
|
+
# Extract macro command and its arguments
|
|
725
|
+
macro_content = '\\' + node.macroname
|
|
726
|
+
if node.nodeargs:
|
|
727
|
+
macro_content += ''.join([extract_node_content(arg) for arg in node.nodeargs])
|
|
728
|
+
return macro_content
|
|
729
|
+
elif isinstance(node, LatexEnvironmentNode):
|
|
730
|
+
# Extract environment, preserve environment name and arguments
|
|
731
|
+
content = '\\begin{' + node.environmentname + '}'
|
|
732
|
+
if node.nodeargd and node.nodeargd.argnlist:
|
|
733
|
+
# content += "".join("{" + extract_node_content(arg) + "}" for arg in node.nodeargd)
|
|
734
|
+
# content += "".join("{" + extract_node_content(node.nodeargd) + "}")
|
|
735
|
+
content += '{' + extract_node_content(node.nodeargd.argnlist[0]) + '}'
|
|
736
|
+
if node.nodelist:
|
|
737
|
+
content += ''.join(extract_node_content(n) for n in node.nodelist)
|
|
738
|
+
content += '\\end{' + node.environmentname + '}'
|
|
739
|
+
return content
|
|
740
|
+
elif isinstance(node, LatexSpecialsNode): # Changed to LatexSpecialsNode
|
|
741
|
+
return node.specials_chars
|
|
742
|
+
else:
|
|
743
|
+
return ''
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
def get_node_end_pos(node):
|
|
747
|
+
"""Recursively determine the end position of a node"""
|
|
748
|
+
if hasattr(node, 'nodelist') and node.nodelist:
|
|
749
|
+
# If the node has child nodes, recursively find the end position of the last child node
|
|
750
|
+
return get_node_end_pos(node.nodelist[-1])
|
|
751
|
+
elif hasattr(node, 'pos_end'):
|
|
752
|
+
# If the node has pos_end attribute, return it directly
|
|
753
|
+
return node.pos_end
|
|
754
|
+
else:
|
|
755
|
+
# If there are no child nodes, assume the node ends at the last character of its content
|
|
756
|
+
return node.pos + len(str(node))
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def remove_tex_table(content):
|
|
760
|
+
tables, positions = extract_tex_table(content)
|
|
761
|
+
|
|
762
|
+
# Delete in reverse order by position to avoid affecting unprocessed start positions
|
|
763
|
+
for start, end in sorted(positions, reverse=True):
|
|
764
|
+
content = content[:start] + content[end:] # Remove table content
|
|
765
|
+
|
|
766
|
+
return content
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def get_pred_category_type(pred_idx, pred_items):
|
|
770
|
+
# if pred_idx:
|
|
771
|
+
if pred_items[pred_idx].get('fine_category_type'):
|
|
772
|
+
pred_pred_category_type = pred_items[pred_idx]['fine_category_type']
|
|
773
|
+
else:
|
|
774
|
+
pred_pred_category_type = pred_items[pred_idx]['category_type']
|
|
775
|
+
# else:
|
|
776
|
+
# pred_pred_category_type = ""
|
|
777
|
+
return pred_pred_category_type
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def compute_edit_distance_matrix_new(gt_lines, matched_lines):
|
|
781
|
+
try:
|
|
782
|
+
distance_matrix = np.zeros((len(gt_lines), len(matched_lines)))
|
|
783
|
+
for i, gt_line in enumerate(gt_lines):
|
|
784
|
+
for j, matched_line in enumerate(matched_lines):
|
|
785
|
+
if len(gt_line) == 0 and len(matched_line) == 0:
|
|
786
|
+
distance_matrix[i][j] = 0
|
|
787
|
+
else:
|
|
788
|
+
distance_matrix[i][j] = Levenshtein.distance(gt_line,
|
|
789
|
+
matched_line) / max(len(matched_line), len(gt_line))
|
|
790
|
+
return distance_matrix
|
|
791
|
+
except ZeroDivisionError:
|
|
792
|
+
#print("ZeroDivisionError occurred. Outputting norm_gt_lines and norm_pred_lines:")
|
|
793
|
+
# print("norm_gt_lines:", gt_lines)
|
|
794
|
+
# print("norm_pred_lines:", matched_lines)
|
|
795
|
+
raise
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
def get_gt_pred_lines(gt_items, pred_items, line_type):
|
|
799
|
+
norm_html_lines = []
|
|
800
|
+
gt_lines = []
|
|
801
|
+
gt_cat_list = []
|
|
802
|
+
for item in gt_items:
|
|
803
|
+
if item.get('fine_category_type'):
|
|
804
|
+
gt_cat_list.append(item['fine_category_type'])
|
|
805
|
+
else:
|
|
806
|
+
gt_cat_list.append(item['category_type'])
|
|
807
|
+
if item.get('content'):
|
|
808
|
+
gt_lines.append(str(item['content']))
|
|
809
|
+
norm_html_lines.append(str(item['content']))
|
|
810
|
+
elif line_type == 'text':
|
|
811
|
+
gt_lines.append(str(item['text']))
|
|
812
|
+
elif line_type == 'html_table':
|
|
813
|
+
gt_lines.append(str(item['html']))
|
|
814
|
+
elif line_type == 'formula':
|
|
815
|
+
gt_lines.append(str(item['latex']))
|
|
816
|
+
elif line_type == 'latex_table':
|
|
817
|
+
gt_lines.append(str(item['latex']))
|
|
818
|
+
norm_html_lines.append(str(item['html']))
|
|
819
|
+
|
|
820
|
+
pred_lines = [str(item['content']) for item in pred_items]
|
|
821
|
+
|
|
822
|
+
if line_type == 'formula':
|
|
823
|
+
norm_gt_lines = [normalized_formula(_) for _ in gt_lines]
|
|
824
|
+
norm_pred_lines = [normalized_formula(_) for _ in pred_lines]
|
|
825
|
+
elif line_type == 'text':
|
|
826
|
+
# norm_gt_lines = [textblock_with_norm_formula(_) for _ in gt_lines]
|
|
827
|
+
# norm_pred_lines = [textblock_with_norm_formula(_) for _ in pred_lines]
|
|
828
|
+
norm_gt_lines = [clean_string(textblock2unicode(_)) for _ in gt_lines]
|
|
829
|
+
norm_pred_lines = [clean_string(textblock2unicode(_)) for _ in pred_lines]
|
|
830
|
+
# norm_gt_lines = get_norm_text_lines(gt_lines)
|
|
831
|
+
# norm_pred_lines = get_norm_text_lines(pred_lines)
|
|
832
|
+
else:
|
|
833
|
+
norm_gt_lines = gt_lines
|
|
834
|
+
norm_pred_lines = pred_lines
|
|
835
|
+
|
|
836
|
+
if line_type == 'latex_table':
|
|
837
|
+
gt_lines = norm_html_lines
|
|
838
|
+
|
|
839
|
+
filtered_lists = [(a, b, c) for a, b, c in zip(gt_lines, norm_gt_lines, gt_cat_list) if a and b]
|
|
840
|
+
|
|
841
|
+
# decompress to three lists
|
|
842
|
+
if filtered_lists:
|
|
843
|
+
gt_lines_c, norm_gt_lines_c, gt_cat_list_c = zip(*filtered_lists)
|
|
844
|
+
|
|
845
|
+
# convert to lists
|
|
846
|
+
gt_lines_c = list(gt_lines_c)
|
|
847
|
+
norm_gt_lines_c = list(norm_gt_lines_c)
|
|
848
|
+
gt_cat_list_c = list(gt_cat_list_c)
|
|
849
|
+
else:
|
|
850
|
+
gt_lines_c = []
|
|
851
|
+
norm_gt_lines_c = []
|
|
852
|
+
gt_cat_list_c = []
|
|
853
|
+
|
|
854
|
+
# pred's empty values
|
|
855
|
+
filtered_lists = [(a, b) for a, b in zip(pred_lines, norm_pred_lines) if a and b]
|
|
856
|
+
|
|
857
|
+
# decompress to two lists
|
|
858
|
+
if filtered_lists:
|
|
859
|
+
pred_lines_c, norm_pred_lines_c = zip(*filtered_lists)
|
|
860
|
+
|
|
861
|
+
# convert to lists
|
|
862
|
+
pred_lines_c = list(pred_lines_c)
|
|
863
|
+
norm_pred_lines_c = list(norm_pred_lines_c)
|
|
864
|
+
else:
|
|
865
|
+
pred_lines_c = []
|
|
866
|
+
norm_pred_lines_c = []
|
|
867
|
+
|
|
868
|
+
return gt_lines_c, norm_gt_lines_c, gt_cat_list_c, pred_lines_c, norm_pred_lines_c
|
|
869
|
+
# return gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
def match_gt2pred_simple(gt_items, pred_items, line_type, img_name):
|
|
873
|
+
|
|
874
|
+
gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines = get_gt_pred_lines(
|
|
875
|
+
gt_items, pred_items, line_type
|
|
876
|
+
)
|
|
877
|
+
|
|
878
|
+
match_list = []
|
|
879
|
+
if not norm_gt_lines: # not matched pred should be concatenated
|
|
880
|
+
# print("One of the lists is empty. Returning an empty gt result.")
|
|
881
|
+
# for pred_idx in range(len(norm_pred_lines)):
|
|
882
|
+
pred_idx_list = range(len(norm_pred_lines))
|
|
883
|
+
match_list.append({
|
|
884
|
+
'gt_idx': [''],
|
|
885
|
+
'gt': '',
|
|
886
|
+
'pred_idx': pred_idx_list,
|
|
887
|
+
'pred': ''.join(pred_lines[_] for _ in pred_idx_list),
|
|
888
|
+
'gt_position': [''],
|
|
889
|
+
'pred_position': pred_items[pred_idx_list[0]]['position'][0], # get the first pred's position
|
|
890
|
+
'norm_gt': '',
|
|
891
|
+
'norm_pred': ''.join(norm_pred_lines[_] for _ in pred_idx_list),
|
|
892
|
+
'gt_category_type': '',
|
|
893
|
+
'pred_category_type': get_pred_category_type(pred_idx_list[0], pred_items), # get the first pred's category
|
|
894
|
+
'gt_attribute': [{}],
|
|
895
|
+
'edit': 1,
|
|
896
|
+
'img_id': img_name
|
|
897
|
+
})
|
|
898
|
+
return match_list
|
|
899
|
+
elif not norm_pred_lines: # not matched gt should be separated
|
|
900
|
+
# print("One of the lists is empty. Returning an empty pred result.")
|
|
901
|
+
for gt_idx in range(len(norm_gt_lines)):
|
|
902
|
+
match_list.append({
|
|
903
|
+
'gt_idx': [gt_idx],
|
|
904
|
+
'gt':
|
|
905
|
+
gt_lines[gt_idx],
|
|
906
|
+
'pred_idx': [''],
|
|
907
|
+
'pred':
|
|
908
|
+
'',
|
|
909
|
+
'gt_position': [
|
|
910
|
+
gt_items[gt_idx].get('order')
|
|
911
|
+
if gt_items[gt_idx].get('order') else gt_items[gt_idx].get('position', [''])[0]
|
|
912
|
+
],
|
|
913
|
+
'pred_position':
|
|
914
|
+
'',
|
|
915
|
+
'norm_gt':
|
|
916
|
+
norm_gt_lines[gt_idx],
|
|
917
|
+
'norm_pred':
|
|
918
|
+
'',
|
|
919
|
+
'gt_category_type':
|
|
920
|
+
gt_cat_list[gt_idx],
|
|
921
|
+
'pred_category_type':
|
|
922
|
+
'',
|
|
923
|
+
'gt_attribute': [gt_items[gt_idx].get('attribute', {})],
|
|
924
|
+
'edit':
|
|
925
|
+
1,
|
|
926
|
+
'img_id':
|
|
927
|
+
img_name
|
|
928
|
+
})
|
|
929
|
+
return match_list
|
|
930
|
+
|
|
931
|
+
cost_matrix = compute_edit_distance_matrix_new(norm_gt_lines, norm_pred_lines)
|
|
932
|
+
|
|
933
|
+
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
|
934
|
+
|
|
935
|
+
for gt_idx in range(len(norm_gt_lines)):
|
|
936
|
+
if gt_idx in row_ind:
|
|
937
|
+
row_i = list(row_ind).index(gt_idx)
|
|
938
|
+
pred_idx = int(col_ind[row_i])
|
|
939
|
+
pred_line = pred_lines[pred_idx]
|
|
940
|
+
norm_pred_line = norm_pred_lines[pred_idx]
|
|
941
|
+
edit = cost_matrix[gt_idx][pred_idx]
|
|
942
|
+
# print('edit_dist', edit)
|
|
943
|
+
# if edit > 0.7:
|
|
944
|
+
# print('! Not match')
|
|
945
|
+
else:
|
|
946
|
+
# print('No match pred')
|
|
947
|
+
pred_idx = ''
|
|
948
|
+
pred_line = ''
|
|
949
|
+
norm_pred_line = ''
|
|
950
|
+
edit = 1
|
|
951
|
+
|
|
952
|
+
match_list.append({
|
|
953
|
+
'gt_idx': [gt_idx],
|
|
954
|
+
'gt':
|
|
955
|
+
gt_lines[gt_idx],
|
|
956
|
+
'norm_gt':
|
|
957
|
+
norm_gt_lines[gt_idx],
|
|
958
|
+
'gt_category_type':
|
|
959
|
+
gt_cat_list[gt_idx],
|
|
960
|
+
'gt_position': [
|
|
961
|
+
gt_items[gt_idx].get('order')
|
|
962
|
+
if gt_items[gt_idx].get('order') else gt_items[gt_idx].get('position', [''])[0]
|
|
963
|
+
],
|
|
964
|
+
'gt_attribute': [gt_items[gt_idx].get('attribute', {})],
|
|
965
|
+
'pred_idx': [pred_idx],
|
|
966
|
+
'pred':
|
|
967
|
+
pred_line,
|
|
968
|
+
'norm_pred':
|
|
969
|
+
norm_pred_line,
|
|
970
|
+
'pred_category_type':
|
|
971
|
+
get_pred_category_type(pred_idx, pred_items) if pred_idx else '',
|
|
972
|
+
'pred_position':
|
|
973
|
+
pred_items[pred_idx]['position'][0] if pred_idx else '',
|
|
974
|
+
'edit':
|
|
975
|
+
edit,
|
|
976
|
+
'img_id':
|
|
977
|
+
img_name
|
|
978
|
+
})
|
|
979
|
+
# print('-'*10)
|
|
980
|
+
# [([0,1], 0),(2, 1), (1,2)] --> [0,2,1]/[0,1,2]
|
|
981
|
+
|
|
982
|
+
pred_idx_list = [
|
|
983
|
+
pred_idx for pred_idx in range(len(norm_pred_lines)) if pred_idx not in col_ind
|
|
984
|
+
] # get not matched preds
|
|
985
|
+
if pred_idx_list: # if there are still remaining pred_idx, concatenate all preds
|
|
986
|
+
match_list.append({
|
|
987
|
+
'gt_idx': [''],
|
|
988
|
+
'gt': '',
|
|
989
|
+
'pred_idx': pred_idx_list,
|
|
990
|
+
'pred': ''.join(pred_lines[_] for _ in pred_idx_list),
|
|
991
|
+
'gt_position': [''],
|
|
992
|
+
'pred_position': pred_items[pred_idx_list[0]]['position'][0], # get the first pred's position
|
|
993
|
+
'norm_gt': '',
|
|
994
|
+
'norm_pred': ''.join(norm_pred_lines[_] for _ in pred_idx_list),
|
|
995
|
+
'gt_category_type': '',
|
|
996
|
+
'pred_category_type': get_pred_category_type(pred_idx_list[0], pred_items), # get the first pred's category
|
|
997
|
+
'gt_attribute': [{}],
|
|
998
|
+
'edit': 1,
|
|
999
|
+
'img_id': img_name
|
|
1000
|
+
})
|
|
1001
|
+
return match_list
|
|
1002
|
+
|
|
1003
|
+
|
|
1004
|
+
def match_gt2pred_no_split(gt_items, pred_items, line_type, img_name):
|
|
1005
|
+
# directly concatenate gt and pred by position
|
|
1006
|
+
gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines = get_gt_pred_lines(
|
|
1007
|
+
gt_items, pred_items, line_type
|
|
1008
|
+
)
|
|
1009
|
+
gt_line_with_position = []
|
|
1010
|
+
for gt_line, norm_gt_line, gt_item in zip(gt_lines, norm_gt_lines, gt_items):
|
|
1011
|
+
gt_position = gt_item['order'] if gt_item.get('order') else gt_item.get('position', [''])[0]
|
|
1012
|
+
if gt_position:
|
|
1013
|
+
gt_line_with_position.append((gt_position, gt_line, norm_gt_line))
|
|
1014
|
+
sorted_gt_lines = sorted(gt_line_with_position, key=lambda x: x[0])
|
|
1015
|
+
gt = '\n\n'.join([_[1] for _ in sorted_gt_lines])
|
|
1016
|
+
norm_gt = '\n\n'.join([_[2] for _ in sorted_gt_lines])
|
|
1017
|
+
pred_line_with_position = [(pred_item['position'], pred_line, pred_norm_line)
|
|
1018
|
+
for pred_line, pred_norm_line, pred_item in zip(pred_lines, norm_pred_lines, pred_items)]
|
|
1019
|
+
sorted_pred_lines = sorted(pred_line_with_position, key=lambda x: x[0])
|
|
1020
|
+
pred = '\n\n'.join([_[1] for _ in sorted_pred_lines])
|
|
1021
|
+
norm_pred = '\n\n'.join([_[2] for _ in sorted_pred_lines])
|
|
1022
|
+
# edit = Levenshtein.distance(norm_gt, norm_pred)/max(len(norm_gt), len(norm_pred))
|
|
1023
|
+
if norm_gt or norm_pred:
|
|
1024
|
+
return [{
|
|
1025
|
+
'gt_idx': [0],
|
|
1026
|
+
'gt': gt,
|
|
1027
|
+
'norm_gt': norm_gt,
|
|
1028
|
+
'gt_category_type': 'text_merge',
|
|
1029
|
+
'gt_position': [''],
|
|
1030
|
+
'gt_attribute': [{}],
|
|
1031
|
+
'pred_idx': [0],
|
|
1032
|
+
'pred': pred,
|
|
1033
|
+
'norm_pred': norm_pred,
|
|
1034
|
+
'pred_category_type': 'text_merge',
|
|
1035
|
+
'pred_position': '',
|
|
1036
|
+
# 'edit': edit,
|
|
1037
|
+
'img_id': img_name
|
|
1038
|
+
}]
|
|
1039
|
+
else:
|
|
1040
|
+
return []
|
|
1041
|
+
|
|
1042
|
+
|
|
1043
|
+
import copy
|
|
1044
|
+
import evaluate
|
|
1045
|
+
|
|
1046
|
+
# from rapidfuzz.distance import Levenshtein
|
|
1047
|
+
import Levenshtein
|
|
1048
|
+
import numpy as np
|
|
1049
|
+
import pdb
|
|
1050
|
+
from collections import Counter, defaultdict
|
|
1051
|
+
from Levenshtein import distance as Levenshtein_distance
|
|
1052
|
+
from scipy.optimize import linear_sum_assignment
|
|
1053
|
+
|
|
1054
|
+
|
|
+def match_gt2pred_quick(gt_items, pred_items, line_type, img_name):
+
+    gt_lines, norm_gt_lines, gt_cat_list, pred_lines, norm_pred_lines = get_gt_pred_lines(
+        gt_items, pred_items, line_type
+    )
+    all_gt_indices = set(range(len(norm_gt_lines)))
+    all_pred_indices = set(range(len(norm_pred_lines)))
+
+    if not norm_gt_lines:
+        match_list = []
+        for pred_idx in range(len(norm_pred_lines)):
+            match_list.append({
+                'gt_idx': [''],
+                'gt': '',
+                'pred_idx': [pred_idx],
+                'pred': pred_lines[pred_idx],
+                'gt_position': '',
+                'pred_position': pred_items[pred_idx]['position'][0],
+                'norm_gt': '',
+                'norm_pred': norm_pred_lines[pred_idx],
+                'gt_category_type': '',
+                'pred_category_type': get_pred_category_type(pred_idx, pred_items),
+                'gt_attribute': [{}],
+                'edit': 1,
+                'img_id': img_name
+            })
+        return match_list
+    elif not norm_pred_lines:
+        match_list = []
+        for gt_idx in range(len(norm_gt_lines)):
+            match_list.append({
+                'gt_idx': [gt_idx],
+                'gt': gt_lines[gt_idx],
+                'pred_idx': [''],
+                'pred': '',
+                'gt_position': [
+                    gt_items[gt_idx].get('order')
+                    if gt_items[gt_idx].get('order') else gt_items[gt_idx].get('position', [''])[0]
+                ],
+                'pred_position': '',
+                'norm_gt': norm_gt_lines[gt_idx],
+                'norm_pred': '',
+                'gt_category_type': gt_cat_list[gt_idx],
+                'pred_category_type': '',
+                'gt_attribute': [gt_items[gt_idx].get('attribute', {})],
+                'edit': 1,
+                'img_id': img_name
+            })
+        return match_list
+    elif len(norm_gt_lines) == 1 and len(norm_pred_lines) == 1:
+        edit_distance = Levenshtein_distance(norm_gt_lines[0], norm_pred_lines[0])
+        normalized_edit_distance = edit_distance / max(len(norm_gt_lines[0]), len(norm_pred_lines[0]))
+        return [{
+            'gt_idx': [0],
+            'gt': gt_lines[0],
+            'pred_idx': [0],
+            'pred': pred_lines[0],
+            'gt_position': [gt_items[0].get('order') if gt_items[0].get('order') else gt_items[0].get('position', [''])[0]],
+            'pred_position': pred_items[0]['position'][0],
+            'norm_gt': norm_gt_lines[0],
+            'norm_pred': norm_pred_lines[0],
+            'gt_category_type': gt_cat_list[0],
+            'pred_category_type': get_pred_category_type(0, pred_items),
+            'gt_attribute': [gt_items[0].get('attribute', {})],
+            'edit': normalized_edit_distance,
+            'img_id': img_name
+        }]
+
+    cost_matrix = compute_edit_distance_matrix_new(norm_gt_lines, norm_pred_lines)
+
+    matched_col_idx, row_ind, cost_list = cal_final_match(cost_matrix, norm_gt_lines, norm_pred_lines)
+
+    gt_lens_dict, pred_lens_dict = initialize_indices(norm_gt_lines, norm_pred_lines)
+
+    matches, unmatched_gt_indices, unmatched_pred_indices = process_matches(
+        matched_col_idx, row_ind, cost_list, norm_gt_lines, norm_pred_lines, pred_lines
+    )
+
+    matching_dict = fuzzy_match_unmatched_items(unmatched_gt_indices, norm_gt_lines, norm_pred_lines)
+
+    final_matches = merge_matches(matches, matching_dict)
+
+    recalculate_edit_distances(final_matches, gt_lens_dict, norm_gt_lines, norm_pred_lines)
+
+    converted_results = convert_final_matches(final_matches, norm_gt_lines, norm_pred_lines)
+
+    merged_results = merge_duplicates_add_unmatched(
+        converted_results, norm_gt_lines, norm_pred_lines, gt_lines, pred_lines, all_gt_indices, all_pred_indices
+    )
+
+    for entry in merged_results:
+        entry['gt_idx'] = [entry['gt_idx']] if not isinstance(entry['gt_idx'], list) else entry['gt_idx']
+        entry['pred_idx'] = [entry['pred_idx']] if not isinstance(entry['pred_idx'], list) else entry['pred_idx']
+        entry['gt_position'] = [
+            gt_items[_].get('order') if gt_items[_].get('order') else gt_items[_].get('position', [''])[0]
+            for _ in entry['gt_idx']
+        ] if entry['gt_idx'] != [''] else ['']
+        entry['pred_position'] = pred_items[entry['pred_idx'][0]]['position'][0] if entry['pred_idx'] != [''] else ''
+        entry['gt'] = ''.join([gt_lines[_] for _ in entry['gt_idx']]) if entry['gt_idx'] != [''] else ''
+        entry['pred'] = ''.join([pred_lines[_] for _ in entry['pred_idx']]) if entry['pred_idx'] != [''] else ''
+        entry['norm_gt'] = ''.join([norm_gt_lines[_] for _ in entry['gt_idx']]) if entry['gt_idx'] != [''] else ''
+        entry['norm_pred'] = ''.join([norm_pred_lines[_] for _ in entry['pred_idx']]) if entry['pred_idx'] != [''] else ''
+
+        if entry['gt_idx'] != ['']:
+            ignore_type = [
+                'figure_caption', 'figure_footnote', 'table_caption', 'table_footnote', 'code_algorithm',
+                'code_algorithm_caption', 'header', 'footer', 'page_footnote', 'page_number', 'equation_caption'
+            ]
+            gt_category_clean = [gt_cat_list[_] for _ in entry['gt_idx'] if gt_cat_list[_] not in ignore_type]
+            if gt_category_clean:
+                entry['gt_category_type'] = Counter(gt_category_clean).most_common(1)[0][0]
+            else:
+                entry['gt_category_type'] = Counter([gt_cat_list[_] for _ in entry['gt_idx']]).most_common(1)[0][0]
+        else:
+            entry['gt_category_type'] = ''
+        entry['pred_category_type'] = get_pred_category_type(entry['pred_idx'][0], pred_items) if entry['pred_idx'] != [''] else ''
+        entry['gt_attribute'] = [gt_items[_].get('attribute', {}) for _ in entry['gt_idx']] if entry['gt_idx'] != [''] else [{}]
+        entry['img_id'] = img_name
+
+    return merged_results
+
+
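The post-processing loop above normalizes every entry to list-valued indices and then picks a majority category, preferring non-auxiliary types. A self-contained sketch of those two idioms with hypothetical values:

```python
from collections import Counter

# Normalize a scalar index to a one-element list (as the loop above does).
gt_idx = 3
gt_idx = [gt_idx] if not isinstance(gt_idx, list) else gt_idx
print(gt_idx)  # [3]

# Majority vote over categories, preferring non-auxiliary types.
ignore_type = ['figure_caption', 'page_number']          # abbreviated list
cats = ['text_block', 'figure_caption', 'text_block']    # hypothetical categories
clean = [c for c in cats if c not in ignore_type]
print(Counter(clean or cats).most_common(1)[0][0])       # 'text_block'
```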
+def merge_duplicates_add_unmatched(
+    converted_results, norm_gt_lines, norm_pred_lines, gt_lines, pred_lines, all_gt_indices, all_pred_indices
+):
+    merged_results = []
+    processed_pred = set()
+    processed_gt = set()
+
+    for entry in converted_results:
+        pred_idx = tuple(entry['pred_idx']) if isinstance(entry['pred_idx'], list) else (entry['pred_idx'], )
+        if pred_idx not in processed_pred and pred_idx != ('', ):
+            merged_entry = {
+                'gt_idx': [entry['gt_idx']],
+                'gt': entry['gt'],
+                'pred_idx': entry['pred_idx'],
+                'pred': entry['pred'],
+                'edit': entry['edit']
+            }
+            for other_entry in converted_results:
+                other_pred_idx = tuple(other_entry['pred_idx']) if isinstance(other_entry['pred_idx'], list) else (other_entry['pred_idx'], )
+                if other_pred_idx == pred_idx and other_entry is not entry:
+                    merged_entry['gt_idx'].append(other_entry['gt_idx'])
+                    merged_entry['gt'] += other_entry['gt']
+                    processed_gt.add(other_entry['gt_idx'])
+            merged_results.append(merged_entry)
+            processed_pred.add(pred_idx)
+            processed_gt.add(entry['gt_idx'])
+
+    for entry in converted_results:
+        if entry['gt_idx'] not in processed_gt:
+            merged_results.append(entry)
+
+    for gt_idx in range(len(norm_gt_lines)):
+        if gt_idx not in processed_gt:
+            merged_results.append({'gt_idx': [gt_idx], 'gt': gt_lines[gt_idx], 'pred_idx': [''], 'pred': '', 'edit': 1})
+    return merged_results
+
+
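The function folds every result that points at the same prediction into one entry, then appends never-matched GT lines with edit 1. The core grouping idea, reduced to a self-contained sketch with toy entries (field layout mirrors `converted_results`):

```python
from collections import defaultdict

# Toy entries: two GT lines were matched to the same prediction (index 0).
converted_results = [
    {'gt_idx': 0, 'gt': 'Hello ', 'pred_idx': [0], 'pred': 'Hello world', 'edit': 0.1},
    {'gt_idx': 1, 'gt': 'world', 'pred_idx': [0], 'pred': 'Hello world', 'edit': 0.1},
]

groups = defaultdict(list)
for entry in converted_results:
    groups[tuple(entry['pred_idx'])].append(entry)

for pred_key, entries in groups.items():
    merged_gt = ''.join(e['gt'] for e in entries)
    print(pred_key, [e['gt_idx'] for e in entries], repr(merged_gt))
# (0,) [0, 1] 'Hello world'
```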
+def formula_format(formula_matches, img_name):
+    return [{
+        'gt': item['gt'],
+        'pred': item['pred'],
+        'img_id': f'{img_name}_{i}'
+    } for i, item in enumerate(formula_matches)]
+
+
+def merge_lists_with_sublists(main_list, sub_lists):
+    main_list_final = list(copy.deepcopy(main_list))
+    for sub_list in sub_lists:
+        pop_idx = main_list_final.index(sub_list[0])
+        for _ in sub_list:
+            main_list_final.pop(pop_idx)
+        main_list_final.insert(pop_idx, sub_list)
+    return main_list_final
+
+
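This helper splices runs of prediction indices into single grouped items. A self-contained copy with a toy input makes the behavior concrete:

```python
import copy

def merge_lists_with_sublists(main_list, sub_lists):
    # Replace each run of indices with a single nested list.
    main_list_final = list(copy.deepcopy(main_list))
    for sub_list in sub_lists:
        pop_idx = main_list_final.index(sub_list[0])
        for _ in sub_list:
            main_list_final.pop(pop_idx)
        main_list_final.insert(pop_idx, sub_list)
    return main_list_final

# Indices 1 and 2 were judged to be one truncated line: group them.
print(merge_lists_with_sublists(range(5), [[1, 2]]))  # [0, [1, 2], 3, 4]
```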
+def sub_pred_fuzzy_matching(gt, pred):
+
+    min_d = float('inf')
+    # pos = -1
+
+    gt_len = len(gt)
+    pred_len = len(pred)
+
+    if gt_len >= pred_len and pred_len > 0:
+        for i in range(gt_len - pred_len + 1):
+            sub = gt[i:i + pred_len]
+            dist = Levenshtein_distance(sub, pred) / pred_len
+            if dist < min_d:
+                min_d = dist
+                pos = i
+
+        return min_d
+    else:
+        return False
+
+
+def sub_gt_fuzzy_matching(pred, gt):
+
+    min_d = float('inf')
+    pos = ''
+    matched_sub = ''
+    gt_len = len(gt)
+    pred_len = len(pred)
+
+    if pred_len >= gt_len and gt_len > 0:
+        for i in range(pred_len - gt_len + 1):
+            sub = pred[i:i + gt_len]
+            dist = Levenshtein.distance(sub, gt) / gt_len
+            if dist < min_d:
+                min_d = dist
+                pos = i
+                matched_sub = sub
+        return min_d, pos, gt_len, matched_sub
+    else:
+        return 1, '', gt_len, ''
+
+
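Both helpers slide the shorter string across the longer one and keep the window with the lowest normalized distance. A self-contained sketch of that window scan:

```python
from Levenshtein import distance as Levenshtein_distance

def best_window(haystack: str, needle: str):
    """Slide a len(needle) window over haystack; return (best_dist, offset)."""
    best = (float('inf'), -1)
    for i in range(len(haystack) - len(needle) + 1):
        window = haystack[i:i + len(needle)]
        dist = Levenshtein_distance(window, needle) / len(needle)
        best = min(best, (dist, i))
    return best

# The GT line contains the (shorter) predicted fragment starting at offset 10.
print(best_window('The quick brown fox jumps', 'brown fox'))  # (0.0, 10)
```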
+def get_final_subset(subset_certain, subset_certain_cost):
+    if not subset_certain or not subset_certain_cost:
+        return []
+
+    subset_tuple = sorted([(a, b) for a, b in zip(subset_certain, subset_certain_cost)], key=lambda x: x[0][0])
+
+    group_list = defaultdict(list)
+    group_idx = 0
+    group_list[group_idx].append(subset_tuple[0])
+
+    for item in subset_tuple[1:]:
+        overlap_flag = False
+        for subset in group_list[group_idx]:
+            for idx in item[0]:
+                if idx in subset[0]:
+                    overlap_flag = True
+                    break
+            if overlap_flag:
+                break
+        if overlap_flag:
+            group_list[group_idx].append(item)
+        else:
+            group_idx += 1
+            group_list[group_idx].append(item)
+
+    final_subset = []
+    for _, group in group_list.items():
+        if len(group) == 1:
+            final_subset.append(group[0][0])
+        else:
+            path_dict = defaultdict(list)
+            path_idx = 0
+            path_dict[path_idx].append(group[0])
+
+            for subset in group[1:]:
+                new_path = True
+                for path_idx_s, path_items in path_dict.items():
+                    is_dup = False
+                    is_same = False
+                    for path_item in path_items:
+                        if path_item[0] == subset[0]:
+                            is_dup = True
+                            is_same = True
+                            if path_item[1] > subset[1]:
+                                path_dict[path_idx_s].pop(path_dict[path_idx_s].index(path_item))
+                                path_dict[path_idx_s].append(subset)
+                        else:
+                            for num_1 in path_item[0]:
+                                for num_2 in subset[0]:
+                                    if num_1 == num_2:
+                                        is_dup = True
+                    if not is_dup:
+                        path_dict[path_idx_s].append(subset)
+                        new_path = False
+                    if is_same:
+                        new_path = False
+                if new_path:
+                    path_idx = len(path_dict.keys())
+                    path_dict[path_idx].append(subset)
+
+            saved_cost = float('inf')
+            saved_subset = []
+            for path_idx, path in path_dict.items():
+                avg_cost = sum([i[1] for i in path]) / len(path)
+                if avg_cost < saved_cost:
+                    saved_subset = [i[0] for i in path]
+                    saved_cost = avg_cost
+
+            final_subset.extend(saved_subset)
+
+    return final_subset
+
+
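The objective here is: from overlapping candidate index groups, keep a mutually disjoint selection with the lowest average cost. The function approximates this by greedy grouping and path enumeration; a brute-force rendering of the same objective (illustration only, not the algorithm above) on a toy input:

```python
from itertools import combinations

# Candidate merge groups (prediction-index runs) and their match costs.
cands = [([0, 1], 0.2), ([1, 2], 0.5), ([3, 4], 0.1)]

def disjoint(sel):
    seen = set()
    for idx_set, _ in sel:
        if seen & set(idx_set):
            return False
        seen |= set(idx_set)
    return True

best = min(
    (sel for r in range(1, len(cands) + 1) for sel in combinations(cands, r) if disjoint(sel)),
    key=lambda sel: sum(c for _, c in sel) / len(sel),
)
print([g for g, _ in best])  # [[3, 4]] — lowest average cost among disjoint picks
```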
+def judge_pred_merge(gt_list, pred_list, threshold=0.6):
+    if len(pred_list) == 1:
+        return False, False
+
+    cur_pred = ' '.join(pred_list[:-1])
+    merged_pred = ' '.join(pred_list)
+
+    cur_dist = Levenshtein.distance(gt_list[0], cur_pred) / max(len(gt_list[0]), len(cur_pred))
+    merged_dist = Levenshtein.distance(gt_list[0], merged_pred) / max(len(gt_list[0]), len(merged_pred))
+
+    if merged_dist > cur_dist:
+        return False, False
+
+    cur_fuzzy_dists = [sub_pred_fuzzy_matching(gt_list[0], cur_pred) for cur_pred in pred_list[:-1]]
+    if any(dist is False or dist > threshold for dist in cur_fuzzy_dists):
+        return False, False
+
+    add_fuzzy_dist = sub_pred_fuzzy_matching(gt_list[0], pred_list[-1])
+    if add_fuzzy_dist is False:
+        return False, False
+
+    merged_pred_flag = add_fuzzy_dist < threshold
+    continue_flag = len(merged_pred) <= len(gt_list[0])
+
+    return merged_pred_flag, continue_flag
+
+
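The merge test accepts a new prediction fragment only if concatenating it moves the prediction closer to the GT line and each fragment fuzzily occurs inside the GT. A runnable sketch of the distance-improvement check at its core, on toy strings:

```python
from Levenshtein import distance

gt = 'Results are reported on the validation split'
fragments = ['Results are reported', 'on the validation split']  # one truncated line

cur = fragments[0]
merged = ' '.join(fragments)
cur_dist = distance(gt, cur) / max(len(gt), len(cur))
merged_dist = distance(gt, merged) / max(len(gt), len(merged))
print(merged_dist < cur_dist)  # True: merging the fragments helps, so keep going
```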
+def deal_with_truncated(cost_matrix, norm_gt_lines, norm_pred_lines):
+    matched_first = np.argwhere(cost_matrix < 0.25)
+    masked_gt_idx = [i[0] for i in matched_first]
+    unmasked_gt_idx = [i for i in range(cost_matrix.shape[0]) if i not in masked_gt_idx]
+    masked_pred_idx = [i[1] for i in matched_first]
+    unmasked_pred_idx = [i for i in range(cost_matrix.shape[1]) if i not in masked_pred_idx]
+
+    merges_gt_dict = {}
+    merges_pred_dict = {}
+    merged_gt_subsets = []
+
+    for gt_idx in unmasked_gt_idx:
+        check_merge_subset = []
+        merged_dist = []
+
+        for pred_idx in unmasked_pred_idx:
+            step = 1
+            merged_pred = [norm_pred_lines[pred_idx]]
+
+            while True:
+                if pred_idx + step in masked_pred_idx or pred_idx + step >= len(norm_pred_lines):
+                    break
+                else:
+                    merged_pred.append(norm_pred_lines[pred_idx + step])
+                    merged_pred_flag, continue_flag = judge_pred_merge([norm_gt_lines[gt_idx]], merged_pred)
+                    if not merged_pred_flag:
+                        break
+                    else:
+                        step += 1
+                        if not continue_flag:
+                            break
+
+            check_merge_subset.append(list(range(pred_idx, pred_idx + step)))
+            matched_line = ' '.join([norm_pred_lines[i] for i in range(pred_idx, pred_idx + step)])
+            dist = Levenshtein_distance(norm_gt_lines[gt_idx], matched_line) / max(len(matched_line), len(norm_gt_lines[gt_idx]))
+            merged_dist.append(dist)
+
+        if not merged_dist:
+            subset_certain = []
+            min_cost_idx = ''
+            min_cost = float('inf')
+        else:
+            min_cost = min(merged_dist)
+            min_cost_idx = merged_dist.index(min_cost)
+            subset_certain = check_merge_subset[min_cost_idx]
+
+        merges_gt_dict[gt_idx] = {
+            'merge_subset': check_merge_subset,
+            'merged_cost': merged_dist,
+            'min_cost_idx': min_cost_idx,
+            'subset_certain': subset_certain,
+            'min_cost': min_cost
+        }
+
+    subset_certain = [
+        merges_gt_dict[gt_idx]['subset_certain']
+        for gt_idx in unmasked_gt_idx
+        if merges_gt_dict[gt_idx]['subset_certain']
+    ]
+    subset_certain_cost = [
+        merges_gt_dict[gt_idx]['min_cost'] for gt_idx in unmasked_gt_idx if merges_gt_dict[gt_idx]['subset_certain']
+    ]
+
+    subset_certain_final = get_final_subset(subset_certain, subset_certain_cost)
+
+    if not subset_certain_final:
+        return cost_matrix, norm_pred_lines, range(len(norm_pred_lines))
+
+    final_pred_idx_list = merge_lists_with_sublists(range(len(norm_pred_lines)), subset_certain_final)
+    final_norm_pred_lines = [
+        ' '.join(norm_pred_lines[idx_list[0]:idx_list[-1] + 1]) if isinstance(idx_list, list) else norm_pred_lines[idx_list]
+        for idx_list in final_pred_idx_list
+    ]
+
+    new_cost_matrix = compute_edit_distance_matrix_new(norm_gt_lines, final_norm_pred_lines)
+
+    return new_cost_matrix, final_norm_pred_lines, final_pred_idx_list
+
+
+def cal_move_dist(gt, pred):
+    assert len(gt) == len(pred), 'Not right length'
+    step = 0
+    for i, gt_c in enumerate(gt):
+        if gt_c != pred[i]:
+            # Look the index up once: re-evaluating pred.index(gt_c) inside the
+            # swap targets would see the half-updated list and undo the swap.
+            j = pred.index(gt_c)
+            step += abs(i - j)
+            pred[i], pred[j] = pred[j], pred[i]
+    return step / len(gt)
+
+
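A quick check of the movement metric (it mutates `pred` in place while sorting it toward `gt`), using a self-contained copy of the fixed helper:

```python
def cal_move_dist(gt, pred):
    # Average displacement needed to reorder pred into gt.
    assert len(gt) == len(pred), 'Not right length'
    step = 0
    for i, gt_c in enumerate(gt):
        if gt_c != pred[i]:
            j = pred.index(gt_c)
            step += abs(i - j)
            pred[i], pred[j] = pred[j], pred[i]
    return step / len(gt)

print(cal_move_dist(['a', 'b', 'c'], ['c', 'b', 'a']))  # one swap across 2 slots: 2/3
```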
+def cal_final_match(cost_matrix, norm_gt_lines, norm_pred_lines):
+    min_indice = cost_matrix.argmax(axis=1)
+
+    new_cost_matrix, final_norm_pred_lines, final_pred_idx_list = deal_with_truncated(
+        cost_matrix, norm_gt_lines, norm_pred_lines
+    )
+
+    row_ind, col_ind = linear_sum_assignment(new_cost_matrix)
+
+    cost_list = [new_cost_matrix[r][c] for r, c in zip(row_ind, col_ind)]
+    matched_col_idx = [final_pred_idx_list[i] for i in col_ind]
+
+    return matched_col_idx, row_ind, cost_list
+
+
+def initialize_indices(norm_gt_lines, norm_pred_lines):
+    gt_lens_dict = {idx: len(gt_line) for idx, gt_line in enumerate(norm_gt_lines)}
+    pred_lens_dict = {idx: len(pred_line) for idx, pred_line in enumerate(norm_pred_lines)}
+    return gt_lens_dict, pred_lens_dict
+
+
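The global matching step is a plain Hungarian assignment over the edit-distance matrix. A self-contained example with a toy 3×3 cost matrix:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows are GT lines, columns are predicted lines; entries are normalized
# edit distances (lower cost = better match).
cost = np.array([
    [0.1, 0.9, 0.8],
    [0.7, 0.2, 0.9],
    [0.8, 0.9, 0.3],
])
row_ind, col_ind = linear_sum_assignment(cost)
print([(int(r), int(c)) for r, c in zip(row_ind, col_ind)])  # [(0, 0), (1, 1), (2, 2)]
print(float(cost[row_ind, col_ind].sum()))                   # ≈ 0.6, the minimal total cost
```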
+def process_matches(matched_col_idx, row_ind, cost_list, norm_gt_lines, norm_pred_lines, pred_lines):
+    matches = {}
+    unmatched_gt_indices = []
+    unmatched_pred_indices = []
+
+    for i in range(len(norm_gt_lines)):
+        if i in row_ind:
+            idx = list(row_ind).index(i)
+            pred_idx = matched_col_idx[idx]
+
+            if pred_idx is None or (isinstance(pred_idx, list) and None in pred_idx):
+                unmatched_pred_indices.append(pred_idx)
+                continue
+
+            if isinstance(pred_idx, list):
+                pred_line = ' | '.join(norm_pred_lines[pred_idx[0]:pred_idx[-1] + 1])
+                ori_pred_line = ' | '.join(pred_lines[pred_idx[0]:pred_idx[-1] + 1])
+                matched_pred_indices_range = list(range(pred_idx[0], pred_idx[-1] + 1))
+            else:
+                pred_line = norm_pred_lines[pred_idx]
+                ori_pred_line = pred_lines[pred_idx]
+                matched_pred_indices_range = [pred_idx]
+
+            edit = cost_list[idx]
+
+            if edit > 0.7:
+                unmatched_pred_indices.extend(matched_pred_indices_range)
+                unmatched_gt_indices.append(i)
+            else:
+                matches[i] = {
+                    'pred_indices': matched_pred_indices_range,
+                    'edit_distance': edit,
+                }
+                for matched_pred_idx in matched_pred_indices_range:
+                    if matched_pred_idx in unmatched_pred_indices:
+                        unmatched_pred_indices.remove(matched_pred_idx)
+        else:
+            unmatched_gt_indices.append(i)
+
+    return matches, unmatched_gt_indices, unmatched_pred_indices
+
+
+def fuzzy_match_unmatched_items(unmatched_gt_indices, norm_gt_lines, norm_pred_lines):
+    matching_dict = {}
+
+    for pred_idx, pred_content in enumerate(norm_pred_lines):
+        if isinstance(pred_idx, list):
+            continue
+
+        matching_indices = []
+
+        for unmatched_gt_idx in unmatched_gt_indices:
+            gt_content = norm_gt_lines[unmatched_gt_idx]
+            cur_fuzzy_dist_unmatch, cur_pos, gt_lens, matched_field = sub_gt_fuzzy_matching(pred_content, gt_content)
+            if cur_fuzzy_dist_unmatch < 0.4:
+                matching_indices.append(unmatched_gt_idx)
+
+        if matching_indices:
+            matching_dict[pred_idx] = matching_indices
+
+    return matching_dict
+
+
+def merge_matches(matches, matching_dict):
+    final_matches = {}
+    processed_gt_indices = set()
+
+    for gt_idx, match_info in matches.items():
+        pred_indices = match_info['pred_indices']
+        edit_distance = match_info['edit_distance']
+
+        pred_key = tuple(sorted(pred_indices))
+
+        if pred_key in final_matches:
+            if gt_idx not in processed_gt_indices:
+                final_matches[pred_key]['gt_indices'].append(gt_idx)
+                processed_gt_indices.add(gt_idx)
+        else:
+            final_matches[pred_key] = {'gt_indices': [gt_idx], 'edit_distance': edit_distance}
+            processed_gt_indices.add(gt_idx)
+
+    for pred_idx, gt_indices in matching_dict.items():
+        pred_key = (pred_idx, ) if not isinstance(pred_idx, (list, tuple)) else tuple(sorted(pred_idx))
+
+        if pred_key in final_matches:
+            for gt_idx in gt_indices:
+                if gt_idx not in processed_gt_indices:
+                    final_matches[pred_key]['gt_indices'].append(gt_idx)
+                    processed_gt_indices.add(gt_idx)
+        else:
+            final_matches[pred_key] = {
+                'gt_indices': [gt_idx for gt_idx in gt_indices if gt_idx not in processed_gt_indices],
+                'edit_distance': None
+            }
+            processed_gt_indices.update(final_matches[pred_key]['gt_indices'])
+
+    return final_matches
+
+
+def recalculate_edit_distances(final_matches, gt_lens_dict, norm_gt_lines, norm_pred_lines):
+    for pred_key, info in final_matches.items():
+        gt_indices = sorted(set(info['gt_indices']))
+
+        if not gt_indices:
+            info['edit_distance'] = 1
+            continue
+
+        if len(gt_indices) > 1:
+            merged_gt_content = ''.join(norm_gt_lines[gt_idx] for gt_idx in gt_indices)
+            pred_content = norm_pred_lines[pred_key[0]] if isinstance(pred_key[0], int) else ''
+
+            try:
+                edit_distance = Levenshtein_distance(merged_gt_content, pred_content)
+                normalized_edit_distance = edit_distance / max(len(merged_gt_content), len(pred_content))
+            except ZeroDivisionError:
+                normalized_edit_distance = 1
+
+            info['edit_distance'] = normalized_edit_distance
+        else:
+            gt_idx = gt_indices[0]
+            pred_content = ' '.join(norm_pred_lines[pred_idx] for pred_idx in pred_key if isinstance(pred_idx, int))
+
+            try:
+                edit_distance = Levenshtein_distance(norm_gt_lines[gt_idx], pred_content)
+                normalized_edit_distance = edit_distance / max(len(norm_gt_lines[gt_idx]), len(pred_content))
+            except ZeroDivisionError:
+                normalized_edit_distance = 1
+
+            info['edit_distance'] = normalized_edit_distance
+            info['pred_content'] = pred_content
+
+
+def convert_final_matches(final_matches, norm_gt_lines, norm_pred_lines):
+    converted_results = []
+
+    all_gt_indices = set(range(len(norm_gt_lines)))
+    all_pred_indices = set(range(len(norm_pred_lines)))
+
+    for pred_key, info in final_matches.items():
+        pred_content = ' '.join(norm_pred_lines[pred_idx] for pred_idx in pred_key if isinstance(pred_idx, int))
+
+        for gt_idx in sorted(set(info['gt_indices'])):
+            result_entry = {
+                'gt_idx': int(gt_idx),
+                'gt': norm_gt_lines[gt_idx],
+                'pred_idx': list(pred_key),
+                'pred': pred_content,
+                'edit': info['edit_distance']
+            }
+            converted_results.append(result_entry)
+
+    matched_gt_indices = set().union(*[set(info['gt_indices']) for info in final_matches.values()])
+    unmatched_gt_indices = all_gt_indices - matched_gt_indices
+    matched_pred_indices = set(idx for pred_key in final_matches.keys() for idx in pred_key if isinstance(idx, int))
+    unmatched_pred_indices = all_pred_indices - matched_pred_indices
+
+    if unmatched_pred_indices:
+        if unmatched_gt_indices:
+            distance_matrix = [[Levenshtein_distance(norm_gt_lines[gt_idx], norm_pred_lines[pred_idx])
+                                for pred_idx in unmatched_pred_indices]
+                               for gt_idx in unmatched_gt_indices]
+
+            row_ind, col_ind = linear_sum_assignment(distance_matrix)
+
+            for i, j in zip(row_ind, col_ind):
+                gt_idx = list(unmatched_gt_indices)[i]
+                pred_idx = list(unmatched_pred_indices)[j]
+                result_entry = {
+                    'gt_idx': int(gt_idx),
+                    'gt': norm_gt_lines[gt_idx],
+                    'pred_idx': [pred_idx],
+                    'pred': norm_pred_lines[pred_idx],
+                    'edit': 1
+                }
+                converted_results.append(result_entry)
+
+            matched_gt_indices.update(list(unmatched_gt_indices)[i] for i in row_ind)
+        else:
+            result_entry = {
+                'gt_idx': '',
+                'gt': '',
+                'pred_idx': list(unmatched_pred_indices),
+                'pred': ' '.join(norm_pred_lines[pred_idx] for pred_idx in unmatched_pred_indices),
+                'edit': 1
+            }
+            converted_results.append(result_entry)
+    else:
+        for gt_idx in unmatched_gt_indices:
+            result_entry = {'gt_idx': int(gt_idx), 'gt': norm_gt_lines[gt_idx], 'pred_idx': '', 'pred': '', 'edit': 1}
+            converted_results.append(result_entry)
+
+    return converted_results
+
+
+import json
+
+
+def read_md_file(filepath):
+    with open(filepath, 'r', encoding='utf-8') as file:
+        content = file.read()
+
+    return content
+
+
+def save_paired_result(preds, gts, save_path):
+    save_result = []
+    formula_id = 0
+    for gt, pred in zip(gts, preds):
+        save_result.append({'gt': gt, 'pred': pred, 'img_id': formula_id})
+        formula_id += 1
+    with open(save_path, 'w', encoding='utf-8') as f:
+        json.dump(save_result, f, indent=4, ensure_ascii=False)
+
+
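A quick sketch of calling the saver (the file name is hypothetical); each GT/prediction pair becomes one JSON record keyed by a running img_id:

```python
# Hypothetical call; assumes save_paired_result from this module is in scope.
gts = ['E = mc^2', '\\frac{a}{b}']
preds = ['E = mc2', '\\frac{a}{b}']
# save_paired_result(preds, gts, 'formula_pairs.json')  # hypothetical filename
# The file then contains:
# [
#     {"gt": "E = mc^2", "pred": "E = mc2", "img_id": 0},
#     {"gt": "\\frac{a}{b}", "pred": "\\frac{a}{b}", "img_id": 1}
# ]
```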
+import numpy as np
+import os
+import re
+
+
+def print_aligned_dict(data):
+    # Find the maximum length of all keys
+    max_key_length = max(len(key) for key in data['testcase1'])
+
+    # Print header
+    print(f"{' ' * (max_key_length + 4)}", end='')
+    for key in data:
+        print(f'{key:>{max_key_length}}', end='')
+    print()
+
+    # Print dictionary content
+    for subkey in data['testcase1']:
+        print(f'{subkey:<{max_key_length + 4}}', end='')
+        for key in data:
+            print(f'{data[key][subkey]:>{max_key_length}}', end='')
+        print()
+
+
+def create_dict_from_folders(directory):
+    body = {}
+    for folder_name in os.listdir(directory):
+        folder_path = os.path.join(directory, folder_name)
+        if os.path.isdir(folder_path):
+            body[folder_name] = {}
+    return body
+
+
+# The function is from https://github.com/intsig-textin/markdown_tester
+def markdown_to_html(markdown_table):
+    rows = [row.strip() for row in markdown_table.strip().split('\n')]
+    num_columns = len(rows[0].split('|')) - 2
+
+    html_table = '<table>\n <thead>\n <tr>\n'
+
+    header_cells = [cell.strip() for cell in rows[0].split('|')[1:-1]]
+    for cell in header_cells:
+        html_table += f' <th>{cell}</th>\n'
+    html_table += ' </tr>\n </thead>\n <tbody>\n'
+
+    for row in rows[2:]:
+        cells = [cell.strip() for cell in row.split('|')[1:-1]]
+        html_table += ' <tr>\n'
+        for cell in cells:
+            html_table += f' <td>{cell}</td>\n'
+        html_table += ' </tr>\n'
+
+    html_table += ' </tbody>\n</table>\n'
+    return html_table
+
+
+def convert_markdown_to_html(markdown_content, md_type):
+    # Define a regex pattern to find Markdown tables with newlines
+    markdown_content = markdown_content.replace('\r', '')
+    pattern = re.compile(r'\|\s*.*?\s*\|\n', re.DOTALL)
+
+    # Find all matches in the Markdown content
+    matches = pattern.findall(markdown_content)
+    for match in matches:
+        html_table = markdown_to_html(match)
+        markdown_content = markdown_content.replace(match, html_table, 1)  # Only replace the first occurrence
+    res_html = convert_table(replace_table_with_placeholder(markdown_content))
+
+    return res_html
+
+
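The table conversion hinges on one idea: row 0 is the header, row 1 is the `---` separator to skip, and everything after is body cells. A self-contained sketch of that row splitting on a toy Markdown table:

```python
md_table = """| name | score |
| --- | --- |
| gpt | 0.91 |"""

rows = [row.strip() for row in md_table.strip().split('\n')]
header = [c.strip() for c in rows[0].split('|')[1:-1]]
body = [[c.strip() for c in row.split('|')[1:-1]] for row in rows[2:]]
print(header)  # ['name', 'score']
print(body)    # [['gpt', '0.91']]
```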
+def convert_table_str(s):
+    s = re.sub(r'<table.*?>', '<table>', s)
+    s = re.sub(r'<th', '<td', s)
+    s = re.sub(r'</th>', '</td>', s)
+    # s = re.sub(r'<td rowspan="(.)">',lambda x:f'<td colspan="1" rowspan="{x.group(1)}">',s)
+    # s = re.sub(r'<td colspan="(.)">',lambda x:f'<td colspan="{x.group(1)}" rowspan="1">',s)
+    res = ''
+    res += '\n\n'
+    temp_item = ''
+    for c in s:
+        temp_item += c
+        if c == '>' and not re.search(r'<td.*?>\$', temp_item):
+            res += temp_item + '\n'
+            temp_item = ''
+    return res + '\n'
+
+
+def merge_table(md):
+    table_temp = ''
+    for line in md:
+        table_temp += line
+    return convert_table_str(table_temp)
+
+
+def find_md_table_mode(line):
+    if re.search(r'-*?:', line) or re.search(r'---', line) or re.search(r':-*?', line):
+        return True
+    return False
+
+
+def delete_table_and_body(input_list):
+    res = []
+    for line in input_list:
+        if not re.search(r'</?t(able|head|body)>', line):
+            res.append(line)
+    return res
+
+
+def merge_tables(input_str):
+    # Delete HTML comments
+    input_str = re.sub(r'<!--[\s\S]*?-->', '', input_str)
+
+    # Use regex to find each <table> block
+    table_blocks = re.findall(r'<table>[\s\S]*?</table>', input_str)
+
+    # Process each <table> block, replace <th> with <td>
+    output_lines = []
+    for block in table_blocks:
+        block_lines = block.split('\n')
+        for i, line in enumerate(block_lines):
+            if '<th>' in line:
+                block_lines[i] = line.replace('<th>', '<td>').replace('</th>', '</td>')
+        final_tr = delete_table_and_body(block_lines)
+        if len(final_tr) > 2:
+            output_lines.extend(final_tr)  # Ignore <table> and </table> tags, keep only table content
+
+    # Rejoin the processed strings
+    merged_output = '<table>\n{}\n</table>'.format('\n'.join(output_lines))
+
+    return '\n\n' + merged_output + '\n\n'
+
+
+def replace_table_with_placeholder(input_string):
+    lines = input_string.split('\n')
+    output_lines = []
+
+    in_table_block = False
+    temp_block = ''
+    last_line = ''
+
+    org_table_list = []
+    in_org_table = False
+
+    for idx, line in enumerate(lines):
+        # if not in_org_table:
+        #     if "<table>" not in last_line and in_table_block == False and temp_block != "":
+        #         output_lines.append(merge_tables(temp_block))
+        #         temp_block = ""
+        if '<table>' in line:
+            # if "<table><tr" in line:
+            #     org_table_list.append(line)
+            #     in_org_table = True
+            #     output_lines.append(last_line)
+            #     continue
+            # else:
+            in_table_block = True
+            temp_block += last_line
+        elif in_table_block:
+            if not find_md_table_mode(last_line) and '</thead>' not in last_line:
+                temp_block += '\n' + last_line
+            if '</table>' in last_line:
+                if '<table>' not in line:
+                    in_table_block = False
+                    output_lines.append(merge_tables(temp_block))
+                    temp_block = ''
+        else:
+            output_lines.append(last_line)
+
+        last_line = line
+        # else:
+        #     org_table_list.append(line)
+        #     if "</table" in line:
+        #         in_org_table = False
+        #         last_line = merge_table(org_table_list)
+        #         org_table_list = []
+
+    if last_line:
+        if in_table_block or '</table>' in last_line:
+            temp_block += '\n' + last_line
+            output_lines.append(merge_tables(temp_block))
+        else:
+            output_lines.append(last_line)
+        # if "</table>" in last_line:
+        #     output_lines.append(merge_tables(temp_block))
+
+    return '\n'.join(output_lines)
+
+
+def convert_table(input_str):
+    # Replace <table>
+    output_str = input_str.replace('<table>', "<table border=\"1\" >")
+
+    # Replace <td>
+    output_str = output_str.replace('<td>', "<td colspan=\"1\" rowspan=\"1\">")
+
+    return output_str
+
+
+# NOTE: redefines convert_markdown_to_html from above without the unused
+# md_type parameter; being later in the module, this definition wins at import.
+def convert_markdown_to_html(markdown_content):
+    # Define a regex pattern to find Markdown tables with newlines
+    markdown_content = markdown_content.replace('\r', '') + '\n'
+    pattern = re.compile(r'\|\s*.*?\s*\|\n', re.DOTALL)
+
+    # Find all matches in the Markdown content
+    matches = pattern.findall(markdown_content)
+
+    for match in matches:
+        html_table = markdown_to_html(match)
+        markdown_content = markdown_content.replace(match, html_table, 1)  # Only replace the first occurrence
+
+    res_html = convert_table(replace_table_with_placeholder(markdown_content))
+
+    return res_html
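The final normalization pass simply rewrites tags with explicit attributes so downstream table scorers see a uniform shape. A self-contained check of those two substitutions:

```python
s = '<table>\n<tr>\n<td>cell</td>\n</tr>\n</table>'
s = s.replace('<table>', '<table border="1" >')
s = s.replace('<td>', '<td colspan="1" rowspan="1">')
print(s)
# <table border="1" >
# <tr>
# <td colspan="1" rowspan="1">cell</td>
# </tr>
# </table>
```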