textpolicy 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +3 -0
- textpolicy/algorithms/__init__.py +29 -4
- textpolicy/algorithms/grpo.py +771 -361
- textpolicy/algorithms/length_shaping.py +151 -0
- textpolicy/analysis/__init__.py +23 -0
- textpolicy/analysis/emergence_logger.py +248 -0
- textpolicy/analysis/planning_patterns.py +105 -0
- textpolicy/analysis/serialization.py +65 -0
- textpolicy/generation/mlx_generation.py +36 -21
- textpolicy/tasks/__init__.py +7 -0
- textpolicy/tasks/countdown/__init__.py +21 -0
- textpolicy/tasks/countdown/dataset.py +163 -0
- textpolicy/tasks/countdown/evaluator.py +197 -0
- textpolicy/tasks/countdown/prompt.py +89 -0
- textpolicy/tasks/countdown/reward.py +56 -0
- textpolicy/training/trainer.py +41 -21
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/METADATA +1 -1
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/RECORD +22 -11
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/WHEEL +0 -0
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/entry_points.txt +0 -0
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.1.2.dist-info → textpolicy-0.1.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# textpolicy/algorithms/length_shaping.py
|
|
2
|
+
"""
|
|
3
|
+
DAPO-style soft overlong penalties and length shaping utilities.
|
|
4
|
+
|
|
5
|
+
These utilities replace hard truncation with graduated penalties,
|
|
6
|
+
reducing training instability from length-based confusion.
|
|
7
|
+
|
|
8
|
+
References:
|
|
9
|
+
DAPO: An Open-Source LLM Reinforcement Learning System at Scale
|
|
10
|
+
https://arxiv.org/abs/2503.14476
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from typing import List, Dict, Union
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import mlx.core as mx # type: ignore
|
|
19
|
+
except ImportError:
|
|
20
|
+
mx = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def compute_length_penalty(
    sequence_length: int,
    max_length: int,
    cache_length: int = 100,
    max_penalty: float = 0.5
) -> float:
    """
    Compute a soft penalty for sequences approaching max length.

    Instead of a hard cutoff at ``max_length``, the penalty grows linearly
    over the last ``cache_length`` tokens before ``max_length`` and is
    clamped once ``max_length`` is reached or exceeded.

    Args:
        sequence_length: Current sequence length.
        max_length: Maximum allowed sequence length.
        cache_length: Start penalizing this many tokens before max_length.
            Must be positive.
        max_penalty: Magnitude of the penalty at (and beyond) max_length
            (default 0.5).

    Returns:
        A value that is 0.0 at or below the threshold
        (``max_length - cache_length``) and decreases linearly to
        ``-max_penalty`` at ``max_length``. A zero penalty is always
        returned as positive ``0.0``, never ``-0.0``.

    Example:
        With max_length=512, cache_length=100 (threshold=412):
        - length=400: penalty=0.0 (below threshold)
        - length=412: penalty=0.0 (at threshold, progress=0)
        - length=462: penalty=-0.25 (50/100 * 0.5)
        - length=512: penalty=-0.5 (at max)

    Raises:
        ValueError: If cache_length <= 0
    """
    if cache_length <= 0:
        raise ValueError(f"cache_length must be positive, got {cache_length}")

    threshold = max_length - cache_length

    # `<=` (not `<`): at the threshold itself progress is 0, and
    # `-max_penalty * 0.0` would produce IEEE negative zero.
    if sequence_length <= threshold:
        return 0.0

    # Fraction of the penalty window already consumed, clamped at 1.0 so
    # sequences past max_length receive exactly the maximum penalty.
    progress = min(1.0, (sequence_length - threshold) / cache_length)
    penalty = max_penalty * progress

    # Guard the max_penalty == 0 case so it also yields +0.0 rather than -0.0.
    return -penalty if penalty else 0.0
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def apply_length_shaping(
    rewards: "mx.array",
    sequence_lengths: List[int],
    max_length: int,
    cache_length: int = 100,
    max_penalty: float = 0.5
) -> "mx.array":
    """
    Apply soft length penalties to rewards.

    Computes one penalty per entry of ``sequence_lengths`` with
    ``compute_length_penalty`` and adds the resulting float32 array to
    ``rewards``. The input array is not modified; a new array is returned.

    Args:
        rewards: Original rewards array [batch_size]
        sequence_lengths: List of sequence lengths for each episode
        max_length: Maximum allowed sequence length
        cache_length: Start penalizing this many tokens before max_length
        max_penalty: Maximum penalty at max_length

    Returns:
        Rewards with length penalties applied

    Raises:
        ImportError: If MLX is not installed (the module-level ``mx`` is None).
        ValueError: If cache_length <= 0 (raised by compute_length_penalty
            when sequence_lengths is non-empty).

    Example:
        >>> rewards = mx.array([1.0, 0.5, 0.0])
        >>> lengths = [400, 500, 520]  # max_length=512, cache_length=100
        >>> shaped = apply_length_shaping(rewards, lengths, 512)
        >>> # shaped ≈ [1.0, 0.06, -0.5]  # last one gets max penalty
    """
    # The module tolerates a missing MLX install by setting `mx = None`;
    # fail here with an actionable message instead of an AttributeError.
    if mx is None:
        raise ImportError(
            "apply_length_shaping requires MLX (mlx.core), which could not be imported"
        )

    penalties = mx.array(
        [
            compute_length_penalty(length, max_length, cache_length, max_penalty)
            for length in sequence_lengths
        ],
        dtype=mx.float32,
    )

    # NOTE(review): assumes len(sequence_lengths) matches rewards' batch
    # dimension; a mismatch is left to array broadcasting rules — confirm.
    return rewards + penalties
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def compute_length_shaping_stats(
    sequence_lengths: List[int],
    max_length: int,
    cache_length: int = 100
) -> Dict[str, Union[int, float]]:
    """
    Summarize how a batch of sequence lengths relates to the length limit.

    Args:
        sequence_lengths: List of sequence lengths
        max_length: Maximum allowed sequence length
        cache_length: Penalty threshold offset; the penalty zone starts at
            ``max_length - cache_length``

    Returns:
        Dictionary with length penalty statistics:
        - mean_length: Average sequence length
        - max_length_observed: Maximum observed sequence length
        - truncation_rate: Fraction of sequences at or past max_length
        - penalty_zone_rate: Fraction of sequences with
          ``threshold <= length < max_length``

        All values are zero when ``sequence_lengths`` is empty.
    """
    count = len(sequence_lengths)
    if not count:
        # Nothing to average; report an all-zero summary.
        return {
            'mean_length': 0.0,
            'max_length_observed': 0,
            'truncation_rate': 0.0,
            'penalty_zone_rate': 0.0,
        }

    zone_start = max_length - cache_length
    at_or_past_max = 0
    inside_zone = 0
    for length in sequence_lengths:
        if length >= max_length:
            at_or_past_max += 1
        elif length >= zone_start:
            # Below max_length (previous branch failed) but past the threshold.
            inside_zone += 1

    return {
        'mean_length': sum(sequence_lengths) / count,
        'max_length_observed': max(sequence_lengths),
        'truncation_rate': at_or_past_max / count,
        'penalty_zone_rate': inside_zone / count,
    }
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# textpolicy/analysis/__init__.py
|
|
2
|
+
"""
|
|
3
|
+
Post-hoc analysis tooling for TextPolicy training runs.
|
|
4
|
+
|
|
5
|
+
Main components:
|
|
6
|
+
- EmergenceLogger: Captures all generations during GRPO training
|
|
7
|
+
- PlanningPatternDetector: Configurable planning-phrase detection
|
|
8
|
+
- PlanningPatternConfig: Pattern configuration dataclass
|
|
9
|
+
- StreamingJSONLWriter: Append-only JSONL writer
|
|
10
|
+
- to_json_safe: MLX/numpy → JSON-native conversion
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from .emergence_logger import EmergenceLogger
|
|
14
|
+
from .planning_patterns import PlanningPatternConfig, PlanningPatternDetector
|
|
15
|
+
from .serialization import StreamingJSONLWriter, to_json_safe
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"EmergenceLogger",
|
|
19
|
+
"PlanningPatternDetector",
|
|
20
|
+
"PlanningPatternConfig",
|
|
21
|
+
"StreamingJSONLWriter",
|
|
22
|
+
"to_json_safe",
|
|
23
|
+
]
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
# textpolicy/analysis/emergence_logger.py
|
|
2
|
+
"""
|
|
3
|
+
Generation logging for emergence analysis during GRPO training.
|
|
4
|
+
|
|
5
|
+
Captures every generation produced during training and writes two JSONL
|
|
6
|
+
streams: per-generation records (``generations.jsonl``) and per-step
|
|
7
|
+
aggregate statistics (``steps.jsonl``).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import time
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
13
|
+
|
|
14
|
+
from .planning_patterns import PlanningPatternConfig, PlanningPatternDetector
|
|
15
|
+
from .serialization import StreamingJSONLWriter, to_json_safe
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _flatten(value: Any) -> list:
|
|
19
|
+
"""Flatten an MLX array, Python list, or scalar to a plain Python list."""
|
|
20
|
+
if hasattr(value, "tolist"):
|
|
21
|
+
result = value.tolist()
|
|
22
|
+
if isinstance(result, list):
|
|
23
|
+
return result
|
|
24
|
+
return [result]
|
|
25
|
+
if isinstance(value, list):
|
|
26
|
+
return value
|
|
27
|
+
return [value]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _default_metadata_extractor(
|
|
31
|
+
example: Optional[dict],
|
|
32
|
+
reward: float,
|
|
33
|
+
) -> dict:
|
|
34
|
+
"""Extract countdown-task metadata from an example dict.
|
|
35
|
+
|
|
36
|
+
Returns ``target``, ``numbers``, and ``correctness`` (reward >= 0.99).
|
|
37
|
+
"""
|
|
38
|
+
if example is None:
|
|
39
|
+
return {}
|
|
40
|
+
meta: Dict[str, Any] = {}
|
|
41
|
+
if "target" in example:
|
|
42
|
+
meta["target"] = to_json_safe(example["target"])
|
|
43
|
+
if "numbers" in example:
|
|
44
|
+
meta["numbers"] = to_json_safe(example["numbers"])
|
|
45
|
+
meta["correctness"] = reward >= 0.99
|
|
46
|
+
return meta
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class EmergenceLogger:
    """Logs every generation during training for post-hoc emergence analysis.

    Writes two JSONL files under *output_dir*:

    * ``generations.jsonl`` — one record per generation
    * ``steps.jsonl`` — one record per training step (aggregated stats)

    The logger can be used as a context manager; leaving the ``with`` block
    calls :meth:`finish`.

    Args:
        output_dir: Directory for JSONL output files (created if needed).
        planning_config: Optional :class:`PlanningPatternConfig`.
        metadata_extractor: Optional callable ``(example, reward) -> dict``.
            Defaults to countdown-task extractor.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        planning_config: Optional[PlanningPatternConfig] = None,
        metadata_extractor: Optional[Callable] = None,
    ) -> None:
        self._output_dir = Path(output_dir)
        self._output_dir.mkdir(parents=True, exist_ok=True)

        self._gen_writer = StreamingJSONLWriter(self._output_dir / "generations.jsonl")
        self._step_writer = StreamingJSONLWriter(self._output_dir / "steps.jsonl")

        self._detector = PlanningPatternDetector(planning_config)
        self._extract_metadata = metadata_extractor or _default_metadata_extractor

    def __enter__(self) -> "EmergenceLogger":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.finish()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def log_step(
        self,
        step: int,
        episodes: list,
        tokenizer: Any,
        examples: Optional[list] = None,
    ) -> dict:
        """Log all generations for a single training step.

        Args:
            step: Current training step index.
            episodes: List of :class:`Episode` objects (or dicts with the
                same fields: ``obs``, ``act``, ``rew``, ``logprob``).
            tokenizer: Tokenizer with a ``decode`` method.
            examples: Optional parallel list of example dicts (same length
                as *episodes*). Used by the metadata extractor. Episodes
                without a matching entry get ``None`` as their example.

        Returns:
            Aggregated step statistics dict (also written to ``steps.jsonl``).
        """
        t0 = time.perf_counter()

        rewards: List[float] = []
        completion_lengths: List[int] = []
        planning_ratios: List[float] = []
        entropy_values: List[float] = []
        correct_count = 0

        for idx, ep in enumerate(episodes):
            example = examples[idx] if examples and idx < len(examples) else None
            record = self._process_episode(
                step=step,
                episode=ep,
                tokenizer=tokenizer,
                example=example,
            )
            self._gen_writer.write(record)

            # Accumulate for step aggregate
            rewards.append(record["reward"])
            completion_lengths.append(len(record["tokens"]))
            planning_ratios.append(record["planning_token_ratio"])
            entropy_values.extend(record["entropy_per_token"])
            # `or {}` tolerates a custom extractor that returns None.
            if (record.get("metadata") or {}).get("correctness", False):
                correct_count += 1

        # Measured before the step record is built and written, so it covers
        # only the per-generation processing and writes above.
        elapsed_ms = (time.perf_counter() - t0) * 1000.0

        step_record = self._build_step_record(
            step=step,
            rewards=rewards,
            completion_lengths=completion_lengths,
            planning_ratios=planning_ratios,
            entropy_values=entropy_values,
            correct_count=correct_count,
            total_count=len(episodes),
            elapsed_ms=elapsed_ms,
        )
        self._step_writer.write(step_record)
        return step_record

    def finish(self) -> None:
        """Close underlying file handles."""
        # try/finally so the step writer is closed even if closing the
        # generation writer raises.
        try:
            self._gen_writer.close()
        finally:
            self._step_writer.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _non_empty(value: Any) -> bool:
        """Return True when *value* holds at least one element.

        Uses ``len()`` rather than truthiness: array-likes can raise on
        ``bool()`` when they hold several elements, and a one-element array
        containing 0 is falsy even though it is not empty. Unsized values
        (plain scalars, 0-d arrays) fall back to ``bool()``.
        """
        if value is None:
            return False
        try:
            return len(value) > 0
        except TypeError:
            return bool(value)

    def _process_episode(
        self,
        step: int,
        episode: Any,
        tokenizer: Any,
        example: Optional[dict],
    ) -> dict:
        """Build a per-generation record from a single episode."""
        # Support both Episode objects and plain dicts.
        if isinstance(episode, dict):
            obs = episode.get("obs", [])
            act = episode.get("act", [])
            rew = episode.get("rew", [])
            logprob_raw = episode.get("logprob")
        else:
            obs = episode.obs
            act = episode.act
            rew = episode.rew
            logprob_raw = episode.logprob

        # Flatten to plain Python lists.
        # NOTE(review): assumes the first entry of obs/act/logprob holds the
        # whole prompt/completion sequence — confirm against Episode layout.
        prompt_tokens = _flatten(obs[0]) if self._non_empty(obs) else []
        completion_tokens = _flatten(act[0]) if self._non_empty(act) else []
        reward_val = float(_flatten(rew)[0]) if self._non_empty(rew) else 0.0

        # Logprobs (may be None)
        logprobs: List[float] = []
        if logprob_raw is not None and len(logprob_raw) > 0:
            logprobs = [float(v) for v in _flatten(logprob_raw[0])]

        # Entropy proxy: -logprob per token
        entropy_per_token = [-lp for lp in logprobs]

        # Decode text for pattern detection
        prompt_text = tokenizer.decode(prompt_tokens) if prompt_tokens else ""
        completion_text = tokenizer.decode(completion_tokens) if completion_tokens else ""

        # Planning pattern detection
        planning_phrases = self._detector.detect(completion_text)
        planning_ratio = self._detector.planning_token_ratio(
            completion_text, len(completion_tokens)
        )

        # Metadata
        metadata = self._extract_metadata(example, reward_val)

        return {
            "step": step,
            "prompt": prompt_text,
            "completion": completion_text,
            "reward": reward_val,
            "tokens": completion_tokens,
            "logprobs": logprobs,
            "entropy_per_token": entropy_per_token,
            "planning_phrases_found": planning_phrases,
            "planning_token_ratio": planning_ratio,
            "metadata": metadata,
        }

    @staticmethod
    def _build_step_record(
        step: int,
        rewards: List[float],
        completion_lengths: List[int],
        planning_ratios: List[float],
        entropy_values: List[float],
        correct_count: int,
        total_count: int,
        elapsed_ms: float,
    ) -> dict:
        """Compute aggregate statistics for a training step.

        Means are 0.0 for empty inputs; standard deviations are population
        standard deviations (divide by N) and 0.0 for fewer than two values.
        """
        import math

        def _mean(xs: list) -> float:
            return sum(xs) / len(xs) if xs else 0.0

        def _std(xs: list) -> float:
            if len(xs) < 2:
                return 0.0
            m = _mean(xs)
            return math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))

        return {
            "step": step,
            "mean_reward": _mean(rewards),
            "std_reward": _std(rewards),
            "mean_completion_length": _mean([float(n) for n in completion_lengths]),
            "planning_token_ratio": _mean(planning_ratios),
            "entropy_mean": _mean(entropy_values),
            "entropy_std": _std(entropy_values),
            "correct_count": correct_count,
            "total_count": total_count,
            "logging_overhead_ms": round(elapsed_ms, 2),
        }
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# textpolicy/analysis/planning_patterns.py
|
|
2
|
+
"""
|
|
3
|
+
Planning pattern detection for emergence analysis.
|
|
4
|
+
|
|
5
|
+
Provides configurable pattern matching to identify reasoning behaviors
|
|
6
|
+
(hesitation, verification, backtracking, etc.) in model generations.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class PlanningPatternConfig:
    """Phrase lists used to spot planning behaviour in generated text.

    Every category field holds literal phrases. ``all_patterns`` exposes
    them as one flat list, in category declaration order.
    """

    # Signs of pausing or reconsidering.
    hesitation: List[str] = field(
        default_factory=lambda: [
            "wait", "hmm", "actually", "let me think", "on second thought",
        ]
    )
    # Explicit self-checking.
    verification: List[str] = field(
        default_factory=lambda: [
            "let me check", "verify", "double check", "is this right",
        ]
    )
    # Abandoning the current line of attack.
    backtracking: List[str] = field(
        default_factory=lambda: [
            "try another", "different approach", "go back", "start over",
        ]
    )
    # Proposing other options.
    alternatives: List[str] = field(
        default_factory=lambda: ["alternatively", "or we could", "another way"]
    )
    # Commentary on the reasoning itself.
    metacognition: List[str] = field(
        default_factory=lambda: ["notice that", "the key is", "importantly"]
    )
    # When False, consumers are expected to match case-insensitively.
    case_sensitive: bool = False

    @property
    def all_patterns(self) -> List[str]:
        """Return a new flat list with the phrases of every category."""
        return [
            *self.hesitation,
            *self.verification,
            *self.backtracking,
            *self.alternatives,
            *self.metacognition,
        ]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PlanningPatternDetector:
    """Finds configured planning phrases with one precompiled regex.

    Args:
        config: Optional pattern configuration. Uses defaults if *None*.
    """

    def __init__(self, config: Optional[PlanningPatternConfig] = None) -> None:
        self.config = config or PlanningPatternConfig()

        # Longest phrases first: regex alternation takes the first branch
        # that matches, so longer phrases must precede their own prefixes.
        ordered = sorted(self.config.all_patterns, key=len, reverse=True)

        if ordered:
            alternation = "|".join(re.escape(phrase) for phrase in ordered)
            self._regex = re.compile(
                alternation,
                0 if self.config.case_sensitive else re.IGNORECASE,
            )
        else:
            # No phrases configured: an empty regex would match everywhere,
            # so keep no regex at all and let detect() short-circuit.
            self._regex = None

    def detect(self, text: str) -> List[str]:
        """Return every planning phrase occurrence in *text*, in order."""
        if self._regex is None or not text:
            return []
        # Phrases are escaped literals, so the pattern has no capture groups
        # and findall() yields the full matched substrings.
        return self._regex.findall(text)

    def planning_token_ratio(self, text: str, total_tokens: int) -> float:
        """Ratio of planning-phrase words to *total_tokens*.

        The numerator is the whitespace-separated word count of all matched
        phrases. Returns 0.0 when *total_tokens* is zero.
        """
        if total_tokens == 0:
            return 0.0
        word_total = 0
        for phrase in self.detect(text):
            word_total += len(phrase.split())
        return word_total / total_tokens
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# textpolicy/analysis/serialization.py
|
|
2
|
+
"""
|
|
3
|
+
JSON-safe conversion utilities and streaming JSONL writer.
|
|
4
|
+
|
|
5
|
+
Handles MLX arrays, numpy scalars, and nested structures for
|
|
6
|
+
serialization to JSONL format used by EmergenceLogger.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Union
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def to_json_safe(obj: Any) -> Any:
    """Recursively convert array/scalar wrappers into JSON-native values.

    Args:
        obj: Any Python object that may contain MLX arrays or numpy types.

    Returns:
        The result of ``obj.tolist()`` or, failing that, ``obj.item()`` when
        the object exposes such a callable; dicts with converted values;
        lists/tuples as lists of converted items; any other object unchanged.
    """
    # Array-likes expose tolist(); scalar wrappers expose item().
    # tolist() is tried first, matching the precedence callers rely on.
    for method_name in ("tolist", "item"):
        converter = getattr(obj, method_name, None)
        if callable(converter):
            return converter()

    if isinstance(obj, dict):
        # Keys are kept as-is; only values are converted.
        return {key: to_json_safe(val) for key, val in obj.items()}

    if isinstance(obj, (list, tuple)):
        return list(map(to_json_safe, obj))

    # int, float, str, bool, None (and anything unrecognized) pass through.
    return obj
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class StreamingJSONLWriter:
    """Append-only JSONL writer with lazy file open and compact serialization.

    The file is opened (in append mode, UTF-8) on the first ``write`` call,
    so constructing a writer that is never written to creates nothing on
    disk. Can be used as a context manager; leaving the ``with`` block
    closes the file.

    Args:
        path: Destination file path. Parent directories are created on first write.
    """

    def __init__(self, path: Union[str, Path]) -> None:
        self._path = Path(path)
        self._file = None  # opened lazily by write()

    def __enter__(self) -> "StreamingJSONLWriter":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def write(self, record: dict) -> None:
        """Serialize *record* as one compact JSON line, then flush."""
        # Serialize before touching the filesystem so an unserializable
        # first record does not leave an empty file behind.
        line = json.dumps(to_json_safe(record), separators=(",", ":"))

        if self._file is None:
            self._path.parent.mkdir(parents=True, exist_ok=True)
            # Explicit encoding: do not depend on the platform locale.
            self._file = open(self._path, "a", encoding="utf-8")

        self._file.write(line + "\n")
        # Flush per record so a crash loses at most the record in flight.
        self._file.flush()

    def close(self) -> None:
        """Close the underlying file handle (idempotent)."""
        # Drop the reference first so the writer is left in the closed
        # state even if closing the handle raises.
        handle, self._file = self._file, None
        if handle is not None:
            handle.close()
|