textpolicy 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +3 -0
- textpolicy/algorithms/__init__.py +29 -4
- textpolicy/algorithms/grpo.py +771 -361
- textpolicy/algorithms/length_shaping.py +151 -0
- textpolicy/analysis/__init__.py +23 -0
- textpolicy/analysis/emergence_logger.py +248 -0
- textpolicy/analysis/planning_patterns.py +105 -0
- textpolicy/analysis/serialization.py +65 -0
- textpolicy/generation/mlx_generation.py +36 -21
- textpolicy/tasks/__init__.py +7 -0
- textpolicy/tasks/countdown/__init__.py +21 -0
- textpolicy/tasks/countdown/dataset.py +163 -0
- textpolicy/tasks/countdown/evaluator.py +197 -0
- textpolicy/tasks/countdown/prompt.py +89 -0
- textpolicy/tasks/countdown/reward.py +56 -0
- textpolicy/training/trainer.py +41 -21
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/METADATA +1 -1
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/RECORD +22 -11
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/WHEEL +0 -0
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/entry_points.txt +0 -0
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/top_level.txt +0 -0
|
@@ -27,8 +27,8 @@ try:
|
|
|
27
27
|
from mlx_lm.sample_utils import make_sampler, make_logits_processors
|
|
28
28
|
# sampling utilities fallback when sample_utils is unavailable
|
|
29
29
|
except ImportError:
|
|
30
|
-
|
|
31
|
-
|
|
30
|
+
make_sampler = None
|
|
31
|
+
make_logits_processors = None
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def _get_eos_configs_for_model(
|
|
@@ -194,7 +194,9 @@ def generate_tokens(
|
|
|
194
194
|
return _simple_generate(model, prompt_tokens, max_tokens, temperature)
|
|
195
195
|
|
|
196
196
|
prompt_list = prompt_tokens.tolist()
|
|
197
|
-
|
|
197
|
+
response_token_list: list = []
|
|
198
|
+
response_logprob_list: list = []
|
|
199
|
+
|
|
198
200
|
# Use stream_generate instead of generate to get proper EOS token handling
|
|
199
201
|
# This is the core fix - stream_generate respects EOS tokens, generate() does not
|
|
200
202
|
try:
|
|
@@ -216,19 +218,20 @@ def generate_tokens(
|
|
|
216
218
|
logits_processors=logits_processors,
|
|
217
219
|
))
|
|
218
220
|
|
|
219
|
-
# Extract tokens from response segments
|
|
220
|
-
response_token_list = []
|
|
221
|
-
|
|
221
|
+
# Extract tokens and logprobs from response segments
|
|
222
222
|
for segment in response_segments:
|
|
223
223
|
response_token_list.append(segment.token)
|
|
224
|
+
# Capture per-token logprob inline to avoid a redundant forward pass
|
|
225
|
+
if segment.logprobs is not None:
|
|
226
|
+
response_logprob_list.append(float(segment.logprobs[segment.token]))
|
|
224
227
|
# Check if this segment indicates natural stopping (EOS token)
|
|
225
228
|
if hasattr(segment, 'finish_reason') and segment.finish_reason == "stop":
|
|
226
229
|
break
|
|
227
|
-
|
|
230
|
+
|
|
228
231
|
# Convert to MLX array
|
|
229
232
|
response_tokens = mx.array(response_token_list) if response_token_list else mx.array([])
|
|
230
|
-
|
|
231
|
-
|
|
233
|
+
|
|
234
|
+
|
|
232
235
|
except ImportError:
|
|
233
236
|
# Fallback to original generate method if stream_generate unavailable
|
|
234
237
|
print("WARNING: stream_generate not available, using fallback generate method")
|
|
@@ -248,8 +251,12 @@ def generate_tokens(
|
|
|
248
251
|
|
|
249
252
|
response_tokens = _extract_response_tokens(response, prompt_list, tokenizer)
|
|
250
253
|
|
|
251
|
-
#
|
|
252
|
-
|
|
254
|
+
# Use inline logprobs captured during generation when available,
|
|
255
|
+
# falling back to a full forward pass only if logprobs were missing.
|
|
256
|
+
if response_logprob_list and len(response_logprob_list) == len(response_token_list):
|
|
257
|
+
logprobs = mx.array(response_logprob_list)
|
|
258
|
+
else:
|
|
259
|
+
logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
|
|
253
260
|
return response_tokens, {'logprob': logprobs}
|
|
254
261
|
|
|
255
262
|
|
|
@@ -292,35 +299,43 @@ def _simple_generate(
|
|
|
292
299
|
"""
|
|
293
300
|
current_tokens = prompt_tokens
|
|
294
301
|
generated = []
|
|
295
|
-
|
|
302
|
+
generated_logprobs = []
|
|
303
|
+
|
|
296
304
|
for _ in range(max_tokens):
|
|
297
305
|
# Model forward pass
|
|
298
306
|
logits = model(current_tokens[None]) # Add batch dimension
|
|
299
307
|
next_token_logits = logits[0, -1, :] # Last token logits
|
|
300
|
-
|
|
308
|
+
|
|
301
309
|
# Temperature scaling
|
|
302
310
|
if temperature > 0:
|
|
303
311
|
scaled_logits = next_token_logits / temperature
|
|
304
312
|
else:
|
|
305
313
|
scaled_logits = next_token_logits
|
|
306
|
-
|
|
314
|
+
|
|
307
315
|
# Sample next token
|
|
308
316
|
probs = mx.softmax(scaled_logits)
|
|
309
317
|
next_token = mx.random.categorical(probs[None])[0]
|
|
310
|
-
|
|
318
|
+
|
|
319
|
+
# Capture logprob inline: log_softmax of the *unscaled* logits at the selected token
|
|
320
|
+
log_probs = next_token_logits - mx.logsumexp(next_token_logits)
|
|
321
|
+
generated_logprobs.append(float(log_probs[next_token]))
|
|
322
|
+
|
|
311
323
|
# Add to sequence
|
|
312
324
|
generated.append(next_token)
|
|
313
325
|
current_tokens = mx.concatenate([current_tokens, next_token[None]])
|
|
314
|
-
|
|
326
|
+
|
|
315
327
|
# Stop on EOS (approximate) - avoid .item() calls
|
|
316
328
|
if len(generated) > 5 and next_token < 5: # Simple stop condition
|
|
317
329
|
break
|
|
318
|
-
|
|
330
|
+
|
|
319
331
|
response_tokens = mx.array(generated) if generated else mx.array([2])
|
|
320
|
-
|
|
321
|
-
#
|
|
322
|
-
|
|
323
|
-
|
|
332
|
+
|
|
333
|
+
# Use inline logprobs captured during generation (avoids redundant forward pass)
|
|
334
|
+
if generated_logprobs and len(generated_logprobs) == len(generated):
|
|
335
|
+
logprobs = mx.array(generated_logprobs)
|
|
336
|
+
else:
|
|
337
|
+
logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
|
|
338
|
+
|
|
324
339
|
return response_tokens, {'logprob': logprobs}
|
|
325
340
|
|
|
326
341
|
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Countdown Numbers Game task for TextPolicy.
|
|
3
|
+
|
|
4
|
+
Importing this module registers the 'countdown' reward function.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .evaluator import ExpressionError, EvalResult, evaluate_expression
|
|
8
|
+
from .prompt import format_countdown_prompt, extract_expression_from_completion
|
|
9
|
+
from .reward import countdown_reward
|
|
10
|
+
from .dataset import generate_countdown_problems, load_countdown_dataset
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"ExpressionError",
|
|
14
|
+
"EvalResult",
|
|
15
|
+
"evaluate_expression",
|
|
16
|
+
"format_countdown_prompt",
|
|
17
|
+
"extract_expression_from_completion",
|
|
18
|
+
"countdown_reward",
|
|
19
|
+
"generate_countdown_problems",
|
|
20
|
+
"load_countdown_dataset",
|
|
21
|
+
]
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Problem generation and HuggingFace dataset loading for the Countdown task.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import ast
|
|
6
|
+
import itertools
|
|
7
|
+
import random
|
|
8
|
+
from typing import Dict, List, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def generate_countdown_problems(
    num_problems: int,
    num_numbers: int = 4,
    number_range: Tuple[int, int] = (1, 25),
    target_range: Tuple[int, int] = (10, 100),
    ensure_solvable: bool = True,
    seed: Optional[int] = None,
    max_attempts: Optional[int] = None,
) -> List[Dict]:
    """
    Draw random Countdown Numbers Game problems.

    Each candidate is produced by drawing ``num_numbers`` integers from
    ``number_range`` and then one target from ``target_range`` using a
    private ``random.Random(seed)`` instance, so a fixed seed gives a
    reproducible list.

    Args:
        num_problems: How many problems to return.
        num_numbers: Count of available numbers in each problem.
        number_range: Inclusive (min, max) bounds for the available numbers.
        target_range: Inclusive (min, max) bounds for the target.
        ensure_solvable: When True, candidates rejected by the brute-force
            solvability check are discarded.
        seed: Seed for the random generator.
        max_attempts: Upper bound on the number of candidates drawn. When
            omitted it is ``num_problems * 100`` if ensure_solvable is True,
            otherwise ``num_problems``.

    Returns:
        List of dicts with keys 'target' and 'numbers'.

    Raises:
        RuntimeError: If the attempt budget runs out before enough problems
            have been collected.
    """
    rng = random.Random(seed)

    if max_attempts is None:
        if ensure_solvable:
            max_attempts = num_problems * 100
        else:
            max_attempts = num_problems

    # Count the budget down instead of counting attempts up.
    budget = max_attempts
    problems = []
    while len(problems) < num_problems:
        if budget <= 0:
            raise RuntimeError(
                f"Could not generate {num_problems} problems within "
                f"{max_attempts} attempts (got {len(problems)}). "
                f"Try wider number_range/target_range or increase max_attempts."
            )
        budget -= 1

        # Numbers are drawn before the target; the order matters for
        # seed reproducibility.
        numbers = [rng.randint(*number_range) for _ in range(num_numbers)]
        target = rng.randint(*target_range)

        if not ensure_solvable or _is_solvable(numbers, target):
            problems.append({"target": target, "numbers": numbers})

    return problems
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def load_countdown_dataset(
    split: str = "train",
    max_examples: Optional[int] = None,
) -> List[Dict]:
    """
    Load the Countdown task dataset from HuggingFace.

    Requires the `datasets` library (optional dependency). Rows are read
    from "Jiayi-Pan/Countdown-Tasks-3to4"; a row is kept only when it has a
    'target' and a 'nums' (or 'numbers') value.

    Args:
        split: Dataset split to load ('train', 'test', etc.).
        max_examples: Maximum number of examples to return. None means no
            limit; zero or a negative value returns an empty list.

    Returns:
        List of dicts with keys 'target' (int) and 'numbers' (list).

    Raises:
        ImportError: If the `datasets` library is not installed.
    """
    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError(
            "The 'datasets' library is required to load HuggingFace datasets. "
            "Install it with: pip install datasets"
        ) from err

    examples: List[Dict] = []

    # The cap used to be checked only after appending, so max_examples=0
    # still returned one row. Handle the empty request up front.
    if max_examples is not None and max_examples <= 0:
        return examples

    ds = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split=split)

    for item in ds:
        target = item.get("target")
        numbers = item.get("nums") or item.get("numbers")
        if target is None or numbers is None:
            continue

        # Some rows may store the list as its string repr; literal_eval
        # parses it without executing code.
        if isinstance(numbers, str):
            numbers = ast.literal_eval(numbers)

        examples.append({"target": int(target), "numbers": list(numbers)})
        if max_examples is not None and len(examples) >= max_examples:
            break

    return examples
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# ---------------------------------------------------------------------------
|
|
107
|
+
# Brute-force solvability check
|
|
108
|
+
# ---------------------------------------------------------------------------
|
|
109
|
+
|
|
110
|
+
# Commutative ops: only need (a, b), not (b, a)
|
|
111
|
+
_COMMUTATIVE_OPS = [
|
|
112
|
+
lambda a, b: a + b,
|
|
113
|
+
lambda a, b: a * b,
|
|
114
|
+
]
|
|
115
|
+
|
|
116
|
+
# Non-commutative ops: must try both orderings
|
|
117
|
+
_NON_COMMUTATIVE_OPS = [
|
|
118
|
+
lambda a, b: a - b,
|
|
119
|
+
lambda a, b: a / b if b != 0 else None,
|
|
120
|
+
]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _is_solvable(numbers: List[int], target: int) -> bool:
    """Return True when the recursive search can reach ``target`` from ``numbers``.

    Inputs are converted to floats first so that division inside the search
    may produce fractional intermediate values.
    """
    as_floats = [float(n) for n in numbers]
    return _solve(as_floats, float(target))
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _solve(nums: List[float], target: float) -> bool:
    """Exhaustively combine pairs of values until the target shows up.

    A state is solved as soon as any value in ``nums`` is within 1e-9 of
    ``target``, which is what lets a solution use only a subset of the
    numbers. Otherwise every unordered pair is replaced by the result of
    one operation and the search recurses on the shorter list.

    Pairs are taken as combinations; only subtraction and division are
    evaluated in both orders.
    """
    # Solved if some value already matches the target.
    if any(abs(value - target) < 1e-9 for value in nums):
        return True

    # A single leftover value that did not match cannot be combined further.
    if len(nums) < 2:
        return False

    for i, j in itertools.combinations(range(len(nums)), 2):
        left, right = nums[i], nums[j]
        rest = [value for k, value in enumerate(nums) if k not in (i, j)]

        # a+b == b+a and a*b == b*a, so one ordering is enough.
        for op in _COMMUTATIVE_OPS:
            if _solve(rest + [op(left, right)], target):
                return True

        # Subtraction/division: try (left, right) and then (right, left).
        for op in _NON_COMMUTATIVE_OPS:
            for outcome in (op(left, right), op(right, left)):
                # None marks a skipped division by zero.
                if outcome is not None and _solve(rest + [outcome], target):
                    return True

    return False
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Safe arithmetic expression evaluator using recursive descent parsing.
|
|
3
|
+
|
|
4
|
+
No eval(), no ast module. Handles +, -, *, /, parentheses, integers only.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import re
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import List, Optional
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ExpressionError(Exception):
    """Signals that an expression is malformed or cannot be evaluated."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class EvalResult:
    """Outcome of evaluating one arithmetic expression."""

    # Numeric value the expression evaluated to.
    value: float

    # Integer literals that appeared in the expression. default_factory
    # gives every instance its own list.
    numbers_used: List[int] = field(default_factory=list)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Allowed characters in expressions
|
|
25
|
+
# Matched against the whole stripped expression before tokenizing: ASCII
# digits, whitespace, + - * / and parentheses are the only accepted characters.
_ALLOWED_CHARS = re.compile(r'^[0-9\s+\-*/()]+$')
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def evaluate_expression(
    expression: str, available_numbers: Optional[List[int]] = None
) -> EvalResult:
    """
    Safely evaluate an arithmetic expression using recursive descent parsing.

    The expression is stripped, checked against the character whitelist,
    tokenized and parsed. The whole token stream must be consumed.

    Args:
        expression: Arithmetic expression string (e.g. "(2+3)*4")
        available_numbers: If provided, validates that only these numbers are used
            (each at most once).

    Returns:
        EvalResult with the computed value and list of numbers used.

    Raises:
        ExpressionError: On syntax errors, division by zero, disallowed chars,
            excessive nesting, or number reuse/unavailability.
    """
    if not expression or not expression.strip():
        raise ExpressionError("Empty expression")

    expr = expression.strip()

    if not _ALLOWED_CHARS.match(expr):
        raise ExpressionError(
            f"Expression contains disallowed characters: {expr!r}"
        )

    tokens = _tokenize(expr)
    if not tokens:
        raise ExpressionError("Empty expression after tokenization")

    parser = _Parser(tokens)
    try:
        value = parser.parse_expression()
    except RecursionError:
        # The parser recurses for every '(' and every unary sign, so a long
        # run of them (easy for a model to emit) exhausts the interpreter
        # stack. Report it through the documented error type instead of
        # letting RecursionError escape to the caller.
        raise ExpressionError("Expression is nested too deeply") from None

    # Leftover tokens mean the input was not a single complete expression.
    if parser.pos < len(parser.tokens):
        raise ExpressionError(
            f"Unexpected token after end of expression: "
            f"{parser.tokens[parser.pos]!r}"
        )

    numbers_used = parser.numbers_used

    if available_numbers is not None:
        _validate_numbers(numbers_used, available_numbers)

    return EvalResult(value=value, numbers_used=numbers_used)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _tokenize(expr: str) -> List[str]:
|
|
78
|
+
"""Tokenize an expression into numbers, operators, and parentheses."""
|
|
79
|
+
tokens = []
|
|
80
|
+
i = 0
|
|
81
|
+
while i < len(expr):
|
|
82
|
+
ch = expr[i]
|
|
83
|
+
if ch.isspace():
|
|
84
|
+
i += 1
|
|
85
|
+
continue
|
|
86
|
+
if ch.isdigit():
|
|
87
|
+
j = i
|
|
88
|
+
while j < len(expr) and expr[j].isdigit():
|
|
89
|
+
j += 1
|
|
90
|
+
tokens.append(expr[i:j])
|
|
91
|
+
i = j
|
|
92
|
+
elif ch in '+-*/()':
|
|
93
|
+
tokens.append(ch)
|
|
94
|
+
i += 1
|
|
95
|
+
else:
|
|
96
|
+
raise ExpressionError(f"Unexpected character: {ch!r}")
|
|
97
|
+
return tokens
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class _Parser:
|
|
101
|
+
"""Recursive descent parser for arithmetic expressions.
|
|
102
|
+
|
|
103
|
+
Grammar:
|
|
104
|
+
expression := term (('+' | '-') term)*
|
|
105
|
+
term := factor (('*' | '/') factor)*
|
|
106
|
+
factor := NUMBER | '(' expression ')' | ('+' | '-') factor
|
|
107
|
+
"""
|
|
108
|
+
|
|
109
|
+
def __init__(self, tokens: List[str]):
|
|
110
|
+
self.tokens = tokens
|
|
111
|
+
self.pos = 0
|
|
112
|
+
self.numbers_used: List[int] = []
|
|
113
|
+
|
|
114
|
+
def _peek(self) -> Optional[str]:
|
|
115
|
+
if self.pos < len(self.tokens):
|
|
116
|
+
return self.tokens[self.pos]
|
|
117
|
+
return None
|
|
118
|
+
|
|
119
|
+
def _consume(self) -> str:
|
|
120
|
+
token = self.tokens[self.pos]
|
|
121
|
+
self.pos += 1
|
|
122
|
+
return token
|
|
123
|
+
|
|
124
|
+
def parse_expression(self) -> float:
|
|
125
|
+
"""Parse an expression: term (('+' | '-') term)*"""
|
|
126
|
+
result = self._parse_term()
|
|
127
|
+
while self._peek() in ('+', '-'):
|
|
128
|
+
op = self._consume()
|
|
129
|
+
right = self._parse_term()
|
|
130
|
+
if op == '+':
|
|
131
|
+
result = result + right
|
|
132
|
+
else:
|
|
133
|
+
result = result - right
|
|
134
|
+
return result
|
|
135
|
+
|
|
136
|
+
def _parse_term(self) -> float:
|
|
137
|
+
"""Parse a term: factor (('*' | '/') factor)*"""
|
|
138
|
+
result = self._parse_factor()
|
|
139
|
+
while self._peek() in ('*', '/'):
|
|
140
|
+
op = self._consume()
|
|
141
|
+
right = self._parse_factor()
|
|
142
|
+
if op == '*':
|
|
143
|
+
result = result * right
|
|
144
|
+
else:
|
|
145
|
+
if right == 0:
|
|
146
|
+
raise ExpressionError("Division by zero")
|
|
147
|
+
result = result / right
|
|
148
|
+
return result
|
|
149
|
+
|
|
150
|
+
def _parse_factor(self) -> float:
|
|
151
|
+
"""Parse a factor: NUMBER | '(' expression ')' | unary +/-"""
|
|
152
|
+
token = self._peek()
|
|
153
|
+
if token is None:
|
|
154
|
+
raise ExpressionError("Unexpected end of expression")
|
|
155
|
+
|
|
156
|
+
# Unary plus/minus
|
|
157
|
+
if token in ('+', '-'):
|
|
158
|
+
op = self._consume()
|
|
159
|
+
value = self._parse_factor()
|
|
160
|
+
return value if op == '+' else -value
|
|
161
|
+
|
|
162
|
+
# Parenthesized expression
|
|
163
|
+
if token == '(':
|
|
164
|
+
self._consume() # eat '('
|
|
165
|
+
value = self.parse_expression()
|
|
166
|
+
if self._peek() != ')':
|
|
167
|
+
raise ExpressionError("Unmatched opening parenthesis")
|
|
168
|
+
self._consume() # eat ')'
|
|
169
|
+
return value
|
|
170
|
+
|
|
171
|
+
if token == ')':
|
|
172
|
+
raise ExpressionError("Unmatched closing parenthesis")
|
|
173
|
+
|
|
174
|
+
# Number — _tokenize only produces all-digit tokens, so isdigit()
|
|
175
|
+
# is sufficient. Avoid masking malformed tokens with a looser check.
|
|
176
|
+
if token.isdigit():
|
|
177
|
+
self._consume()
|
|
178
|
+
num = int(token)
|
|
179
|
+
self.numbers_used.append(num)
|
|
180
|
+
return float(num)
|
|
181
|
+
|
|
182
|
+
raise ExpressionError(f"Unexpected token: {token!r}")
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _validate_numbers(
|
|
186
|
+
numbers_used: List[int], available_numbers: List[int]
|
|
187
|
+
) -> None:
|
|
188
|
+
"""Validate that numbers_used is a valid subset of available_numbers."""
|
|
189
|
+
available_copy = list(available_numbers)
|
|
190
|
+
for num in numbers_used:
|
|
191
|
+
if num in available_copy:
|
|
192
|
+
available_copy.remove(num)
|
|
193
|
+
else:
|
|
194
|
+
raise ExpressionError(
|
|
195
|
+
f"Number {num} is not available or has been used too many times. "
|
|
196
|
+
f"Available: {available_numbers}"
|
|
197
|
+
)
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Prompt formatting and expression extraction for the Countdown task.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def format_countdown_prompt(target: int, numbers: List[int]) -> str:
    """
    Build the instruction text for one Countdown problem.

    Args:
        target: The number the expression must evaluate to.
        numbers: The numbers the expression may draw from.

    Returns:
        A single-line prompt string.
    """
    fragments = (
        f"Using the numbers {numbers}, create an arithmetic expression ",
        f"that equals {target}. You may use each number at most once. ",
        f"Use only +, -, *, / and parentheses. ",
        f"Provide your expression on its own line.",
    )
    return "".join(fragments)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Pattern matching pure arithmetic expressions (digits, operators, parens, spaces)
|
|
29
|
+
# Anchored with ^...$ so a line qualifies only when it consists entirely of
# digits, whitespace, operators and parentheses.
_EXPR_PATTERN = re.compile(r'^[\d\s+\-*/()]+$')
|
|
30
|
+
|
|
31
|
+
# Pattern for lines with delimiters like "= ...", ": ...", "answer ..."
|
|
32
|
+
_DELIMITER_PATTERN = re.compile(
|
|
33
|
+
r'(?:=|:|answer\s*(?:is|:)?)\s*([\d\s+\-*/()]+)',
|
|
34
|
+
re.IGNORECASE,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Find longest arithmetic-like substring
|
|
38
|
+
# Unanchored: findall() returns every run of these characters, and the
# extractor keeps the longest run that has both a digit and an operator.
_ARITH_SUBSTRING = re.compile(r'[\d\s+\-*/()]+')
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def extract_expression_from_completion(completion: str) -> str:
    """
    Pull the most likely arithmetic expression out of raw model output.

    The strategies are tried in order and the first hit wins:
        1. The first line that is nothing but an arithmetic expression
        2. The text following the first =, :, or 'answer' delimiter
        3. The longest arithmetic-like substring

    A candidate only counts if it has at least one digit and one operator.
    If nothing qualifies, the stripped completion is returned unchanged.

    Args:
        completion: The raw model output.

    Returns:
        The extracted expression string (may still be invalid), or "" when
        the completion is empty or whitespace only.
    """
    if not completion or not completion.strip():
        return ""

    text = completion.strip()

    # Strategy 1: first line that is a pure arithmetic expression
    stripped_lines = (raw.strip() for raw in text.splitlines())
    pure_line = next(
        (
            ln
            for ln in stripped_lines
            if ln and _EXPR_PATTERN.match(ln) and _has_digit_and_operator(ln)
        ),
        None,
    )
    if pure_line is not None:
        return pure_line

    # Strategy 2: text following the first delimiter match
    delimited = _DELIMITER_PATTERN.search(text)
    if delimited:
        candidate = delimited.group(1).strip()
        if candidate and _has_digit_and_operator(candidate):
            return candidate

    # Strategy 3: longest arithmetic-like run with a digit and an operator
    runs = (chunk.strip() for chunk in _ARITH_SUBSTRING.findall(text))
    usable = [chunk for chunk in runs if _has_digit_and_operator(chunk)]
    if usable:
        return max(usable, key=len)

    # Last resort: return the whole text stripped
    return text
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _has_digit_and_operator(s: str) -> bool:
|
|
86
|
+
"""Check if string has at least one digit and one operator."""
|
|
87
|
+
has_digit = any(c.isdigit() for c in s)
|
|
88
|
+
has_op = any(c in '+-*/' for c in s)
|
|
89
|
+
return has_digit and has_op
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Countdown reward function for GRPO training.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, Dict
|
|
7
|
+
|
|
8
|
+
from textpolicy.rewards.registry import reward
|
|
9
|
+
from .evaluator import ExpressionError, evaluate_expression
|
|
10
|
+
from .prompt import extract_expression_from_completion
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@reward(name="countdown")
def countdown_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    **kwargs,
) -> float:
    """
    Score a completion for the Countdown Numbers Game.

    Scoring:
        1.0  — the extracted expression evaluates to the target using only
               the available numbers
        0.0  — the expression evaluates to a different value, or the example
               lacks 'target' or 'numbers'
        -0.5 — nothing could be extracted, or evaluation raised
               ExpressionError (syntax error, number reuse, invalid numbers)

    The example dict must contain 'target' (int) and 'numbers' (list of int).
    ``prompt`` and extra keyword arguments are accepted but not used.
    """
    target, numbers = example.get("target"), example.get("numbers")

    # A broken example is not the model's fault, so it scores neutral.
    if target is None or numbers is None:
        logger.warning("Malformed example: missing 'target' or 'numbers'")
        return 0.0

    expression = extract_expression_from_completion(completion)
    if not expression:
        return -0.5

    try:
        result = evaluate_expression(expression, available_numbers=numbers)
    except ExpressionError as e:
        logger.debug(f"Expression error: {e}")
        return -0.5

    # The evaluator works in floats, so compare with a small tolerance.
    return 1.0 if abs(result.value - target) < 1e-9 else 0.0
|