thinkbooster 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. llm_tts/datasets/__init__.py +46 -0
  2. llm_tts/datasets/gsm8k.py +168 -0
  3. llm_tts/datasets/human_eval_plus.py +266 -0
  4. llm_tts/datasets/kernelbench.py +238 -0
  5. llm_tts/datasets/mbpp_plus.py +283 -0
  6. llm_tts/early_stopping.py +295 -0
  7. llm_tts/evaluation/__init__.py +13 -0
  8. llm_tts/evaluation/alignscore.py +86 -0
  9. llm_tts/evaluation/exact_match.py +258 -0
  10. llm_tts/evaluation/grader.py +399 -0
  11. llm_tts/evaluation/human_eval_plus_evaluator.py +277 -0
  12. llm_tts/evaluation/latex2sympy/__init__.py +8 -0
  13. llm_tts/evaluation/latex2sympy/asciimath_printer.py +50 -0
  14. llm_tts/evaluation/latex2sympy/gen/PSLexer.py +1692 -0
  15. llm_tts/evaluation/latex2sympy/gen/PSListener.py +579 -0
  16. llm_tts/evaluation/latex2sympy/gen/PSParser.py +7502 -0
  17. llm_tts/evaluation/latex2sympy/gen/PSVisitor.py +328 -0
  18. llm_tts/evaluation/latex2sympy/gen/__init__.py +0 -0
  19. llm_tts/evaluation/latex2sympy/latex2sympy2.py +1157 -0
  20. llm_tts/evaluation/latex2sympy/sandbox/linalg_equations.py +10 -0
  21. llm_tts/evaluation/latex2sympy/sandbox/linalg_span.py +19 -0
  22. llm_tts/evaluation/latex2sympy/sandbox/matrix.py +46 -0
  23. llm_tts/evaluation/latex2sympy/sandbox/matrix_placeholders.py +65 -0
  24. llm_tts/evaluation/latex2sympy/sandbox/sandbox.py +23 -0
  25. llm_tts/evaluation/latex2sympy/sandbox/sandbox_equality.py +75 -0
  26. llm_tts/evaluation/latex2sympy/sandbox/sectan.py +51 -0
  27. llm_tts/evaluation/latex2sympy/sandbox/vector.py +75 -0
  28. llm_tts/evaluation/latex2sympy/setup.py +45 -0
  29. llm_tts/evaluation/latex2sympy/tests/__init__.py +0 -0
  30. llm_tts/evaluation/latex2sympy/tests/abs_test.py +19 -0
  31. llm_tts/evaluation/latex2sympy/tests/all_bad_test.py +70 -0
  32. llm_tts/evaluation/latex2sympy/tests/all_good_test.py +284 -0
  33. llm_tts/evaluation/latex2sympy/tests/atom_expr_test.py +58 -0
  34. llm_tts/evaluation/latex2sympy/tests/binomial_test.py +36 -0
  35. llm_tts/evaluation/latex2sympy/tests/ceil_test.py +29 -0
  36. llm_tts/evaluation/latex2sympy/tests/complex_test.py +21 -0
  37. llm_tts/evaluation/latex2sympy/tests/context.py +84 -0
  38. llm_tts/evaluation/latex2sympy/tests/exp_test.py +57 -0
  39. llm_tts/evaluation/latex2sympy/tests/floor_test.py +29 -0
  40. llm_tts/evaluation/latex2sympy/tests/gcd_test.py +161 -0
  41. llm_tts/evaluation/latex2sympy/tests/greek_test.py +19 -0
  42. llm_tts/evaluation/latex2sympy/tests/grouping_test.py +52 -0
  43. llm_tts/evaluation/latex2sympy/tests/lcm_test.py +161 -0
  44. llm_tts/evaluation/latex2sympy/tests/left_right_cdot_test.py +9 -0
  45. llm_tts/evaluation/latex2sympy/tests/linalg_test.py +15 -0
  46. llm_tts/evaluation/latex2sympy/tests/max_test.py +79 -0
  47. llm_tts/evaluation/latex2sympy/tests/min_test.py +79 -0
  48. llm_tts/evaluation/latex2sympy/tests/mod_test.py +70 -0
  49. llm_tts/evaluation/latex2sympy/tests/overline_test.py +9 -0
  50. llm_tts/evaluation/latex2sympy/tests/pi_test.py +15 -0
  51. llm_tts/evaluation/latex2sympy/tests/trig_test.py +21 -0
  52. llm_tts/evaluation/latex2sympy/tests/variable_test.py +92 -0
  53. llm_tts/evaluation/llm_as_a_judge.py +309 -0
  54. llm_tts/evaluation/math_normalize.py +417 -0
  55. llm_tts/evaluation/mbpp_plus_evaluator.py +277 -0
  56. llm_tts/evaluation/parser.py +770 -0
  57. llm_tts/generators/__init__.py +66 -0
  58. llm_tts/generators/api.py +1249 -0
  59. llm_tts/generators/base.py +430 -0
  60. llm_tts/generators/huggingface.py +728 -0
  61. llm_tts/generators/vllm.py +1394 -0
  62. llm_tts/integrations/__init__.py +17 -0
  63. llm_tts/integrations/langchain_chat_model.py +168 -0
  64. llm_tts/models/__init__.py +8 -0
  65. llm_tts/models/base.py +62 -0
  66. llm_tts/models/blackboxmodel_with_streaming.py +392 -0
  67. llm_tts/scale_discriminator.py +127 -0
  68. llm_tts/scorers/__init__.py +14 -0
  69. llm_tts/scorers/estimator_uncertainty_pd.py +48 -0
  70. llm_tts/scorers/majority_voting.py +236 -0
  71. llm_tts/scorers/multi_scorer.py +236 -0
  72. llm_tts/scorers/step_scorer_base.py +153 -0
  73. llm_tts/scorers/step_scorer_confidence.py +47 -0
  74. llm_tts/scorers/step_scorer_llm_critic.py +947 -0
  75. llm_tts/scorers/step_scorer_prm.py +1002 -0
  76. llm_tts/scorers/step_scorer_reward_base.py +47 -0
  77. llm_tts/scorers/step_scorer_uncertainty.py +48 -0
  78. llm_tts/step_boundary_detectors/__init__.py +65 -0
  79. llm_tts/step_boundary_detectors/base.py +23 -0
  80. llm_tts/step_boundary_detectors/non_thinking/__init__.py +12 -0
  81. llm_tts/step_boundary_detectors/non_thinking/structured.py +169 -0
  82. llm_tts/step_boundary_detectors/thinking/__init__.py +39 -0
  83. llm_tts/step_boundary_detectors/thinking/huggingface/__init__.py +18 -0
  84. llm_tts/step_boundary_detectors/thinking/marker.py +662 -0
  85. llm_tts/step_boundary_detectors/thinking/offline/__init__.py +18 -0
  86. llm_tts/step_boundary_detectors/thinking/offline/hybrid.py +308 -0
  87. llm_tts/step_boundary_detectors/thinking/offline/llm.py +384 -0
  88. llm_tts/step_boundary_detectors/thinking/offline/sentence.py +138 -0
  89. llm_tts/step_boundary_detectors/thinking/vllm/__init__.py +31 -0
  90. llm_tts/step_boundary_detectors/thinking/vllm/stop_tokens.py +480 -0
  91. llm_tts/strategies/__init__.py +35 -0
  92. llm_tts/strategies/adaptive_scaling_best_of_n.py +679 -0
  93. llm_tts/strategies/deepconf/__init__.py +9 -0
  94. llm_tts/strategies/deepconf/strategy.py +1364 -0
  95. llm_tts/strategies/deepconf/utils.py +312 -0
  96. llm_tts/strategies/metadata_builder.py +222 -0
  97. llm_tts/strategies/phi.py +228 -0
  98. llm_tts/strategies/strategy_base.py +183 -0
  99. llm_tts/strategies/strategy_baseline.py +399 -0
  100. llm_tts/strategies/strategy_beam_search.py +1168 -0
  101. llm_tts/strategies/strategy_chain_of_thought.py +119 -0
  102. llm_tts/strategies/strategy_extended_thinking.py +386 -0
  103. llm_tts/strategies/strategy_offline_best_of_n.py +969 -0
  104. llm_tts/strategies/strategy_online_best_of_n.py +1101 -0
  105. llm_tts/strategies/strategy_self_consistency.py +512 -0
  106. llm_tts/strategies/strategy_uncertainty_cot.py +343 -0
  107. llm_tts/utils/__init__.py +15 -0
  108. llm_tts/utils/answer_extraction.py +141 -0
  109. llm_tts/utils/flops.py +295 -0
  110. llm_tts/utils/parallel.py +82 -0
  111. llm_tts/utils/telegram.py +154 -0
  112. llm_tts/utils/telegram_bot.py +83 -0
  113. llm_tts/utils/torch_dtype.py +25 -0
  114. service_app/__init__.py +0 -0
  115. service_app/api/__init__.py +0 -0
  116. service_app/api/models/__init__.py +0 -0
  117. service_app/api/models/openai_compat.py +238 -0
  118. service_app/api/routes/__init__.py +1 -0
  119. service_app/api/routes/chat.py +514 -0
  120. service_app/api/routes/debugger.py +103 -0
  121. service_app/api/routes/models.py +71 -0
  122. service_app/core/__init__.py +0 -0
  123. service_app/core/config.py +95 -0
  124. service_app/core/debugger_events.py +1035 -0
  125. service_app/core/logging_config.py +83 -0
  126. service_app/core/prm_scorer_factory.py +86 -0
  127. service_app/core/strategy_manager.py +687 -0
  128. service_app/core/visual_debugger_demo.py +689 -0
  129. service_app/main.py +314 -0
  130. thinkbooster-0.1.0.dist-info/METADATA +288 -0
  131. thinkbooster-0.1.0.dist-info/RECORD +134 -0
  132. thinkbooster-0.1.0.dist-info/WHEEL +5 -0
  133. thinkbooster-0.1.0.dist-info/licenses/LICENSE +22 -0
  134. thinkbooster-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,46 @@
1
"""Dataset loaders for various benchmarks.

Re-exports the loading, formatting and extraction helpers of each benchmark
module so they can be imported from the package directly. The HumanEval+
helpers whose names collide with the MBPP+ ones are exposed under
HumanEval-specific aliases.
"""

from .gsm8k import (
    evaluate_gsm8k_answer,
    extract_answer_from_gsm8k,
    format_gsm8k_for_deepconf,
    load_gsm8k,
)
from .human_eval_plus import (
    create_evalplus_samples as create_human_eval_plus_samples,
    extract_code_from_response as extract_code_from_response_human_eval,
    format_human_eval_prompt,
    load_evalplus_samples as load_human_eval_plus_samples,
    load_human_eval_plus,
)
from .mbpp_plus import (
    create_evalplus_samples,
    extract_code_from_response,
    format_mbpp_prompt,
    load_evalplus_samples,
    load_mbpp_plus,
)

__all__ = [
    # GSM8K helpers
    "load_gsm8k",
    "evaluate_gsm8k_answer",
    "extract_answer_from_gsm8k",
    "format_gsm8k_for_deepconf",
    # MBPP+ helpers (exported under their original names)
    "load_mbpp_plus",
    "extract_code_from_response",
    "format_mbpp_prompt",
    "create_evalplus_samples",
    "load_evalplus_samples",
    # HumanEval+ helpers (aliased where they would clash with MBPP+)
    "load_human_eval_plus",
    "extract_code_from_response_human_eval",
    "format_human_eval_prompt",
    "create_human_eval_plus_samples",
    "load_human_eval_plus_samples",
]
@@ -0,0 +1,168 @@
1
+ """
2
+ GSM8K dataset loader and preprocessing for DeepConf evaluation.
3
+
4
+ GSM8K (Grade School Math 8K) is a dataset of 8.5K grade school math word problems.
5
+ Each problem requires multi-step reasoning to solve.
6
+ """
7
+
8
+ import logging
9
+ from typing import Dict, List, Optional
10
+
11
+ from datasets import load_dataset
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
def extract_answer_from_gsm8k(solution: str) -> str:
    """
    Pull the final answer out of a GSM8K reference solution.

    GSM8K solutions terminate with a "#### {answer}" line; everything after
    the last such marker is taken as the answer.

    Args:
        solution: The solution string from GSM8K

    Returns:
        The answer text with surrounding whitespace and all commas removed
        (e.g. "1,000" -> "1000"), or "" when the marker is absent.
    """
    marker = "####"
    if marker not in solution:
        return ""

    # The last piece after splitting on the marker is the answer segment.
    tail = solution.split(marker)[-1]
    return tail.strip().replace(",", "")
34
+
35
+
36
def format_gsm8k_for_deepconf(question: str, answer: str) -> Dict[str, str]:
    """
    Convert one raw GSM8K record into the layout used for DeepConf evaluation.

    Args:
        question: The question text; surrounding whitespace is stripped
        answer: The full solution text (includes "#### {answer}")

    Returns:
        Dict with three keys: 'question' (stripped question), 'answer'
        (final answer parsed from the solution) and 'original_solution'
        (the solution text, untouched, kept for reference)
    """
    # Parse the answer first, then assemble the record.
    final_answer = extract_answer_from_gsm8k(answer)
    return dict(
        question=question.strip(),
        answer=final_answer,
        original_solution=answer,
    )
59
+
60
+
61
def load_gsm8k(
    split: str = "test",
    subset_size: Optional[int] = None,
    cache_dir: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Fetch GSM8K via `datasets.load_dataset` and format every record for DeepConf.

    Args:
        split: Dataset split ('train' or 'test')
        subset_size: If provided, keep only the first N examples (capped at
            the size of the split)
        cache_dir: Cache directory handed to `load_dataset`

    Returns:
        List of dicts as produced by `format_gsm8k_for_deepconf`
    """
    log.info(f"Loading GSM8K dataset (split={split})...")

    raw_split = load_dataset("openai/gsm8k", "main", split=split, cache_dir=cache_dir)

    if subset_size is not None:
        # Never ask for more rows than the split actually contains.
        keep = min(subset_size, len(raw_split))
        raw_split = raw_split.select(range(keep))
        log.info(f"Using subset of {len(raw_split)} examples")

    examples = [
        format_gsm8k_for_deepconf(question=row["question"], answer=row["answer"])
        for row in raw_split
    ]

    log.info(f"Loaded {len(examples)} GSM8K examples")
    return examples
98
+
99
+
100
def evaluate_gsm8k_answer(predicted: str, ground_truth: str) -> bool:
    """
    Decide whether a predicted GSM8K answer agrees with the reference answer.

    Both answers are normalised (whitespace stripped, commas removed,
    lower-cased) and compared as strings; if that fails they are compared
    as floats with a small tolerance.

    Args:
        predicted: Predicted answer (typically the extracted boxed value)
        ground_truth: Ground truth answer

    Returns:
        True if answers match, False otherwise (including when the strings
        differ and either one is not parseable as a number)
    """

    def _normalise(text: str) -> str:
        return text.strip().replace(",", "").lower()

    guess = _normalise(predicted)
    truth = _normalise(ground_truth)

    # Cheap path: identical after normalisation.
    if guess == truth:
        return True

    # Otherwise both sides must be numeric to have any chance of matching.
    try:
        guess_value = float(guess)
        truth_value = float(truth)
    except (ValueError, TypeError):
        return False

    # Tolerance scales with magnitude but never drops below 1e-6.
    tolerance = 1e-6 * max(abs(guess_value), abs(truth_value), 1)
    return abs(guess_value - truth_value) < tolerance
135
+
136
+
137
if __name__ == "__main__":
    # Manual smoke test: load a handful of examples, print them, then run a
    # few hand-picked comparisons through the answer evaluator.
    logging.basicConfig(level=logging.INFO)

    print("\n=== Testing GSM8K loader ===\n")

    sample_rows = load_gsm8k(split="test", subset_size=5)
    print(f"Loaded {len(sample_rows)} examples\n")

    for number, row in enumerate(sample_rows[:3], start=1):
        print(f"Example {number}:")
        print(f"  Question: {row['question'][:100]}...")
        print(f"  Answer: {row['answer']}")
        print(f"  Original: {row['original_solution'][:80]}...")
        print()

    print("\n=== Testing answer evaluation ===\n")

    # (predicted, ground truth, expected verdict)
    checks = [
        ("70", "70", True),
        ("70", "70.0", True),
        ("1000", "1,000", True),
        ("70", "71", False),
        ("abc", "ABC", True),
    ]

    for guess, truth, wanted in checks:
        verdict = evaluate_gsm8k_answer(guess, truth)
        mark = "✓" if verdict == wanted else "✗"
        print(f"{mark} '{guess}' vs '{truth}': {verdict} (expected {wanted})")
@@ -0,0 +1,266 @@
1
+ """
2
+ HumanEval+ dataset loader and utilities.
3
+
4
+ HumanEval+ is an enhanced version of HumanEval with 80x more test cases
5
+ for rigorous evaluation of code generation.
6
+
7
+ Dataset: https://huggingface.co/datasets/evalplus/humanevalplus
8
+ EvalPlus: https://github.com/evalplus/evalplus
9
+ """
10
+
11
+ import json
12
+ import logging
13
+ import re
14
+ from typing import Any, Dict, List, Optional
15
+
16
+ from evalplus.data import get_human_eval_plus, write_jsonl
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+
21
def load_human_eval_plus(
    subset_size: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Public entry point for loading HumanEval+.

    Delegates to `_load_from_evalplus`, which reads the problems through the
    evalplus API.

    Args:
        subset_size: If provided, only load first N examples

    Returns:
        List of dicts with formatted data for the evaluation pipeline
    """
    problems = _load_from_evalplus(subset_size)
    return problems
34
+
35
+
36
def _load_from_evalplus(subset_size: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load HumanEval+ using evalplus API.

    Formats prompts to match EvalPlus official methodology:
    - instruction_prefix + code block with docstring

    Args:
        subset_size: Maximum number of problems to return. ``None`` means no
            limit; ``0`` (or a negative value) yields an empty list, matching
            how ``load_gsm8k`` treats its ``subset_size``.

    Returns:
        One dict per problem with the pipeline fields ``question`` (the
        formatted prompt) and ``answer`` (canonical solution), plus the
        HumanEval+ fields ``task_id``, ``entry_point``, ``prompt``,
        ``base_input``, ``plus_input``, ``atol`` and ``contract``.
    """
    log.info("Loading HumanEval+ using evalplus API...")

    # EvalPlus instruction prefix for chat/instruction models
    instruction_prefix = (
        "Please provide a self-contained Python script that solves the "
        "following problem in a markdown code block:"
    )

    problems = get_human_eval_plus()
    formatted_data: List[Dict[str, Any]] = []

    for task_id, problem in problems.items():
        # Test the limit *before* appending and with an explicit None check:
        # a truthiness test would treat subset_size=0 as "no limit".
        if subset_size is not None and len(formatted_data) >= subset_size:
            break

        # Format prompt exactly like EvalPlus does for chat models:
        # instruction_prefix + "\n```python\n" + prompt + "\n```"
        raw_prompt = problem["prompt"].strip()
        formatted_prompt = f"{instruction_prefix}\n```python\n{raw_prompt}\n```"

        # Only fall back to parsing the prompt when the dataset gives no
        # entry point, so the regex is not run needlessly for every problem.
        if "entry_point" in problem:
            entry_point = problem["entry_point"]
        else:
            entry_point = _extract_function_name(raw_prompt)

        formatted_data.append(
            {
                # Standard fields for the evaluation pipeline
                "question": formatted_prompt,
                "answer": problem["canonical_solution"],
                # HumanEval+ specific fields
                "task_id": task_id,
                "entry_point": entry_point,
                "prompt": raw_prompt,  # Original prompt (function signature + docstring)
                "base_input": problem.get("base_input", []),
                "plus_input": problem.get("plus_input", []),
                "atol": problem.get("atol", 0),
                "contract": problem.get("contract", ""),
            }
        )

    log.info(f"Loaded {len(formatted_data)} HumanEval+ problems via evalplus API")
    return formatted_data
81
+
82
+
83
+ def _extract_function_name(prompt: str) -> str:
84
+ """Extract function name from HumanEval prompt."""
85
+ # HumanEval prompts typically start with function signature
86
+ # Pattern: "def function_name("
87
+ match = re.search(r"def (\w+)\s*\(", prompt)
88
+ if match:
89
+ return match.group(1)
90
+
91
+ # Default fallback
92
+ return "solution"
93
+
94
+
95
def extract_code_from_response(response: str) -> str:
    """
    Extract Python code from model response.

    Tried in order:
    1. Fenced code blocks. The fence may be bare or tagged ``python``,
       ``python3``, ``py`` or ``py3`` (case-insensitive). The last block wins.
       A stray ``python`` / ``python3`` line at the top of the block, as
       emitted by some models, is dropped.
    2. A bare function definition, with or without a return annotation,
       running up to the next blank line or the end of the text. The last
       match wins.
    3. The whole response, stripped.

    Args:
        response: Model's response text

    Returns:
        Extracted code string
    """
    # Accept the common Python fence tags; without this a "```python3" fence
    # is missed and the fallback below returns code with a trailing fence.
    code_block_pattern = r"```(?:python3?|py3?)?\s*\n(.*?)```"
    code_blocks = re.findall(
        code_block_pattern, response, re.DOTALL | re.IGNORECASE
    )

    if code_blocks:
        # Return the last code block (usually the final solution)
        code = code_blocks[-1].strip()
        # Handle malformed code blocks where the language tag appears on its
        # own line ("```\npython\n..." instead of "```python\n...").
        first_line, newline, remainder = code.partition("\n")
        if newline and first_line.strip().lower() in ("python", "python3"):
            code = remainder
        return code.strip()

    # Try to find a function definition: "def name(args)" optionally followed
    # by "-> annotation", then ":" and the body up to a blank line or the end.
    func_pattern = r"(def \w+\s*\([^)]*\)(?:\s*->[^:\n]+)?:.*?)(?:\n\n|\Z)"
    func_matches = re.findall(func_pattern, response, re.DOTALL)

    if func_matches:
        return func_matches[-1].strip()

    # Return the response as-is (might be raw code)
    return response.strip()
138
+
139
+
140
def format_human_eval_prompt(
    problem: Dict[str, Any],
    prompt_template: Optional[str] = None,
) -> str:
    """
    Build the model prompt for one HumanEval+ problem.

    Args:
        problem: A formatted HumanEval+ problem dict; must contain 'question'
        prompt_template: Optional `str.format` template. The placeholders
            {prompt}, {entry_point} and {task_id} are available. When empty
            or omitted the problem's 'question' is returned unchanged.

    Returns:
        Formatted prompt string
    """
    question = problem["question"]

    # No template supplied: the stored question already is the prompt.
    if not prompt_template:
        return question

    fields = {
        "prompt": question,
        "entry_point": problem.get("entry_point", "solution"),
        "task_id": problem.get("task_id", ""),
    }
    return prompt_template.format(**fields)
163
+
164
+
165
def create_evalplus_samples(
    results: List[Dict[str, Any]],
    output_path: str,
) -> None:
    """
    Create a samples file in EvalPlus format for evaluation.

    The format expected by evalplus is JSONL with:
    - task_id: str
    - solution: str (the complete solution code)

    For each result the solution is the first non-empty value among
    'generated_code', 'extracted_answer', 'generated_answer' and
    'generated_trajectory'. If none is set, an empty string is written.
    Solutions that contain a markdown fence are passed through
    `extract_code_from_response`.

    Args:
        results: List of result dicts with 'task_id' and generated code
        output_path: Path to save the samples file
    """
    log.info(f"Saving {len(results)} samples to {output_path}")

    samples = []
    for result in results:
        task_id = result.get("task_id", "")
        # Get the generated code from various possible fields. The trailing
        # `or ""` guarantees a string even when every field is missing or
        # explicitly None; otherwise the fence check below raises TypeError.
        solution = (
            result.get("generated_code")
            or result.get("extracted_answer")
            or result.get("generated_answer")
            or result.get("generated_trajectory")
            or ""
        )

        # Extract code if it contains markdown
        if "```" in solution:
            solution = extract_code_from_response(solution)

        samples.append(
            {
                "task_id": task_id,
                "solution": solution,
            }
        )

    write_jsonl(output_path, samples)

    log.info(f"Samples saved to {output_path}")
207
+
208
+
209
def load_evalplus_samples(path: str) -> List[Dict[str, Any]]:
    """
    Read an EvalPlus-style JSONL samples file.

    Each non-blank line is parsed as one JSON document; blank or
    whitespace-only lines are skipped.

    Args:
        path: Path to the samples JSONL file

    Returns:
        List of sample dictionaries, in file order
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped_lines if entry]
226
+
227
+
228
if __name__ == "__main__":
    # Manual smoke test: load a few problems, print a summary of each, then
    # run the code extractor on a canned model response.
    logging.basicConfig(level=logging.INFO)

    print("\n=== Testing HumanEval+ loader ===\n")

    sample_problems = load_human_eval_plus(subset_size=5)
    print(f"Loaded {len(sample_problems)} problems\n")

    for number, entry in enumerate(sample_problems[:3], start=1):
        print(f"Problem {number}:")
        print(f"  Task ID: {entry['task_id']}")
        print(f"  Entry point: {entry['entry_point']}")
        print(f"  Prompt: {entry['question'][:100]}...")
        print(f"  Solution preview: {entry['answer'][:80]}...")
        print()

    print("\n=== Testing code extraction ===\n")

    # Response with prose around a fenced solution, as a chat model would
    # typically produce it.
    canned_response = """
Here's the solution:

```python
def has_close_elements(numbers: List[float], threshold: float) -> bool:
    for i in range(len(numbers)):
        for j in range(i + 1, len(numbers)):
            if abs(numbers[i] - numbers[j]) < threshold:
                return True
    return False
```

This function checks if any two elements are closer than the threshold.
"""

    code_only = extract_code_from_response(canned_response)
    print(f"Extracted code:\n{code_only}")