evalscope-0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- evalscope/__init__.py +3 -0
- evalscope/backend/__init__.py +3 -0
- evalscope/backend/base.py +27 -0
- evalscope/backend/opencompass/__init__.py +3 -0
- evalscope/backend/opencompass/api_meta_template.py +64 -0
- evalscope/backend/opencompass/backend_manager.py +247 -0
- evalscope/backend/opencompass/tasks/__init__.py +1 -0
- evalscope/backend/opencompass/tasks/eval_api.py +30 -0
- evalscope/backend/opencompass/tasks/eval_datasets.py +71 -0
- evalscope/backend/vlm_eval_kit/__init__.py +1 -0
- evalscope/backend/vlm_eval_kit/backend_manager.py +153 -0
- evalscope/benchmarks/__init__.py +4 -0
- evalscope/benchmarks/arc/__init__.py +5 -0
- evalscope/benchmarks/arc/ai2_arc.py +148 -0
- evalscope/benchmarks/arc/arc_adapter.py +231 -0
- evalscope/benchmarks/bbh/__init__.py +6 -0
- evalscope/benchmarks/bbh/bbh_adapter.py +308 -0
- evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +23 -0
- evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +33 -0
- evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +72 -0
- evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +78 -0
- evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +42 -0
- evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/navigate.txt +43 -0
- evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +41 -0
- evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +63 -0
- evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/snarks.txt +30 -0
- evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +10 -0
- evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +77 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +17 -0
- evalscope/benchmarks/benchmark.py +65 -0
- evalscope/benchmarks/ceval/__init__.py +5 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +340 -0
- evalscope/benchmarks/ceval/ceval_exam.py +159 -0
- evalscope/benchmarks/cmmlu/__init__.py +5 -0
- evalscope/benchmarks/cmmlu/cmmlu.py +166 -0
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +369 -0
- evalscope/benchmarks/competition_math/__init__.py +5 -0
- evalscope/benchmarks/competition_math/competition_math.py +88 -0
- evalscope/benchmarks/competition_math/competition_math_adapter.py +470 -0
- evalscope/benchmarks/data_adapter.py +263 -0
- evalscope/benchmarks/general_qa/__init__.py +5 -0
- evalscope/benchmarks/general_qa/general_qa_adapter.py +186 -0
- evalscope/benchmarks/gsm8k/__init__.py +5 -0
- evalscope/benchmarks/gsm8k/gsm8k.py +127 -0
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +236 -0
- evalscope/benchmarks/hellaswag/__init__.py +5 -0
- evalscope/benchmarks/hellaswag/hellaswag.py +116 -0
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +222 -0
- evalscope/benchmarks/humaneval/__init__.py +5 -0
- evalscope/benchmarks/humaneval/humaneval.py +82 -0
- evalscope/benchmarks/humaneval/humaneval_adapter.py +21 -0
- evalscope/benchmarks/mmlu/__init__.py +5 -0
- evalscope/benchmarks/mmlu/mmlu.py +174 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +375 -0
- evalscope/benchmarks/race/__init__.py +5 -0
- evalscope/benchmarks/race/race.py +118 -0
- evalscope/benchmarks/race/race_adapter.py +229 -0
- evalscope/benchmarks/trivia_qa/__init__.py +5 -0
- evalscope/benchmarks/trivia_qa/trivia_qa.py +104 -0
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +207 -0
- evalscope/benchmarks/truthful_qa/__init__.py +5 -0
- evalscope/benchmarks/truthful_qa/truthful_qa.py +167 -0
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +351 -0
- evalscope/cache.py +98 -0
- evalscope/cli/__init__.py +1 -0
- evalscope/cli/base.py +20 -0
- evalscope/cli/cli.py +26 -0
- evalscope/cli/start_perf.py +37 -0
- evalscope/cli/start_server.py +138 -0
- evalscope/config.py +165 -0
- evalscope/constants.py +150 -0
- evalscope/evaluator/__init__.py +3 -0
- evalscope/evaluator/evaluator.py +689 -0
- evalscope/evaluator/rating_eval.py +178 -0
- evalscope/evaluator/reviewer/__init__.py +1 -0
- evalscope/evaluator/reviewer/auto_reviewer.py +411 -0
- evalscope/metrics/__init__.py +1 -0
- evalscope/metrics/bundled_rouge_score/__init__.py +14 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +342 -0
- evalscope/metrics/code_metric.py +104 -0
- evalscope/metrics/math_accuracy.py +60 -0
- evalscope/metrics/metrics.py +405 -0
- evalscope/metrics/rouge_metric.py +129 -0
- evalscope/models/__init__.py +4 -0
- evalscope/models/custom/__init__.py +4 -0
- evalscope/models/custom/custom_model.py +53 -0
- evalscope/models/dummy_chat_model.py +50 -0
- evalscope/models/model.py +88 -0
- evalscope/models/model_adapter.py +586 -0
- evalscope/models/openai_model.py +103 -0
- evalscope/models/template.py +1446 -0
- evalscope/perf/__init__.py +0 -0
- evalscope/perf/_logging.py +32 -0
- evalscope/perf/api_plugin_base.py +60 -0
- evalscope/perf/custom_api.py +87 -0
- evalscope/perf/dashscope_api.py +84 -0
- evalscope/perf/dataset_plugin_base.py +64 -0
- evalscope/perf/datasets/__init__.py +0 -0
- evalscope/perf/datasets/line_by_line.py +18 -0
- evalscope/perf/datasets/longalpaca_12k.py +20 -0
- evalscope/perf/datasets/openqa.py +22 -0
- evalscope/perf/how_to_analysis_result.py +24 -0
- evalscope/perf/http_client.py +756 -0
- evalscope/perf/openai_api.py +130 -0
- evalscope/perf/plugin_registry.py +35 -0
- evalscope/perf/query_parameters.py +42 -0
- evalscope/perf/server_sent_event.py +43 -0
- evalscope/preprocess/__init__.py +1 -0
- evalscope/preprocess/tokenizers/__init__.py +0 -0
- evalscope/preprocess/tokenizers/gpt2_tokenizer.py +221 -0
- evalscope/registry/__init__.py +1 -0
- evalscope/registry/tasks/arc.yaml +29 -0
- evalscope/registry/tasks/bbh.yaml +27 -0
- evalscope/registry/tasks/bbh_mini.yaml +27 -0
- evalscope/registry/tasks/ceval.yaml +27 -0
- evalscope/registry/tasks/ceval_mini.yaml +27 -0
- evalscope/registry/tasks/cmmlu.yaml +27 -0
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +28 -0
- evalscope/registry/tasks/general_qa.yaml +27 -0
- evalscope/registry/tasks/gsm8k.yaml +29 -0
- evalscope/registry/tasks/mmlu.yaml +29 -0
- evalscope/registry/tasks/mmlu_mini.yaml +27 -0
- evalscope/run.py +404 -0
- evalscope/run_arena.py +204 -0
- evalscope/run_ms.py +140 -0
- evalscope/summarizer.py +144 -0
- evalscope/third_party/__init__.py +1 -0
- evalscope/third_party/toolbench_static/__init__.py +3 -0
- evalscope/third_party/toolbench_static/eval.py +219 -0
- evalscope/third_party/toolbench_static/infer.py +278 -0
- evalscope/third_party/toolbench_static/llm/__init__.py +1 -0
- evalscope/third_party/toolbench_static/llm/swift_infer.py +45 -0
- evalscope/third_party/toolbench_static/toolbench_static.py +50 -0
- evalscope/tools/__init__.py +1 -0
- evalscope/tools/combine_reports.py +140 -0
- evalscope/tools/gen_mmlu_subject_mapping.py +90 -0
- evalscope/tools/rewrite_eval_results.py +95 -0
- evalscope/utils/__init__.py +4 -0
- evalscope/utils/arena_utils.py +247 -0
- evalscope/utils/completion_parsers.py +87 -0
- evalscope/utils/logger.py +64 -0
- evalscope/utils/task_cfg_parser.py +10 -0
- evalscope/utils/task_utils.py +19 -0
- evalscope/utils/utils.py +625 -0
- evalscope/version.py +4 -0
- evalscope-0.5.0.dist-info/METADATA +566 -0
- evalscope-0.5.0.dist-info/RECORD +165 -0
- evalscope-0.5.0.dist-info/WHEEL +5 -0
- evalscope-0.5.0.dist-info/entry_points.txt +3 -0
- evalscope-0.5.0.dist-info/top_level.txt +1 -0

evalscope/benchmarks/competition_math/competition_math_adapter.py

@@ -0,0 +1,470 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright (c) EleutherAI, Inc. and its affiliates.
+import glob
+import json
+import os
+
+from evalscope.benchmarks import DataAdapter
+from evalscope.metrics.metrics import weighted_mean
+from evalscope.utils import normalize_score
+from evalscope.utils.logger import get_logger
+# flake8: noqa
+
+logger = get_logger()
+
+DATASET_ID = 'modelscope/competition_math'
+SUBSET_LIST = ['default']
+
+
+class CompetitionMathAdapter(DataAdapter):
+    """ TODO: To be tested for all models. """
+
+    def __init__(self,
+                 subset_list: list = None,
+                 metric_list: list = None,
+                 few_shot_num: int = None,
+                 train_split: str = 'train',
+                 eval_split: str = 'test',
+                 **kwargs):
+
+        if subset_list is None:
+            subset_list = SUBSET_LIST
+
+        if metric_list is None:
+            metric_list = [{'name': 'WeightedAverageAccuracy', 'object': weighted_mean}]
+
+        if few_shot_num is None:
+            # Use 4-shot by default
+            logger.info(f'Set 4-shot examples by system for MATH.')
+            few_shot_num = 4
+
+        if few_shot_num != 4 and few_shot_num != 0:
+            logger.error(f'The MATH benchmark ONLY supports 4-shot by system or 0-shot settings, '
+                         f'but got {few_shot_num}. Use 4-shot by default.')
+            few_shot_num = 4
+
+        super().__init__(subset_list=subset_list,
+                         metric_list=metric_list,
+                         few_shot_num=few_shot_num,
+                         train_split=train_split,
+                         eval_split=eval_split,
+                         **kwargs)
+
+    def load_from_disk(self, dataset_name_or_path, subset_list, work_dir, **kwargs) -> dict:
+        data_dict: dict = {}
+        for subset_name in subset_list:
+            for split_name in [self.train_split, self.eval_split]:
+                if os.path.exists(dataset_name_or_path):
+                    split_dir = os.path.join(dataset_name_or_path, split_name)
+                else:
+                    split_dir = os.path.join(work_dir, dataset_name_or_path, split_name)
+                split_files = glob.glob(os.path.join(split_dir, '**', '*.json'))
+                split_data = []
+                for file_path in split_files:
+                    if os.path.exists(file_path):
+                        with open(file_path, 'r') as f:
+                            split_data.append(json.load(f))
+                if subset_name in data_dict:
+                    data_dict[subset_name].update({split_name: split_data})
+                else:
+                    data_dict[subset_name] = {split_name: split_data}
+
+        return data_dict
+
+    def gen_prompt(self, input_d: dict, few_shot_list: list, **kwargs) -> dict:
+        """
+        Generate the prompt for the model input.
+
+        Args:
+            input_d: raw input dict.
+                {"problem": "How many vertical asymptotes does the graph of $y=\\frac{2}{x^2+x-6}$ have?", "level": "Level 3", "type": "Algebra", "solution": "The denominator of the rational function factors into $x^2+x-6=(x-2)(x+3)$. Since the numerator is always nonzero, there is a vertical asymptote whenever the denominator is $0$, which occurs for $x = 2$ and $x = -3$. Therefore, the graph has $\\boxed{2}$ vertical asymptotes."}
+
+            few_shot_list: few shot list. Each item is a raw input dict.
+            **kwargs:
+
+        Returns:
+            {'data': [prompt]}
+        """
+        use_fewshot = self.few_shot_num > 0
+        full_prompt = self._generate_prompt(input_d, use_fewshot=use_fewshot)
+
+        return {'data': [full_prompt]}
+
+    def get_gold_answer(self, input_d: dict) -> str:
+        # Extract the gold answer from the input dict.
+        return self._preprocess_input(input_d['solution'])
+
+    def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = 'checkpoint') -> str:
+        """
+        Parse the model output to get the answer. Could be the best choice index.
+
+        Args:
+            result: Predicted answer from the model. Usually a string for chat.
+            raw_input_d (dict): The raw input. Depending on the dataset.
+            eval_type: 'checkpoint' or 'service' or `custom`
+
+        Returns:
+            The parsed answer. Depending on the dataset. Usually a string for chat.
+        """
+        # TODO: check answer extraction
+        # Note: Use same extraction method for both of checkpoint/service/custom
+        return self._math_postprocess(result)
+
+    def match(self, gold: str, pred: str) -> float:
+        res = 0
+        if self._is_equiv(pred, gold):
+            res = 1
+
+        return res
+
+    def compute_metric(self, review_res_list: list) -> float:
+        """
+        Compute evaluation result by specific metric.
+
+        Args:
+            review_res_list: review score list, e.g. [0, 1, 1, 0, ...]
+
+        Returns:
+            The metric score.
+        """
+        items = [(score, 1.0) for score in review_res_list]
+        return weighted_mean(items)
+
+    def gen_report(self, subset_score_map: dict, report_name: str = None) -> dict:
+        """
+        Generate the report for the model output.
+
+        Args:
+            subset_score_map: The subset-score mapping. e.g. {subset_name: (score, num), ...}
+            report_name: The user-defined report name.
+
+        Returns: A dict of metric calculation results. The format is like:
+        {
+            "name":"CompetitionMath",
+            "metric":"WeightedAverageAccuracy",
+            "score":0.5632,
+            "category":[
+                {
+                    "name":"DEFAULT",
+                    "score":0.5632,
+                    "subset":[
+                        {
+                            "name":"main",
+                            "score":0.5632
+                        },
+                    ]
+                }
+            ],
+            "total_num":100
+        }
+        """
+        total_num: int = sum([num for _, num in subset_score_map.values()])
+        weighted_avg_acc: float = sum([score * num for score, num in subset_score_map.values()]) / total_num
+        weighted_avg_acc = normalize_score(score=weighted_avg_acc)
+        cate_avg_list = [{'name': subset_name, 'score': normalize_score(score=score)} for subset_name, (score, _) in subset_score_map.items()]
+
+        category_d = dict(name='DEFAULT',
+                          score=weighted_avg_acc,
+                          subset=cate_avg_list)
+
+        res_map = dict(name=report_name or 'competition_math',
+                       metric=self.metric_list[0]['name'],
+                       score=weighted_avg_acc,
+                       category=[category_d],
+                       total_num=total_num)
+
+        return res_map
+
+    @classmethod
+    def _generate_prompt(cls, input_d: dict, use_fewshot: bool = True) -> str:
+        problem: str = input_d['problem']
+
+        if use_fewshot:
+            # Use 4-shot examples by system
+            context = (
+                'Problem:\nFind the domain of the expression $\\frac{{\sqrt{{x-2}}}}{{\sqrt{{5-x}}}}$.}}\nSolution:\nThe expressions inside each square root must be non-negative. Therefore, $x-2 \ge 0$, so $x\ge2$, and $5 - x \ge 0$, so $x \le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{{[2,5)}}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.\n'
+                'Problem:\nIf $\det \mathbf{{A}} = 2$ and $\det \mathbf{{B}} = 12,$ then find $\det (\mathbf{{A}} \mathbf{{B}}).$\nSolution:\nWe have that $\det (\mathbf{{A}} \mathbf{{B}}) = (\det \mathbf{{A}})(\det \mathbf{{B}}) = (2)(12) = \\boxed{{24}}.$\nFinal Answer: The final answer is $24$. I hope it is correct.\n'
+                'Problem:\nTerrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?\nSolution:\nIf Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\cdot 12\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\cdot15\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{{align*}} 30n&=480\\\\ \Rightarrow\qquad n&=480/30=\\boxed{{16}} \end{{align*}}\nFinal Answer: The final answer is $16$. I hope it is correct.\n'
+                'Problem:\nIf the system of equations: \\begin{{align*}} 6x-4y&=a,\\\\ 6y-9x &=b. \end{{align*}}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{{a}}{{b}},$ assuming $b$ is nonzero.\nSolution:\nIf we multiply the first equation by $-\\frac{{3}}{{2}}$, we obtain $$6y-9x=-\\frac{{3}}{{2}}a.$$Since we also know that $6y-9x=b$, we have $$-\\frac{{3}}{{2}}a=b\Rightarrow\\frac{{a}}{{b}}=\\boxed{{-\\frac{{2}}{{3}}}}.$$\nFinal Answer: The final answer is $-\\frac{{2}}{{3}}$. I hope it is correct.\n'
+                f'Problem:\n{problem}\nSolution:\n'
+            )
+        else:
+            context = 'Problem:\n' + problem + '\nSolution:\n'
+        return context
+
+    @classmethod
+    def _preprocess_input(cls, input: str) -> str:
+        """
+        Preprocess the input data, remove the boxed solution.
+
+        Args:
+            input: The raw input. A single data format of the Competition Math.
+
+        Returns:
+            The preprocessed input.
+        """
+        return cls._remove_boxed(cls._last_boxed_only_string(input))
+
+    @classmethod
+    def _remove_boxed(cls, s):
+        if s is None:
+            return s
+
+        if '\\boxed ' in s:
+            left = '\\boxed '
+            assert s[: len(left)] == left
+            return s[len(left):]
+
+        left = '\\boxed{'
+
+        assert s[: len(left)] == left
+        assert s[-1] == '}'
+
+        return s[len(left): -1]
+
+    @classmethod
+    def _last_boxed_only_string(cls, string):
+
+        idx = string.rfind('\\boxed')
+        if '\\boxed ' in string:
+            return '\\boxed ' + string.split('\\boxed ')[-1].split('$')[0]
+        if idx < 0:
+            idx = string.rfind('\\fbox')
+            if idx < 0:
+                return None
+
+        i = idx
+        right_brace_idx = None
+        num_left_braces_open = 0
+        while i < len(string):
+            if string[i] == '{':
+                num_left_braces_open += 1
+            if string[i] == '}':
+                num_left_braces_open -= 1
+                if num_left_braces_open == 0:
+                    right_brace_idx = i
+                    break
+            i += 1
+
+        if right_brace_idx is None:
+            retval = None
+        else:
+            retval = string[idx: right_brace_idx + 1]
+
+        return retval
+
+    @classmethod
+    def _is_equiv(cls, str1, str2, verbose=False):
+        if str1 is None and str2 is None:
+            logger.warning('WARNING: Both None')
+            return True
+        if str1 is None or str2 is None:
+            return False
+
+        try:
+            ss1 = cls.strip_string(str1)
+            ss2 = cls.strip_string(str2)
+            if verbose:
+                logger.info(f'ss1: {ss1}, ss2: {ss2}')
+            return ss1 == ss2
+        except Exception:
+            return str1 == str2
+
+    @classmethod
+    def strip_string(cls, string):
+        # linebreaks
+        string = string.replace('\n', '')
+
+        # remove inverse spaces
+        string = string.replace('\\!', '')
+
+        # replace \\ with \
+        string = string.replace('\\\\', '\\')
+
+        # replace tfrac and dfrac with frac
+        string = string.replace('tfrac', 'frac')
+        string = string.replace('dfrac', 'frac')
+
+        # remove \left and \right
+        string = string.replace('\\left', '')
+        string = string.replace('\\right', '')
+
+        # Remove circ (degrees)
+        string = string.replace('^{\\circ}', '')
+        string = string.replace('^\\circ', '')
+
+        # remove dollar signs
+        string = string.replace('\\$', '')
+
+        # remove units (on the right)
+        string = cls.remove_right_units(string)
+
+        # remove percentage
+        string = string.replace('\\%', '')
+        string = string.replace('\%', '') # noqa: W605
+
+        # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
+        string = string.replace(' .', ' 0.')
+        string = string.replace('{.', '{0.')
+        # if empty, return empty string
+        if len(string) == 0:
+            return string
+        if string[0] == '.':
+            string = '0' + string
+
+        # to consider: get rid of e.g. "k = " or "q = " at beginning
+        if len(string.split('=')) == 2:
+            if len(string.split('=')[0]) <= 2:
+                string = string.split('=')[1]
+
+        # fix sqrt3 --> sqrt{3}
+        string = cls.fix_sqrt(string)
+
+        # remove spaces
+        string = string.replace(' ', '')
+
+        # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
+        string = cls.fix_fracs(string)
+
+        # manually change 0.5 --> \frac{1}{2}
+        if string == '0.5':
+            string = '\\frac{1}{2}'
+
+        # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
+        string = cls.fix_a_slash_b(string)
+
+        return string
+
+    @classmethod
+    def remove_right_units(cls, string):
+        # "\\text{ " only ever occurs (at least in the val set) when describing units
+        if '\\text{ ' in string:
+            splits = string.split('\\text{ ')
+            assert len(splits) == 2
+            return splits[0]
+        else:
+            return string
+
+    @classmethod
+    def fix_fracs(cls, string):
+        substrs = string.split('\\frac')
+        new_str = substrs[0]
+        if len(substrs) > 1:
+            substrs = substrs[1:]
+            for substr in substrs:
+                new_str += '\\frac'
+                if substr[0] == '{':
+                    new_str += substr
+                else:
+                    try:
+                        assert len(substr) >= 2
+                    except AssertionError:
+                        return string
+                    a = substr[0]
+                    b = substr[1]
+                    if b != '{':
+                        if len(substr) > 2:
+                            post_substr = substr[2:]
+                            new_str += '{' + a + '}{' + b + '}' + post_substr
+                        else:
+                            new_str += '{' + a + '}{' + b + '}'
+                    else:
+                        if len(substr) > 2:
+                            post_substr = substr[2:]
+                            new_str += '{' + a + '}' + b + post_substr
+                        else:
+                            new_str += '{' + a + '}' + b
+        string = new_str
+        return string
+
+    @classmethod
+    def fix_sqrt(cls, string):
+        if '\\sqrt' not in string:
+            return string
+        splits = string.split('\\sqrt')
+        new_string = splits[0]
+        for split in splits[1:]:
+            if split[0] != '{':
+                a = split[0]
+                new_substr = '\\sqrt{' + a + '}' + split[1:]
+            else:
+                new_substr = '\\sqrt' + split
+            new_string += new_substr
+        return new_string
+
+    @classmethod
+    def fix_a_slash_b(cls, string):
+        if len(string.split('/')) != 2:
+            return string
+        a = string.split('/')[0]
+        b = string.split('/')[1]
+        try:
+            a = int(a)
+            b = int(b)
+            assert string == '{}/{}'.format(a, b)
+            new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
+            return new_string
+        except AssertionError:
+            return string
+
+    @classmethod
+    def _math_postprocess(cls, text: str) -> str:
+        SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''),
+                         (r'\ ', ''), (' ', ''), ('mbox', 'text'),
+                         (',\\text{and}', ','), ('\\text{and}', ','),
+                         ('\\text{m}', '\\text{}'), ('\\le', '<')]
+        REMOVED_EXPRESSIONS = [
+            'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft',
+            'hours', 'km', 'units', '\\ldots', 'sue', 'points', 'feet', 'minutes',
+            'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds', 'meters', 'meals',
+            'edges', 'students', 'childrentickets', 'multiples', '\\text{s}',
+            '\\text{.}', '\\text{\ns}', '\\text{}^2', '\\text{}^3', '\\text{\n}',
+            '\\text{}', r'\mathrm{th}', r'^\circ', r'^{\circ}', r'\;', r',\!',
+            '{,}', '"', '\\dots', '\n', '\r', '\f'
+        ]
+        import re
+
+        def normalize_final_answer(final_answer: str) -> str:
+            """Normalize a final answer to a quantitative reasoning question."""
+            # final_answer = final_answer.split('=')[-1]
+            for before, after in SUBSTITUTIONS:
+                final_answer = final_answer.replace(before, after)
+            for expr in REMOVED_EXPRESSIONS:
+                final_answer = final_answer.replace(expr, '')
+
+            # Extract answer that is in LaTeX math, is bold,
+            # is surrounded by a box, etc.
+            final_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', final_answer)
+            final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer)
+            final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer)
+            final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer)
+            assert '\n' not in final_answer
+            assert '\r' not in final_answer
+            assert '\f' not in final_answer
+            if len(re.findall(r'finalansweris(.*)', final_answer)) > 0:
+                final_answer = re.findall(r'finalansweris(.*)', final_answer)[-1]

+            if len(re.findall(r'oxed\{(.*?)\}', final_answer)) > 0:
+                final_answer = re.findall(r'oxed\{(.*?)\}', final_answer)[-1]
+
+            if len(re.findall(r'\$(.*?)\$', final_answer)) > 0:
+                final_answer = re.findall(r'\$(.*?)\$', final_answer)[-1]
+            final_answer = final_answer.strip()
+            if 'rac' in final_answer and '\\frac' not in final_answer:
+                final_answer = final_answer.replace('rac', '\\frac')
+
+            final_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}',
+                                  final_answer)
+            final_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', final_answer)
+            final_answer = final_answer.replace('$', '')
+
+            # Normalize 100,000 -> 100000
+            if final_answer.replace(',', '').isdigit():
+                final_answer = final_answer.replace(',', '')
+
+            return final_answer
+
+        for maybe_ans in text.split('.'):
+            if 'final answer' in maybe_ans.lower():
+                return normalize_final_answer(maybe_ans)
+        return normalize_final_answer(text.split('.')[0])