textpolicy 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,8 +27,8 @@ try:
27
27
  from mlx_lm.sample_utils import make_sampler, make_logits_processors
28
28
  # sampling utilities fallback when sample_utils is unavailable
29
29
  except ImportError:
30
- _make_sampler = None
31
- _make_logits_processors = None
30
+ make_sampler = None
31
+ make_logits_processors = None
32
32
 
33
33
 
34
34
  def _get_eos_configs_for_model(
@@ -194,7 +194,9 @@ def generate_tokens(
194
194
  return _simple_generate(model, prompt_tokens, max_tokens, temperature)
195
195
 
196
196
  prompt_list = prompt_tokens.tolist()
197
-
197
+ response_token_list: list = []
198
+ response_logprob_list: list = []
199
+
198
200
  # Use stream_generate instead of generate to get proper EOS token handling
199
201
  # This is the core fix - stream_generate respects EOS tokens, generate() does not
200
202
  try:
@@ -216,19 +218,20 @@ def generate_tokens(
216
218
  logits_processors=logits_processors,
217
219
  ))
218
220
 
219
- # Extract tokens from response segments and detect natural EOS stopping
220
- response_token_list = []
221
-
221
+ # Extract tokens and logprobs from response segments
222
222
  for segment in response_segments:
223
223
  response_token_list.append(segment.token)
224
+ # Capture per-token logprob inline to avoid a redundant forward pass
225
+ if segment.logprobs is not None:
226
+ response_logprob_list.append(float(segment.logprobs[segment.token]))
224
227
  # Check if this segment indicates natural stopping (EOS token)
225
228
  if hasattr(segment, 'finish_reason') and segment.finish_reason == "stop":
226
229
  break
227
-
230
+
228
231
  # Convert to MLX array
229
232
  response_tokens = mx.array(response_token_list) if response_token_list else mx.array([])
230
-
231
-
233
+
234
+
232
235
  except ImportError:
233
236
  # Fallback to original generate method if stream_generate unavailable
234
237
  print("WARNING: stream_generate not available, using fallback generate method")
@@ -248,8 +251,12 @@ def generate_tokens(
248
251
 
249
252
  response_tokens = _extract_response_tokens(response, prompt_list, tokenizer)
250
253
 
251
- # Compute logprobs for the response tokens
252
- logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
254
+ # Use inline logprobs captured during generation when available,
255
+ # falling back to a full forward pass only if logprobs were missing.
256
+ if response_logprob_list and len(response_logprob_list) == len(response_token_list):
257
+ logprobs = mx.array(response_logprob_list)
258
+ else:
259
+ logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
253
260
  return response_tokens, {'logprob': logprobs}
254
261
 
255
262
 
@@ -292,35 +299,43 @@ def _simple_generate(
292
299
  """
293
300
  current_tokens = prompt_tokens
294
301
  generated = []
295
-
302
+ generated_logprobs = []
303
+
296
304
  for _ in range(max_tokens):
297
305
  # Model forward pass
298
306
  logits = model(current_tokens[None]) # Add batch dimension
299
307
  next_token_logits = logits[0, -1, :] # Last token logits
300
-
308
+
301
309
  # Temperature scaling
302
310
  if temperature > 0:
303
311
  scaled_logits = next_token_logits / temperature
304
312
  else:
305
313
  scaled_logits = next_token_logits
306
-
314
+
307
315
  # Sample next token
308
316
  probs = mx.softmax(scaled_logits)
309
317
  next_token = mx.random.categorical(probs[None])[0]
310
-
318
+
319
+ # Capture logprob inline: log_softmax of the *unscaled* logits at the selected token
320
+ log_probs = next_token_logits - mx.logsumexp(next_token_logits)
321
+ generated_logprobs.append(float(log_probs[next_token]))
322
+
311
323
  # Add to sequence
312
324
  generated.append(next_token)
313
325
  current_tokens = mx.concatenate([current_tokens, next_token[None]])
314
-
326
+
315
327
  # Stop on EOS (approximate) - avoid .item() calls
316
328
  if len(generated) > 5 and next_token < 5: # Simple stop condition
317
329
  break
318
-
330
+
319
331
  response_tokens = mx.array(generated) if generated else mx.array([2])
320
-
321
- # Compute simple logprobs
322
- logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
323
-
332
+
333
+ # Use inline logprobs captured during generation (avoids redundant forward pass)
334
+ if generated_logprobs and len(generated_logprobs) == len(generated):
335
+ logprobs = mx.array(generated_logprobs)
336
+ else:
337
+ logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
338
+
324
339
  return response_tokens, {'logprob': logprobs}
325
340
 
326
341
 
@@ -0,0 +1,7 @@
1
+ """
2
+ Task implementations for TextPolicy.
3
+
4
+ Importing this package triggers auto-registration of task reward functions.
5
+ """
6
+
7
+ from . import countdown # noqa: F401 — triggers @reward registration
@@ -0,0 +1,21 @@
1
+ """
2
+ Countdown Numbers Game task for TextPolicy.
3
+
4
+ Importing this module registers the 'countdown' reward function.
5
+ """
6
+
7
+ from .evaluator import ExpressionError, EvalResult, evaluate_expression
8
+ from .prompt import format_countdown_prompt, extract_expression_from_completion
9
+ from .reward import countdown_reward
10
+ from .dataset import generate_countdown_problems, load_countdown_dataset
11
+
12
+ __all__ = [
13
+ "ExpressionError",
14
+ "EvalResult",
15
+ "evaluate_expression",
16
+ "format_countdown_prompt",
17
+ "extract_expression_from_completion",
18
+ "countdown_reward",
19
+ "generate_countdown_problems",
20
+ "load_countdown_dataset",
21
+ ]
@@ -0,0 +1,163 @@
1
+ """
2
+ Problem generation and HuggingFace dataset loading for the Countdown task.
3
+ """
4
+
5
+ import ast
6
+ import itertools
7
+ import random
8
+ from typing import Dict, List, Optional, Tuple
9
+
10
+
11
def generate_countdown_problems(
    num_problems: int,
    num_numbers: int = 4,
    number_range: Tuple[int, int] = (1, 25),
    target_range: Tuple[int, int] = (10, 100),
    ensure_solvable: bool = True,
    seed: Optional[int] = None,
    max_attempts: Optional[int] = None,
) -> List[Dict]:
    """
    Build a list of randomly drawn Countdown Numbers Game problems.

    Args:
        num_problems: How many problems the returned list must contain.
        num_numbers: Count of available numbers drawn for each problem.
        number_range: Inclusive (min, max) bounds for each available number.
        target_range: Inclusive (min, max) bounds for the target.
        ensure_solvable: When True, candidates rejected by the solvability
            check are discarded and another candidate is drawn.
        seed: Seed for the private random generator; the same seed yields
            the same problems.
        max_attempts: Upper bound on candidates drawn. When omitted it is
            num_problems * 100 if ensure_solvable is True, otherwise
            num_problems.

    Returns:
        List of dicts, each with the keys 'target' and 'numbers'.

    Raises:
        RuntimeError: If the attempt budget runs out before num_problems
            problems have been accepted.
    """
    rng = random.Random(seed)

    # Resolve the attempt budget once, up front
    if max_attempts is not None:
        budget = max_attempts
    elif ensure_solvable:
        budget = num_problems * 100
    else:
        budget = num_problems

    accepted: List[Dict] = []
    tried = 0
    while len(accepted) < num_problems:
        if tried >= budget:
            raise RuntimeError(
                f"Could not generate {num_problems} problems within "
                f"{budget} attempts (got {len(accepted)}). "
                f"Try wider number_range/target_range or increase max_attempts."
            )
        tried += 1

        # Draw order matters for reproducibility: numbers first, then target
        candidate_numbers = [
            rng.randint(*number_range) for _ in range(num_numbers)
        ]
        candidate_target = rng.randint(*target_range)

        if not ensure_solvable or _is_solvable(candidate_numbers, candidate_target):
            accepted.append(
                {"target": candidate_target, "numbers": candidate_numbers}
            )

    return accepted
64
+
65
+
66
def load_countdown_dataset(
    split: str = "train",
    max_examples: Optional[int] = None,
) -> List[Dict]:
    """
    Load the Countdown task dataset from HuggingFace.

    Requires the `datasets` library (optional dependency).

    Args:
        split: Dataset split to load ('train', 'test', etc.).
        max_examples: Maximum number of examples to return. A value of
            zero or less yields an empty list without loading the dataset.

    Returns:
        List of dicts with keys 'target' and 'numbers'. Rows lacking a
        target or a number list are skipped.

    Raises:
        ImportError: If the `datasets` library is not installed; the
            original import failure is attached as the cause.
    """
    try:
        from datasets import load_dataset
    except ImportError as err:
        # Chain the cause so the underlying import failure stays visible
        raise ImportError(
            "The 'datasets' library is required to load HuggingFace datasets. "
            "Install it with: pip install datasets"
        ) from err

    examples: List[Dict] = []

    # The cap is checked after each append below, so a non-positive cap
    # would otherwise still let one example through. Answer it here, before
    # the dataset is fetched.
    if max_examples is not None and max_examples <= 0:
        return examples

    ds = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split=split)

    for item in ds:
        target = item.get("target")
        # The number list is read from 'nums', falling back to 'numbers'
        numbers = item.get("nums") or item.get("numbers")
        if target is None or numbers is None:
            continue

        if isinstance(numbers, str):
            # A string value is parsed as a Python literal (e.g. "[1, 2, 3]")
            numbers = ast.literal_eval(numbers)
        examples.append({"target": int(target), "numbers": list(numbers)})

        if max_examples is not None and len(examples) >= max_examples:
            break

    return examples
104
+
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # Brute-force solvability check
108
+ # ---------------------------------------------------------------------------
109
+
110
+ # Commutative ops: only need (a, b), not (b, a)
111
+ _COMMUTATIVE_OPS = [
112
+ lambda a, b: a + b,
113
+ lambda a, b: a * b,
114
+ ]
115
+
116
+ # Non-commutative ops: must try both orderings
117
+ _NON_COMMUTATIVE_OPS = [
118
+ lambda a, b: a - b,
119
+ lambda a, b: a / b if b != 0 else None,
120
+ ]
121
+
122
+
123
+ def _is_solvable(numbers: List[int], target: int) -> bool:
124
+ """Check if target is reachable using any subset and arrangement of numbers."""
125
+ return _solve(list(map(float, numbers)), float(target))
126
+
127
+
128
+ def _solve(nums: List[float], target: float) -> bool:
129
+ """Recursively try all pairs of numbers with all operations.
130
+
131
+ Allows using a subset of numbers — if any single number in the
132
+ current list already equals the target, that counts as solved.
133
+
134
+ Uses combinations (not permutations) for pair selection and only
135
+ tries both orderings for non-commutative operations (-, /).
136
+ """
137
+ # Any number in the current set already equals the target → solvable
138
+ for n in nums:
139
+ if abs(n - target) < 1e-9:
140
+ return True
141
+
142
+ if len(nums) < 2:
143
+ return False
144
+
145
+ for i in range(len(nums)):
146
+ for j in range(i + 1, len(nums)):
147
+ a, b = nums[i], nums[j]
148
+ remaining = [nums[k] for k in range(len(nums)) if k != i and k != j]
149
+
150
+ # Commutative: a+b == b+a, a*b == b*a — one ordering suffices
151
+ for op in _COMMUTATIVE_OPS:
152
+ result = op(a, b)
153
+ if _solve(remaining + [result], target):
154
+ return True
155
+
156
+ # Non-commutative: try both (a,b) and (b,a)
157
+ for op in _NON_COMMUTATIVE_OPS:
158
+ for x, y in ((a, b), (b, a)):
159
+ result = op(x, y)
160
+ if result is not None and _solve(remaining + [result], target):
161
+ return True
162
+
163
+ return False
@@ -0,0 +1,197 @@
1
+ """
2
+ Safe arithmetic expression evaluator using recursive descent parsing.
3
+
4
+ No eval(), no ast module. Handles +, -, *, /, parentheses, integers only.
5
+ """
6
+
7
+ import re
8
+ from dataclasses import dataclass, field
9
+ from typing import List, Optional
10
+
11
+
12
class ExpressionError(Exception):
    """Signals an arithmetic expression that is invalid or cannot be evaluated."""
15
+
16
+
17
@dataclass
class EvalResult:
    """Outcome of evaluating one arithmetic expression."""

    # Numeric value the expression evaluated to
    value: float
    # Numbers that appeared in the expression; every instance gets its own
    # fresh list by default
    numbers_used: List[int] = field(default_factory=list)
23
+
24
# Whitelist of characters an expression may contain: ASCII digits,
# whitespace, the four operators and parentheses
_ALLOWED_CHARS = re.compile(r'^[0-9\s+\-*/()]+$')


def evaluate_expression(
    expression: str, available_numbers: Optional[List[int]] = None
) -> EvalResult:
    """
    Safely evaluate an arithmetic expression using recursive descent parsing.

    Args:
        expression: Arithmetic expression string (e.g. "(2+3)*4")
        available_numbers: If provided, validates that only these numbers are used
            (each at most once).

    Returns:
        EvalResult with the computed value and list of numbers used.

    Raises:
        ExpressionError: On syntax errors, division by zero, disallowed chars,
            number reuse/unavailability, nesting too deep to parse, or
            number literals too large to evaluate.
    """
    if not expression or not expression.strip():
        raise ExpressionError("Empty expression")

    expr = expression.strip()

    if not _ALLOWED_CHARS.match(expr):
        raise ExpressionError(
            f"Expression contains disallowed characters: {expr!r}"
        )

    tokens = _tokenize(expr)
    if not tokens:
        raise ExpressionError("Empty expression after tokenization")

    parser = _Parser(tokens)
    try:
        value = parser.parse_expression()
    except RecursionError as err:
        # The parser recurses per nesting level; pathological input such as
        # hundreds of "(" or "-" must surface as the documented error type
        # instead of crashing the caller.
        raise ExpressionError("Expression is nested too deeply") from err
    except (OverflowError, ValueError) as err:
        # Number literals beyond what int()/float() can convert
        raise ExpressionError(f"Number out of range: {err}") from err

    # Anything left over means the input was not a single expression
    if parser.pos < len(parser.tokens):
        raise ExpressionError(
            f"Unexpected token after end of expression: "
            f"{parser.tokens[parser.pos]!r}"
        )

    numbers_used = parser.numbers_used

    if available_numbers is not None:
        _validate_numbers(numbers_used, available_numbers)

    return EvalResult(value=value, numbers_used=numbers_used)
75
+
76
+
77
+ def _tokenize(expr: str) -> List[str]:
78
+ """Tokenize an expression into numbers, operators, and parentheses."""
79
+ tokens = []
80
+ i = 0
81
+ while i < len(expr):
82
+ ch = expr[i]
83
+ if ch.isspace():
84
+ i += 1
85
+ continue
86
+ if ch.isdigit():
87
+ j = i
88
+ while j < len(expr) and expr[j].isdigit():
89
+ j += 1
90
+ tokens.append(expr[i:j])
91
+ i = j
92
+ elif ch in '+-*/()':
93
+ tokens.append(ch)
94
+ i += 1
95
+ else:
96
+ raise ExpressionError(f"Unexpected character: {ch!r}")
97
+ return tokens
98
+
99
+
100
+ class _Parser:
101
+ """Recursive descent parser for arithmetic expressions.
102
+
103
+ Grammar:
104
+ expression := term (('+' | '-') term)*
105
+ term := factor (('*' | '/') factor)*
106
+ factor := NUMBER | '(' expression ')' | ('+' | '-') factor
107
+ """
108
+
109
+ def __init__(self, tokens: List[str]):
110
+ self.tokens = tokens
111
+ self.pos = 0
112
+ self.numbers_used: List[int] = []
113
+
114
+ def _peek(self) -> Optional[str]:
115
+ if self.pos < len(self.tokens):
116
+ return self.tokens[self.pos]
117
+ return None
118
+
119
+ def _consume(self) -> str:
120
+ token = self.tokens[self.pos]
121
+ self.pos += 1
122
+ return token
123
+
124
+ def parse_expression(self) -> float:
125
+ """Parse an expression: term (('+' | '-') term)*"""
126
+ result = self._parse_term()
127
+ while self._peek() in ('+', '-'):
128
+ op = self._consume()
129
+ right = self._parse_term()
130
+ if op == '+':
131
+ result = result + right
132
+ else:
133
+ result = result - right
134
+ return result
135
+
136
+ def _parse_term(self) -> float:
137
+ """Parse a term: factor (('*' | '/') factor)*"""
138
+ result = self._parse_factor()
139
+ while self._peek() in ('*', '/'):
140
+ op = self._consume()
141
+ right = self._parse_factor()
142
+ if op == '*':
143
+ result = result * right
144
+ else:
145
+ if right == 0:
146
+ raise ExpressionError("Division by zero")
147
+ result = result / right
148
+ return result
149
+
150
+ def _parse_factor(self) -> float:
151
+ """Parse a factor: NUMBER | '(' expression ')' | unary +/-"""
152
+ token = self._peek()
153
+ if token is None:
154
+ raise ExpressionError("Unexpected end of expression")
155
+
156
+ # Unary plus/minus
157
+ if token in ('+', '-'):
158
+ op = self._consume()
159
+ value = self._parse_factor()
160
+ return value if op == '+' else -value
161
+
162
+ # Parenthesized expression
163
+ if token == '(':
164
+ self._consume() # eat '('
165
+ value = self.parse_expression()
166
+ if self._peek() != ')':
167
+ raise ExpressionError("Unmatched opening parenthesis")
168
+ self._consume() # eat ')'
169
+ return value
170
+
171
+ if token == ')':
172
+ raise ExpressionError("Unmatched closing parenthesis")
173
+
174
+ # Number — _tokenize only produces all-digit tokens, so isdigit()
175
+ # is sufficient. Avoid masking malformed tokens with a looser check.
176
+ if token.isdigit():
177
+ self._consume()
178
+ num = int(token)
179
+ self.numbers_used.append(num)
180
+ return float(num)
181
+
182
+ raise ExpressionError(f"Unexpected token: {token!r}")
183
+
184
+
185
+ def _validate_numbers(
186
+ numbers_used: List[int], available_numbers: List[int]
187
+ ) -> None:
188
+ """Validate that numbers_used is a valid subset of available_numbers."""
189
+ available_copy = list(available_numbers)
190
+ for num in numbers_used:
191
+ if num in available_copy:
192
+ available_copy.remove(num)
193
+ else:
194
+ raise ExpressionError(
195
+ f"Number {num} is not available or has been used too many times. "
196
+ f"Available: {available_numbers}"
197
+ )
@@ -0,0 +1,89 @@
1
+ """
2
+ Prompt formatting and expression extraction for the Countdown task.
3
+ """
4
+
5
+ import re
6
+ from typing import List
7
+
8
+
9
def format_countdown_prompt(target: int, numbers: List[int]) -> str:
    """
    Build the instruction text for one Countdown problem.

    Args:
        target: The value the expression must equal.
        numbers: The numbers the expression may draw from.

    Returns:
        A single-line prompt string.
    """
    sentences = [
        f"Using the numbers {numbers}, create an arithmetic expression "
        f"that equals {target}.",
        "You may use each number at most once.",
        "Use only +, -, *, / and parentheses.",
        "Provide your expression on its own line.",
    ]
    return " ".join(sentences)
26
+
27
+
28
+ # Pattern matching pure arithmetic expressions (digits, operators, parens, spaces)
29
+ _EXPR_PATTERN = re.compile(r'^[\d\s+\-*/()]+$')
30
+
31
+ # Pattern for lines with delimiters like "= ...", ": ...", "answer ..."
32
+ _DELIMITER_PATTERN = re.compile(
33
+ r'(?:=|:|answer\s*(?:is|:)?)\s*([\d\s+\-*/()]+)',
34
+ re.IGNORECASE,
35
+ )
36
+
37
+ # Find longest arithmetic-like substring
38
+ _ARITH_SUBSTRING = re.compile(r'[\d\s+\-*/()]+')
39
+
40
+
41
+ def extract_expression_from_completion(completion: str) -> str:
42
+ """
43
+ Extract an arithmetic expression from model output.
44
+
45
+ Uses fallback strategies:
46
+ 1. Lines that are pure arithmetic expressions
47
+ 2. Text after =, :, or 'answer' delimiters
48
+ 3. Longest arithmetic-like substring
49
+
50
+ Args:
51
+ completion: The raw model output.
52
+
53
+ Returns:
54
+ The extracted expression string (may still be invalid).
55
+ """
56
+ if not completion or not completion.strip():
57
+ return ""
58
+
59
+ text = completion.strip()
60
+
61
+ # Strategy 1: Find lines that are pure arithmetic expressions
62
+ for line in text.splitlines():
63
+ line = line.strip()
64
+ if line and _EXPR_PATTERN.match(line) and _has_digit_and_operator(line):
65
+ return line
66
+
67
+ # Strategy 2: Look for delimiters
68
+ match = _DELIMITER_PATTERN.search(text)
69
+ if match:
70
+ candidate = match.group(1).strip()
71
+ if candidate and _has_digit_and_operator(candidate):
72
+ return candidate
73
+
74
+ # Strategy 3: Longest arithmetic-like substring containing at least
75
+ # one digit and one operator
76
+ candidates = _ARITH_SUBSTRING.findall(text)
77
+ valid = [c.strip() for c in candidates if _has_digit_and_operator(c.strip())]
78
+ if valid:
79
+ return max(valid, key=len)
80
+
81
+ # Last resort: return the whole text stripped
82
+ return text
83
+
84
+
85
+ def _has_digit_and_operator(s: str) -> bool:
86
+ """Check if string has at least one digit and one operator."""
87
+ has_digit = any(c.isdigit() for c in s)
88
+ has_op = any(c in '+-*/' for c in s)
89
+ return has_digit and has_op
@@ -0,0 +1,56 @@
1
+ """
2
+ Countdown reward function for GRPO training.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any, Dict
7
+
8
+ from textpolicy.rewards.registry import reward
9
+ from .evaluator import ExpressionError, evaluate_expression
10
+ from .prompt import extract_expression_from_completion
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
@reward(name="countdown")
def countdown_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    **kwargs,
) -> float:
    """
    Reward function for the Countdown Numbers Game.

    Scoring:
        1.0 — expression equals target with valid numbers
        0.0 — evaluates but wrong answer, or malformed example
        -0.5 — syntax error, empty, unparseable, number reuse, invalid
            numbers, or an expression the evaluator cannot process

    The example dict must contain 'target' (int) and 'numbers' (list of int).

    Args:
        prompt: The prompt shown to the model (not used for scoring).
        completion: The raw model output to score.
        example: Task parameters; see above.
        **kwargs: Accepted and ignored.

    Returns:
        One of 1.0, 0.0 or -0.5 as described under Scoring.
    """
    # Extract task parameters
    target = example.get("target")
    numbers = example.get("numbers")

    if target is None or numbers is None:
        logger.warning("Malformed example: missing 'target' or 'numbers'")
        return 0.0

    # A target that is not numeric, or numbers that cannot be listed, is
    # also a malformed example: score it 0.0 instead of failing later.
    try:
        target_value = float(target)
        available = list(numbers)
    except (TypeError, ValueError):
        logger.warning(
            "Malformed example: non-numeric 'target' or non-iterable 'numbers'"
        )
        return 0.0

    # Extract expression from completion
    expression = extract_expression_from_completion(completion)
    if not expression:
        return -0.5

    # Evaluate. Completions are untrusted model output, so degenerate text
    # (extreme nesting, enormous literals) counts as unparseable rather
    # than being allowed to abort the caller.
    try:
        result = evaluate_expression(expression, available_numbers=available)
    except (ExpressionError, RecursionError, OverflowError, ValueError) as e:
        logger.debug("Expression error: %s", e)
        return -0.5

    # Check if result matches target (use tolerance for float comparison)
    if abs(result.value - target_value) < 1e-9:
        return 1.0

    return 0.0