evalscope 1.0.2__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/api/benchmark/__init__.py +8 -1
- evalscope/api/benchmark/adapters/__init__.py +1 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +12 -0
- evalscope/api/benchmark/adapters/ner_adapter.py +212 -0
- evalscope/api/benchmark/benchmark.py +14 -0
- evalscope/api/dataset/dataset.py +21 -0
- evalscope/api/dataset/loader.py +6 -2
- evalscope/api/mixin/sandbox_mixin.py +32 -54
- evalscope/api/model/generate_config.py +6 -0
- evalscope/app/ui/multi_model.py +6 -1
- evalscope/app/ui/single_model.py +8 -2
- evalscope/app/utils/data_utils.py +3 -2
- evalscope/app/utils/visualization.py +2 -2
- evalscope/benchmarks/aa_lcr/aa_lcr_adapter.py +205 -0
- evalscope/benchmarks/ai2d/ai2d_adapter.py +3 -2
- evalscope/benchmarks/bfcl/bfcl_adapter.py +11 -46
- evalscope/benchmarks/blink/__init__.py +0 -0
- evalscope/benchmarks/blink/blink_adapter.py +61 -0
- evalscope/benchmarks/chartqa/__init__.py +0 -0
- evalscope/benchmarks/chartqa/chartqa_adapter.py +80 -0
- evalscope/benchmarks/chartqa/utils.py +38 -0
- evalscope/benchmarks/data_collection/data_collection_adapter.py +2 -1
- evalscope/benchmarks/docvqa/__init__.py +0 -0
- evalscope/benchmarks/docvqa/docvqa_adapter.py +67 -0
- evalscope/benchmarks/general_arena/general_arena_adapter.py +1 -1
- evalscope/benchmarks/general_arena/utils.py +2 -1
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +1 -1
- evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +23 -4
- evalscope/benchmarks/hallusion_bench/__init__.py +0 -0
- evalscope/benchmarks/hallusion_bench/hallusion_bench_adapter.py +158 -0
- evalscope/benchmarks/hle/hle_adapter.py +3 -2
- evalscope/benchmarks/humaneval/humaneval_adapter.py +2 -1
- evalscope/benchmarks/infovqa/__init__.py +0 -0
- evalscope/benchmarks/infovqa/infovqa_adapter.py +66 -0
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +3 -1
- evalscope/benchmarks/math_verse/__init__.py +0 -0
- evalscope/benchmarks/math_verse/math_verse_adapter.py +100 -0
- evalscope/benchmarks/math_vision/__init__.py +0 -0
- evalscope/benchmarks/math_vision/math_vision_adapter.py +111 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +6 -26
- evalscope/benchmarks/mm_bench/mm_bench_adapter.py +2 -2
- evalscope/benchmarks/mmmu/mmmu_adapter.py +1 -1
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +1 -1
- evalscope/benchmarks/ner/__init__.py +0 -0
- evalscope/benchmarks/ner/broad_twitter_corpus_adapter.py +52 -0
- evalscope/benchmarks/ner/conll2003_adapter.py +48 -0
- evalscope/benchmarks/ner/copious_adapter.py +85 -0
- evalscope/benchmarks/ner/cross_ner_adapter.py +120 -0
- evalscope/benchmarks/ner/cross_ner_entities/__init__.py +0 -0
- evalscope/benchmarks/ner/cross_ner_entities/ai.py +54 -0
- evalscope/benchmarks/ner/cross_ner_entities/literature.py +36 -0
- evalscope/benchmarks/ner/cross_ner_entities/music.py +39 -0
- evalscope/benchmarks/ner/cross_ner_entities/politics.py +37 -0
- evalscope/benchmarks/ner/cross_ner_entities/science.py +58 -0
- evalscope/benchmarks/ner/genia_ner_adapter.py +66 -0
- evalscope/benchmarks/ner/harvey_ner_adapter.py +58 -0
- evalscope/benchmarks/ner/mit_movie_trivia_adapter.py +74 -0
- evalscope/benchmarks/ner/mit_restaurant_adapter.py +66 -0
- evalscope/benchmarks/ner/ontonotes5_adapter.py +87 -0
- evalscope/benchmarks/ner/wnut2017_adapter.py +61 -0
- evalscope/benchmarks/ocr_bench/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench/ocr_bench_adapter.py +101 -0
- evalscope/benchmarks/ocr_bench_v2/IoUscore_metric.py +87 -0
- evalscope/benchmarks/ocr_bench_v2/TEDS_metric.py +963 -0
- evalscope/benchmarks/ocr_bench_v2/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/ocr_bench_v2_adapter.py +161 -0
- evalscope/benchmarks/ocr_bench_v2/page_ocr_metric.py +50 -0
- evalscope/benchmarks/ocr_bench_v2/parallel.py +46 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/__init__.py +0 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/readme.txt +26 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/rrc_evaluation_funcs_1_1.py +537 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_eval/script.py +481 -0
- evalscope/benchmarks/ocr_bench_v2/spotting_metric.py +179 -0
- evalscope/benchmarks/ocr_bench_v2/utils.py +433 -0
- evalscope/benchmarks/ocr_bench_v2/vqa_metric.py +254 -0
- evalscope/benchmarks/omnidoc_bench/__init__.py +0 -0
- evalscope/benchmarks/omnidoc_bench/end2end_eval.py +349 -0
- evalscope/benchmarks/omnidoc_bench/metrics.py +547 -0
- evalscope/benchmarks/omnidoc_bench/omnidoc_bench_adapter.py +135 -0
- evalscope/benchmarks/omnidoc_bench/utils.py +1937 -0
- evalscope/benchmarks/poly_math/__init__.py +0 -0
- evalscope/benchmarks/poly_math/poly_math_adapter.py +127 -0
- evalscope/benchmarks/poly_math/utils/instruction.py +105 -0
- evalscope/benchmarks/pope/__init__.py +0 -0
- evalscope/benchmarks/pope/pope_adapter.py +111 -0
- evalscope/benchmarks/seed_bench_2_plus/__init__.py +0 -0
- evalscope/benchmarks/seed_bench_2_plus/seed_bench_2_plus_adapter.py +72 -0
- evalscope/benchmarks/simple_vqa/__init__.py +0 -0
- evalscope/benchmarks/simple_vqa/simple_vqa_adapter.py +169 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +1 -1
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +1 -1
- evalscope/benchmarks/visu_logic/__init__.py +0 -0
- evalscope/benchmarks/visu_logic/visu_logic_adapter.py +75 -0
- evalscope/benchmarks/zerobench/__init__.py +0 -0
- evalscope/benchmarks/zerobench/zerobench_adapter.py +64 -0
- evalscope/constants.py +4 -0
- evalscope/evaluator/evaluator.py +72 -79
- evalscope/metrics/math_parser.py +14 -0
- evalscope/metrics/metric.py +52 -1
- evalscope/metrics/metrics.py +16 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +2 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +2 -6
- evalscope/models/utils/openai.py +4 -0
- evalscope/perf/arguments.py +24 -4
- evalscope/perf/benchmark.py +74 -89
- evalscope/perf/http_client.py +31 -16
- evalscope/perf/main.py +15 -2
- evalscope/perf/plugin/api/base.py +9 -7
- evalscope/perf/plugin/api/custom_api.py +13 -58
- evalscope/perf/plugin/api/default_api.py +179 -79
- evalscope/perf/plugin/api/openai_api.py +4 -3
- evalscope/perf/plugin/datasets/base.py +21 -0
- evalscope/perf/plugin/datasets/custom.py +2 -3
- evalscope/perf/plugin/datasets/line_by_line.py +2 -3
- evalscope/perf/plugin/datasets/longalpaca.py +2 -3
- evalscope/perf/plugin/datasets/openqa.py +2 -4
- evalscope/perf/plugin/datasets/random_dataset.py +1 -3
- evalscope/perf/utils/benchmark_util.py +36 -22
- evalscope/perf/utils/db_util.py +14 -19
- evalscope/perf/utils/local_server.py +0 -44
- evalscope/perf/utils/log_utils.py +21 -6
- evalscope/report/__init__.py +11 -2
- evalscope/report/combinator.py +52 -2
- evalscope/run.py +4 -0
- evalscope/utils/function_utils.py +195 -12
- evalscope/utils/io_utils.py +74 -0
- evalscope/utils/json_schema.py +8 -6
- evalscope/utils/logger.py +49 -17
- evalscope/utils/multi_choices.py +16 -1
- evalscope/utils/ner.py +377 -0
- evalscope/version.py +2 -2
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/METADATA +239 -393
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/RECORD +140 -98
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/WHEEL +1 -1
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/top_level.txt +0 -1
- tests/__init__.py +0 -1
- tests/benchmark/__init__.py +0 -1
- tests/benchmark/test_eval.py +0 -429
- tests/benchmark/test_image_edit.py +0 -65
- tests/benchmark/test_sandbox.py +0 -81
- tests/benchmark/test_t2i.py +0 -142
- tests/benchmark/test_vlm.py +0 -137
- tests/cli/__init__.py +0 -1
- tests/cli/test_all.py +0 -269
- tests/cli/test_collection.py +0 -99
- tests/cli/test_custom.py +0 -268
- tests/cli/test_reasoning.py +0 -81
- tests/common.py +0 -73
- tests/perf/__init__.py +0 -1
- tests/perf/test_perf.py +0 -206
- tests/rag/test_clip_benchmark.py +0 -87
- tests/rag/test_mteb.py +0 -213
- tests/rag/test_ragas.py +0 -128
- tests/swift/__init__.py +0 -1
- tests/swift/test_run_swift_eval.py +0 -146
- tests/swift/test_run_swift_vlm_eval.py +0 -128
- tests/swift/test_run_swift_vlm_jugde_eval.py +0 -157
- tests/test_run_all.py +0 -12
- tests/utils.py +0 -13
- tests/vlm/__init__.py +0 -1
- tests/vlm/test_vlmeval.py +0 -102
- {tests/rag → evalscope/benchmarks/aa_lcr}/__init__.py +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info}/entry_points.txt +0 -0
- {evalscope-1.0.2.dist-info → evalscope-1.1.1.dist-info/licenses}/LICENSE +0 -0
evalscope/benchmarks/ocr_bench_v2/utils.py
@@ -0,0 +1,433 @@
+# flake8: noqa
+import ast
+import os
+import re
+
+from .IoUscore_metric import calculate_iou, extract_coordinates, vqa_with_position_evaluation
+from .page_ocr_metric import cal_per_metrics
+from .spotting_metric import extract_bounding_boxes_robust, spotting_evaluation
+from .TEDS_metric import (
+    TEDS,
+    compute_f1_score,
+    convert_markdown_table_to_html,
+    convert_str_to_dict,
+    convert_str_to_multi_dict,
+    dict_to_html,
+    doc_parsing_evaluation,
+    generate_combinations,
+    wrap_html_table,
+)
+from .vqa_metric import (
+    cn_math_expression_evaluation,
+    cn_vqa_evaluation,
+    counting_evaluation,
+    math_expression_evaluation,
+    vqa_evaluation,
+    vqa_evaluation_case_sensitive,
+)
+
+teds = TEDS(n_jobs=os.cpu_count() or 1)
+
+
+def is_nan_value(value):
+    if value is None:
+        return True
+    if isinstance(value, str) and value.lower() == 'nan':
+        return True
+    try:
+        import pandas as pd
+
+        if pd.isna(value):
+            return True
+    except:
+        pass
+    return False
+
+
+def get_value_or_zero(value):
+    return 0.0 if value is None else value
+
+
+def ocrbench_v2_process_results(doc, pred):
+    question = doc['question']
+    gt_ans = doc['answers']
+    data_type = doc['type']
+
+    score = 0
+
+    if (
+        data_type == 'APP agent en' or data_type == 'ASCII art classification en' or data_type == 'math QA en'
+        or data_type == 'reasoning VQA en' or data_type == 'science QA en' or data_type == 'text recognition en'
+        or data_type == 'document classification en' or data_type == 'cognition VQA en' or data_type == 'diagram QA en'
+    ):
+        if doc['eval'] == 'multiple choice':
+            if not isinstance(gt_ans, list):
+                gt_ans = [gt_ans]
+            assert len(gt_ans) == 1
+
+            if not isinstance(pred, str):
+                score = 0
+            else:
+                predict = ''.join(c for c in pred if c.isalpha())
+
+                if predict == gt_ans[0]:
+                    score = 1
+                else:
+                    score = 0
+        elif doc['eval'] == 'case sensitive':
+            score = vqa_evaluation_case_sensitive(pred, gt_ans)
+
+        else:
+            score = vqa_evaluation(pred, gt_ans)
+
+    elif data_type == 'cognition VQA cn' or data_type == 'reasoning VQA cn':
+        if doc['eval'] == 'multiple choice':
+            assert len(gt_ans) == 1
+            predict = ''.join(c for c in pred if c.isalpha())
+
+            if predict == gt_ans[0]:
+                score = 1
+            else:
+                score = 0
+        elif doc['eval'] == 'case sensitive':
+            score = vqa_evaluation_case_sensitive(pred, gt_ans)
+
+        else:
+            score = cn_vqa_evaluation(pred, gt_ans)
+
+    elif data_type == 'handwritten answer extraction cn':
+        if '简答' in question:
+            ocr_metric = cal_per_metrics(pred, gt_ans[0])
+            score = (
+                get_value_or_zero(ocr_metric['bleu']) + get_value_or_zero(ocr_metric['meteor'])
+                + get_value_or_zero(ocr_metric['f_measure']) + (1 - get_value_or_zero(ocr_metric['edit_dist']))
+            ) / 4
+        else:
+            assert len(gt_ans) == 1
+            answer = gt_ans[0]
+            chars = list(answer)
+            if len(answer) > 1:
+                answer_list = [
+                    ''.join(chars), '.'.join(chars), '. '.join(chars), ','.join(chars), ', '.join(chars),
+                    '、'.join(chars), ';'.join(chars), '; '.join(chars), ' '.join(chars), '和'.join(chars)
+                ]
+                max_score = 0
+                for answer in answer_list:
+                    if answer in pred:
+                        temp_score = 1
+                    else:
+                        temp_score = 0
+                    if temp_score > max_score:
+                        max_score = temp_score
+                score = max_score
+
+            else:
+                if gt_ans[0] in pred:
+                    score = 1
+                else:
+                    score = 0
+
+    elif data_type == 'formula recognition cn':
+        if is_nan_value(pred):
+            score = 0
+        else:
+            score = cn_math_expression_evaluation(pred, gt_ans)
+
+    elif data_type == 'text counting en':
+        score = counting_evaluation(pred, gt_ans, doc['eval'])
+
+    elif data_type == 'formula recognition en':
+        score = math_expression_evaluation(pred, gt_ans)
+
+    elif data_type == 'table parsing en':
+        if type(gt_ans) == list and len(gt_ans) == 1:
+            if not isinstance(pred, str):
+                score = 0
+
+            elif 'html' in question.lower():
+                no_find = False
+                predict_table = pred.replace('\n', '')
+                if '<body' in predict_table:
+                    predict_table = re.findall('<body.*', predict_table)[0]
+                elif '<table' in predict_table:
+                    predict_table = re.findall('<table.*', predict_table)[0]
+                else:
+                    no_find = True
+
+                if no_find:
+                    score = 0
+                else:
+                    pred_table_html = wrap_html_table(predict_table)
+                    gold_table_html = wrap_html_table(gt_ans[0])
+                    try:
+                        score = teds.evaluate(pred_table_html, gold_table_html)
+                    except:
+                        score = 0
+
+            elif 'markdown' in question.lower():
+                if not isinstance(pred, str):
+                    prediction = str(pred)
+                    pred_table_html = convert_markdown_table_to_html(prediction)
+                    gt_table_html = convert_markdown_table_to_html(gt_ans[0])
+                    score = teds.evaluate(pred_table_html, gt_table_html)
+
+                else:
+                    pred_table_html = convert_markdown_table_to_html(pred)
+                    gt_table_html = convert_markdown_table_to_html(gt_ans[0])
+                    score = teds.evaluate(pred_table_html, gt_table_html)
+        else:
+            raise ValueError
+
+    elif data_type == 'table parsing cn':
+        if not isinstance(pred, str):
+            score = 0
+        else:
+            no_find = False
+            predict_table = pred.replace('\n', '')
+            if '<body' in predict_table:
+                predict_table = re.findall('<body.*', predict_table)[0]
+            elif '<table' in predict_table:
+                predict_table = re.findall('<table.*', predict_table)[0]
+            else:
+                no_find = True
+
+            if no_find:
+                score = 0
+            else:
+                pred_table_html = wrap_html_table(predict_table)
+                gold_table_html = wrap_html_table(gt_ans[0])
+                try:
+                    score = teds.evaluate(pred_table_html, gold_table_html)
+                except:
+                    score = 0
+                    print('error')
+
+    elif data_type == 'chart parsing en':
+        answer = gt_ans[0]
+        if pred:
+            pred_chart_dict = convert_str_to_multi_dict(pred)
+            if len(pred_chart_dict) == 0:
+                score = 0
+            else:
+                pred_chart_html = dict_to_html(pred_chart_dict)
+                if isinstance(answer, str):
+                    answer = convert_str_to_multi_dict(pred)
+                gt_chart_html = dict_to_html(answer)
+                score = teds.evaluate(pred_chart_html, gt_chart_html)
+        else:
+            score = 0
+
+    elif data_type == 'document parsing en':
+        assert type(gt_ans) == list and len(gt_ans) == 1
+        score = doc_parsing_evaluation(pred, gt_ans[0])
+
+    elif data_type == 'document parsing cn':
+        assert type(gt_ans) == list and len(gt_ans) == 1
+        score = doc_parsing_evaluation(pred, gt_ans[0])
+
+    elif data_type == 'key information extraction en' or data_type == 'key information mapping en':
+        assert len(gt_ans) == 1
+        answers = generate_combinations(gt_ans[0])
+
+        if type(answers) == list and len(answers) == 1:
+            if not isinstance(pred, str):
+                score = 0
+            else:
+                pred_kie_dict = convert_str_to_dict(pred)
+                score = compute_f1_score(pred_kie_dict, answers[0])
+        else:
+            max_score = 0
+            for answer in answers:
+                pred_kie_dict = convert_str_to_dict(pred)
+                score = compute_f1_score(pred_kie_dict, answer)
+
+                if score > max_score:
+                    max_score = score
+            score = max_score
+
+    elif data_type == 'key information extraction cn':
+        assert len(gt_ans) == 1
+        answers = ast.literal_eval(gt_ans[0])
+        answers = {k: v if isinstance(v, list) else [v] for k, v in answers.items()}
+        answers = generate_combinations(answers)
+        if type(answers) == list and len(answers) == 1:
+            if not isinstance(pred, str):
+                score = 0
+            else:
+                pred_kie_dict = convert_str_to_dict(pred)
+                score = compute_f1_score(pred_kie_dict, answers[0])
+        else:
+            max_score = 0
+            for answer in answers:
+                pred_kie_dict = convert_str_to_dict(pred)
+                score = compute_f1_score(pred_kie_dict, answer)
+
+                if score > max_score:
+                    max_score = score
+            score = max_score
+
+    elif data_type == 'VQA with position en':
+        if not isinstance(pred, str):
+            score = 0
+        else:
+            pred_dict = convert_str_to_dict(pred)
+            score = vqa_with_position_evaluation(pred_dict, doc)
+
+    elif data_type == 'text translation cn':
+        if len(pred) == 0:
+            score = 0
+        else:
+            ocr_metric = cal_per_metrics(pred, gt_ans[0])
+            score = (
+                ocr_metric['bleu'] + ocr_metric['meteor'] + ocr_metric['f_measure'] + (1 - ocr_metric['edit_dist'])
+            ) / 4
+
+    elif data_type == 'fine-grained text recognition en':
+        if not isinstance(pred, str):
+            score = 0
+        elif len(pred) == 0:
+            score = 0
+        else:
+            ocr_metric = cal_per_metrics(pred, gt_ans[0])
+            score = (
+                get_value_or_zero(ocr_metric['bleu']) + get_value_or_zero(ocr_metric['meteor'])
+                + get_value_or_zero(ocr_metric['f_measure']) + (1 - get_value_or_zero(ocr_metric['edit_dist']))
+            ) / 4
+    elif data_type == 'full-page OCR en':
+        if not pred:
+            score = 0
+        else:
+            ocr_metric = cal_per_metrics(pred, gt_ans[0])
+            score = (
+                get_value_or_zero(ocr_metric['bleu']) + get_value_or_zero(ocr_metric['meteor'])
+                + get_value_or_zero(ocr_metric['f_measure']) + (1 - get_value_or_zero(ocr_metric['edit_dist']))
+            ) / 4
+
+    elif data_type == 'full-page OCR cn':
+        if not isinstance(pred, str):
+            score = 0
+        else:
+            if len(pred) == 0:
+                score = 0
+            else:
+                ocr_metric = cal_per_metrics(pred, gt_ans[0])
+                score = (
+                    ocr_metric['bleu'] + ocr_metric['meteor'] + ocr_metric['f_measure'] + (1 - ocr_metric['edit_dist'])
+                ) / 4
+
+    elif data_type == 'text grounding en':
+        if not isinstance(pred, str):
+            score = 0
+        else:
+            predict_bbox = extract_coordinates(pred)
+            if not predict_bbox:
+                score = 0
+            else:
+                score = calculate_iou(predict_bbox, gt_ans)
+
+    elif data_type == 'text spotting en':
+        if not isinstance(pred, str):
+            score = 0
+        else:
+            predict_bbox = extract_bounding_boxes_robust(pred)
+            if not predict_bbox:
+                score = 0
+            else:
+                score = spotting_evaluation(predict_bbox, doc)
+
+    return score
+
+
+def calculate_average_score(categories, OCRBench_v2_score):
+    return sum(
+        sum(OCRBench_v2_score[cat]) / len(OCRBench_v2_score[cat]) if len(OCRBench_v2_score[cat]) > 0 else 0
+        for cat in categories
+    ) / len(categories)
+
+
+def ocrbench_v2_aggregate_accuracy(results):
+    question_type_scores = {}
+    OCRBench_v2_score = {
+        'text_recognition_en': [],
+        'text_detection_en': [],
+        'text_spotting_en': [],
+        'relationship_extraction_en': [],
+        'element_parsing_en': [],
+        'mathematical_calculation_en': [],
+        'visual_text_understanding_en': [],
+        'knowledge_reasoning_en': [],
+        'text_recognition_cn': [],
+        'relationship_extraction_cn': [],
+        'element_parsing_cn': [],
+        'visual_text_understanding_cn': [],
+        'knowledge_reasoning_cn': [],
+    }
+
+    for result in results:
+
+        question_type = result['question_type']
+        score = result['score']
+
+        if question_type not in question_type_scores:
+            question_type_scores[question_type] = []
+        question_type_scores[question_type].append(score)
+
+        if question_type in ['text recognition en', 'fine-grained text recognition en', 'full-page OCR en']:
+            OCRBench_v2_score['text_recognition_en'].append(score)
+
+        elif question_type in ['text grounding en', 'VQA with position en']:
+            OCRBench_v2_score['text_detection_en'].append(score)
+
+        elif question_type == 'text spotting en':
+            OCRBench_v2_score['text_spotting_en'].append(score)
+
+        elif question_type in ['key information extraction en', 'key information mapping en']:
+            OCRBench_v2_score['relationship_extraction_en'].append(score)
+
+        elif question_type in ['document parsing en', 'chart parsing en', 'table parsing en', 'formula recognition en']:
+            OCRBench_v2_score['element_parsing_en'].append(score)
+
+        elif question_type in ['math QA en', 'text counting en']:
+            OCRBench_v2_score['mathematical_calculation_en'].append(score)
+
+        elif question_type in ['document classification en', 'cognition VQA en', 'diagram QA en']:
+            OCRBench_v2_score['visual_text_understanding_en'].append(score)
+
+        elif question_type in ['reasoning VQA en', 'science QA en', 'APP agent en', 'ASCII art classification en']:
+            OCRBench_v2_score['knowledge_reasoning_en'].append(score)
+
+        elif question_type == 'full-page OCR cn':
+            OCRBench_v2_score['text_recognition_cn'].append(score)
+
+        elif question_type in ['key information extraction cn', 'handwritten answer extraction cn']:
+            OCRBench_v2_score['relationship_extraction_cn'].append(score)
+
+        elif question_type in ['document parsing cn', 'table parsing cn', 'formula recognition cn']:
+            OCRBench_v2_score['element_parsing_cn'].append(score)
+
+        elif question_type == 'cognition VQA cn':
+            OCRBench_v2_score['visual_text_understanding_cn'].append(score)
+
+        elif question_type in ['reasoning VQA cn', 'text translation cn']:
+            OCRBench_v2_score['knowledge_reasoning_cn'].append(score)
+
+        else:
+            print('No such task!')
+            raise TypeError
+
+    english_tasks = [
+        'text_recognition_en', 'text_detection_en', 'text_spotting_en', 'relationship_extraction_en',
+        'element_parsing_en', 'mathematical_calculation_en', 'visual_text_understanding_en', 'knowledge_reasoning_en'
+    ]
+
+    chinese_tasks = [
+        'text_recognition_cn', 'relationship_extraction_cn', 'element_parsing_cn', 'visual_text_understanding_cn',
+        'knowledge_reasoning_cn'
+    ]
+
+    OCRBench_v2_English_subset_score = calculate_average_score(english_tasks, OCRBench_v2_score)
+    OCRBench_v2_Chinese_subset_score = calculate_average_score(chinese_tasks, OCRBench_v2_score)
+
+    Final_score = (OCRBench_v2_English_subset_score + OCRBench_v2_Chinese_subset_score) / 2
+
+    return Final_score  # return the final score as accuracy
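The two entry points above pair a per-sample scorer (ocrbench_v2_process_results) with an aggregator (ocrbench_v2_aggregate_accuracy) that averages per-category means over the English and Chinese subsets and then averages the two subsets. A minimal usage sketch, not part of the released diff; it assumes the module path shown in the file list above, and the sample dict is illustrative, carrying only the fields the scorer reads ('question', 'answers', 'type', 'eval'):

    from evalscope.benchmarks.ocr_bench_v2.utils import (
        ocrbench_v2_aggregate_accuracy,
        ocrbench_v2_process_results,
    )

    # One 'text recognition en' sample; 'eval' is None, so scoring falls back to vqa_evaluation.
    doc = {
        'question': 'What is written on the sign?',
        'answers': ['OCRBench'],
        'type': 'text recognition en',
        'eval': None,
    }
    score = ocrbench_v2_process_results(doc, 'The sign reads OCRBench.')  # 1

    # The aggregator expects one record per sample with the task name and its score.
    results = [{'question_type': doc['type'], 'score': score}]
    print(ocrbench_v2_aggregate_accuracy(results))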
evalscope/benchmarks/ocr_bench_v2/vqa_metric.py
@@ -0,0 +1,254 @@
+# flake8: noqa
+import math
+import re
+
+
+def levenshtein_distance(s1, s2):
+    if len(s1) > len(s2):
+        s1, s2 = s2, s1
+
+    distances = range(len(s1) + 1)
+    for i2, c2 in enumerate(s2):
+        distances_ = [i2 + 1]
+        for i1, c1 in enumerate(s1):
+            if c1 == c2:
+                distances_.append(distances[i1])
+            else:
+                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+        distances = distances_
+    return distances[-1]
+
+
+def vqa_evaluation(predict, answers):
+    score = 0
+    if isinstance(answers, list):
+        predict_str = str(predict).lower().strip().replace('\n', ' ')
+        for ans in answers:
+            answer = str(ans).lower().strip().replace('\n', ' ')
+            if len(answer.split()) < 5:
+                if answer in predict_str:
+                    score = 1
+            else:
+                dist = levenshtein_distance(predict_str, answer)
+                length = max(len(predict_str), len(answer))
+                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
+                ANLS_value = 1 - ANLS_value
+
+                if ANLS_value >= 0.5 and ANLS_value > score:
+                    score = ANLS_value
+
+    else:
+        answer = str(answers).lower().strip().replace('\n', ' ')
+        predict_str = str(predict).lower().strip().replace('\n', ' ')
+        if len(answer.split()) < 5:
+            if answer in predict_str:
+                score = 1
+        else:
+            dist = levenshtein_distance(predict_str, answer)
+            length = max(len(predict_str), len(answer))
+            ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
+            ANLS_value = 1 - ANLS_value
+
+            if ANLS_value >= 0.5 and ANLS_value > score:
+                score = ANLS_value
+
+    return score
+
+
+def cn_vqa_evaluation(predict, answers):
+    score = 0
+    if isinstance(answers, list):
+        predict_str = str(predict).lower().strip().replace('\n', ' ').replace(' ', '')
+        for ans in answers:
+            answer = str(ans).lower().strip().replace('\n', ' ').replace(' ', '')
+            if len(answer.split(',')) < 4:
+                if answer in predict_str:
+                    score = 1
+            else:
+                dist = levenshtein_distance(predict_str, answer)
+                length = max(len(predict_str), len(answer))
+                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
+                ANLS_value = 1 - ANLS_value
+
+                if ANLS_value >= 0.5 and ANLS_value > score:
+                    score = ANLS_value
+
+    else:
+        answer = str(answers).lower().strip().replace('\n', ' ').replace(' ', '')
+        predict_str = str(predict).lower().strip().replace('\n', ' ').replace(' ', '')
+        if len(answer.split(',')) < 4:
+            if answer in predict_str:
+                score = 1
+        else:
+            dist = levenshtein_distance(predict_str, answer)
+            length = max(len(predict_str), len(answer))
+            ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
+            ANLS_value = 1 - ANLS_value
+
+            if ANLS_value >= 0.5 and ANLS_value > score:
+                score = ANLS_value
+
+    return score
+
+
+def vqa_evaluation_case_sensitive(predict, answers):
+    score = 0
+    if isinstance(answers, list):
+        predict_str = str(predict).strip().replace('\n', ' ')
+        for ans in answers:
+            answer = str(ans).strip().replace('\n', ' ')
+            if len(answer.split()) < 5:
+                if answer in predict_str:
+                    score = 1
+            else:
+                dist = levenshtein_distance(predict_str, answer)
+                length = max(len(predict_str), len(answer))
+                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
+                ANLS_value = 1 - ANLS_value
+
+                if ANLS_value >= 0.5 and ANLS_value > score:
+                    score = ANLS_value
+
+    else:
+        answer = str(answers).strip().replace('\n', ' ')
+        predict_str = str(predict).strip().replace('\n', ' ')
+        if len(answer.split()) < 5:
+            if answer in predict_str:
+                score = 1
+        else:
+            dist = levenshtein_distance(predict_str, answer)
+            length = max(len(predict_str), len(answer))
+            ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
+            ANLS_value = 1 - ANLS_value
+
+            if ANLS_value >= 0.5 and ANLS_value > score:
+                score = ANLS_value
+
+    return score
+
+
+def extract_first_number(string):
+    match = re.search(r'\d+', string)
+    if match:
+        return int(match.group())
+    return None
+
+
+def counting_evaluation(predict, answers, eval_method):
+    score = 0
+
+    # normalize predict to string for both matching and number extraction
+    if isinstance(predict, str):
+        predict_str = predict.lower().strip().replace('\n', ' ')
+    elif isinstance(predict, (int, float)):
+        if isinstance(predict, float) and math.isnan(predict):
+            return 0
+        predict_str = str(predict).lower().strip().replace('\n', ' ')
+    else:
+        predict_str = str(predict).lower().strip().replace('\n', ' ')
+
+    if isinstance(answers, list):
+        temp_score = 0
+        for ans in answers:
+            answer = str(ans).lower().strip().replace('\n', ' ')
+            if eval_method == 'exact match':
+                score = 1 if answer in predict_str else 0
+            elif eval_method == 'regression':
+                predict_number = extract_first_number(predict_str)
+                if predict_number is not None:
+                    try:
+                        answer_int = int(answer)
+                    except ValueError:
+                        score = 0
+                    else:
+                        if predict_number <= 0 or predict_number >= 2 * answer_int:
+                            score = 0
+                        else:
+                            iou = 1 - abs(predict_number - answer_int) / answer_int
+                            score = iou if iou > 0.5 else 0
+                else:
+                    score = 0
+            if score > temp_score:
+                temp_score = score
+        score = temp_score
+
+    else:
+        answer = str(answers).lower().strip().replace('\n', ' ')
+        if eval_method == 'exact match':
+            score = 1 if answer in predict_str else 0
+        elif eval_method == 'regression':
+            predict_number = extract_first_number(predict_str)
+            if predict_number is not None:
+                try:
+                    answer_int = int(answer)
+                except ValueError:
+                    score = 0
+                else:
+                    if predict_number <= 0 or predict_number >= 2 * answer_int:
+                        score = 0
+                    else:
+                        iou = 1 - abs(predict_number - answer_int) / answer_int
+                        score = iou if iou > 0.5 else 0
+            else:
+                score = 0
+    return score
+
+
+def math_expression_evaluation(predict, answers):
+    score = 0
+    if type(answers) == list:
+        for j in range(len(answers)):
+            answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
+            predict = predict.strip().replace('\n', ' ').replace(' ', '')
+            if answer in predict:
+                score = 1
+    else:
+        answers = answers.strip().replace('\n', ' ').replace(' ', '')
+        predict = predict.strip().replace('\n', ' ').replace(' ', '')
+        if answers in predict:
+            score = 1
+    return score
+
+
+def remove_text_tags(latex_str):
+    """
+    Removes LaTeX \text{...} tags while keeping their content.
+
+    :param latex_str: A string containing LaTeX expressions
+    :return: The processed string with \text{...} tags removed
+    """
+
+    pattern = r'\\text\{([^{}]*)\}'
+
+    processed_str = re.sub(pattern, r'\1', latex_str)
+
+    return processed_str
+
+
+def cn_math_expression_evaluation(predict, answers):
+    score = 0
+
+    assert len(answers) == 1
+    answers = [remove_text_tags(answers[0])]
+    predict = remove_text_tags(predict)
+
+    if type(answers) == list:
+        for j in range(len(answers)):
+            answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
+            predict = predict.strip().replace('\n', ' ').replace(' ', '')
+            if answer in predict:
+                score = 1
+    else:
+        answers = answers.strip().replace('\n', ' ').replace(' ', '')
+        predict = predict.strip().replace('\n', ' ').replace(' ', '')
+        if answers in predict:
+            score = 1
+    return score
+
+
+if __name__ == '__main__':
+    test_predict = 'apple pie and banana'
+    test_answers = ['apple', 'banana pie', 'apple pie and orange']
+
+    vqa_score = vqa_evaluation(test_predict, test_answers)
+    print(f"VQA evaluation score for predict '{test_predict}' and answers {test_answers}: {vqa_score}")
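vqa_evaluation treats short references (fewer than five words) as case-insensitive substring matches and longer ones with an ANLS-style score, 1 minus the normalized Levenshtein distance, kept only when it reaches 0.5. A small illustrative sketch, not part of the released diff, assuming the module path shown in the file list above:

    from evalscope.benchmarks.ocr_bench_v2.vqa_metric import vqa_evaluation

    # Short reference (< 5 words): plain substring match after lowercasing.
    print(vqa_evaluation('The total is 42 dollars.', ['42 dollars']))  # 1

    # Longer reference: ANLS = 1 - edit_distance / max_length, kept only if >= 0.5.
    print(vqa_evaluation(
        'the quick brown fox jumps over a dog',
        ['the quick brown fox jumps over the lazy dog'],
    ))  # 1 - 7/43, roughly 0.84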