evalscope 0.8.2__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +2 -0
- evalscope/arguments.py +10 -3
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py +0 -1
- evalscope/backend/rag_eval/utils/llm.py +1 -1
- evalscope/benchmarks/__init__.py +20 -1
- evalscope/benchmarks/arc/__init__.py +0 -5
- evalscope/benchmarks/arc/arc_adapter.py +23 -99
- evalscope/benchmarks/bbh/__init__.py +0 -4
- evalscope/benchmarks/bbh/bbh_adapter.py +19 -89
- evalscope/benchmarks/benchmark.py +70 -59
- evalscope/benchmarks/ceval/__init__.py +0 -5
- evalscope/benchmarks/ceval/ceval_adapter.py +22 -46
- evalscope/benchmarks/cmmlu/__init__.py +0 -5
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +20 -41
- evalscope/benchmarks/competition_math/__init__.py +0 -5
- evalscope/benchmarks/competition_math/competition_math_adapter.py +29 -371
- evalscope/benchmarks/data_adapter.py +114 -85
- evalscope/benchmarks/general_qa/__init__.py +0 -5
- evalscope/benchmarks/general_qa/general_qa_adapter.py +16 -19
- evalscope/benchmarks/gsm8k/__init__.py +0 -4
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +19 -98
- evalscope/benchmarks/hellaswag/__init__.py +0 -5
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +23 -96
- evalscope/benchmarks/humaneval/__init__.py +0 -4
- evalscope/benchmarks/humaneval/humaneval_adapter.py +16 -117
- evalscope/benchmarks/mmlu/__init__.py +0 -5
- evalscope/benchmarks/mmlu/mmlu_adapter.py +26 -48
- evalscope/benchmarks/mmlu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +110 -0
- evalscope/benchmarks/race/__init__.py +0 -5
- evalscope/benchmarks/race/race_adapter.py +25 -53
- evalscope/benchmarks/trivia_qa/__init__.py +0 -5
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +24 -97
- evalscope/benchmarks/truthful_qa/__init__.py +0 -5
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +23 -33
- evalscope/collections/__init__.py +3 -0
- evalscope/collections/evaluator.py +178 -0
- evalscope/collections/sampler.py +132 -0
- evalscope/collections/schema.py +122 -0
- evalscope/config.py +7 -5
- evalscope/constants.py +7 -28
- evalscope/evaluator/evaluator.py +66 -109
- evalscope/evaluator/reviewer/auto_reviewer.py +12 -4
- evalscope/metrics/__init__.py +6 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +1 -1
- evalscope/metrics/math_accuracy.py +193 -50
- evalscope/metrics/metrics.py +7 -4
- evalscope/metrics/rouge_metric.py +13 -8
- evalscope/models/__init__.py +14 -1
- evalscope/models/base_adapter.py +52 -0
- evalscope/models/chat_adapter.py +138 -0
- evalscope/models/choice_adapter.py +211 -0
- evalscope/models/custom_adapter.py +67 -0
- evalscope/models/local_model.py +74 -0
- evalscope/models/model.py +141 -0
- evalscope/models/server_adapter.py +104 -0
- evalscope/run.py +37 -66
- evalscope/run_arena.py +1 -1
- evalscope/utils/__init__.py +1 -1
- evalscope/utils/chat_service.py +4 -3
- evalscope/utils/io_utils.py +8 -0
- evalscope/utils/logger.py +4 -0
- evalscope/utils/model_utils.py +10 -0
- evalscope/utils/utils.py +3 -25
- evalscope/version.py +2 -2
- {evalscope-0.8.2.dist-info → evalscope-0.9.0.dist-info}/METADATA +32 -15
- {evalscope-0.8.2.dist-info → evalscope-0.9.0.dist-info}/RECORD +75 -66
- tests/cli/test_collection.py +53 -0
- tests/cli/test_run.py +43 -1
- tests/rag/test_mteb.py +3 -2
- evalscope/models/api/__init__.py +0 -3
- evalscope/models/dummy_chat_model.py +0 -49
- evalscope/models/model_adapter.py +0 -525
- evalscope/models/openai_model.py +0 -103
- /evalscope/{models/api → third_party/longbench_write/tools}/openai_api.py +0 -0
- {evalscope-0.8.2.dist-info → evalscope-0.9.0.dist-info}/LICENSE +0 -0
- {evalscope-0.8.2.dist-info → evalscope-0.9.0.dist-info}/WHEEL +0 -0
- {evalscope-0.8.2.dist-info → evalscope-0.9.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.8.2.dist-info → evalscope-0.9.0.dist-info}/top_level.txt +0 -0

evalscope/benchmarks/cmmlu/__init__.py
@@ -1,6 +1 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
-from evalscope.benchmarks.cmmlu.cmmlu_adapter import DATASET_ID, SUBJECT_MAPPING, SUBSET_LIST
-from evalscope.benchmarks.cmmlu.cmmlu_adapter import CMMLUAdapter
-from evalscope.benchmarks.cmmlu.cmmlu_adapter import CMMLUAdapter as DataAdapterClass
-from evalscope.models.model_adapter import MultiChoiceModelAdapter as ModelAdapterClass  # noqa

evalscope/benchmarks/cmmlu/cmmlu_adapter.py
@@ -3,8 +3,10 @@
 import csv
 import os
 
-from evalscope.benchmarks
-from evalscope.
+from evalscope.benchmarks import Benchmark, DataAdapter
+from evalscope.constants import EvalType
+from evalscope.metrics import WeightedAverageAccuracy, exact_match
+from evalscope.models import MultiChoiceModelAdapter
 from evalscope.utils import ResponseParser, normalize_score
 from evalscope.utils.logger import get_logger
 
@@ -12,8 +14,6 @@ from evalscope.utils.logger import get_logger
 
 logger = get_logger()
 
-DATASET_ID = 'modelscope/cmmlu'
-
 SUBSET_LIST = [
     'agronomy', 'anatomy', 'ancient_chinese', 'arts', 'astronomy', 'business_ethics', 'chinese_civil_service_exam',
     'chinese_driving_rule', 'chinese_food_culture', 'chinese_foreign_policy', 'chinese_history', 'chinese_literature',
@@ -101,31 +101,23 @@ SUBJECT_MAPPING = {
 }
 
 
+@Benchmark.register(
+    name='cmmlu',
+    dataset_id='modelscope/cmmlu',
+    model_adapter=MultiChoiceModelAdapter,
+    subset_list=SUBSET_LIST,
+    metric_list=[WeightedAverageAccuracy],
+    few_shot_num=5,
+    train_split='dev',
+    eval_split='test',
+)
 class CMMLUAdapter(DataAdapter):
 
     choices = ['A', 'B', 'C', 'D']
 
-    def __init__(self,
-                 subset_list: list = None,
-                 metric_list: list = None,
-                 few_shot_num: int = 5,
-                 train_split: str = 'dev',
-                 eval_split: str = 'test',
-                 **kwargs):
-
-        if subset_list is None:
-            subset_list = SUBSET_LIST
-
-        if metric_list is None:
-            metric_list = [{'name': 'WeightedAverageAccuracy', 'object': weighted_mean}]
+    def __init__(self, **kwargs):
 
-        super().__init__(
-            subset_list=subset_list,
-            metric_list=metric_list,
-            few_shot_num=few_shot_num,
-            train_split=train_split,
-            eval_split=eval_split,
-            **kwargs)
+        super().__init__(**kwargs)
 
     def load_from_disk(self, dataset_name_or_path, subset_list, work_dir, **kwargs) -> dict:
         data_dict = {}
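
Note: the hunk above captures the main 0.9.0 refactor. Benchmark metadata (dataset id, model adapter, metrics, few-shot count, splits) moves out of per-benchmark __init__.py exports and long __init__ signatures into a single @Benchmark.register(...) decorator on the adapter class. A minimal sketch of a custom adapter written against this pattern follows; the benchmark name, dataset id and the choice of overridden methods are illustrative only, and a real adapter would also implement the data-loading and prompt methods shown in the surrounding hunks.

    from evalscope.benchmarks import Benchmark, DataAdapter
    from evalscope.metrics import WeightedAverageAccuracy, exact_match
    from evalscope.models import MultiChoiceModelAdapter


    @Benchmark.register(
        name='my_mcq_benchmark',                # hypothetical benchmark name
        dataset_id='my-org/my_mcq_benchmark',   # hypothetical ModelScope dataset id
        model_adapter=MultiChoiceModelAdapter,  # multiple-choice evaluation, as in cmmlu above
        subset_list=['default'],
        metric_list=[WeightedAverageAccuracy],
        few_shot_num=5,
        train_split='dev',
        eval_split='test',
    )
    class MyMCQAdapter(DataAdapter):
        """Illustrative adapter: defaults now come from register(), so __init__ only forwards kwargs."""

        choices = ['A', 'B', 'C', 'D']

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def match(self, gold: str, pred: str) -> float:
            return exact_match(gold=gold, pred=pred)
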
@@ -187,7 +179,7 @@ class CMMLUAdapter(DataAdapter):
         # Get the gold choice
         return input_d.get('Answer', '')
 
-    def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str =
+    def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str:
         """
         Parse the model output to get the answer. Could be the best choice index.
 
@@ -199,11 +191,11 @@ class CMMLUAdapter(DataAdapter):
         Returns:
             The parsed answer. Depending on the dataset. Usually a string for chat.
         """
-        if eval_type ==
+        if eval_type == EvalType.CHECKPOINT:
             return result
-        elif eval_type ==
+        elif eval_type == EvalType.SERVICE:
             return ResponseParser.parse_first_option_with_choices(result, self.choices)  # TODO: to be checked !
-        elif eval_type ==
+        elif eval_type == EvalType.CUSTOM:
             return ResponseParser.parse_first_option_with_choices(result, self.choices)  # TODO: to be checked !
         else:
             raise ValueError(f'Invalid eval_type: {eval_type}')
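
The branching above now keys off EvalType constants instead of the bare strings used in 0.8.2. A self-contained sketch of the same dispatch is below; the member names mirror the hunk, while the string values and the regex-based first-option extraction are assumptions standing in for evalscope.constants.EvalType and ResponseParser, which are not reproduced here.

    import re
    from enum import Enum


    class EvalTypeSketch(str, Enum):
        # Stand-in for evalscope.constants.EvalType; member names follow the hunk above,
        # the string values are assumptions.
        CHECKPOINT = 'checkpoint'
        SERVICE = 'service'
        CUSTOM = 'custom'


    def parse_pred_sketch(result, choices=('A', 'B', 'C', 'D'), eval_type=EvalTypeSketch.CHECKPOINT):
        # Checkpoint output is trusted as-is; service/custom output is reduced to the first
        # choice letter that appears, roughly what ResponseParser does in the hunk above.
        if eval_type == EvalTypeSketch.CHECKPOINT:
            return result
        m = re.search(r'\b(' + '|'.join(re.escape(c) for c in choices) + r')\b', result)
        return m.group(1) if m else ''


    print(parse_pred_sketch('The correct answer is B.', eval_type=EvalTypeSketch.SERVICE))  # -> B
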
@@ -211,19 +203,6 @@ class CMMLUAdapter(DataAdapter):
     def match(self, gold: str, pred: str) -> float:
         return exact_match(gold=gold, pred=pred)
 
-    def compute_metric(self, review_res_list: list) -> float:
-        """
-        Compute evaluation result by specific metric.
-
-        Args:
-            review_res_list: review score list, e.g. [0, 1, 1, 0, ...]
-
-        Returns:
-            The metric score.
-        """
-        items = [(score, 1.0) for score in review_res_list]
-        return weighted_mean(items)
-
     def gen_report(self, subset_score_map: dict, report_name: str = None) -> dict:
         """
         Generate report for the evaluation.
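
The compute_metric removed here (and from the other adapters in this release) was the same few lines everywhere: wrap each review score with weight 1.0 and take a weighted mean. A standalone sketch of that computation is below; the shared implementation now presumably lives behind the WeightedAverageAccuracy metric in evalscope.metrics, and weighted_mean_sketch is illustrative, not the package function.

    def weighted_mean_sketch(items):
        # items: list of (score, weight) pairs, as built by the removed compute_metric
        total_weight = sum(w for _, w in items)
        return sum(s * w for s, w in items) / total_weight if total_weight else 0.0


    review_res_list = [0, 1, 1, 0, 1]
    items = [(score, 1.0) for score in review_res_list]
    print(weighted_mean_sketch(items))  # -> 0.6
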

evalscope/benchmarks/competition_math/__init__.py
@@ -1,6 +1 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
-from evalscope.benchmarks.competition_math.competition_math_adapter import DATASET_ID, SUBSET_LIST
-from evalscope.benchmarks.competition_math.competition_math_adapter import CompetitionMathAdapter
-from evalscope.benchmarks.competition_math.competition_math_adapter import CompetitionMathAdapter as DataAdapterClass
-from evalscope.models.model_adapter import ChatGenerationModelAdapter as ModelAdapterClass  # noqa

evalscope/benchmarks/competition_math/competition_math_adapter.py
@@ -4,53 +4,40 @@ import glob
 import json
 import os
 
-from evalscope.benchmarks import DataAdapter
-from evalscope.metrics
-from evalscope.
+from evalscope.benchmarks import Benchmark, DataAdapter
+from evalscope.metrics import WeightedAverageAccuracy
+from evalscope.metrics.math_accuracy import is_equiv, last_boxed_only_string, remove_boxed
+from evalscope.models import ChatGenerationModelAdapter
 from evalscope.utils.logger import get_logger
 
 # flake8: noqa
 
 logger = get_logger()
 
-DATASET_ID = 'modelscope/competition_math'
-SUBSET_LIST = ['default']
-
 
+@Benchmark.register(
+    name='competition_math',
+    dataset_id='modelscope/competition_math',
+    model_adapter=ChatGenerationModelAdapter,
+    subset_list=['default'],
+    metric_list=[WeightedAverageAccuracy],
+    few_shot_num=4,
+    train_split='train',
+    eval_split='test',
+    prompt_template='',
+)
 class CompetitionMathAdapter(DataAdapter):
-    """
-
-    def __init__(self,
-                 subset_list: list = None,
-                 metric_list: list = None,
-                 few_shot_num: int = None,
-                 train_split: str = 'train',
-                 eval_split: str = 'test',
-                 **kwargs):
-
-        if subset_list is None:
-            subset_list = SUBSET_LIST
+    """ To be tested for all models. """
 
-
-            metric_list = [{'name': 'WeightedAverageAccuracy', 'object': weighted_mean}]
-
-        if few_shot_num is None:
-            # Use 4-shot by default
-            logger.info(f'Set 4-shot examples by system for MATH.')
-            few_shot_num = 4
+    def __init__(self, **kwargs):
 
+        few_shot_num = kwargs.get('few_shot_num', 4)
         if few_shot_num != 4 and few_shot_num != 0:
             logger.error(f'The MATH benchmark ONLY supports 4-shot by system or 0-shot settings, '
-                         f'but got {
-            few_shot_num = 4
+                         f'but got {few_shot_num}. Use 4-shot by default.')
+            kwargs['few_shot_num'] = 4
 
-        super().__init__(
-            subset_list=subset_list,
-            metric_list=metric_list,
-            few_shot_num=few_shot_num,
-            train_split=train_split,
-            eval_split=eval_split,
-            **kwargs)
+        super().__init__(**kwargs)
 
     def load_from_disk(self, dataset_name_or_path, subset_list, work_dir, **kwargs) -> dict:
         data_dict: dict = {}
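
One behavioural detail in the rewritten __init__ above: the few-shot setting is now read from kwargs and coerced back to 4 when an unsupported value is passed. A tiny sketch of that guard in isolation (the function name and prints are illustrative, not part of the adapter):

    def resolve_math_few_shot(requested: int) -> int:
        # Mirrors the guard above: MATH only supports 4-shot (by system) or 0-shot.
        if requested not in (0, 4):
            print(f'MATH only supports 4-shot or 0-shot, got {requested}; falling back to 4.')
            return 4
        return requested


    print(resolve_math_few_shot(8))  # -> 4 (with a warning)
    print(resolve_math_few_shot(0))  # -> 0
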
@@ -90,11 +77,11 @@ class CompetitionMathAdapter(DataAdapter):
         use_fewshot = self.few_shot_num > 0
         full_prompt = self._generate_prompt(input_d, use_fewshot=use_fewshot)
 
-        return {'data': [full_prompt]}
+        return {'data': [full_prompt], 'system_prompt': 'Put the final answer in \\boxed{}.'}
 
     def get_gold_answer(self, input_d: dict) -> str:
         # Extract the gold answer from the input dict.
-        return
+        return remove_boxed(last_boxed_only_string(input_d['solution']))
 
     def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = 'checkpoint') -> str:
         """
@@ -108,77 +95,20 @@ class CompetitionMathAdapter(DataAdapter):
         Returns:
             The parsed answer. Depending on the dataset. Usually a string for chat.
         """
-        # TODO: check answer extraction
         # Note: Use same extraction method for both of checkpoint/service/custom
-
+        try:
+            result = remove_boxed(last_boxed_only_string(result))
+        except Exception:
+            return None
+        return result
 
     def match(self, gold: str, pred: str) -> float:
         res = 0
-        if
+        if is_equiv(pred, gold):
             res = 1
 
         return res
 
-    def compute_metric(self, review_res_list: list) -> float:
-        """
-        Compute evaluation result by specific metric.
-
-        Args:
-            review_res_list: review score list, e.g. [0, 1, 1, 0, ...]
-
-        Returns:
-            The metric score.
-        """
-        items = [(score, 1.0) for score in review_res_list]
-        return weighted_mean(items)
-
-    def gen_report(self, subset_score_map: dict, report_name: str = None) -> dict:
-        """
-        Generate the report for the model output.
-
-        Args:
-            subset_score_map: The subset-score mapping. e.g. {subset_name: (score, num), ...}
-            report_name: The user-defined report name.
-
-        Returns: A dict of metric calculation results. The format is like:
-        {
-            "name":"CompetitionMath",
-            "metric":"WeightedAverageAccuracy",
-            "score":0.5632,
-            "category":[
-                {
-                    "name":"DEFAULT",
-                    "score":0.5632,
-                    "subset":[
-                        {
-                            "name":"main",
-                            "score":0.5632
-                        },
-                    ]
-                }
-            ],
-            "total_num":100
-        }
-        """
-        total_num: int = sum([num for _, num in subset_score_map.values()])
-        weighted_avg_acc: float = sum([score * num for score, num in subset_score_map.values()]) / total_num
-        weighted_avg_acc = normalize_score(score=weighted_avg_acc)
-        cate_avg_list = [{
-            'name': subset_name,
-            'score': normalize_score(score=score)
-        } for subset_name, (score, _) in subset_score_map.items()]
-
-        category_d = dict(name='DEFAULT', score=weighted_avg_acc, subset=cate_avg_list)
-
-        res_map = dict(
-            name=report_name or 'competition_math',
-            metric=self.metric_list[0]['name'],
-            score=weighted_avg_acc,
-            category=[category_d],
-            total_num=total_num)
-
-        return res_map
-
     @classmethod
     def _generate_prompt(cls, input_d: dict, use_fewshot: bool = True) -> str:
         problem: str = input_d['problem']
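
Both the gold answer and the model prediction are now reduced with remove_boxed(last_boxed_only_string(...)) from evalscope.metrics.math_accuracy, replacing the private classmethods removed in the final hunk below. A minimal standalone sketch of that extraction, following the removed _last_boxed_only_string/_remove_boxed logic (the _sketch functions are illustrative re-implementations, not the package's):

    def last_boxed_only_string_sketch(string):
        # Find the last \boxed{...} group by scanning braces from the last occurrence.
        idx = string.rfind('\\boxed')
        if idx < 0:
            return None
        i, depth, right = idx, 0, None
        while i < len(string):
            if string[i] == '{':
                depth += 1
            if string[i] == '}':
                depth -= 1
                if depth == 0:
                    right = i
                    break
            i += 1
        return None if right is None else string[idx:right + 1]


    def remove_boxed_sketch(s):
        # Strip the surrounding \boxed{...} wrapper.
        left = '\\boxed{'
        if s is None or not (s.startswith(left) and s.endswith('}')):
            return s
        return s[len(left):-1]


    solution = 'Therefore the answer is $\\boxed{\\frac{1}{2}}$.'
    print(remove_boxed_sketch(last_boxed_only_string_sketch(solution)))  # -> \frac{1}{2}
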
@@ -194,275 +124,3 @@ class CompetitionMathAdapter(DataAdapter):
         else:
             context = 'Problem:\n' + problem + '\nSolution:\n'
         return context
-
-    @classmethod
-    def _preprocess_input(cls, input: str) -> str:
-        """
-        Preprocess the input data, remove the boxed solution.
-
-        Args:
-            input_d: The raw input. A single data format of the Competition Math.
-
-        Returns:
-            The preprocessed input.
-        """
-        return cls._remove_boxed(cls._last_boxed_only_string(input))
-
-    @classmethod
-    def _remove_boxed(cls, s):
-        if s is None:
-            return s
-
-        if '\\boxed ' in s:
-            left = '\\boxed '
-            assert s[:len(left)] == left
-            return s[len(left):]
-
-        left = '\\boxed{'
-
-        assert s[:len(left)] == left
-        assert s[-1] == '}'
-
-        return s[len(left):-1]
-
-    @classmethod
-    def _last_boxed_only_string(cls, string):
-
-        idx = string.rfind('\\boxed')
-        if '\\boxed ' in string:
-            return '\\boxed ' + string.split('\\boxed ')[-1].split('$')[0]
-        if idx < 0:
-            idx = string.rfind('\\fbox')
-            if idx < 0:
-                return None
-
-        i = idx
-        right_brace_idx = None
-        num_left_braces_open = 0
-        while i < len(string):
-            if string[i] == '{':
-                num_left_braces_open += 1
-            if string[i] == '}':
-                num_left_braces_open -= 1
-                if num_left_braces_open == 0:
-                    right_brace_idx = i
-                    break
-            i += 1
-
-        if right_brace_idx is None:
-            retval = None
-        else:
-            retval = string[idx:right_brace_idx + 1]
-
-        return retval
-
-    @classmethod
-    def _is_equiv(cls, str1, str2, verbose=False):
-        if str1 is None and str2 is None:
-            logger.warning('WARNING: Both None')
-            return True
-        if str1 is None or str2 is None:
-            return False
-
-        try:
-            ss1 = cls.strip_string(str1)
-            ss2 = cls.strip_string(str2)
-            if verbose:
-                logger.info(f'ss1: {ss1}, ss2: {ss2}')
-            return ss1 == ss2
-        except Exception:
-            return str1 == str2
-
-    @classmethod
-    def strip_string(cls, string):
-        # linebreaks
-        string = string.replace('\n', '')
-
-        # remove inverse spaces
-        string = string.replace('\\!', '')
-
-        # replace \\ with \
-        string = string.replace('\\\\', '\\')
-
-        # replace tfrac and dfrac with frac
-        string = string.replace('tfrac', 'frac')
-        string = string.replace('dfrac', 'frac')
-
-        # remove \left and \right
-        string = string.replace('\\left', '')
-        string = string.replace('\\right', '')
-
-        # Remove circ (degrees)
-        string = string.replace('^{\\circ}', '')
-        string = string.replace('^\\circ', '')
-
-        # remove dollar signs
-        string = string.replace('\\$', '')
-
-        # remove units (on the right)
-        string = cls.remove_right_units(string)
-
-        # remove percentage
-        string = string.replace('\\%', '')
-        string = string.replace('\%', '')  # noqa: W605
-
-        # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
-        string = string.replace(' .', ' 0.')
-        string = string.replace('{.', '{0.')
-        # if empty, return empty string
-        if len(string) == 0:
-            return string
-        if string[0] == '.':
-            string = '0' + string
-
-        # to consider: get rid of e.g. "k = " or "q = " at beginning
-        if len(string.split('=')) == 2:
-            if len(string.split('=')[0]) <= 2:
-                string = string.split('=')[1]
-
-        # fix sqrt3 --> sqrt{3}
-        string = cls.fix_sqrt(string)
-
-        # remove spaces
-        string = string.replace(' ', '')
-
-        # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
-        string = cls.fix_fracs(string)
-
-        # manually change 0.5 --> \frac{1}{2}
-        if string == '0.5':
-            string = '\\frac{1}{2}'
-
-        # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
-        string = cls.fix_a_slash_b(string)
-
-        return string
-
-    @classmethod
-    def remove_right_units(cls, string):
-        # "\\text{ " only ever occurs (at least in the val set) when describing units
-        if '\\text{ ' in string:
-            splits = string.split('\\text{ ')
-            assert len(splits) == 2
-            return splits[0]
-        else:
-            return string
-
-    @classmethod
-    def fix_fracs(cls, string):
-        substrs = string.split('\\frac')
-        new_str = substrs[0]
-        if len(substrs) > 1:
-            substrs = substrs[1:]
-            for substr in substrs:
-                new_str += '\\frac'
-                if substr[0] == '{':
-                    new_str += substr
-                else:
-                    try:
-                        assert len(substr) >= 2
-                    except AssertionError:
-                        return string
-                    a = substr[0]
-                    b = substr[1]
-                    if b != '{':
-                        if len(substr) > 2:
-                            post_substr = substr[2:]
-                            new_str += '{' + a + '}{' + b + '}' + post_substr
-                        else:
-                            new_str += '{' + a + '}{' + b + '}'
-                    else:
-                        if len(substr) > 2:
-                            post_substr = substr[2:]
-                            new_str += '{' + a + '}' + b + post_substr
-                        else:
-                            new_str += '{' + a + '}' + b
-        string = new_str
-        return string
-
-    @classmethod
-    def fix_sqrt(cls, string):
-        if '\\sqrt' not in string:
-            return string
-        splits = string.split('\\sqrt')
-        new_string = splits[0]
-        for split in splits[1:]:
-            if split[0] != '{':
-                a = split[0]
-                new_substr = '\\sqrt{' + a + '}' + split[1:]
-            else:
-                new_substr = '\\sqrt' + split
-            new_string += new_substr
-        return new_string
-
-    @classmethod
-    def fix_a_slash_b(cls, string):
-        if len(string.split('/')) != 2:
-            return string
-        a = string.split('/')[0]
-        b = string.split('/')[1]
-        try:
-            a = int(a)
-            b = int(b)
-            assert string == '{}/{}'.format(a, b)
-            new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
-            return new_string
-        except AssertionError:
-            return string
-
-    @classmethod
-    def _math_postprocess(cls, text: str) -> str:
-        SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''), (r'\ ', ''), (' ', ''), ('mbox', 'text'),
-                         (',\\text{and}', ','), ('\\text{and}', ','), ('\\text{m}', '\\text{}'), ('\\le', '<')]
-        REMOVED_EXPRESSIONS = [
-            'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft', 'hours', 'km', 'units', '\\ldots', 'sue',
-            'points', 'feet', 'minutes', 'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges',
-            'students', 'childrentickets', 'multiples', '\\text{s}', '\\text{.}', '\\text{\ns}', '\\text{}^2',
-            '\\text{}^3', '\\text{\n}', '\\text{}', r'\mathrm{th}', r'^\circ', r'^{\circ}', r'\;', r',\!', '{,}', '"',
-            '\\dots', '\n', '\r', '\f'
-        ]
-        import re
-
-        def normalize_final_answer(final_answer: str) -> str:
-            """Normalize a final answer to a quantitative reasoning question."""
-            # final_answer = final_answer.split('=')[-1]
-            for before, after in SUBSTITUTIONS:
-                final_answer = final_answer.replace(before, after)
-            for expr in REMOVED_EXPRESSIONS:
-                final_answer = final_answer.replace(expr, '')
-
-            # Extract answer that is in LaTeX math, is bold,
-            # is surrounded by a box, etc.
-            final_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', final_answer)
-            final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer)
-            final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer)
-            final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer)
-            assert '\n' not in final_answer
-            assert '\r' not in final_answer
-            assert '\f' not in final_answer
-            if len(re.findall(r'finalansweris(.*)', final_answer)) > 0:
-                final_answer = re.findall(r'finalansweris(.*)', final_answer)[-1]
-
-            if len(re.findall(r'oxed\{(.*?)\}', final_answer)) > 0:
-                final_answer = re.findall(r'oxed\{(.*?)\}', final_answer)[-1]
-
-            if len(re.findall(r'\$(.*?)\$', final_answer)) > 0:
-                final_answer = re.findall(r'\$(.*?)\$', final_answer)[-1]
-            final_answer = final_answer.strip()
-            if 'rac' in final_answer and '\\frac' not in final_answer:
-                final_answer = final_answer.replace('rac', '\\frac')
-
-            final_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}', final_answer)
-            final_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', final_answer)
-            final_answer = final_answer.replace('$', '')
-
-            # Normalize 100,000 -> 100000
-            if final_answer.replace(',', '').isdigit():
-                final_answer = final_answer.replace(',', '')
-
-            return final_answer
-
-        for maybe_ans in text.split('.'):
-            if 'final answer' in maybe_ans.lower():
-                return normalize_final_answer(maybe_ans)
-        return normalize_final_answer(text.split('.')[0])