evalscope 1.0.2__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/api/benchmark/__init__.py +8 -1
- evalscope/api/benchmark/adapters/__init__.py +1 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +12 -0
- evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
- evalscope/api/benchmark/benchmark.py +14 -0
- evalscope/api/dataset/dataset.py +21 -0
- evalscope/api/dataset/loader.py +6 -2
- evalscope/api/mixin/sandbox_mixin.py +32 -54
- evalscope/api/model/generate_config.py +6 -0
- evalscope/app/ui/multi_model.py +6 -1
- evalscope/app/ui/single_model.py +8 -2
- evalscope/app/utils/data_utils.py +3 -2
- evalscope/app/utils/visualization.py +2 -2
- evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
- evalscope/benchmarks/ai2d/ai2d_adapter.py +3 -2
- evalscope/benchmarks/bfcl/bfcl_adapter.py +11 -46
- evalscope/benchmarks/blink/__init__.py +0 -0
- evalscope/benchmarks/blink/blink_adapter.py +61 -0
- evalscope/benchmarks/chartqa/__init__.py +0 -0
- evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
- evalscope/benchmarks/chartqa/utils.py +38 -0
- evalscope/benchmarks/data_collection/data_collection_adapter.py +2 -1
- evalscope/benchmarks/docvqa/__init__.py +0 -0
- evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
- evalscope/benchmarks/general_arena/general_arena_adapter.py +1 -1
- evalscope/benchmarks/general_arena/utils.py +2 -1
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
- evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +23 -4
- evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
- evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +158 -0
- evalscope/benchmarks/hle/hle_adapter.py +3 -2
- evalscope/benchmarks/humaneval/humaneval_adapter.py +2 -1
- evalscope/benchmarks/infovqa/__init__.py +0 -0
- evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +3 -1
- evalscope/benchmarks/math_verse/__init__.py +0 -0
- evalscope/benchmarks/math_verse/math_verse_adapter.py +100 -0
- evalscope/benchmarks/math_vision/__init__.py +0 -0
- evalscope/benchmarks/math_vision/math_vision_adapter.py +111 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +6 -26
- evalscope/benchmarks/mm_bench/mm_bench_adapter.py +2 -2
- evalscope/benchmarks/mmmu/mmmu_adapter.py +1 -1
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +1 -1
- evalscope/benchmarks/ner/__init__.py +0 -0
- evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
- evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
- evalscope/benchmarks/ner/copious_adapter.py +85 -0
- evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
- evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
- evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
- evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
- evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
- evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
- evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
- evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
- evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
- evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
- evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
- evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
- evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
- evalscope/benchmarks/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_adapter.py +101 -0
- evalscope/benchmarks/ocr_bench_v2/IoUscore_metric.py +87 -0
- evalscope/benchmarks/ocr_bench_v2/TEDS_metric.py +963 -0
- evalscope/benchmarks/ocr_bench_v2/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
- evalscope/benchmarks/ocr_bench_v2/page_ocr_metric.py +50 -0
- evalscope/benchmarks/ocr_bench_v2/parallel.py +46 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/readme.txt +26 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/script.py +481 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_metric.py +179 -0
- evalscope/benchmarks/ocr_bench_v2/utils.py +433 -0
- evalscope/benchmarks/ocr_bench_v2/vqa_metric.py +254 -0
- evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
- evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
- evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
- evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
- evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
- evalscope/benchmarks/poly_math/__init__.py +0 -0
- evalscope/benchmarks/poly_math/poly_math_adapter.py +127 -0
- evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
- evalscope/benchmarks/pope/__init__.py +0 -0
- evalscope/benchmarks/pope/pope_adapter.py +111 -0
- evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
- evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
- evalscope/benchmarks/simple_vqa/__init__.py +0 -0
- evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +1 -1
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +1 -1
- evalscope/benchmarks/visu_logic/__init__.py +0 -0
- evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
- evalscope/benchmarks/zerobench/__init__.py +0 -0
- evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
- evalscope/constants.py +4 -0
- evalscope/evaluator/evaluator.py +72 -79
- evalscope/metrics/math_parser.py +14 -0
- evalscope/metrics/metric.py +52 -1
- evalscope/metrics/metrics.py +16 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +2 -6
- evalscope/models/utils/openai.py +4 -0
- evalscope/perf/arguments.py +24 -4
- evalscope/perf/benchmark.py +74 -89
- evalscope/perf/http_client.py +31 -16
- evalscope/perf/main.py +15 -2
- evalscope/perf/plugin/api/base.py +9 -7
- evalscope/perf/plugin/api/custom_api.py +13 -58
- evalscope/perf/plugin/api/default_api.py +179 -79
- evalscope/perf/plugin/api/openai_api.py +4 -3
- evalscope/perf/plugin/datasets/base.py +21 -0
- evalscope/perf/plugin/datasets/custom.py +2 -3
- evalscope/perf/plugin/datasets/line_by_line.py +2 -3
- evalscope/perf/plugin/datasets/longalpaca.py +2 -3
- evalscope/perf/plugin/datasets/openqa.py +2 -4
- evalscope/perf/plugin/datasets/random_dataset.py +1 -3
- evalscope/perf/utils/benchmark_util.py +36 -22
- evalscope/perf/utils/db_util.py +14 -19
- evalscope/perf/utils/local_server.py +0 -44
- evalscope/perf/utils/log_utils.py +21 -6
- evalscope/report/__init__.py +11 -2
- evalscope/report/combinator.py +52 -2
- evalscope/run.py +4 -0
- evalscope/utils/function_utils.py +195 -12
- evalscope/utils/io_utils.py +74 -0
- evalscope/utils/json_schema.py +8 -6
- evalscope/utils/logger.py +49 -17
- evalscope/utils/multi_choices.py +16 -1
- evalscope/utils/ner.py +377 -0
- evalscope/version.py +2 -2
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/METADATA +239 -393
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/RECORD +140 -98
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/WHEEL +1 -1
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/top_level.txt +0 -1
- tests/__init__.py +0 -1
- tests/benchmark/__init__.py +0 -1
- tests/benchmark/test_eval.py +0 -429
- tests/benchmark/test_image_edit.py +0 -65
- tests/benchmark/test_sandbox.py +0 -81
- tests/benchmark/test_t2i.py +0 -142
- tests/benchmark/test_vlm.py +0 -137
- tests/cli/__init__.py +0 -1
- tests/cli/test_all.py +0 -269
- tests/cli/test_collection.py +0 -99
- tests/cli/test_custom.py +0 -268
- tests/cli/test_reasoning.py +0 -81
- tests/common.py +0 -73
- tests/perf/__init__.py +0 -1
- tests/perf/test_perf.py +0 -206
- tests/rag/test_clip_benchmark.py +0 -87
- tests/rag/test_mteb.py +0 -213
- tests/rag/test_ragas.py +0 -128
- tests/swift/__init__.py +0 -1
- tests/swift/test_run_swift_eval.py +0 -146
- tests/swift/test_run_swift_vlm_eval.py +0 -128
- tests/swift/test_run_swift_vlm_jugde_eval.py +0 -157
- tests/test_run_all.py +0 -12
- tests/utils.py +0 -13
- tests/vlm/__init__.py +0 -1
- tests/vlm/test_vlmeval.py +0 -102
- {tests/rag → evalscope/benchmarks/aa_lcr}/__init__.py +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/entry_points.txt +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,963 @@
|
|
|
1
|
+
# flake8: noqa
|
|
2
|
+
# Copyright 2020 IBM
|
|
3
|
+
# Author: peter.zhong@au1.ibm.com
|
|
4
|
+
#
|
|
5
|
+
# This is free software; you can redistribute it and/or modify
|
|
6
|
+
# it under the terms of the Apache 2.0 License.
|
|
7
|
+
#
|
|
8
|
+
# This software is distributed in the hope that it will be useful,
|
|
9
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
10
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
11
|
+
# Apache 2.0 License for more details.
|
|
12
|
+
|
|
13
|
+
import ast
|
|
14
|
+
import distance
|
|
15
|
+
import editdistance
|
|
16
|
+
import json
|
|
17
|
+
import Levenshtein
|
|
18
|
+
import numpy as np
|
|
19
|
+
import re
|
|
20
|
+
import string
|
|
21
|
+
from apted import APTED, Config
|
|
22
|
+
from apted.helpers import Tree
|
|
23
|
+
from collections import deque
|
|
24
|
+
from itertools import product
|
|
25
|
+
from lxml import etree, html
|
|
26
|
+
from tqdm import tqdm
|
|
27
|
+
from typing import Any, Callable, Optional, Sequence
|
|
28
|
+
from zss import Node, simple_distance
|
|
29
|
+
|
|
30
|
+
from .parallel import parallel_process
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TableTree(Tree):
    """Tree node in the format consumed by APTED for HTML tables.

    ``colspan``, ``rowspan`` and ``content`` (a list of cell tokens) are filled in
    for ``td`` nodes; other nodes only carry their ``tag`` and ``children``.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation."""
        if self.tag != 'td':
            label = '"tag": %s' % self.tag
        else:
            label = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag, self.colspan, self.rowspan, self.content
            )
        nested = ''.join(child.bracket() for child in self.children)
        return '{{{}}}'.format(label + nested)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class CustomConfig(Config):
    """APTED cost configuration for ``TableTree`` nodes.

    Renaming two nodes costs 1.0 when their tag, colspan or rowspan differ. For
    two ``td`` nodes with matching attributes where at least one has content,
    the cost is the normalized Levenshtein distance between their token lists.
    Otherwise the cost is 0.0.
    """

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest sequence."""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Levenshtein distance divided by the longest length (range 0 to 1)."""
        raw_distance = distance.levenshtein(*sequences)
        return float(raw_distance) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Cost of renaming ``node1`` into ``node2``."""
        same_shape = (node1.tag == node2.tag and node1.colspan == node2.colspan and node1.rowspan == node2.rowspan)
        if not same_shape:
            return 1.0
        # normalized_distance is only reached when at least one list is non-empty,
        # so the division inside it is safe.
        if node1.tag == 'td' and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.0
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class TEDS(object):
    """Tree Edit Distance based Similarity (TEDS) between HTML tables.

    The score is ``1 - edit_distance / n_nodes``. The edit distance is computed
    with APTED (using ``CustomConfig`` costs) over the ``body/table`` element of
    each input. ``n_nodes`` is the larger descendant count of the two tables.

    Args:
        structure_only: If True, cell contents are ignored and only tags,
            colspan and rowspan are compared.
        n_jobs: Number of workers used by ``batch_evaluate`` (1 = sequential).
        ignore_nodes: Optional tag names stripped (content kept) from both
            tables before comparison via ``etree.strip_tags``.
    """

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by tokenize(); reset for every table cell.
        self.__tokens__ = []

    def tokenize(self, node):
        """Append the tokens of ``node`` (tags and single characters) to ``self.__tokens__``."""
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for child in node:  # iterating an lxml element yields its children
            self.tokenize(child)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        # The tail of a nested element is still cell text; the tail of the <td>
        # itself lies outside the cell and is skipped.
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Convert an lxml element into the ``TableTree`` format required by APTED.

        Returns the new root when ``parent`` is None. Otherwise the converted node
        is appended to ``parent.children`` and None is returned.
        """
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the surrounding '<td>' and '</td>' tokens.
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(
                node.tag, int(node.attrib.get('colspan', '1')), int(node.attrib.get('rowspan', '1')), cell
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Compute the TEDS score between a predicted and a ground-truth HTML string.

        Both inputs are expected to be HTML documents containing ``<body><table>``.
        Returns 0.0 when either input is empty or lacks a ``body/table`` element.
        Returns 1.0 for two tables without any descendant nodes. Otherwise returns
        ``1 - edit_distance / n_nodes``, where 1.0 means identical.
        """
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath('body/table')
        true_tables = true.xpath('body/table')
        if not (pred_tables and true_tables):
            return 0.0
        pred, true = pred_tables[0], true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes = max(len(pred.xpath('.//*')), len(true.xpath('.//*')))
        if n_nodes == 0:
            # Both tables are bare <table> roots. Those always rename at cost 0
            # (same tag, no spans), so they are identical. This also avoids
            # dividing by zero.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        edit_distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
        return 1.0 - (float(edit_distance) / n_nodes)

    @staticmethod
    def _ground_truth_html(value):
        """Return the HTML of a ground-truth entry given as ``{'html': ...}`` or as a plain string."""
        return value['html'] if isinstance(value, dict) else value

    def batch_evaluate(self, pred_json, true_json):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples.

        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...} or {'FILENAME': 'HTML CODE', ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}

        A missing prediction counts as an empty string (score 0.0). The same
        ground-truth shapes are accepted for both the sequential and the parallel
        code paths.
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_json.get(filename, ''), self._ground_truth_html(true_json[filename]))
                for filename in tqdm(samples)
            ]
        else:
            inputs = [{
                'pred': pred_json.get(filename, ''),
                'true': self._ground_truth_html(true_json[filename])
            } for filename in samples]
            scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        return dict(zip(samples, scores))
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def convert_table_to_html_str(table_row_list=None):
    """
    Given a list of table rows, build the corresponding html string, which is used to compute the TEDS score.
    We use the official code of PubTabNet to compute TEDS score, it does not consider '<th>' label.
    We also remove unnecessary spaces within a table cell and extra '\n' as they will influence the TEDS score.

    Args:
        table_row_list: Iterable of rows; each row is an iterable of cell values
            (formatted with an f-string). Defaults to no rows.

    Returns:
        ``'<html><body><table>...</table></body></html>'`` with one ``<tr>`` per
        row and one ``<td>`` per cell. All newline characters are removed,
        including any inside cell text.
    """
    if table_row_list is None:
        table_row_list = []
    rows_html = ''.join(
        '<tr>' + ''.join(f'<td>{cell_str}</td>' for cell_str in data_row) + '</tr>' for data_row in table_row_list
    )
    html_table_str = f'<html><body><table>{rows_html}</table></body></html>'
    return html_table_str.replace('\n', '')
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def convert_markdown_table_to_html(markdown_table):
    """
    Converts a markdown table to the corresponding html string for TEDS computation.

    An optional surrounding code fence (```markdown ... ``` or ``` ... ```) is
    removed. The first line is treated as the header row and the second line as
    the ``|---|`` separator, which is skipped. Cells are the pieces between the
    outer pipes of each row: they are stripped, and whitespace-only cells become ' '.
    """
    # Remove the opening/closing fences as whole tokens. The previous
    # ``str.strip('```markdown')`` treated its argument as a character set and
    # could eat leading/trailing letters (m, a, r, k, d, o, w, n) of the table.
    markdown_table = markdown_table.strip()
    markdown_table = re.sub(r'^```(?:markdown)?', '', markdown_table)
    markdown_table = re.sub(r'```$', '', markdown_table).strip()
    row_str_list = markdown_table.split('\n')
    # Keep the header row and the data rows; skip the separator line.
    valid_row_str_list = [row_str_list[0]] + row_str_list[2:]
    table_rows = []
    for row_str in valid_row_str_list:
        cells = row_str.strip().split('|')[1:-1]
        table_rows.append([cell.strip() if set(cell) != {' '} else ' ' for cell in cells])
    return convert_table_to_html_str(table_rows)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def dict_to_html(data):
    """
    Render a flat mapping as a two-column HTML table (key | value) for TEDS scoring.

    Non-string values are converted with ``str``. ``' '.join`` is applied to the
    value *string*, so its characters end up separated by single spaces
    (e.g. ``'ab'`` -> ``'a b'``).
    """
    row_lines = []
    for key, value in data.items():
        text = value if isinstance(value, str) else str(value)
        spaced_text = ' '.join(text)
        row_lines.append(f'  <tr><td>{key}</td><td>{spaced_text}</td></tr>\n')
    return '<html><body><table>\n' + ''.join(row_lines) + '</table></body></html>'
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def convert_str_to_dict(predict_str: str):
    """
    Parses the 'predict' string and returns a flat dictionary of stripped strings.

    The content of the first ```python / ```json / ``` code fence is used when
    present, otherwise the whole stripped string. Parsing is attempted in order
    with ``json.loads``, ``ast.literal_eval`` and finally a permissive
    ``key: value`` / ``key = value`` regex. A parser result is only accepted if
    it is a dict; otherwise the next strategy is tried.

    Parameters:
    - predict_str (str): The prediction string containing the output dict.

    Returns:
    - dict: ``{str(key).strip(): str(value).strip()}``, or ``{}`` if nothing could be parsed.
    """
    code_fence_pattern = r'```(?:python|json)?\n(.*?)\n```'
    match = re.search(code_fence_pattern, predict_str, re.DOTALL | re.IGNORECASE)
    content = match.group(1) if match else predict_str.strip()

    data = None

    # 1) strict JSON
    try:
        parsed = json.loads(content)
    except (json.JSONDecodeError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        data = parsed

    # 2) Python literal (single quotes, etc.). literal_eval also raises TypeError
    #    for e.g. unhashable dict keys, and MemoryError/RecursionError on
    #    pathological input.
    if data is None:
        try:
            parsed = ast.literal_eval(content)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, dict):
            data = parsed

    # 3) loose "key: value" / "key = value" pairs; later duplicates win
    if data is None:
        key_value_pattern = r'["\']?([\w\s]+)["\']?\s*[:=]\s*["\']?([^\n,"\'{}]+)["\']?'
        data = {key.strip(): value.strip() for key, value in re.findall(key_value_pattern, content)}

    return {str(k).strip(): str(v).strip() for k, v in data.items()}
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def convert_str_to_multi_dict(predict_str: str):
    """
    Parse a prediction string into a (possibly nested) dictionary.

    Processing steps:
    1. Use the longest ```python / ```json / ``` fenced block if one exists,
       otherwise the stripped string.
    2. Drop a leading ``name =`` assignment.
    3. Remove everything from ``#`` to the end of each line. This also applies
       inside string literals.
    4. Cut the text after the last ``}``.
    5. Parse with ``ast.literal_eval``.

    Parameters:
    - predict_str (str): The prediction string containing the output dict.

    Returns:
    - dict: The parsed dictionary, or ``{}`` when the content is not a dict literal.
    """
    fenced_blocks = re.findall(r'```(?:python|json)?\n(.*?)\n```', predict_str, re.DOTALL | re.IGNORECASE)
    content = max(fenced_blocks, key=len) if fenced_blocks else predict_str.strip()

    content = re.sub(r'^\s*\w+\s*=\s*', '', content.strip(), count=1)
    content = re.sub(r'#.*', '', content)

    closing_brace = content.rfind('}')
    if closing_brace != -1:
        content = content[:closing_brace + 1]

    try:
        parsed = ast.literal_eval(content)
    except (ValueError, SyntaxError, TypeError):
        return {}
    if not isinstance(parsed, dict):
        return {}

    def rebuild(obj):
        # Structural copy of nested dicts/lists; any other value is returned unchanged.
        if isinstance(obj, dict):
            return {k: rebuild(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [rebuild(item) for item in obj]
        return obj

    return rebuild(parsed)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def generate_combinations(input_dict):
    """
    Function to generate all possible combinations of values from a dictionary.

    ``input_dict`` is either a dict or a string holding one. Strings are stripped
    of surrounding double quotes and parsed with ``json.loads``, falling back to
    ``ast.literal_eval``. A second ``literal_eval`` is applied if the first one
    yields a non-dict (e.g. a doubly-encoded string). Each value that is a list is
    treated as a set of alternatives; any other value is a single alternative.

    Returns:
        A list with one dict per element of the Cartesian product of the
        alternatives, or ``{}`` if a string input cannot be parsed.

    Raises:
        ValueError: If a string input parses (via JSON) to something other than a dict.
    """
    kie_answer = input_dict
    if not isinstance(kie_answer, dict):
        kie_answer = kie_answer.strip('"')
        try:
            kie_answer = json.loads(kie_answer)
        except json.JSONDecodeError:
            try:
                kie_answer = ast.literal_eval(kie_answer)
                if not isinstance(kie_answer, dict):
                    kie_answer = ast.literal_eval(kie_answer)
            except (ValueError, SyntaxError, TypeError):
                print(f"Unable to parse 'answers' field: {kie_answer}")
                return {}

        # Ensure the parsed result is a dictionary.
        if not isinstance(kie_answer, dict):
            print("Parsed 'answers' is still not a dictionary.")
            raise ValueError('Input could not be parsed into a dictionary.')

    keys = list(kie_answer.keys())
    # Wrap scalars so product() does not iterate over e.g. the characters of a string.
    value_lists = [value if isinstance(value, list) else [value] for value in (kie_answer[key] for key in keys)]

    # Cartesian product of the alternatives, one dict per combination.
    return [dict(zip(keys, values)) for values in product(*value_lists)]
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def compute_f1_score(preds, gts, ignores=None):
    """Compute the F1-score for KIE task between predicted and ground truth dictionaries.

    Values are compared case-insensitively with all whitespace removed. Non-string
    values (e.g. numbers parsed from JSON ground truth) are converted with ``str``
    first. Each field is a single exact-match decision, so its F1 is 1.0 or 0.0.
    Keys whose value is None (or missing) on both sides are skipped.

    Args:
        preds (dict): The predicted key-value pairs.
        gts (dict): The ground truth key-value pairs.
        ignores (iterable, optional): Keys to ignore during evaluation.

    Returns:
        The average F1 over the evaluated keys, or ``0`` if no key was evaluated.
    """
    ignored = set(ignores) if ignores else set()
    keys = (set(preds) | set(gts)) - ignored

    def _normalize(value):
        if value is None:
            return None
        return str(value).lower().strip().replace('\n', ' ').replace(' ', '')

    f1_scores = []
    for key in keys:
        pred_value = _normalize(preds.get(key))
        gt_value = _normalize(gts.get(key))
        if pred_value is None and gt_value is None:
            continue
        # Precision == recall for a single exact-match decision, so F1 is 1.0 or 0.0.
        # Missing on exactly one side is a false negative / false positive.
        f1_scores.append(1.0 if pred_value == gt_value else 0.0)

    if not f1_scores:
        return 0
    return sum(f1_scores) / len(f1_scores)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def pre_clean(text):
    """Normalize one line of markdown/LaTeX text before it is placed into a document tree.

    The function:
    - removes special tokens (<bos>, <eos>, <pad>, <unk>);
    - joins word-piece '##' continuations;
    - tightens whitespace after backslashes, around '{'/'}' and after \\begin/\\end;
    - appends blank lines after \\end{table}.

    Finally, newlines, '*' and '_' are all replaced by spaces.
    """
    substitutions = (
        (r'<bos>|<eos>|<pad>|<unk>', ''),
        (r'\s##(\S)', r'\1'),
        (r'\\\s', r'\\'),
        (r'\s\*\s\*\s', r'**'),
        (r'{\s', r'{'),
        # Deliberately applied twice: each pass removes only one whitespace
        # character in front of '}', so two passes collapse e.g. '  }' to '}'.
        (r'\s}', r'}'),
        (r'\s}', r'}'),
        (r'\\begin\s', r'\\begin'),
        (r'\\end\s', r'\\end'),
        (r'\\end{table}', r'\\end{table} \n\n'),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    for char in ('\n', '*', '_'):
        text = text.replace(char, ' ')
    return text
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def get_tree(input_str):
    """
    Build a zss tree that describes the heading structure of a markdown document.

    The root node is ``ROOT``, and its first child is a ``TITLE`` node. Each line
    is passed through ``pre_clean``; then:
    - a line starting with '#' becomes a heading node under ``ROOT`` (all '#'
      characters removed);
    - any other line becomes a child of the most recent heading, or of ``TITLE``
      if no heading has been seen yet or the most recent heading text is empty.
    """
    title_node = Node('TITLE')
    tree = Node('ROOT').addkid(title_node)

    # Keep a direct reference to the current parent instead of looking it up by
    # label with ``tree.get``. That lookup traverses the tree for every line and
    # returns the *first* node carrying the label, which attaches content to the
    # wrong node when a heading repeats an earlier label.
    current_parent = title_node
    for line in input_str.split('\n'):
        line = pre_clean(line)
        if line.startswith('#'):
            heading_text = line.replace('#', '')
            heading_node = Node(heading_text)
            tree.addkid(heading_node)
            # An empty heading keeps routing content to TITLE, as before.
            current_parent = heading_node if heading_text != '' else title_node
        else:
            current_parent.addkid(Node(line))
    return tree
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def STEDS(pred_tree, ref_tree):
    """
    Structural tree-edit-distance similarity between two zss trees.

    The label distance passed to ``simple_distance`` is 0 when both labels contain
    at least one whitespace-separated token, and 1 otherwise. The actual text of
    two non-empty labels never matters. The edit distance is divided by the node
    count of the larger tree and subtracted from 1.
    """
    def label_cost(pred_label, ref_label):
        both_have_text = bool(pred_label.split()) and bool(ref_label.split())
        return 0 if both_have_text else 1

    edit_distance = simple_distance(pred_tree, ref_tree, label_dist=label_cost)
    larger_size = max(sum(1 for _ in pred_tree.iter()), sum(1 for _ in ref_tree.iter()))
    return 1 - edit_distance / larger_size
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def doc_parsing_evaluation(pred, gt):
    """
    Score a document-parsing prediction against the ground truth.

    Returns 0 when ``pred`` is not a string. Otherwise both texts are converted
    with ``get_tree`` and compared with ``STEDS``.
    """
    if isinstance(pred, str):
        return STEDS(get_tree(pred), get_tree(gt))
    return 0
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def wrap_html_table(html_table):
    """
    The TEDS computation from PubTabNet code requires that the input html table
    has <html>, <body>, and <table> tags; add whichever are missing.

    Newlines are removed first. A missing opening '<table' tag is prepended and a
    missing '</table>' is appended independently. The result is then wrapped in
    <body> and <html> unless those tags already occur in the string.
    """
    html_table = html_table.replace('\n', '')
    has_opening = '<table' in html_table
    has_closing = '</table>' in html_table
    if not has_opening:
        html_table = '<table>' + html_table
    if not has_closing:
        html_table = html_table + '</table>'
    if '<body>' not in html_table:
        html_table = f'<body>{html_table}</body>'
    if '<html>' not in html_table:
        html_table = f'<html>{html_table}</html>'
    return html_table
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def get_anls(s1, s2):
    """
    Normalized Levenshtein similarity ``1 - edit_distance / max_len`` between two sequences.

    Each argument that supports ``.lower()`` is lower-cased on its own, so the
    comparison is case-insensitive even when only one side is a string.
    Returns 1.0 for equal values and for two (unequal) empty sequences.
    """
    def _lower(value):
        try:
            return value.lower()
        except (AttributeError, TypeError):
            return value

    s1, s2 = _lower(s1), _lower(s2)
    if s1 == s2:
        return 1.0
    longest = max(len(s1), len(s2))
    if longest == 0:
        # e.g. '' vs []: nothing to compare; avoid dividing by zero.
        return 1.0
    return 1 - editdistance.eval(s1, s2) / longest
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def ocr_eval(references, predictions):
    """Compute the mean OCR similarity between paired references and predictions.

    ``references[i]`` and ``predictions[i]`` may each be a single string or a
    list of strings. The literal string ``'None'`` marks an absent field.

    Scoring per sample:
      * A single non-'None' prediction, with no 'None' among the references,
        scores the best ``get_anls`` over the references, divided by the
        number of references.
      * Otherwise predictions are paired positionally with the references;
        extra references reuse the first prediction. A 'None' prediction for
        a real reference scores 0. A 'None' reference is dropped from the
        denominator. Any other pair adds its ``get_anls`` score.

    Samples left with a zero denominator are excluded from the mean.

    Returns:
        The mean sample score rounded to 5 decimals. Returns the sentinel
        ``9999`` when every sample was excluded, including when
        ``references`` is empty.
    """

    def as_list(value):
        # A bare string becomes a one-element list; anything else is used as-is.
        return [value] if isinstance(value, str) else value

    total = 0.0
    excluded = 0
    for position, raw_ref in enumerate(references):
        preds = as_list(predictions[position])
        refs = as_list(raw_ref)

        sample_score = 0.0
        denominator = len(refs)
        for k, ref in enumerate(refs):
            pred = preds[k] if k < len(preds) else preds[0]
            if len(preds) == 1 and pred != 'None' and 'None' not in refs:
                # NOTE(review): the best match is still divided by len(refs)
                # below, so a perfect match among several references scores < 1.
                # Kept for parity with the reference implementation; confirm intent.
                sample_score = max(sample_score, get_anls(pred, ref))
            elif pred == 'None' and ref != 'None':
                pass  # a missing prediction for a real reference contributes 0
            elif ref == 'None':
                denominator -= 1
            else:
                sample_score += get_anls(pred, ref)

        if denominator == 0:
            excluded += 1
        else:
            total += sample_score / denominator

    if excluded == len(references):
        return 9999
    return round(total / (len(references) - excluded), 5)
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def csv_eval(predictions, references, easy, pred_type='json'):
    """Score chart-to-table predictions with the SCRM precision metrics.

    Each prediction and reference is turned into a set of
    ``(row, column, value)`` triples. Each chart is then scored with a
    tolerant intersection-over-union of the two sets. "Precision at threshold
    t" is the fraction of charts whose score reaches t. The mPrecision values
    average it over t = 0.50, 0.55, ..., 0.95.

    Args:
        predictions: Per-chart predictions. When ``pred_type == 'json'`` they
            are dicts, possibly nested one level. Otherwise they are
            header-first CSV-like strings whose cells and rows are separated
            by the literal two-character texts backslash-t and backslash-n.
        references: Per-chart ground-truth dicts, paired positionally with
            ``predictions``. Extra items on either side are ignored.
        easy: ``1`` selects the tight numeric tolerances; any other value
            selects the loose ones.
        pred_type: ``'json'`` for dict predictions; anything else means
            CSV-string predictions.

    Returns:
        A 13-tuple ``(em, map_strict, map_slight, map_high, ap_50_strict,
        ap_75_strict, ap_90_strict, ap_50_slight, ap_75_slight, ap_90_slight,
        ap_50_high, ap_75_high, ap_90_high)``.
    """
    non_numeric = re.compile(r'[^\d.-]')
    separator, delimiter = '\\t', '\\n'

    def is_float(val):
        try:
            float(val)
            return True
        except (TypeError, ValueError):
            return False

    def strip_colon(text):
        # Drop one trailing ':' from a row/column label ("sales:" -> "sales").
        return text[:-1] if text.endswith(':') else text

    def convert_dict_to_list(data):
        """Flatten a (possibly one-level nested) dict into (row, column, value) triples.

        Top-level scalar values get the column name 'value'. Values keep only
        digits, '.' and '-'.
        """
        triples = []
        for key, value in data.items():
            if isinstance(value, dict):
                triples.extend(
                    (key, subkey, non_numeric.sub('', str(subvalue))) for subkey, subvalue in value.items()
                )
            else:
                triples.append((key, 'value', non_numeric.sub('', str(value))))
        return triples

    def csv2triples(csv):
        """Parse a header-first CSV-like string into (row, column, value) triples."""
        lines = csv.strip().split(delimiter)
        header = lines[0].split(separator)
        triples = []
        for line in lines[1:]:
            if not line:
                continue
            cells = line.split(separator)
            entity = strip_colon(cells[0].strip())
            # zip stops at the shorter side: cells without a header column are dropped.
            for column, cell in zip(header[1:], cells[1:]):
                triples.append((entity, strip_colon(column.strip()), non_numeric.sub('', cell.strip())))
        return triples

    def process_triplets(triplets):
        """Lower-case labels and convert numeric values to float."""
        processed = []
        for triplet in triplets:
            row, column = triplet[0].lower(), triplet[1].lower()
            if len(triplet) <= 2:
                value = 'no meaning'
            elif is_float(triplet[2]):
                value = float(triplet[2])
            else:
                value = triplet[2].lower()
            processed.append((row, column, value))
        return processed

    def align(pred, label):
        """Normalise the heads of positionally paired triples, in place.

        If the reference triple uses the 'value' column, the prediction's
        column is renamed to 'value' as well. Both heads are then sorted, so
        (row, column) order does not matter.
        """
        for idx in range(min(len(pred), len(label))):
            p, g = pred[idx], label[idx]
            if g[1] == 'value' and 'value' not in p[:2]:
                p = (p[0], 'value', p[2])
            pred[idx] = (*sorted(p[:2]), p[2])
            label[idx] = (*sorted(g[:2]), g[2])

    def triple_matches(elem1, elem2, tol_word, tol_num):
        """Return True if two triples match within the word/number tolerances."""
        if is_float(elem1[-1]) and is_float(elem2[-1]):
            relative_error = abs(elem1[-1] - elem2[-1]) / (abs(elem2[-1]) + 0.000001)
            if relative_error > tol_num:
                return False
            head1, head2 = ''.join(elem1[:-1]), ''.join(elem2[:-1])
            return Levenshtein.distance(head1, head2) <= tol_word or head1 in head2 or head2 in head1
        return Levenshtein.distance(''.join(map(str, elem1)), ''.join(map(str, elem2))) <= tol_word

    def similarity(pred, label, tol_word, tol_num):
        """Return the tolerant intersection-over-union of two triple collections."""
        pred_set, label_set = set(pred), set(label)
        matched = {p for p in pred_set if any(triple_matches(p, g, tol_word, tol_num) for g in label_set)}
        # A prediction that matched tolerantly but is not an exact copy is merged
        # with its reference partner, so it is not counted twice in the union.
        union = (pred_set | label_set) - (matched - (pred_set & label_set))
        if not union:
            # Both sides are empty: nothing to extract and nothing extracted.
            # Previously this raised ZeroDivisionError.
            return 1.0
        return len(matched) / len(union)

    # Parsing and alignment do not depend on the tolerance, so they run once.
    if pred_type == 'json':
        pred_triples = [process_triplets(convert_dict_to_list(it)) for it in predictions]
    else:
        pred_triples = [process_triplets(csv2triples(it)) for it in predictions]
    label_triples = [process_triplets(convert_dict_to_list(it)) for it in references]
    pairs = list(zip(pred_triples, label_triples))
    for pred, label in pairs:
        align(pred, label)

    # name -> (max Levenshtein distance, relative numeric tol if easy == 1, otherwise)
    tolerances = {
        'strict': (0, 0, 0.1),
        'slight': (2, 0.05, 0.3),
        'high': (5, 0.1, 0.5),
    }
    sims = {}
    for name, (tol_word, tol_num_easy, tol_num_hard) in tolerances.items():
        tol_num = tol_num_easy if easy == 1 else tol_num_hard
        sims[name] = [similarity(pred, label, tol_word, tol_num) for pred, label in pairs]

    def precision_at(name, threshold):
        scores = sims[name]
        # The 1e-16 keeps an empty evaluation set at 0 instead of dividing by zero.
        return sum(score >= threshold for score in scores) / (len(scores) + 1e-16)

    # Rounded so the thresholds are exactly 0.5, 0.55, ..., 0.95. np.arange
    # accumulates float error (e.g. 0.6000000000000001), which wrongly rejects
    # a chart whose similarity is exactly 0.6.
    thresholds = [round(0.5 + 0.05 * i, 2) for i in range(10)]

    def mean_precision(name):
        return sum(precision_at(name, t) for t in thresholds) / len(thresholds)

    em = precision_at('strict', 1)
    return (
        em,
        mean_precision('strict'),
        mean_precision('slight'),
        mean_precision('high'),
        precision_at('strict', 0.5),
        precision_at('strict', 0.75),
        precision_at('strict', 0.90),
        precision_at('slight', 0.5),
        precision_at('slight', 0.75),
        precision_at('slight', 0.90),
        precision_at('high', 0.5),
        precision_at('high', 0.75),
        precision_at('high', 0.90),
    )
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def draw_SCRM_table(
    em,
    map_strict,
    map_slight,
    map_high,
    ap_50_strict,
    ap_75_strict,
    ap_90_strict,
    ap_50_slight,
    ap_75_slight,
    ap_90_slight,
    ap_50_high,
    ap_75_high,
    ap_90_high,
    title_ocr_socre,
    source_ocr_socre,
    x_title_ocr_socre,
    y_title_ocr_socre,
    structure_accuracy,
):
    """Render the SCRM chart-parsing metrics as a fixed-width ASCII table.

    Every value is formatted with ``'%.4f'``. The precision metrics are
    grouped by similarity threshold, with one row per tolerance
    (strict / slight / high). The single-value metrics (EM, OCR scores and
    structure accuracy) follow, one row each.

    Returns:
        The table as one string. Every row has the same width, and the string
        starts and ends with a newline.
    """

    def row(metric, threshold, tolerance, value):
        return f'| {metric:<13} | {threshold:<13} | {tolerance:<9} | {value:>9} |'

    def summary_row(name, value):
        # 41 = the three left columns plus their inner separators (13+3+13+3+9).
        return f'| {name:<41} | {value:>9} |'

    def fmt(value):
        return '%.4f' % value

    rule = '-' * len(row('', '', '', ''))
    lines = [rule, row('Metrics', 'Sim_threshold', 'Tolerance', 'Value'), rule]

    precision_groups = [
        ('mPrecision', '0.5:0.05:0.95', (map_strict, map_slight, map_high)),
        ('Precision', '0.5', (ap_50_strict, ap_50_slight, ap_50_high)),
        ('Precision', '0.75', (ap_75_strict, ap_75_slight, ap_75_high)),
        ('Precision', '0.9', (ap_90_strict, ap_90_slight, ap_90_high)),
    ]
    for metric, threshold, values in precision_groups:
        for position, (tolerance, value) in enumerate(zip(('strict', 'slight', 'high'), values)):
            # Label each group only on its middle row, like a merged cell.
            labelled = position == 1
            lines.append(
                row(metric if labelled else '', threshold if labelled else '', tolerance, fmt(value))
            )
        lines.append(rule)

    summaries = [
        ('Precision(EM)', em),
        ('Title(EM)', title_ocr_socre),
        ('Source(EM)', source_ocr_socre),
        ('X_title(EM)', x_title_ocr_socre),
        ('Y_title(EM)', y_title_ocr_socre),
        ('structure_acc', structure_accuracy),
    ]
    for name, value in summaries:
        lines.append(summary_row(name, fmt(value)))
        lines.append(rule)

    return '\n' + '\n'.join(lines) + '\n'
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
if __name__ == '__main__':
    # Manual smoke test of the TEDS metric; run this module directly.
    import pprint

    pp = pprint.PrettyPrinter()

    # markdown structure for Table Parsing task.
    # The prediction omits the header and separator rows present in the truth.
    pred_markdown = '| 1 | august 5 , 1972 | detroit lions | l 23 - 31 | 0 - 1 |\n| 2 | august 12 , 1972 | green bay packers | l 13 - 14 | 0 - 2 |\n| 3 | august 19 , 1972 | cincinnati bengals | w 35 - 17 | 1 - 2 |\n| 4 | august 25 , 1972 | atlanta falcons | w 24 - 10 | 2 - 2 |\n| 5 | august 31 , 1972 | washington redskins | l 24 - 27 | 2 - 3 |\n| 6 | september 10 , 1972 | minnesota vikings | w 21 - 19 | 3 - 3 |'
    true_markdown = '| week | date | opponent | result | record |\n| --- | --- | --- | --- | --- |\n| 1 | august 5 , 1972 | detroit lions | l 23 - 31 | 0 - 1 |\n| 2 | august 12 , 1972 | green bay packers | l 13 - 14 | 0 - 2 |\n| 3 | august 19 , 1972 | cincinnati bengals | w 35 - 17 | 1 - 2 |\n| 4 | august 25 , 1972 | atlanta falcons | w 24 - 10 | 2 - 2 |\n| 5 | august 31 , 1972 | washington redskins | l 24 - 27 | 2 - 3 |\n| 6 | september 10 , 1972 | minnesota vikings | w 21 - 19 | 3 - 3 |'
    teds = TEDS(n_jobs=4)
    pred_table_html = convert_markdown_table_to_html(pred_markdown)
    true_table_html = convert_markdown_table_to_html(true_markdown)
    scores = teds.evaluate(pred_table_html, true_table_html)
    pp.pprint(scores)

    # dict structure for Key Information Extraction task
    pred_dict = {'company': ['OLD TOWN '], 'date': ['2024'], 'address': ['SRI RAMPAI'], 'total': ['30']}
    true_dict = {
        'company': ['OLD TOWN KOPITAM SND BHD'],
        'date': ['2024/9/27'],
        'address': ['SRI RAMPAI'],
        'total': ['30'],
    }
    teds = TEDS(n_jobs=4)
    pred_dict_html = dict_to_html(pred_dict)
    true_dict_html = dict_to_html(true_dict)
    print(pred_dict_html)
    print(true_dict_html)
    scores = teds.evaluate(pred_dict_html, true_dict_html)
    pp.pprint(scores)
|