textpolicy 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,151 @@
+ # textpolicy/algorithms/length_shaping.py
+ """
+ DAPO-style soft overlong penalties and length shaping utilities.
+
+ These utilities replace hard truncation with graduated penalties,
+ reducing training instability from length-based confusion.
+
+ References:
+     DAPO: An Open-Source LLM Reinforcement Learning System at Scale
+     https://arxiv.org/abs/2503.14476
+ """
+
+ from __future__ import annotations
+
+ from typing import List, Dict, Union
+
+ try:
+     import mlx.core as mx  # type: ignore
+ except ImportError:
+     mx = None
+
+
+ def compute_length_penalty(
+     sequence_length: int,
+     max_length: int,
+     cache_length: int = 100,
+     max_penalty: float = 0.5
+ ) -> float:
+     """
+     Compute soft penalty for sequences approaching max length.
+
+     Instead of hard cutoffs for max sequence length (which cause truncation
+     that looks like failure to the model), use graduated penalties within
+     an interval before max_length.
+
+     This reduces training instability from length-based confusion and helps
+     the model learn to be concise without hard punishment.
+
+     Args:
+         sequence_length: Current sequence length
+         max_length: Maximum allowed sequence length
+         cache_length: Start penalizing this many tokens before max_length.
+             Must be positive.
+         max_penalty: Maximum penalty at max_length (default 0.5)
+
+     Returns:
+         Penalty value (0.0 for normal lengths, up to -max_penalty at max_length)
+
+     Example:
+         With max_length=512, cache_length=100 (threshold=412):
+         - length=400: penalty=0.0 (below threshold)
+         - length=412: penalty=0.0 (at threshold, progress=0)
+         - length=462: penalty=-0.25 (50/100 * 0.5)
+         - length=512: penalty=-0.5 (at max)
+
+     Raises:
+         ValueError: If cache_length <= 0
+     """
+     if cache_length <= 0:
+         raise ValueError(f"cache_length must be positive, got {cache_length}")
+
+     threshold = max_length - cache_length
+
+     if sequence_length < threshold:
+         return 0.0
+
+     # Linear penalty from 0 to max_penalty as we approach max
+     progress = (sequence_length - threshold) / cache_length
+     progress = min(1.0, progress)  # Clamp at 1.0
+
+     return -max_penalty * progress
+
+
+ def apply_length_shaping(
+     rewards: "mx.array",
+     sequence_lengths: List[int],
+     max_length: int,
+     cache_length: int = 100,
+     max_penalty: float = 0.5
+ ) -> "mx.array":
+     """
+     Apply soft length penalties to rewards.
+
+     Modifies rewards by adding graduated penalties for sequences that
+     approach the maximum length. This provides a smoother learning signal
+     than hard truncation.
+
+     Args:
+         rewards: Original rewards array [batch_size]
+         sequence_lengths: List of sequence lengths for each episode
+         max_length: Maximum allowed sequence length
+         cache_length: Start penalizing this many tokens before max_length
+         max_penalty: Maximum penalty at max_length
+
+     Returns:
+         Rewards with length penalties applied
+
+     Example:
+         >>> rewards = mx.array([1.0, 0.5, 0.0])
+         >>> lengths = [400, 500, 520]  # max_length=512, cache_length=100
+         >>> shaped = apply_length_shaping(rewards, lengths, 512)
+         >>> # shaped ≈ [1.0, 0.06, -0.5]  # last one gets max penalty
+     """
+     penalties = mx.array([
+         compute_length_penalty(length, max_length, cache_length, max_penalty)
+         for length in sequence_lengths
+     ], dtype=mx.float32)
+
+     return rewards + penalties
+
+
+ def compute_length_shaping_stats(
+     sequence_lengths: List[int],
+     max_length: int,
+     cache_length: int = 100
+ ) -> Dict[str, Union[int, float]]:
+     """
+     Compute statistics about length penalties for monitoring.
+
+     Args:
+         sequence_lengths: List of sequence lengths
+         max_length: Maximum allowed sequence length
+         cache_length: Penalty threshold offset
+
+     Returns:
+         Dictionary with length penalty statistics:
+         - mean_length: Average sequence length
+         - max_length_observed: Maximum observed sequence length
+         - truncation_rate: Fraction of sequences at or past max_length
+         - penalty_zone_rate: Fraction of sequences in penalty zone
+     """
+     threshold = max_length - cache_length
+     total = len(sequence_lengths)
+
+     if total == 0:
+         return {
+             'mean_length': 0.0,
+             'max_length_observed': 0,
+             'truncation_rate': 0.0,
+             'penalty_zone_rate': 0.0,
+         }
+
+     truncated = sum(1 for l in sequence_lengths if l >= max_length)
+     in_penalty_zone = sum(1 for l in sequence_lengths if threshold <= l < max_length)
+
+     return {
+         'mean_length': sum(sequence_lengths) / total,
+         'max_length_observed': max(sequence_lengths),
+         'truncation_rate': truncated / total,
+         'penalty_zone_rate': in_penalty_zone / total,
+     }
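Usage note: a minimal sketch of how the length-shaping helpers above might be called, assuming the package is installed and MLX is available; the module path is taken from the file header and the reward/length values are illustrative only.

    import mlx.core as mx
    from textpolicy.algorithms.length_shaping import (
        apply_length_shaping,
        compute_length_shaping_stats,
    )

    rewards = mx.array([1.0, 0.5, 0.0])
    lengths = [400, 500, 520]

    # Graduated penalties instead of hard truncation: only sequences inside the
    # final `cache_length` tokens before max_length are penalised.
    shaped = apply_length_shaping(rewards, lengths, max_length=512, cache_length=100)

    # Monitoring stats: how many sequences hit the penalty zone or max_length.
    stats = compute_length_shaping_stats(lengths, max_length=512, cache_length=100)
    print(shaped, stats["penalty_zone_rate"], stats["truncation_rate"])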
@@ -0,0 +1,23 @@
+ # textpolicy/analysis/__init__.py
+ """
+ Post-hoc analysis tooling for TextPolicy training runs.
+
+ Main components:
+ - EmergenceLogger: Captures all generations during GRPO training
+ - PlanningPatternDetector: Configurable planning-phrase detection
+ - PlanningPatternConfig: Pattern configuration dataclass
+ - StreamingJSONLWriter: Append-only JSONL writer
+ - to_json_safe: MLX/numpy → JSON-native conversion
+ """
+
+ from .emergence_logger import EmergenceLogger
+ from .planning_patterns import PlanningPatternConfig, PlanningPatternDetector
+ from .serialization import StreamingJSONLWriter, to_json_safe
+
+ __all__ = [
+     "EmergenceLogger",
+     "PlanningPatternDetector",
+     "PlanningPatternConfig",
+     "StreamingJSONLWriter",
+     "to_json_safe",
+ ]
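Usage note: given the re-exports above, downstream code can import the analysis tools directly from the subpackage (package name as in the file headers):

    from textpolicy.analysis import (
        EmergenceLogger,
        PlanningPatternConfig,
        PlanningPatternDetector,
        StreamingJSONLWriter,
        to_json_safe,
    )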
@@ -0,0 +1,248 @@
+ # textpolicy/analysis/emergence_logger.py
+ """
+ Generation logging for emergence analysis during GRPO training.
+
+ Captures every generation produced during training and writes two JSONL
+ streams: per-generation records (``generations.jsonl``) and per-step
+ aggregate statistics (``steps.jsonl``).
+ """
+
+ import time
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ from .planning_patterns import PlanningPatternConfig, PlanningPatternDetector
+ from .serialization import StreamingJSONLWriter, to_json_safe
+
+
+ def _flatten(value: Any) -> list:
+     """Flatten an MLX array, Python list, or scalar to a plain Python list."""
+     if hasattr(value, "tolist"):
+         result = value.tolist()
+         if isinstance(result, list):
+             return result
+         return [result]
+     if isinstance(value, list):
+         return value
+     return [value]
+
+
+ def _default_metadata_extractor(
+     example: Optional[dict],
+     reward: float,
+ ) -> dict:
+     """Extract countdown-task metadata from an example dict.
+
+     Returns ``target``, ``numbers``, and ``correctness`` (reward >= 0.99).
+     """
+     if example is None:
+         return {}
+     meta: Dict[str, Any] = {}
+     if "target" in example:
+         meta["target"] = to_json_safe(example["target"])
+     if "numbers" in example:
+         meta["numbers"] = to_json_safe(example["numbers"])
+     meta["correctness"] = reward >= 0.99
+     return meta
+
+
+ class EmergenceLogger:
+     """Logs every generation during training for post-hoc emergence analysis.
+
+     Writes two JSONL files under *output_dir*:
+
+     * ``generations.jsonl`` — one record per generation
+     * ``steps.jsonl`` — one record per training step (aggregated stats)
+
+     Args:
+         output_dir: Directory for JSONL output files (created if needed).
+         planning_config: Optional :class:`PlanningPatternConfig`.
+         metadata_extractor: Optional callable ``(example, reward) -> dict``.
+             Defaults to countdown-task extractor.
+     """
+
+     def __init__(
+         self,
+         output_dir: Union[str, Path],
+         planning_config: Optional[PlanningPatternConfig] = None,
+         metadata_extractor: Optional[Callable] = None,
+     ) -> None:
+         self._output_dir = Path(output_dir)
+         self._output_dir.mkdir(parents=True, exist_ok=True)
+
+         self._gen_writer = StreamingJSONLWriter(self._output_dir / "generations.jsonl")
+         self._step_writer = StreamingJSONLWriter(self._output_dir / "steps.jsonl")
+
+         self._detector = PlanningPatternDetector(planning_config)
+         self._extract_metadata = metadata_extractor or _default_metadata_extractor
+
+     # ------------------------------------------------------------------
+     # Public API
+     # ------------------------------------------------------------------
+
+     def log_step(
+         self,
+         step: int,
+         episodes: list,
+         tokenizer: Any,
+         examples: Optional[list] = None,
+     ) -> dict:
+         """Log all generations for a single training step.
+
+         Args:
+             step: Current training step index.
+             episodes: List of :class:`Episode` objects (or dicts with the
+                 same fields: ``obs``, ``act``, ``rew``, ``logprob``).
+             tokenizer: Tokenizer with a ``decode`` method.
+             examples: Optional parallel list of example dicts (same length
+                 as *episodes*). Used by the metadata extractor.
+
+         Returns:
+             Aggregated step statistics dict (also written to ``steps.jsonl``).
+         """
+         t0 = time.perf_counter()
+
+         rewards: List[float] = []
+         completion_lengths: List[int] = []
+         planning_ratios: List[float] = []
+         entropy_values: List[float] = []
+         correct_count = 0
+
+         for idx, ep in enumerate(episodes):
+             record = self._process_episode(
+                 step=step,
+                 episode=ep,
+                 tokenizer=tokenizer,
+                 example=examples[idx] if examples and idx < len(examples) else None,
+             )
+             self._gen_writer.write(record)
+
+             # Accumulate for step aggregate
+             rewards.append(record["reward"])
+             completion_lengths.append(len(record["tokens"]))
+             planning_ratios.append(record["planning_token_ratio"])
+             if record["entropy_per_token"]:
+                 entropy_values.extend(record["entropy_per_token"])
+             if record.get("metadata", {}).get("correctness", False):
+                 correct_count += 1
+
+         elapsed_ms = (time.perf_counter() - t0) * 1000.0
+         total = len(episodes)
+
+         step_record = self._build_step_record(
+             step=step,
+             rewards=rewards,
+             completion_lengths=completion_lengths,
+             planning_ratios=planning_ratios,
+             entropy_values=entropy_values,
+             correct_count=correct_count,
+             total_count=total,
+             elapsed_ms=elapsed_ms,
+         )
+         self._step_writer.write(step_record)
+         return step_record
+
+     def finish(self) -> None:
+         """Close underlying file handles."""
+         self._gen_writer.close()
+         self._step_writer.close()
+
+     # ------------------------------------------------------------------
+     # Internal helpers
+     # ------------------------------------------------------------------
+
+     def _process_episode(
+         self,
+         step: int,
+         episode: Any,
+         tokenizer: Any,
+         example: Optional[dict],
+     ) -> dict:
+         """Build a per-generation record from a single episode."""
+         # Support both Episode objects and plain dicts.
+         # Use isinstance check instead of `or` to avoid falsy empty-list fallthrough.
+         if isinstance(episode, dict):
+             obs = episode.get("obs", [])
+             act = episode.get("act", [])
+             rew = episode.get("rew", [])
+             logprob_raw = episode.get("logprob")
+         else:
+             obs = episode.obs
+             act = episode.act
+             rew = episode.rew
+             logprob_raw = episode.logprob
+
+         # Flatten to plain Python lists
+         prompt_tokens = _flatten(obs[0]) if obs else []
+         completion_tokens = _flatten(act[0]) if act else []
+         reward_val = float(_flatten(rew)[0]) if rew else 0.0
+
+         # Logprobs (may be None)
+         logprobs: List[float] = []
+         if logprob_raw is not None and len(logprob_raw) > 0:
+             logprobs = [float(v) for v in _flatten(logprob_raw[0])]
+
+         # Entropy proxy: -logprob per token
+         entropy_per_token = [-lp for lp in logprobs] if logprobs else []
+
+         # Decode text for pattern detection
+         prompt_text = tokenizer.decode(prompt_tokens) if prompt_tokens else ""
+         completion_text = tokenizer.decode(completion_tokens) if completion_tokens else ""
+
+         # Planning pattern detection
+         planning_phrases = self._detector.detect(completion_text)
+         planning_ratio = self._detector.planning_token_ratio(
+             completion_text, len(completion_tokens)
+         )
+
+         # Metadata
+         metadata = self._extract_metadata(example, reward_val)
+
+         return {
+             "step": step,
+             "prompt": prompt_text,
+             "completion": completion_text,
+             "reward": reward_val,
+             "tokens": completion_tokens,
+             "logprobs": logprobs,
+             "entropy_per_token": entropy_per_token,
+             "planning_phrases_found": planning_phrases,
+             "planning_token_ratio": planning_ratio,
+             "metadata": metadata,
+         }
+
+     @staticmethod
+     def _build_step_record(
+         step: int,
+         rewards: List[float],
+         completion_lengths: List[int],
+         planning_ratios: List[float],
+         entropy_values: List[float],
+         correct_count: int,
+         total_count: int,
+         elapsed_ms: float,
+     ) -> dict:
+         """Compute aggregate statistics for a training step."""
+         import math
+
+         def _mean(xs: list) -> float:
+             return sum(xs) / len(xs) if xs else 0.0
+
+         def _std(xs: list) -> float:
+             if len(xs) < 2:
+                 return 0.0
+             m = _mean(xs)
+             return math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))
+
+         return {
+             "step": step,
+             "mean_reward": _mean(rewards),
+             "std_reward": _std(rewards),
+             "mean_completion_length": _mean([float(l) for l in completion_lengths]),
+             "planning_token_ratio": _mean(planning_ratios),
+             "entropy_mean": _mean(entropy_values),
+             "entropy_std": _std(entropy_values),
+             "correct_count": correct_count,
+             "total_count": total_count,
+             "logging_overhead_ms": round(elapsed_ms, 2),
+         }
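Usage note: a hedged sketch of wiring the logger into a training loop. The ToyTokenizer stub, the output directory, and the literal episode/example values are invented for illustration; only the dict fields documented in the log_step docstring (obs, act, rew, logprob) are assumed.

    from textpolicy.analysis import EmergenceLogger

    class ToyTokenizer:
        """Stand-in for a real tokenizer; only `decode` is required by the logger."""
        def decode(self, token_ids):
            return " ".join(str(t) for t in token_ids)

    logger = EmergenceLogger(output_dir="runs/emergence")

    # One fake episode in the dict form accepted by log_step: obs/act each hold
    # one token sequence, rew holds a single scalar reward, logprob is optional.
    episode = {
        "obs": [[101, 102, 103]],
        "act": [[7, 8, 9, 10]],
        "rew": [1.0],
        "logprob": [[-0.2, -0.5, -0.1, -0.3]],
    }
    example = {"target": 24, "numbers": [3, 8, 1]}

    step_stats = logger.log_step(
        step=0, episodes=[episode], tokenizer=ToyTokenizer(), examples=[example]
    )
    print(step_stats["mean_reward"], step_stats["correct_count"])
    logger.finish()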
@@ -0,0 +1,105 @@
+ # textpolicy/analysis/planning_patterns.py
+ """
+ Planning pattern detection for emergence analysis.
+
+ Provides configurable pattern matching to identify reasoning behaviors
+ (hesitation, verification, backtracking, etc.) in model generations.
+ """
+
+ import re
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional
+
+
+ @dataclass
+ class PlanningPatternConfig:
+     """Configuration for planning pattern detection.
+
+     Each category maps to a list of literal phrases. The detector builds
+     a single compiled regex from all phrases for efficient matching.
+     """
+
+     hesitation: List[str] = field(
+         default_factory=lambda: [
+             "wait",
+             "hmm",
+             "actually",
+             "let me think",
+             "on second thought",
+         ]
+     )
+     verification: List[str] = field(
+         default_factory=lambda: [
+             "let me check",
+             "verify",
+             "double check",
+             "is this right",
+         ]
+     )
+     backtracking: List[str] = field(
+         default_factory=lambda: [
+             "try another",
+             "different approach",
+             "go back",
+             "start over",
+         ]
+     )
+     alternatives: List[str] = field(
+         default_factory=lambda: [
+             "alternatively",
+             "or we could",
+             "another way",
+         ]
+     )
+     metacognition: List[str] = field(
+         default_factory=lambda: [
+             "notice that",
+             "the key is",
+             "importantly",
+         ]
+     )
+     case_sensitive: bool = False
+
+     @property
+     def all_patterns(self) -> List[str]:
+         """Return a flat list of all patterns across every category."""
+         patterns: List[str] = []
+         for cat in ("hesitation", "verification", "backtracking",
+                     "alternatives", "metacognition"):
+             patterns.extend(getattr(self, cat))
+         return patterns
+
+
+ class PlanningPatternDetector:
+     """Efficient planning-phrase detector using a single compiled regex.
+
+     Args:
+         config: Optional pattern configuration. Uses defaults if *None*.
+     """
+
+     def __init__(self, config: Optional[PlanningPatternConfig] = None) -> None:
+         self.config = config or PlanningPatternConfig()
+         flags = 0 if self.config.case_sensitive else re.IGNORECASE
+         # Sort longest-first so greedy alternation prefers longer matches
+         patterns = sorted(self.config.all_patterns, key=len, reverse=True)
+         escaped = [re.escape(p) for p in patterns]
+         # Guard against empty pattern list — an empty regex matches every position
+         self._regex = re.compile("|".join(escaped), flags) if escaped else None
+
+     def detect(self, text: str) -> List[str]:
+         """Return all matched planning phrases found in *text*."""
+         if not text or self._regex is None:
+             return []
+         return [m.group() for m in self._regex.finditer(text)]
+
+     def planning_token_ratio(self, text: str, total_tokens: int) -> float:
+         """Ratio of planning-phrase words to *total_tokens*.
+
+         Uses whitespace word count of matched phrases as numerator.
+         Returns 0.0 when *total_tokens* is zero.
+         """
+         if total_tokens == 0:
+             return 0.0
+         matches = self.detect(text)
+         planning_words = sum(len(m.split()) for m in matches)
+         return planning_words / total_tokens
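Usage note: a small stand-alone sketch of the detector; the sample text and the token count of 16 are illustrative.

    from textpolicy.analysis import PlanningPatternConfig, PlanningPatternDetector

    text = "Wait, let me check that again. Alternatively, we could try another approach."

    # Default patterns, case-insensitive matching.
    detector = PlanningPatternDetector()
    print(detector.detect(text))                    # ['Wait', 'let me check', 'Alternatively', 'try another']
    print(detector.planning_token_ratio(text, 16))  # planning-phrase words / assumed 16 tokens

    # Custom, case-sensitive configuration; the other categories keep their defaults.
    config = PlanningPatternConfig(hesitation=["Wait"], case_sensitive=True)
    print(PlanningPatternDetector(config).detect(text))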
@@ -0,0 +1,65 @@
+ # textpolicy/analysis/serialization.py
+ """
+ JSON-safe conversion utilities and streaming JSONL writer.
+
+ Handles MLX arrays, numpy scalars, and nested structures for
+ serialization to JSONL format used by EmergenceLogger.
+ """
+
+ import json
+ from pathlib import Path
+ from typing import Any, Union
+
+
+ def to_json_safe(obj: Any) -> Any:
+     """Recursively convert MLX arrays, numpy scalars, etc. to JSON-native types.
+
+     Args:
+         obj: Any Python object that may contain MLX arrays or numpy types.
+
+     Returns:
+         JSON-serializable equivalent.
+     """
+     # MLX array → list
+     if hasattr(obj, "tolist") and callable(obj.tolist):
+         return obj.tolist()
+
+     # numpy scalar → Python scalar
+     if hasattr(obj, "item") and callable(obj.item):
+         return obj.item()
+
+     if isinstance(obj, dict):
+         return {k: to_json_safe(v) for k, v in obj.items()}
+
+     if isinstance(obj, (list, tuple)):
+         return [to_json_safe(v) for v in obj]
+
+     # int, float, str, bool, None pass through
+     return obj
+
+
+ class StreamingJSONLWriter:
+     """Append-only JSONL writer with lazy file open and compact serialization.
+
+     Args:
+         path: Destination file path. Parent directories are created on first write.
+     """
+
+     def __init__(self, path: Union[str, Path]) -> None:
+         self._path = Path(path)
+         self._file = None
+
+     def write(self, record: dict) -> None:
+         """Serialize *record* as one compact JSON line, then flush."""
+         if self._file is None:
+             self._path.parent.mkdir(parents=True, exist_ok=True)
+             self._file = open(self._path, "a")
+         line = json.dumps(to_json_safe(record), separators=(",", ":"))
+         self._file.write(line + "\n")
+         self._file.flush()
+
+     def close(self) -> None:
+         """Close the underlying file handle (idempotent)."""
+         if self._file is not None:
+             self._file.close()
+             self._file = None
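Usage note: a brief sketch of the serialization helpers on their own, assuming MLX is installed; the record contents and output path are illustrative.

    import mlx.core as mx  # optional; to_json_safe also handles plain Python values
    from textpolicy.analysis import StreamingJSONLWriter, to_json_safe

    record = {
        "step": 0,
        "reward": mx.array([1.0, 0.5]),   # MLX array -> list via .tolist()
        "meta": {"target": 24, "numbers": (3, 8, 1)},
    }
    print(to_json_safe(record))           # {'step': 0, 'reward': [1.0, 0.5], 'meta': {...}}

    writer = StreamingJSONLWriter("runs/emergence/example.jsonl")
    writer.write(record)                  # converts, appends one JSON line, flushes
    writer.close()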