thinkbooster 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_tts/datasets/__init__.py +46 -0
- llm_tts/datasets/gsm8k.py +168 -0
- llm_tts/datasets/human_eval_plus.py +266 -0
- llm_tts/datasets/kernelbench.py +238 -0
- llm_tts/datasets/mbpp_plus.py +283 -0
- llm_tts/early_stopping.py +295 -0
- llm_tts/evaluation/__init__.py +13 -0
- llm_tts/evaluation/alignscore.py +86 -0
- llm_tts/evaluation/exact_match.py +258 -0
- llm_tts/evaluation/grader.py +399 -0
- llm_tts/evaluation/human_eval_plus_evaluator.py +277 -0
- llm_tts/evaluation/latex2sympy/__init__.py +8 -0
- llm_tts/evaluation/latex2sympy/asciimath_printer.py +50 -0
- llm_tts/evaluation/latex2sympy/gen/PSLexer.py +1692 -0
- llm_tts/evaluation/latex2sympy/gen/PSListener.py +579 -0
- llm_tts/evaluation/latex2sympy/gen/PSParser.py +7502 -0
- llm_tts/evaluation/latex2sympy/gen/PSVisitor.py +328 -0
- llm_tts/evaluation/latex2sympy/gen/__init__.py +0 -0
- llm_tts/evaluation/latex2sympy/latex2sympy2.py +1157 -0
- llm_tts/evaluation/latex2sympy/sandbox/linalg_equations.py +10 -0
- llm_tts/evaluation/latex2sympy/sandbox/linalg_span.py +19 -0
- llm_tts/evaluation/latex2sympy/sandbox/matrix.py +46 -0
- llm_tts/evaluation/latex2sympy/sandbox/matrix_placeholders.py +65 -0
- llm_tts/evaluation/latex2sympy/sandbox/sandbox.py +23 -0
- llm_tts/evaluation/latex2sympy/sandbox/sandbox_equality.py +75 -0
- llm_tts/evaluation/latex2sympy/sandbox/sectan.py +51 -0
- llm_tts/evaluation/latex2sympy/sandbox/vector.py +75 -0
- llm_tts/evaluation/latex2sympy/setup.py +45 -0
- llm_tts/evaluation/latex2sympy/tests/__init__.py +0 -0
- llm_tts/evaluation/latex2sympy/tests/abs_test.py +19 -0
- llm_tts/evaluation/latex2sympy/tests/all_bad_test.py +70 -0
- llm_tts/evaluation/latex2sympy/tests/all_good_test.py +284 -0
- llm_tts/evaluation/latex2sympy/tests/atom_expr_test.py +58 -0
- llm_tts/evaluation/latex2sympy/tests/binomial_test.py +36 -0
- llm_tts/evaluation/latex2sympy/tests/ceil_test.py +29 -0
- llm_tts/evaluation/latex2sympy/tests/complex_test.py +21 -0
- llm_tts/evaluation/latex2sympy/tests/context.py +84 -0
- llm_tts/evaluation/latex2sympy/tests/exp_test.py +57 -0
- llm_tts/evaluation/latex2sympy/tests/floor_test.py +29 -0
- llm_tts/evaluation/latex2sympy/tests/gcd_test.py +161 -0
- llm_tts/evaluation/latex2sympy/tests/greek_test.py +19 -0
- llm_tts/evaluation/latex2sympy/tests/grouping_test.py +52 -0
- llm_tts/evaluation/latex2sympy/tests/lcm_test.py +161 -0
- llm_tts/evaluation/latex2sympy/tests/left_right_cdot_test.py +9 -0
- llm_tts/evaluation/latex2sympy/tests/linalg_test.py +15 -0
- llm_tts/evaluation/latex2sympy/tests/max_test.py +79 -0
- llm_tts/evaluation/latex2sympy/tests/min_test.py +79 -0
- llm_tts/evaluation/latex2sympy/tests/mod_test.py +70 -0
- llm_tts/evaluation/latex2sympy/tests/overline_test.py +9 -0
- llm_tts/evaluation/latex2sympy/tests/pi_test.py +15 -0
- llm_tts/evaluation/latex2sympy/tests/trig_test.py +21 -0
- llm_tts/evaluation/latex2sympy/tests/variable_test.py +92 -0
- llm_tts/evaluation/llm_as_a_judge.py +309 -0
- llm_tts/evaluation/math_normalize.py +417 -0
- llm_tts/evaluation/mbpp_plus_evaluator.py +277 -0
- llm_tts/evaluation/parser.py +770 -0
- llm_tts/generators/__init__.py +66 -0
- llm_tts/generators/api.py +1249 -0
- llm_tts/generators/base.py +430 -0
- llm_tts/generators/huggingface.py +728 -0
- llm_tts/generators/vllm.py +1394 -0
- llm_tts/integrations/__init__.py +17 -0
- llm_tts/integrations/langchain_chat_model.py +168 -0
- llm_tts/models/__init__.py +8 -0
- llm_tts/models/base.py +62 -0
- llm_tts/models/blackboxmodel_with_streaming.py +392 -0
- llm_tts/scale_discriminator.py +127 -0
- llm_tts/scorers/__init__.py +14 -0
- llm_tts/scorers/estimator_uncertainty_pd.py +48 -0
- llm_tts/scorers/majority_voting.py +236 -0
- llm_tts/scorers/multi_scorer.py +236 -0
- llm_tts/scorers/step_scorer_base.py +153 -0
- llm_tts/scorers/step_scorer_confidence.py +47 -0
- llm_tts/scorers/step_scorer_llm_critic.py +947 -0
- llm_tts/scorers/step_scorer_prm.py +1002 -0
- llm_tts/scorers/step_scorer_reward_base.py +47 -0
- llm_tts/scorers/step_scorer_uncertainty.py +48 -0
- llm_tts/step_boundary_detectors/__init__.py +65 -0
- llm_tts/step_boundary_detectors/base.py +23 -0
- llm_tts/step_boundary_detectors/non_thinking/__init__.py +12 -0
- llm_tts/step_boundary_detectors/non_thinking/structured.py +169 -0
- llm_tts/step_boundary_detectors/thinking/__init__.py +39 -0
- llm_tts/step_boundary_detectors/thinking/huggingface/__init__.py +18 -0
- llm_tts/step_boundary_detectors/thinking/marker.py +662 -0
- llm_tts/step_boundary_detectors/thinking/offline/__init__.py +18 -0
- llm_tts/step_boundary_detectors/thinking/offline/hybrid.py +308 -0
- llm_tts/step_boundary_detectors/thinking/offline/llm.py +384 -0
- llm_tts/step_boundary_detectors/thinking/offline/sentence.py +138 -0
- llm_tts/step_boundary_detectors/thinking/vllm/__init__.py +31 -0
- llm_tts/step_boundary_detectors/thinking/vllm/stop_tokens.py +480 -0
- llm_tts/strategies/__init__.py +35 -0
- llm_tts/strategies/adaptive_scaling_best_of_n.py +679 -0
- llm_tts/strategies/deepconf/__init__.py +9 -0
- llm_tts/strategies/deepconf/strategy.py +1364 -0
- llm_tts/strategies/deepconf/utils.py +312 -0
- llm_tts/strategies/metadata_builder.py +222 -0
- llm_tts/strategies/phi.py +228 -0
- llm_tts/strategies/strategy_base.py +183 -0
- llm_tts/strategies/strategy_baseline.py +399 -0
- llm_tts/strategies/strategy_beam_search.py +1168 -0
- llm_tts/strategies/strategy_chain_of_thought.py +119 -0
- llm_tts/strategies/strategy_extended_thinking.py +386 -0
- llm_tts/strategies/strategy_offline_best_of_n.py +969 -0
- llm_tts/strategies/strategy_online_best_of_n.py +1101 -0
- llm_tts/strategies/strategy_self_consistency.py +512 -0
- llm_tts/strategies/strategy_uncertainty_cot.py +343 -0
- llm_tts/utils/__init__.py +15 -0
- llm_tts/utils/answer_extraction.py +141 -0
- llm_tts/utils/flops.py +295 -0
- llm_tts/utils/parallel.py +82 -0
- llm_tts/utils/telegram.py +154 -0
- llm_tts/utils/telegram_bot.py +83 -0
- llm_tts/utils/torch_dtype.py +25 -0
- service_app/__init__.py +0 -0
- service_app/api/__init__.py +0 -0
- service_app/api/models/__init__.py +0 -0
- service_app/api/models/openai_compat.py +238 -0
- service_app/api/routes/__init__.py +1 -0
- service_app/api/routes/chat.py +514 -0
- service_app/api/routes/debugger.py +103 -0
- service_app/api/routes/models.py +71 -0
- service_app/core/__init__.py +0 -0
- service_app/core/config.py +95 -0
- service_app/core/debugger_events.py +1035 -0
- service_app/core/logging_config.py +83 -0
- service_app/core/prm_scorer_factory.py +86 -0
- service_app/core/strategy_manager.py +687 -0
- service_app/core/visual_debugger_demo.py +689 -0
- service_app/main.py +314 -0
- thinkbooster-0.1.0.dist-info/METADATA +288 -0
- thinkbooster-0.1.0.dist-info/RECORD +134 -0
- thinkbooster-0.1.0.dist-info/WHEEL +5 -0
- thinkbooster-0.1.0.dist-info/licenses/LICENSE +22 -0
- thinkbooster-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Dataset loaders for various benchmarks."""
|
|
2
|
+
|
|
3
|
+
from .gsm8k import (
|
|
4
|
+
evaluate_gsm8k_answer,
|
|
5
|
+
extract_answer_from_gsm8k,
|
|
6
|
+
format_gsm8k_for_deepconf,
|
|
7
|
+
load_gsm8k,
|
|
8
|
+
)
|
|
9
|
+
from .human_eval_plus import create_evalplus_samples as create_human_eval_plus_samples
|
|
10
|
+
from .human_eval_plus import (
|
|
11
|
+
extract_code_from_response as extract_code_from_response_human_eval,
|
|
12
|
+
)
|
|
13
|
+
from .human_eval_plus import (
|
|
14
|
+
format_human_eval_prompt,
|
|
15
|
+
)
|
|
16
|
+
from .human_eval_plus import load_evalplus_samples as load_human_eval_plus_samples
|
|
17
|
+
from .human_eval_plus import (
|
|
18
|
+
load_human_eval_plus,
|
|
19
|
+
)
|
|
20
|
+
from .mbpp_plus import (
|
|
21
|
+
create_evalplus_samples,
|
|
22
|
+
extract_code_from_response,
|
|
23
|
+
format_mbpp_prompt,
|
|
24
|
+
load_evalplus_samples,
|
|
25
|
+
load_mbpp_plus,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
__all__ = [
|
|
29
|
+
# GSM8K
|
|
30
|
+
"load_gsm8k",
|
|
31
|
+
"evaluate_gsm8k_answer",
|
|
32
|
+
"extract_answer_from_gsm8k",
|
|
33
|
+
"format_gsm8k_for_deepconf",
|
|
34
|
+
# MBPP+
|
|
35
|
+
"load_mbpp_plus",
|
|
36
|
+
"extract_code_from_response",
|
|
37
|
+
"format_mbpp_prompt",
|
|
38
|
+
"create_evalplus_samples",
|
|
39
|
+
"load_evalplus_samples",
|
|
40
|
+
# HumanEval+
|
|
41
|
+
"load_human_eval_plus",
|
|
42
|
+
"extract_code_from_response_human_eval",
|
|
43
|
+
"format_human_eval_prompt",
|
|
44
|
+
"create_human_eval_plus_samples",
|
|
45
|
+
"load_human_eval_plus_samples",
|
|
46
|
+
]
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GSM8K dataset loader and preprocessing for DeepConf evaluation.
|
|
3
|
+
|
|
4
|
+
GSM8K (Grade School Math 8K) is a dataset of 8.5K grade school math word problems.
|
|
5
|
+
Each problem requires multi-step reasoning to solve.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Dict, List, Optional
|
|
10
|
+
|
|
11
|
+
from datasets import load_dataset
|
|
12
|
+
|
|
13
|
+
log = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def extract_answer_from_gsm8k(solution: str) -> str:
    """
    Pull the final answer out of a GSM8K reference solution.

    GSM8K solutions finish with a "#### {answer}" marker; the text that
    follows the last marker is treated as the answer.

    Args:
        solution: The solution string from GSM8K

    Returns:
        The text after the last "####", stripped of surrounding whitespace
        and of commas, or an empty string when no marker is present
    """
    pieces = solution.split("####")
    # A single piece means the marker never occurred in the text.
    if len(pieces) == 1:
        return ""

    # Drop thousands separators (e.g., "1,000" -> "1000")
    return pieces[-1].strip().replace(",", "")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def format_gsm8k_for_deepconf(question: str, answer: str) -> Dict[str, str]:
    """
    Build the record DeepConf evaluation consumes from one GSM8K example.

    The question text is kept (whitespace-trimmed), the final answer is
    extracted from the "#### X" marker of the solution, and the untouched
    solution is carried along for reference. The model is expected to
    answer in \\boxed{X} form.

    Args:
        question: The question text
        answer: The solution text (includes #### {answer})

    Returns:
        Dict with 'question', 'answer' and 'original_solution' keys
    """
    final_answer = extract_answer_from_gsm8k(answer)
    record = dict(
        question=question.strip(),
        answer=final_answer,
        original_solution=answer,  # Keep for reference
    )
    return record
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def load_gsm8k(
    split: str = "test",
    subset_size: Optional[int] = None,
    cache_dir: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Fetch GSM8K from HuggingFace and convert every example for DeepConf.

    Args:
        split: Dataset split ('train' or 'test')
        subset_size: If provided, only the first N examples are kept
        cache_dir: Cache directory for HuggingFace datasets

    Returns:
        List of dicts as produced by format_gsm8k_for_deepconf
    """
    log.info(f"Loading GSM8K dataset (split={split})...")

    dataset = load_dataset("openai/gsm8k", "main", split=split, cache_dir=cache_dir)

    if subset_size is not None:
        # Never ask for more rows than the split contains.
        keep = min(subset_size, len(dataset))
        dataset = dataset.select(range(keep))
        log.info(f"Using subset of {len(dataset)} examples")

    formatted_data = [
        format_gsm8k_for_deepconf(question=row["question"], answer=row["answer"])
        for row in dataset
    ]

    log.info(f"Loaded {len(formatted_data)} GSM8K examples")
    return formatted_data
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def evaluate_gsm8k_answer(predicted: str, ground_truth: str) -> bool:
    """
    Evaluate if predicted answer matches ground truth for GSM8K.

    Handles:
    - Numeric comparison (with tolerance for floats)
    - String normalization (strip whitespace, lowercase)
    - Comma removal from numbers
    - Missing answers (None), which never match

    Args:
        predicted: Predicted answer (extracted from \\boxed{}); None when
            nothing could be extracted
        ground_truth: Ground truth answer

    Returns:
        True if answers match, False otherwise
    """
    # A missing answer can never be correct; without this guard the
    # .strip() call below would raise AttributeError.
    if predicted is None or ground_truth is None:
        return False

    # Normalize both answers; str() lets already-parsed numbers through.
    pred_clean = str(predicted).strip().replace(",", "").lower()
    gt_clean = str(ground_truth).strip().replace(",", "").lower()

    # Direct string match
    if pred_clean == gt_clean:
        return True

    # Try numeric comparison
    try:
        pred_num = float(pred_clean)
        gt_num = float(gt_clean)
    except (ValueError, TypeError):
        # At least one side is not a number and the strings differ.
        return False

    # Relative tolerance, with a floor of 1 so values near zero are
    # compared with an absolute tolerance of 1e-6.
    return abs(pred_num - gt_num) < 1e-6 * max(abs(pred_num), abs(gt_num), 1)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
if __name__ == "__main__":
    # Manual smoke test: load a few examples, then exercise the answer checker.
    logging.basicConfig(level=logging.INFO)

    print("\n=== Testing GSM8K loader ===\n")

    # A handful of examples is enough for a sanity check
    data = load_gsm8k(split="test", subset_size=5)

    print(f"Loaded {len(data)} examples\n")

    for number, example in enumerate(data[:3], start=1):
        print(f"Example {number}:")
        print(f" Question: {example['question'][:100]}...")
        print(f" Answer: {example['answer']}")
        print(f" Original: {example['original_solution'][:80]}...")
        print()

    # Each case is (predicted, ground truth, expected verdict)
    print("\n=== Testing answer evaluation ===\n")
    test_cases = [
        ("70", "70", True),
        ("70", "70.0", True),
        ("1000", "1,000", True),
        ("70", "71", False),
        ("abc", "ABC", True),
    ]

    for pred, gt, expected in test_cases:
        result = evaluate_gsm8k_answer(pred, gt)
        if result == expected:
            status = "✓"
        else:
            status = "✗"
        print(f"{status} '{pred}' vs '{gt}': {result} (expected {expected})")
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HumanEval+ dataset loader and utilities.
|
|
3
|
+
|
|
4
|
+
HumanEval+ is an enhanced version of HumanEval with 80x more test cases
|
|
5
|
+
for rigorous evaluation of code generation.
|
|
6
|
+
|
|
7
|
+
Dataset: https://huggingface.co/datasets/evalplus/humanevalplus
|
|
8
|
+
EvalPlus: https://github.com/evalplus/evalplus
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import re
|
|
14
|
+
from typing import Any, Dict, List, Optional
|
|
15
|
+
|
|
16
|
+
from evalplus.data import get_human_eval_plus, write_jsonl
|
|
17
|
+
|
|
18
|
+
log = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def load_human_eval_plus(
    subset_size: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Public entry point for loading HumanEval+ through the evalplus API.

    Args:
        subset_size: If provided, only load first N examples

    Returns:
        List of dicts with formatted data for the evaluation pipeline
    """
    problems = _load_from_evalplus(subset_size)
    return problems
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _load_from_evalplus(subset_size: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load HumanEval+ using evalplus API.

    Each prompt is wrapped the way EvalPlus does for chat models:
    instruction prefix, then the function stub inside a python code block.

    Args:
        subset_size: When truthy, stop once this many problems are collected

    Returns:
        One dict per problem with the pipeline's standard 'question' and
        'answer' fields plus the HumanEval+ specific fields
    """
    log.info("Loading HumanEval+ using evalplus API...")

    # EvalPlus instruction prefix for chat/instruction models
    instruction_prefix = (
        "Please provide a self-contained Python script that solves the "
        "following problem in a markdown code block:"
    )

    records = []
    for task_id, problem in get_human_eval_plus().items():
        stub = problem["prompt"].strip()
        # instruction_prefix + "\n```python\n" + prompt + "\n```"
        question = f"{instruction_prefix}\n```python\n{stub}\n```"
        entry_point = problem.get("entry_point", _extract_function_name(stub))

        records.append(
            {
                # Standard fields for the evaluation pipeline
                "question": question,
                "answer": problem["canonical_solution"],
                # HumanEval+ specific fields
                "task_id": task_id,
                "entry_point": entry_point,
                "prompt": stub,  # Original prompt (function signature + docstring)
                "base_input": problem.get("base_input", []),
                "plus_input": problem.get("plus_input", []),
                "atol": problem.get("atol", 0),
                "contract": problem.get("contract", ""),
            }
        )

        # A falsy subset_size (None or 0) means "no limit".
        if not subset_size:
            continue
        if len(records) >= subset_size:
            break

    log.info(f"Loaded {len(records)} HumanEval+ problems via evalplus API")
    return records
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _extract_function_name(prompt: str) -> str:
|
|
84
|
+
"""Extract function name from HumanEval prompt."""
|
|
85
|
+
# HumanEval prompts typically start with function signature
|
|
86
|
+
# Pattern: "def function_name("
|
|
87
|
+
match = re.search(r"def (\w+)\s*\(", prompt)
|
|
88
|
+
if match:
|
|
89
|
+
return match.group(1)
|
|
90
|
+
|
|
91
|
+
# Default fallback
|
|
92
|
+
return "solution"
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def extract_code_from_response(response: str) -> str:
    """
    Extract Python code from model response.

    Handles various formats:
    - Fenced code blocks, with or without a language tag (```python,
      ```python3, ```py, ``` ...)
    - Raw code
    - Code with explanation

    Args:
        response: Model's response text

    Returns:
        Extracted code string: the last fenced block if any, otherwise the
        last bare function definition, otherwise the stripped response
    """
    # An opening fence may carry any info string (language tag) up to the end
    # of its line. Accepting only the literal "python" made fences tagged
    # "python3"/"py"/"Python" fail to open, after which a *closing* fence
    # could pair with the next opening fence and capture the prose between
    # two code blocks.
    code_block_pattern = r"```[^\n`]*\n(.*?)```"
    code_blocks = re.findall(code_block_pattern, response, re.DOTALL)

    if code_blocks:
        # Return the last code block (usually the final solution)
        code = code_blocks[-1].strip()
        # Handle malformed code blocks where the language tag appears on its
        # own line: some models output "```\npython\n..." instead of
        # "```python\n..."
        for stray_tag in ("python3\n", "python\n", "Python\n"):
            if code.startswith(stray_tag):
                code = code[len(stray_tag):]
                break
        return code.strip()

    # Try to find function definition
    # Look for def ... up to the next blank line or end
    func_pattern = r"(def \w+\s*\([^)]*\):.*?)(?:\n\n|\Z)"
    func_matches = re.findall(func_pattern, response, re.DOTALL)

    if func_matches:
        return func_matches[-1].strip()

    # Return the response as-is (might be raw code)
    return response.strip()
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def format_human_eval_prompt(
    problem: Dict[str, Any],
    prompt_template: Optional[str] = None,
) -> str:
    """
    Turn a formatted HumanEval+ problem into the prompt sent to the model.

    Args:
        problem: A formatted HumanEval+ problem dict; its 'question' entry
            is required, 'entry_point' and 'task_id' are optional
        prompt_template: Optional template with {prompt}, {entry_point} and
            {task_id} placeholders; {prompt} receives the 'question' entry

    Returns:
        The filled-in template, or the problem's 'question' unchanged when
        no template is given
    """
    question = problem["question"]

    # Without a template the question is already the final prompt.
    if not prompt_template:
        return question

    fields = {
        "prompt": question,
        "entry_point": problem.get("entry_point", "solution"),
        "task_id": problem.get("task_id", ""),
    }
    return prompt_template.format(**fields)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def create_evalplus_samples(
    results: List[Dict[str, Any]],
    output_path: str,
) -> None:
    """
    Create a samples file in EvalPlus format for evaluation.

    The format expected by evalplus is JSONL with:
    - task_id: str
    - solution: str (the complete solution code)

    The solution is taken from the first non-empty field among
    'generated_code', 'extracted_answer', 'generated_answer' and
    'generated_trajectory'; when none is set an empty solution is written.

    Args:
        results: List of result dicts with 'task_id' and generated code
        output_path: Path to save the samples file
    """
    log.info(f"Saving {len(results)} samples to {output_path}")

    samples = []
    for result in results:
        task_id = result.get("task_id", "")
        # Get the generated code from various possible fields. The trailing
        # `or ""` covers a 'generated_trajectory' key that is present but
        # None, which would otherwise break the fence check below.
        solution = (
            result.get("generated_code")
            or result.get("extracted_answer")
            or result.get("generated_answer")
            or result.get("generated_trajectory")
            or ""
        )

        # Extract code if it contains markdown; only strings can hold a fence.
        if isinstance(solution, str) and "```" in solution:
            solution = extract_code_from_response(solution)

        samples.append(
            {
                "task_id": task_id,
                "solution": solution,
            }
        )

    write_jsonl(output_path, samples)

    log.info(f"Samples saved to {output_path}")
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def load_evalplus_samples(path: str) -> List[Dict[str, Any]]:
    """
    Read an EvalPlus samples file (JSONL) back into memory.

    Blank and whitespace-only lines are skipped; every other line is parsed
    as one JSON document.

    Args:
        path: Path to the samples JSONL file

    Returns:
        List of sample dictionaries, in file order
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
if __name__ == "__main__":
    # Manual smoke test: load a few problems, then try the code extractor.
    logging.basicConfig(level=logging.INFO)

    print("\n=== Testing HumanEval+ loader ===\n")

    # A handful of problems is enough for a sanity check
    data = load_human_eval_plus(subset_size=5)

    print(f"Loaded {len(data)} problems\n")

    for number, problem in enumerate(data[:3], start=1):
        print(f"Problem {number}:")
        print(f" Task ID: {problem['task_id']}")
        print(f" Entry point: {problem['entry_point']}")
        print(f" Prompt: {problem['question'][:100]}...")
        print(f" Solution preview: {problem['answer'][:80]}...")
        print()

    # Sample reply: a fenced solution surrounded by explanatory prose
    print("\n=== Testing code extraction ===\n")

    test_response = """
Here's the solution:

```python
def has_close_elements(numbers: List[float], threshold: float) -> bool:
    for i in range(len(numbers)):
        for j in range(i + 1, len(numbers)):
            if abs(numbers[i] - numbers[j]) < threshold:
                return True
    return False
```

This function checks if any two elements are closer than the threshold.
"""

    extracted = extract_code_from_response(test_response)
    print(f"Extracted code:\n{extracted}")
|