evalscope 1.0.2__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of evalscope has been flagged as possibly problematic.
- evalscope/api/benchmark/__init__.py +8 -1
- evalscope/api/benchmark/adapters/__init__.py +1 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +12 -0
- evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
- evalscope/api/benchmark/benchmark.py +14 -0
- evalscope/api/dataset/dataset.py +21 -0
- evalscope/api/dataset/loader.py +6 -2
- evalscope/api/mixin/sandbox_mixin.py +32 -54
- evalscope/api/model/generate_config.py +6 -0
- evalscope/app/ui/multi_model.py +6 -1
- evalscope/app/ui/single_model.py +8 -2
- evalscope/app/utils/data_utils.py +3 -2
- evalscope/app/utils/visualization.py +2 -2
- evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
- evalscope/benchmarks/ai2d/ai2d_adapter.py +3 -2
- evalscope/benchmarks/bfcl/bfcl_adapter.py +11 -46
- evalscope/benchmarks/blink/__init__.py +0 -0
- evalscope/benchmarks/blink/blink_adapter.py +61 -0
- evalscope/benchmarks/chartqa/__init__.py +0 -0
- evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
- evalscope/benchmarks/chartqa/utils.py +38 -0
- evalscope/benchmarks/data_collection/data_collection_adapter.py +2 -1
- evalscope/benchmarks/docvqa/__init__.py +0 -0
- evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
- evalscope/benchmarks/general_arena/general_arena_adapter.py +1 -1
- evalscope/benchmarks/general_arena/utils.py +2 -1
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
- evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +23 -4
- evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
- evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +158 -0
- evalscope/benchmarks/hle/hle_adapter.py +3 -2
- evalscope/benchmarks/humaneval/humaneval_adapter.py +2 -1
- evalscope/benchmarks/infovqa/__init__.py +0 -0
- evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +3 -1
- evalscope/benchmarks/math_verse/__init__.py +0 -0
- evalscope/benchmarks/math_verse/math_verse_adapter.py +100 -0
- evalscope/benchmarks/math_vision/__init__.py +0 -0
- evalscope/benchmarks/math_vision/math_vision_adapter.py +111 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +6 -26
- evalscope/benchmarks/mm_bench/mm_bench_adapter.py +2 -2
- evalscope/benchmarks/mmmu/mmmu_adapter.py +1 -1
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +1 -1
- evalscope/benchmarks/ner/__init__.py +0 -0
- evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
- evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
- evalscope/benchmarks/ner/copious_adapter.py +85 -0
- evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
- evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
- evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
- evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
- evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
- evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
- evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
- evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
- evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
- evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
- evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
- evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
- evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
- evalscope/benchmarks/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_adapter.py +101 -0
- evalscope/benchmarks/ocr_bench_v2/IoUscore_metric.py +87 -0
- evalscope/benchmarks/ocr_bench_v2/TEDS_metric.py +963 -0
- evalscope/benchmarks/ocr_bench_v2/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
- evalscope/benchmarks/ocr_bench_v2/page_ocr_metric.py +50 -0
- evalscope/benchmarks/ocr_bench_v2/parallel.py +46 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/readme.txt +26 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/script.py +481 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_metric.py +179 -0
- evalscope/benchmarks/ocr_bench_v2/utils.py +433 -0
- evalscope/benchmarks/ocr_bench_v2/vqa_metric.py +254 -0
- evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
- evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
- evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
- evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
- evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
- evalscope/benchmarks/poly_math/__init__.py +0 -0
- evalscope/benchmarks/poly_math/poly_math_adapter.py +127 -0
- evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
- evalscope/benchmarks/pope/__init__.py +0 -0
- evalscope/benchmarks/pope/pope_adapter.py +111 -0
- evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
- evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
- evalscope/benchmarks/simple_vqa/__init__.py +0 -0
- evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +1 -1
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +1 -1
- evalscope/benchmarks/visu_logic/__init__.py +0 -0
- evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
- evalscope/benchmarks/zerobench/__init__.py +0 -0
- evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
- evalscope/constants.py +4 -0
- evalscope/evaluator/evaluator.py +72 -79
- evalscope/metrics/math_parser.py +14 -0
- evalscope/metrics/metric.py +52 -1
- evalscope/metrics/metrics.py +16 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +2 -6
- evalscope/models/utils/openai.py +4 -0
- evalscope/perf/arguments.py +24 -4
- evalscope/perf/benchmark.py +74 -89
- evalscope/perf/http_client.py +31 -16
- evalscope/perf/main.py +15 -2
- evalscope/perf/plugin/api/base.py +9 -7
- evalscope/perf/plugin/api/custom_api.py +13 -58
- evalscope/perf/plugin/api/default_api.py +179 -79
- evalscope/perf/plugin/api/openai_api.py +4 -3
- evalscope/perf/plugin/datasets/base.py +21 -0
- evalscope/perf/plugin/datasets/custom.py +2 -3
- evalscope/perf/plugin/datasets/line_by_line.py +2 -3
- evalscope/perf/plugin/datasets/longalpaca.py +2 -3
- evalscope/perf/plugin/datasets/openqa.py +2 -4
- evalscope/perf/plugin/datasets/random_dataset.py +1 -3
- evalscope/perf/utils/benchmark_util.py +36 -22
- evalscope/perf/utils/db_util.py +14 -19
- evalscope/perf/utils/local_server.py +0 -44
- evalscope/perf/utils/log_utils.py +21 -6
- evalscope/report/__init__.py +11 -2
- evalscope/report/combinator.py +52 -2
- evalscope/run.py +4 -0
- evalscope/utils/function_utils.py +195 -12
- evalscope/utils/io_utils.py +74 -0
- evalscope/utils/json_schema.py +8 -6
- evalscope/utils/logger.py +49 -17
- evalscope/utils/multi_choices.py +16 -1
- evalscope/utils/ner.py +377 -0
- evalscope/version.py +2 -2
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/METADATA +239 -393
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/RECORD +140 -98
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/WHEEL +1 -1
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/top_level.txt +0 -1
- tests/__init__.py +0 -1
- tests/benchmark/__init__.py +0 -1
- tests/benchmark/test_eval.py +0 -429
- tests/benchmark/test_image_edit.py +0 -65
- tests/benchmark/test_sandbox.py +0 -81
- tests/benchmark/test_t2i.py +0 -142
- tests/benchmark/test_vlm.py +0 -137
- tests/cli/__init__.py +0 -1
- tests/cli/test_all.py +0 -269
- tests/cli/test_collection.py +0 -99
- tests/cli/test_custom.py +0 -268
- tests/cli/test_reasoning.py +0 -81
- tests/common.py +0 -73
- tests/perf/__init__.py +0 -1
- tests/perf/test_perf.py +0 -206
- tests/rag/test_clip_benchmark.py +0 -87
- tests/rag/test_mteb.py +0 -213
- tests/rag/test_ragas.py +0 -128
- tests/swift/__init__.py +0 -1
- tests/swift/test_run_swift_eval.py +0 -146
- tests/swift/test_run_swift_vlm_eval.py +0 -128
- tests/swift/test_run_swift_vlm_jugde_eval.py +0 -157
- tests/test_run_all.py +0 -12
- tests/utils.py +0 -13
- tests/vlm/__init__.py +0 -1
- tests/vlm/test_vlmeval.py +0 -102
- {tests/rag → evalscope/benchmarks/aa_lcr}/__init__.py +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/entry_points.txt +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info/licenses}/LICENSE +0 -0
evalscope/benchmarks/omnidoc_bench/metrics.py (new file; the +547-line addition listed above):

@@ -0,0 +1,547 @@

# flake8: noqa
import copy
import jieba
import Levenshtein
import pandas as pd
import re
from apted import APTED, Config
from apted.helpers import Tree
from collections import defaultdict, deque
from lxml import etree, html
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from tabulate import tabulate
from tqdm import tqdm

from evalscope.utils import get_logger
from .utils import normalized_table

logger = get_logger()


def show_result(results):
    for metric_name in results.keys():
        score_table = [[k, v] for k, v in results[metric_name].items()]
        logger.info(f'\n{metric_name}:\n' + tabulate(score_table) + '\n' + '=' * 100)


def sort_nested_dict(d):
    # If it's a dictionary, recursively sort it
    if isinstance(d, dict):
        # Sort the current dictionary
        sorted_dict = {k: sort_nested_dict(v) for k, v in sorted(d.items())}
        return sorted_dict
    # If not a dictionary, return directly
    return d


def get_full_labels_results(samples: dict):
    if not samples:
        return {}
    label_group_dict = defaultdict(lambda: defaultdict(list))
    for sample in samples:
        label_list = []
        if not sample.get('gt_attribute'):
            continue
        for anno in sample['gt_attribute']:
            for k, v in anno.items():
                label_list.append(k + ': ' + str(v))
        # Currently if there are merged cases, calculate based on the set of all labels involved after merging
        for label_name in list(set(label_list)):
            for metric, score in sample['metric'].items():
                label_group_dict[label_name][metric].append(score)

    logger.info('----Anno Attribute---------------')
    result = {}
    result['sample_count'] = {}
    for attribute in label_group_dict.keys():
        for metric, scores in label_group_dict[attribute].items():
            mean_score = sum(scores) / len(scores)
            if not result.get(metric):
                result[metric] = {}
            result[metric][attribute] = mean_score
            result['sample_count'][attribute] = len(scores)
    result = sort_nested_dict(result)
    show_result(result)
    return result


def get_page_split(samples, page_info):  # Page level metric
    if not page_info:
        return {}
    result_list = defaultdict(list)

    for sample in samples:
        img_name = sample['img_id'] if sample['img_id'].endswith('.jpg') else '_'.join(sample['img_id'].split('_')[:-1])
        page_info_s = page_info[img_name]
        if not sample.get('metric'):
            continue
        for metric, score in sample['metric'].items():
            gt = sample['norm_gt'] if sample.get('norm_gt') else sample['gt']
            pred = sample['norm_pred'] if sample.get('norm_pred') else sample['pred']
            result_list[metric].append({
                'image_name': img_name,
                'metric': metric,
                'attribute': 'ALL',
                'score': score,
                'upper_len': max(len(gt), len(pred))
            })
            for k, v in page_info_s.items():
                if isinstance(v, list):  # special issue
                    for special_issue in v:
                        if 'table' not in special_issue:  # Table-related special fields have duplicates
                            result_list[metric].append({
                                'image_name': img_name,
                                'metric': metric,
                                'attribute': special_issue,
                                'score': score,
                                'upper_len': max(len(gt), len(pred))
                            })
                else:
                    result_list[metric].append({
                        'image_name': img_name,
                        'metric': metric,
                        'attribute': k + ': ' + str(v),
                        'score': score,
                        'upper_len': max(len(gt), len(pred))
                    })

    # Page level logic, accumulation is only done within pages, and mean operation is performed between pages
    result = {}
    if result_list.get('Edit_dist'):
        df = pd.DataFrame(result_list['Edit_dist'])
        # At page level, accumulate edits, denominator is sum of max(gt, pred) from each sample
        up_total_avg = df.groupby(['image_name', 'attribute']).apply(
            lambda x: (x['score'] * x['upper_len']).sum() / x['upper_len'].sum(), include_groups=False
        ).groupby('attribute').mean()
        result['Edit_dist'] = up_total_avg.to_dict()
    for metric in result_list.keys():
        if metric == 'Edit_dist':
            continue
        df = pd.DataFrame(result_list[metric])
        page_avg = df.groupby(['image_name', 'attribute']).apply(
            lambda x: x['score'].mean(), include_groups=False
        ).groupby('attribute').mean()
        result[metric] = page_avg.to_dict()

    result = sort_nested_dict(result)
    # print('----Page Attribute---------------')
    show_result(result)
    return result


def get_groups(samples, group_info):
    group_samples = defaultdict(list)
    for sample in samples:
        group_samples['all'].append(sample)
        for group in group_info:
            select_flag = True
            for k, v in group.items():
                for gt_attribute in sample['gt_attribute']:  # gt_attribute is a list containing all merged gt attributes
                    if not gt_attribute:  # if no GT attributes, don't include in calculation
                        select_flag = False
                    elif gt_attribute[k] != v:  # if any gt attribute doesn't meet criteria, don't select
                        select_flag = False
            if select_flag:
                group_samples[str(group)].append(sample)
    return group_samples


class Registry:

    def __init__(self):
        self._registry = {}

    def register(self, name):

        def decorator(item):
            if name in self._registry:
                raise ValueError(f'Item {name} already registered.')
            self._registry[name] = item
            return item

        return decorator

    def get(self, name):
        if name not in self._registry:
            raise ValueError(f'Item {name} not found in registry.')
        return self._registry[name]

    def list_items(self):
        return list(self._registry.keys())


METRIC_REGISTRY = Registry()


@METRIC_REGISTRY.register('TEDS')
class call_TEDS():

    def __init__(self, samples):
        self.samples = samples

    def evaluate(self, group_info=[], save_name='default'):
        teds = TEDS(structure_only=False)
        teds_structure_only = TEDS(structure_only=True)

        group_scores = defaultdict(list)
        group_scores_structure_only = defaultdict(list)

        samples = self.samples
        for sample in samples:
            gt = sample['norm_gt'] if sample.get('norm_gt') else sample['gt']
            pred = sample['norm_pred'] if sample.get('norm_pred') else sample['pred']

            score = teds.evaluate(pred, gt)
            score_structure_only = teds_structure_only.evaluate(pred, gt)
            # print('TEDS score:', score)
            group_scores['all'].append(score)
            group_scores_structure_only['all'].append(score_structure_only)

            if not sample.get('metric'):
                sample['metric'] = {}
            sample['metric']['TEDS'] = score
            sample['metric']['TEDS_structure_only'] = score_structure_only

            for group in group_info:
                select_flag = True
                for k, v in group.items():
                    for gt_attribute in sample['gt_attribute']:  # gt_attribute is a list containing all merged gt attributes
                        if not gt_attribute:  # if no GT attributes, don't include in calculation
                            select_flag = False
                        elif gt_attribute[k] != v:  # if any gt attribute doesn't meet criteria, don't select
                            select_flag = False
                if select_flag:
                    group_scores[str(group)].append(score)

        result = {}
        for group_name, scores in group_scores.items():
            if len(scores) > 0:
                result[group_name] = sum(scores) / len(scores)  # average of normalized scores at sample level
            else:
                result[group_name] = 'NaN'
                logger.warning(f'Empty matched samples for {group_name}.')

        structure_only_result = {}
        for group_name, scores in group_scores_structure_only.items():
            if len(scores) > 0:
                structure_only_result[group_name] = sum(scores) / len(scores)  # average of normalized scores at sample level
            else:
                structure_only_result[group_name] = 'NaN'
                logger.warning(f'Empty matched samples for {group_name}.')

        return samples, {'TEDS': result, 'TEDS_structure_only': structure_only_result}


def tokenize(text) -> list[str]:
    """Tokenizes text, handling Chinese and non-Chinese strings appropriately."""

    def contain_chinese_string(text):
        chinese_pattern = re.compile(r'[\u4e00-\u9fa5]')
        return bool(chinese_pattern.search(text))

    if contain_chinese_string(text):
        res = jieba.lcut(text)
    else:
        res = text.split()
    return res


@METRIC_REGISTRY.register('BLEU')
class call_BLEU():

    def __init__(self, samples):
        self.samples = samples

    def evaluate(self, group_info=[], save_name='default'):
        group_samples = get_groups(self.samples, group_info)
        result = {}

        for group_name, samples in group_samples.items():
            predictions, references = [], []
            for sample in samples:
                gt = sample['norm_gt'] if sample.get('norm_gt') else sample['gt']
                pred = sample['norm_pred'] if sample.get('norm_pred') else sample['pred']
                predictions.append(tokenize(pred))
                references.append([tokenize(gt)])

            if not predictions or not any(predictions) or not references or not any(references):
                bleu_score = 0
            else:
                try:
                    bleu_score = corpus_bleu(references, predictions)
                except ZeroDivisionError:
                    bleu_score = 0

            result[group_name] = bleu_score

        return self.samples, {'BLEU': result}


@METRIC_REGISTRY.register('METEOR')
class call_METEOR():

    def __init__(self, samples):
        self.samples = samples

    def evaluate(self, group_info=[], save_name='default'):
        group_samples = get_groups(self.samples, group_info)
        result = {}
        for group_name, samples in group_samples.items():
            predictions, references = [], []
            for sample in samples:
                gt = sample['norm_gt'] if sample.get('norm_gt') else sample['gt']
                pred = sample['norm_pred'] if sample.get('norm_pred') else sample['pred']
                predictions.append(tokenize(pred))
                references.append(tokenize(gt))
            # Calculate METEOR score
            if not predictions or not references:
                meteor_results = 0
            else:
                try:
                    total_score = 0
                    for ref, pred in zip(references, predictions):
                        score = meteor_score([ref], pred)
                        total_score += score
                    meteor_results = total_score / len(references)
                except Exception as e:
                    logger.error(f'METEOR calculation error: {e}')
                    meteor_results = 0
            result[group_name] = meteor_results

        return self.samples, {'METEOR': result}


@METRIC_REGISTRY.register('Edit_dist')
class call_Edit_dist():

    def __init__(self, samples):
        self.samples = samples

    def evaluate(self, group_info=[], save_name='default'):
        samples = self.samples
        for sample in samples:
            img_name = sample['img_id'] if sample['img_id'].endswith('.jpg') else '_'.join(sample['img_id'].split('_')[:-1])
            sample['image_name'] = img_name
            gt = sample['norm_gt'] if sample.get('norm_gt') else sample['gt']
            pred = sample['norm_pred'] if sample.get('norm_pred') else sample['pred']
            upper_len = max(len(pred), len(gt))
            sample['upper_len'] = upper_len
            if len(pred) > 0 or len(gt) > 0:
                edit_dist = Levenshtein.distance(pred, gt)
                if not sample.get('metric'):
                    sample['metric'] = {}
                sample['metric']['Edit_dist'] = edit_dist / upper_len
                sample['Edit_num'] = edit_dist

        if isinstance(samples, list):
            saved_samples = samples
        else:
            saved_samples = samples.samples

        if not saved_samples:
            return self.samples, {'Edit_dist': {'ALL_page_avg': 'NaN'}}

        df = pd.DataFrame(saved_samples)
        # page level, sum of edits divided by sum of max(gt,pred) lengths for each sample
        up_total_avg = df.groupby('image_name').apply(
            lambda x: x['Edit_num'].sum() / x['upper_len'].sum(), include_groups=False
        )

        return samples, {'Edit_dist': {'ALL_page_avg': up_total_avg.mean()}}


@METRIC_REGISTRY.register('CDM')
class call_CDM():

    def __init__(self, samples):
        self.samples = samples

    def evaluate(self, group_info=[], save_name='default'):
        if isinstance(self.samples, list):
            cdm_samples = copy.deepcopy(self.samples)
        else:
            cdm_samples = copy.deepcopy(self.samples.samples)
        for idx, sample in enumerate(cdm_samples):
            sample['img_name'] = sample['img_id']
            sample['img_id'] = str(idx)
            sample['gt'] = sample['gt'].lstrip('$$').rstrip('$$').strip()
            sample['pred'] = sample['pred'].split('```latex')[-1].split('```')[0]
            sample['pred'] = sample['pred'].lstrip('$$').rstrip('$$').strip()

        return self.samples, False


class TEDS(object):
    ''' Tree Edit Distance basead Similarity
    '''

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted
        '''
        global __tokens__
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(
                node.tag, int(node.attrib.get('colspan', '1')), int(node.attrib.get('rowspan', '1')), cell, *deque()
            )
        else:
            new_node = TableTree(node.tag, None, None, None, *deque())
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        ''' Computes TEDS score between the prediction and the ground truth of a
        given sample
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath('body/table') and true.xpath('body/table'):
            pred = pred.xpath('body/table')[0]
            true = true.xpath('body/table')[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath('.//*'))
            n_nodes_true = len(true.xpath('.//*'))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
            return 1.0 - (float(distance) / n_nodes)
        else:
            return 0.0

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}
        '''
        samples = true_json.keys()
        # if self.n_jobs == 1:
        scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
        # else:
        #     inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
        #     scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores


class CustomConfig(Config):

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value
        """
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1
        """
        return float(Levenshtein.distance(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.


class TableTree(Tree):

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return '{{{}}}'.format(result)


class recogition_end2end_base_dataset():

    def __init__(self, samples):
        img_id = 0
        for sample in samples:
            if not sample.get('img_id'):
                sample['img_id'] = img_id
                img_id += 1
        self.samples = samples

    def __getitem__(self, idx):
        return self.samples[idx]


class recogition_end2end_table_dataset(recogition_end2end_base_dataset):

    def __init__(self, samples, table_format):
        self.pred_table_format = table_format
        self.samples = self.normalize_data(samples)

    def normalize_data(self, samples):
        img_id = 0
        for sample in samples:
            p = sample['pred']
            r = sample['gt']
            p = normalized_table(p, self.pred_table_format)
            r = normalized_table(r)
            sample['norm_gt'] = r
            sample['norm_pred'] = p
            sample['img_id'] = sample['img_id'] if sample.get('img_id') else img_id
            img_id += 1

        return samples
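For orientation, here is a minimal usage sketch of the metric registry defined in the file above. It is not part of the diff: the sample dict and the HTML strings are illustrative assumptions, and only the `METRIC_REGISTRY` / `call_TEDS` behavior shown in the file is relied on.

```python
# Minimal sketch (assumed inputs, not part of the diff): score one predicted HTML
# table with the 'TEDS' metric registered above.
from evalscope.benchmarks.omnidoc_bench.metrics import METRIC_REGISTRY

samples = [{
    # 'norm_gt'/'norm_pred' are optional in the metric classes and fall back to 'gt'/'pred'
    'gt': '<html><body><table><tr><td>1</td><td>2</td></tr></table></body></html>',
    'pred': '<html><body><table><tr><td>1</td><td>3</td></tr></table></body></html>',
}]

metric_cls = METRIC_REGISTRY.get('TEDS')                # resolves to call_TEDS
scored_samples, results = metric_cls(samples).evaluate(group_info=[])

print(scored_samples[0]['metric']['TEDS'])              # per-sample TEDS score in [0, 1]
print(results['TEDS']['all'])                           # mean TEDS over the 'all' group
```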
evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py (new file; the +135-line addition listed above):

@@ -0,0 +1,135 @@

import ast
import numpy as np
from typing import Any, Dict, List

from evalscope.api.benchmark import BenchmarkMeta, VisionLanguageAdapter
from evalscope.api.dataset import Sample
from evalscope.api.messages import ChatMessageUser, Content, ContentImage, ContentText
from evalscope.api.metric.scorer import AggScore, SampleScore, Score
from evalscope.api.registry import register_benchmark
from evalscope.constants import Tags
from evalscope.utils.import_utils import check_import
from evalscope.utils.logger import get_logger

logger = get_logger()

PROMPT_TEMPLATE = r""" You are an AI assistant specialized in converting PDF images to Markdown format. Please follow these instructions for the conversion:

1. Text Processing:
- Accurately recognize all text content in the PDF image without guessing or inferring.
- Convert the recognized text into Markdown format.
- Maintain the original document structure, including headings, paragraphs, lists, etc.

2. Mathematical Formula Processing:
- Convert all mathematical formulas to LaTeX format.
- Enclose inline formulas with \( \). For example: This is an inline formula \( E = mc^2 \)
- Enclose block formulas with \\[ \\]. For example: \[ \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \]

3. Table Processing:
- Convert tables to HTML format.
- Wrap the entire table with <table> and </table>.

4. Figure Handling:
- Ignore figures content in the PDF image. Do not attempt to describe or convert images.

5. Output Format:
- Ensure the output Markdown document has a clear structure with appropriate line breaks between elements.
- For complex layouts, try to maintain the original document's structure and format as closely as possible.

Please strictly follow these guidelines to ensure accuracy and consistency in the conversion. Your task is to accurately convert the content of the PDF image into Markdown format without adding any extra explanations or comments.
"""  # noqa: E501


@register_benchmark(
    BenchmarkMeta(
        name='omni_doc_bench',
        pretty_name='OmniDocBench',
        tags=[Tags.MULTI_MODAL, Tags.KNOWLEDGE, Tags.QA],
        description=
        """OmniDocBench is an evaluation dataset for diverse document parsing in real-world scenarios, with the following characteristics:
- Diverse Document Types: The evaluation set contains 1355 PDF pages, covering 9 document types, 4 layout types and 3 language types. It has broad coverage including academic papers, financial reports, newspapers, textbooks, handwritten notes, etc.
- Rich Annotations: Contains location information for 15 block-level (text paragraphs, titles, tables, etc., over 20k in total) and 4 span-level (text lines, inline formulas, superscripts/subscripts, etc., over 80k in total) document elements, as well as recognition results for each element region (text annotations, LaTeX formula annotations, tables with both LaTeX and HTML annotations). OmniDocBench also provides reading order annotations for document components. Additionally, it includes various attribute labels at page and block levels, with 5 page attribute labels, 3 text attribute labels and 6 table attribute labels.
**The evaluation in EvalScope implements the `end2end` and `quick_match` methods from the official [OmniDocBench-v1.5 repository](https://github.com/opendatalab/OmniDocBench).**
""",  # noqa: E501
        dataset_id='evalscope/OmniDocBench_tsv',
        metric_list={
            'text_block': {
                'metric': ['Edit_dist', 'BLEU', 'METEOR']
            },
            'display_formula': {
                'metric': ['Edit_dist']
            },
            'table': {
                'metric': ['TEDS', 'Edit_dist']
            },
            'reading_order': {
                'metric': ['Edit_dist']
            }
        },
        eval_split='train',
        prompt_template=PROMPT_TEMPLATE,
        extra_params={
            'match_method': 'quick_match',
        }
    )
)
class OmniDocBenchAdapter(VisionLanguageAdapter):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_aggregation_name = False
        self.match_method = self.extra_params.get('match_method', 'quick_match')

        check_import(
            module_name=['apted', 'distance', 'editdistance', 'Levenshtein', 'lxml', 'pylatexenc', 'bs4'],
            package=['apted', 'distance', 'editdistance', 'Levenshtein', 'lxml', 'pylatexenc', 'BeautifulSoup4'],
            raise_error=True,
            feature_name='OmniDocBench'
        )

    def record_to_sample(self, record) -> Sample:
        content_list: List[Content] = [ContentText(text=self.prompt_template)]
        content_list.append(ContentImage(image=f'data:image/png;base64,{record["image"]}'))

        return Sample(
            input=[ChatMessageUser(content=content_list)], target='', metadata=ast.literal_eval(record['answer'])
        )

    def match_score(self, original_prediction, filtered_prediction, reference, task_state) -> Score:
        # Dummy implementation to comply with the interface
        score = Score(
            prediction=original_prediction,
            extracted_prediction=filtered_prediction,
        )

        return score

    def aggregate_scores(self, sample_scores: List[SampleScore]) -> List[AggScore]:
        from .end2end_eval import End2EndEvaluator

        if not sample_scores:
            return []

        predictions = [s.score.prediction for s in sample_scores]
        references = [s.sample_metadata for s in sample_scores]

        evaluator = End2EndEvaluator(
            prediction=predictions,
            reference=references,
            metrics=self.metric_list,
            match_method=self.match_method,
        )
        agg_results = evaluator.score()

        agg_scores = []
        for metric_name, agg_result in agg_results.items():
            if agg_result is not np.nan:
                agg_score = AggScore(
                    score=agg_result,
                    metric_name=metric_name,
                    num=len(sample_scores),
                )
                agg_scores.append(agg_score)

        return agg_scores
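Finally, a hedged sketch of how the newly registered benchmark might be selected from user code. `TaskConfig` and `run_task` follow evalscope's documented entry points; the exact argument names (`datasets`, `dataset_args`, `limit`) and the placeholder model name are assumptions and are not part of this diff, while the benchmark name `omni_doc_bench` and the `match_method` option come from the adapter above.

```python
# Hedged sketch (argument names assumed from evalscope's documented usage; only the
# benchmark name and 'match_method' come from the adapter shown above).
from evalscope import TaskConfig, run_task

task_cfg = TaskConfig(
    model='qwen-vl-plus',             # placeholder: any served vision-language model
    datasets=['omni_doc_bench'],      # name registered by OmniDocBenchAdapter
    dataset_args={
        'omni_doc_bench': {
            'extra_params': {'match_method': 'quick_match'},  # default shown in extra_params above
        },
    },
    limit=5,                          # small smoke-test run
)

run_task(task_cfg)
```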