textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
+ # textpolicy/rollout/worker.py
+ """
+ Worker process management and queue communication.
+ """
+
+ import multiprocessing as mp
+ import queue
+ from typing import Callable, Any, Optional
+ from .runner import RolloutRunner
+ from .base import DEFAULT_WORKER_TIMEOUT, DEFAULT_MAX_STEPS
+ from textpolicy.buffer import Buffer
+
+
+ class RolloutWorker:
+     """
+     Manages a rollout runner in a separate process.
+
+     Handles:
+     - Process lifecycle (start, stop, cleanup)
+     - Queue-based communication with trainer
+     - Data serialization for multiprocessing
+     - Async rollout collection
+     """
+
+     def __init__(
+         self,
+         env_fn: Callable[[], Any],
+         policy_fn: Callable[[], Any],
+         strategy: Any,
+         max_steps: int = DEFAULT_MAX_STEPS,
+         send_queue: Optional[mp.Queue] = None,
+     ):
+         """
+         Initialize worker for separate process execution.
+
+         Args:
+             env_fn: Function that returns a fresh environment
+             policy_fn: Function that returns a fresh policy (for multiprocessing compatibility)
+             strategy: RolloutStrategy instance (e.g., PPOStrategy)
+             max_steps: Maximum steps per rollout collection
+             send_queue: Queue to send data back to trainer
+         """
+         self.env_fn = env_fn
+         self.policy_fn = policy_fn
+         self.strategy = strategy
+         self.max_steps = max_steps
+
+         # Communication queues
+         self.send_queue: mp.Queue = send_queue if send_queue is not None else mp.Queue()
+         self.control_queue: mp.Queue = mp.Queue()  # Commands: "collect", "exit"
+
+         # Process handle and status
+         self.process: Optional[mp.Process] = None
+         self.is_closed = False  # Track cleanup status
+
+     def run(self):
+         """
+         Target function for the worker process.
+
+         Creates fresh environment and policy instances, then runs collection loop.
+         Communicates via queues to avoid shared memory issues.
+         """
+         # Create fresh instances in worker process (avoid pickle issues)
+         env = self.env_fn()
+         policy = self.policy_fn()
+         runner = RolloutRunner(env, policy, self.strategy, self.max_steps)
+
+         while True:
+             try:
+                 # Wait for commands from trainer
+                 msg = self.control_queue.get(timeout=DEFAULT_WORKER_TIMEOUT)
+
+                 if msg == "collect":
+                     # Run rollout collection
+                     buffer = runner.collect()
+
+                     # Serialize buffer to pure Python types for queue transmission
+                     episodes_data = [ep.to_dict() for ep in buffer.episodes]
+                     self.send_queue.put(episodes_data)
+
+                 elif msg == "exit":
+                     break
+
+             except queue.Empty:
+                 # Timeout - continue listening for commands
+                 continue
+
+     def start(self):
+         """Launch the worker process."""
+         if self.process is not None:
+             raise RuntimeError("Worker process already started")
+
+         self.process = mp.Process(target=self.run, daemon=True)
+         self.process.start()
+
+     def collect_async(self):
+         """Request a rollout without blocking."""
+         if self.process is None:
+             raise RuntimeError("Worker process not started")
+
+         self.control_queue.put("collect")
+
+     def has_data(self) -> bool:
+         """Check if rollout result is ready."""
+         return not self.send_queue.empty()
+
+     def get_buffer(self) -> Buffer:
+         """
+         Retrieve the collected buffer.
+
+         Call only after has_data() returns True.
+
+         Returns:
+             Buffer instance with collected episodes
+         """
+         episodes_data = self.send_queue.get()
+
+         # Reconstruct buffer from serialized data
+         buffer = Buffer(max_episodes=len(episodes_data))
+         for ep_dict in episodes_data:
+             buffer.add_episode_from_dict(ep_dict)
+
+         return buffer
+
+     def close(self):
+         """Shut down the worker process gracefully."""
+         if self.process is None or self.is_closed:
+             return
+
+         try:
+             # Step 1: Signal shutdown to worker process
+             self.control_queue.put("exit")
+
+             # Step 2: Wait for graceful shutdown
+             self.process.join(timeout=3)
+
+             # Step 3: Force termination if needed
+             if self.process.is_alive():
+                 self.process.terminate()
+                 self.process.join(timeout=1)
+
+             # Step 4: Properly drain and close queues to prevent file descriptor errors
+             # This prevents the background feeder threads from trying to close already-closed FDs
+             self._cleanup_queues()
+
+         except Exception:
+             # Ignore cleanup errors during testing - they're just noise
+             pass
+         finally:
+             # Step 5: Mark as closed
+             self.is_closed = True
+             self.process = None
+
+     def _cleanup_queues(self):
+         """
+         Properly cleanup multiprocessing queues to prevent file descriptor errors.
+
+         The issue: Multiprocessing queues have background "feeder" threads that can
+         try to close file descriptors after the main process has already closed them.
+
+         Solution: Explicitly drain and close queues before process termination.
+         """
+         try:
+             # Cancel the queue's join thread to prevent hanging on exit
+             # This tells the queue not to wait for the feeder thread
+             if hasattr(self.send_queue, 'cancel_join_thread'):
+                 self.send_queue.cancel_join_thread()
+             if hasattr(self.control_queue, 'cancel_join_thread'):
+                 self.control_queue.cancel_join_thread()
+
+             # Drain any remaining items from queues to prevent deadlock
+             # Empty queues close more cleanly
+             try:
+                 while not self.send_queue.empty():
+                     self.send_queue.get_nowait()
+             except queue.Empty:
+                 pass  # Queue might be closed already
+
+             try:
+                 while not self.control_queue.empty():
+                     self.control_queue.get_nowait()
+             except queue.Empty:
+                 pass  # Queue might be closed already
+
+             # Now close the queues properly
+             if hasattr(self.send_queue, 'close'):
+                 self.send_queue.close()
+             if hasattr(self.control_queue, 'close'):
+                 self.control_queue.close()
+
+         except Exception:
+             # During testing, ignore queue cleanup errors
+             # They're artifacts of rapid process termination
+             pass
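
A minimal sketch of how a trainer might drive the RolloutWorker added above, using only the methods in this hunk (start, collect_async, has_data, get_buffer, close). The make_env, make_policy, and make_strategy factories are hypothetical placeholders, not part of the package:

import time
from textpolicy.rollout.worker import RolloutWorker

# Hypothetical factories: each must be picklable and build a fresh instance
# inside the worker process (see RolloutWorker.run above).
worker = RolloutWorker(
    env_fn=make_env,            # hypothetical environment factory
    policy_fn=make_policy,      # hypothetical policy factory
    strategy=make_strategy(),   # e.g. a strategy object from textpolicy.rollout.strategy
    max_steps=512,
)
worker.start()
worker.collect_async()          # non-blocking: puts "collect" on the control queue

while not worker.has_data():    # poll so the training loop is never blocked
    time.sleep(0.1)

buffer = worker.get_buffer()    # Buffer reconstructed from the serialized episodes
worker.close()                  # signals "exit", joins the process, cleans up queues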
@@ -0,0 +1,14 @@
+ # textpolicy/training/__init__.py
+ """
+ Unified training infrastructure for all RL algorithms.
+ """
+
+ from .trainer import Trainer
+ from .rollout_manager import RolloutManager
+ from .metrics import TrainingMetrics
+
+ __all__ = [
+     "Trainer",
+     "RolloutManager",
+     "TrainingMetrics"
+ ]
@@ -0,0 +1,242 @@
+ # textpolicy/training/metrics.py
+ """
+ Training metrics collection and analysis for all RL algorithms.
+ """
+
+ from typing import Dict, Any, Optional, cast
+ import mlx.core as mx  # type: ignore
+ from collections import defaultdict, deque
+
+
+ class TrainingMetrics:
+     """
+     Lightweight metrics collector optimized for MLX training.
+
+     Tracks algorithm-agnostic and algorithm-specific metrics
+     with minimal overhead during training.
+     """
+
+     def __init__(self, history_length: int = 100):
+         """
+         Initialize metrics collector.
+
+         Args:
+             history_length: Number of recent values to keep for rolling averages
+         """
+         self.history_length = history_length
+         self.metrics = defaultdict(lambda: deque(maxlen=history_length))
+         self.total_steps = 0
+
+     def update(self, metrics_dict: Dict[str, float]):
+         """
+         Update metrics with new values.
+
+         Args:
+             metrics_dict: Dictionary of metric_name -> value
+         """
+         for name, value in metrics_dict.items():
+             self.metrics[name].append(value)
+
+         if 'step' in metrics_dict:
+             self.total_steps = metrics_dict['step']
+
+     def get_latest(self, metric_name: str) -> Optional[float]:
+         """Get the most recent value for a metric."""
+         if metric_name in self.metrics and self.metrics[metric_name]:
+             return self.metrics[metric_name][-1]
+         return None
+
+     def get_mean(self, metric_name: str, last_n: Optional[int] = None) -> Optional[float]:
+         """
+         Get mean of recent metric values.
+
+         Args:
+             metric_name: Name of the metric
+             last_n: Number of recent values to average (None for all)
+
+         Returns:
+             Mean value or None if metric doesn't exist
+         """
+         if metric_name not in self.metrics or not self.metrics[metric_name]:
+             return None
+
+         values = list(self.metrics[metric_name])
+         if last_n is not None:
+             values = values[-last_n:]
+
+         return sum(values) / len(values) if values else None
+
+     def get_summary(self) -> Dict[str, Any]:
+         """
+         Get comprehensive metrics summary.
+
+         Returns:
+             Dictionary with latest, mean, and other statistics
+         """
+         summary = {
+             'total_steps': self.total_steps,
+             'metrics': {}
+         }
+
+         for metric_name, values in self.metrics.items():
+             if not values:
+                 continue
+
+             values_list = list(values)
+             summary['metrics'][metric_name] = {
+                 'latest': values_list[-1],
+                 'mean': sum(values_list) / len(values_list),
+                 'min': min(values_list),
+                 'max': max(values_list),
+                 'count': len(values_list)
+             }
+
+         return summary
+
+     def reset(self):
+         """Reset all metrics."""
+         self.metrics.clear()
+         self.total_steps = 0
+
+     def __len__(self) -> int:
+         """Return number of metrics being tracked."""
+         return len(self.metrics)
+
+
+ class RolloutMetrics:
+     """
+     Metrics specific to rollout collection phase.
+     """
+
+     def __init__(self):
+         self.episodes_collected = 0
+         self.total_reward = 0.0
+         self.episode_lengths = []
+         self.episode_rewards = []
+
+     def add_episode(self, reward: float, length: int):
+         """Add metrics from a completed episode."""
+         self.episodes_collected += 1
+         self.total_reward += reward
+         self.episode_lengths.append(length)
+         self.episode_rewards.append(reward)
+
+     def get_summary(self) -> Dict[str, float]:
+         """Get rollout metrics summary."""
+         if not self.episode_rewards:
+             return {
+                 'episodes_collected': 0,
+                 'mean_reward': 0.0,
+                 'mean_length': 0.0,
+                 'total_reward': 0.0
+             }
+
+         return {
+             'episodes_collected': self.episodes_collected,
+             'mean_reward': sum(self.episode_rewards) / len(self.episode_rewards),
+             'mean_length': sum(self.episode_lengths) / len(self.episode_lengths),
+             'total_reward': self.total_reward,
+             'min_reward': min(self.episode_rewards),
+             'max_reward': max(self.episode_rewards),
+             'min_length': min(self.episode_lengths),
+             'max_length': max(self.episode_lengths)
+         }
+
+     def reset(self):
+         """Reset rollout metrics."""
+         self.episodes_collected = 0
+         self.total_reward = 0.0
+         self.episode_lengths.clear()
+         self.episode_rewards.clear()
+
+
+ def log_metrics(
+     metrics: Dict[str, float],
+     step: int,
+     logger: Optional[Any] = None,
+     prefix: str = ""
+ ):
+     """
+     Log metrics to console and optionally to external logger.
+
+     Args:
+         metrics: Metrics dictionary
+         step: Training step number
+         logger: Optional external logger (wandb, tensorboard, etc.)
+         prefix: Prefix for metric names
+     """
+     # Console logging
+     metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
+     print(f"Step {step} | {prefix}{metrics_str}")
+
+     # External logger
+     if logger is not None and hasattr(logger, 'log'):
+         prefixed_metrics = {f"{prefix}{k}": v for k, v in metrics.items()}
+         logger.log(prefixed_metrics, step=step)
+
+
+ def compute_explained_variance(predicted: mx.array, targets: mx.array) -> float:
+     """
+     Compute explained variance for value function evaluation.
+
+     Explained variance = 1 - Var(targets - predicted) / Var(targets)
+
+     Args:
+         predicted: Predicted values
+         targets: Target values
+
+     Returns:
+         Explained variance (1.0 = perfect, 0.0 = no better than mean)
+     """
+     target_var = mx.var(targets)
+     if target_var == 0:
+         return 0.0
+
+     residual_var = mx.var(targets - predicted)
+     explained_var = 1.0 - (residual_var / target_var)
+
+     return cast(float, explained_var.item())  # MLX scalar array .item() returns Python float for float dtypes
+
+
+ def compute_policy_metrics(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     clip_ratio: float = 0.2
+ ) -> Dict[str, float]:
+     """
+     Compute standard policy optimization metrics.
+
+     Args:
+         old_logprobs: Log probabilities from rollout
+         new_logprobs: Log probabilities from current policy
+         clip_ratio: Clipping ratio used in loss
+
+     Returns:
+         Dictionary of policy metrics
+     """
+     # Importance ratio
+     ratio = mx.exp(new_logprobs - old_logprobs)
+
+     # Clipping statistics
+     clip_lower = 1 - clip_ratio
+     clip_upper = 1 + clip_ratio
+     clipped = (ratio < clip_lower) | (ratio > clip_upper)
+     clip_fraction = mx.mean(clipped.astype(mx.float32))
+
+     # KL divergence approximation
+     kl_div = mx.mean(old_logprobs - new_logprobs)
+
+     # Policy change magnitude
+     ratio_mean = mx.mean(ratio)
+     ratio_std = mx.std(ratio)
+
+     # Policy entropy (negative log probability mean)
+     entropy_mean = mx.mean(new_logprobs)
+
+     return {
+         'policy/ratio_mean': cast(float, ratio_mean.item()),
+         'policy/ratio_std': cast(float, ratio_std.item()),
+         'policy/clip_fraction': cast(float, clip_fraction.item()),
+         'policy/kl_divergence': cast(float, kl_div.item()),
+         'policy/entropy': -cast(float, entropy_mean.item())  # Entropy estimate: -E[log p]
+     }
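
A short sketch of how these metrics helpers compose in a training step. The log-probability arrays and the loss number below are made-up placeholder values, used only to exercise the API defined in this hunk:

import mlx.core as mx
from textpolicy.training.metrics import TrainingMetrics, compute_policy_metrics, log_metrics

tracker = TrainingMetrics(history_length=50)

old_lp = mx.array([-1.2, -0.7, -2.3])   # log-probs recorded during rollout (placeholder values)
new_lp = mx.array([-1.0, -0.9, -2.0])   # log-probs under the current policy (placeholder values)

step_metrics = compute_policy_metrics(old_lp, new_lp, clip_ratio=0.2)
step_metrics["loss"] = 0.42             # placeholder loss from the optimizer step
tracker.update(step_metrics)

log_metrics(step_metrics, step=1, prefix="train/")          # console (and optional external) logging
print(tracker.get_mean("policy/kl_divergence", last_n=10))  # rolling mean over recent steps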
@@ -0,0 +1,78 @@
+ # textpolicy/training/rollout_manager.py
+ """
+ Lightweight rollout manager that integrates with existing rollout system.
+ """
+
+ from typing import Callable, Any
+ from textpolicy.rollout import RolloutCoordinator
+ from textpolicy.buffer import Buffer
+ from .metrics import RolloutMetrics
+
+
+ class RolloutManager:
+     """
+     Simple manager that wraps the existing RolloutCoordinator
+     with metrics tracking and convenient interface.
+     """
+
+     def __init__(
+         self,
+         env_fn: Callable[[], Any],
+         policy_fn: Callable[[], Any],
+         algorithm: str = 'grpo',
+         num_workers: int = 0,
+         max_steps: int = 1000,
+         max_episodes: int = 100
+     ):
+         """
+         Initialize rollout manager.
+
+         Args:
+             env_fn: Function that creates environment instances
+             policy_fn: Function that creates policy instances
+             algorithm: Algorithm name ('grpo', 'ppo', etc.)
+             num_workers: Number of worker processes (0 = single-process)
+             max_steps: Maximum steps per rollout
+             max_episodes: Maximum episodes to buffer
+         """
+         self.coordinator = RolloutCoordinator(
+             env_fn=env_fn,
+             policy_fn=policy_fn,
+             algorithm=algorithm,
+             num_workers=num_workers,
+             max_steps=max_steps,
+             max_episodes=max_episodes
+         )
+
+         self.metrics = RolloutMetrics()
+
+     def collect(self) -> Buffer:
+         """
+         Collect rollout data and update metrics.
+
+         Returns:
+             Buffer containing collected episodes
+         """
+         # Use existing rollout system
+         buffer = self.coordinator.collect()
+
+         # Update metrics
+         for episode in buffer.episodes:
+             episode_data = episode.to_tensor_dict()
+             reward = episode_data['rew'].sum().item()
+             length = len(episode_data['obs'])
+             self.metrics.add_episode(reward, length)
+
+         return buffer
+
+     def get_metrics(self) -> dict:
+         """Get rollout collection metrics."""
+         return self.metrics.get_summary()
+
+     def reset_metrics(self):
+         """Reset rollout metrics."""
+         self.metrics.reset()
+
+     def close(self):
+         """Cleanup resources."""
+         self.coordinator.close()
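
A brief usage sketch for the RolloutManager above, wired into a collect-then-report loop. As before, make_env and make_policy are hypothetical factories, not part of the package:

from textpolicy.training.rollout_manager import RolloutManager

manager = RolloutManager(
    env_fn=make_env,        # hypothetical environment factory
    policy_fn=make_policy,  # hypothetical policy factory
    algorithm="grpo",
    num_workers=0,          # 0 = collect in the current process
    max_steps=512,
)

buffer = manager.collect()       # delegates to RolloutCoordinator and tallies per-episode stats
print(manager.get_metrics())     # mean_reward, mean_length, episodes_collected, ...
manager.reset_metrics()          # clear stats between training iterations
manager.close()                  # release coordinator resources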