textpolicy-0.0.1-py3-none-any.whl → textpolicy-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
+# textpolicy/rollout/worker.py
+"""
+Worker process management and queue communication.
+"""
+
+import multiprocessing as mp
+import queue
+from typing import Callable, Any, Optional
+from .runner import RolloutRunner
+from .base import DEFAULT_WORKER_TIMEOUT, DEFAULT_MAX_STEPS
+from textpolicy.buffer import Buffer
+
+
+class RolloutWorker:
+    """
+    Manages a rollout runner in a separate process.
+
+    Handles:
+    - Process lifecycle (start, stop, cleanup)
+    - Queue-based communication with trainer
+    - Data serialization for multiprocessing
+    - Async rollout collection
+    """
+
+    def __init__(
+        self,
+        env_fn: Callable[[], Any],
+        policy_fn: Callable[[], Any],
+        strategy: Any,
+        max_steps: int = DEFAULT_MAX_STEPS,
+        send_queue: Optional[mp.Queue] = None,
+    ):
+        """
+        Initialize worker for separate process execution.
+
+        Args:
+            env_fn: Function that returns a fresh environment
+            policy_fn: Function that returns a fresh policy (for multiprocessing compatibility)
+            strategy: RolloutStrategy instance (e.g., PPOStrategy)
+            max_steps: Maximum steps per rollout collection
+            send_queue: Queue to send data back to trainer
+        """
+        self.env_fn = env_fn
+        self.policy_fn = policy_fn
+        self.strategy = strategy
+        self.max_steps = max_steps
+
+        # Communication queues
+        self.send_queue: mp.Queue = send_queue if send_queue is not None else mp.Queue()
+        self.control_queue: mp.Queue = mp.Queue()  # Commands: "collect", "exit"
+
+        # Process handle and status
+        self.process: Optional[mp.Process] = None
+        self.is_closed = False  # Track cleanup status
+
+    def run(self):
+        """
+        Target function for the worker process.
+
+        Creates fresh environment and policy instances, then runs collection loop.
+        Communicates via queues to avoid shared memory issues.
+        """
+        # Create fresh instances in worker process (avoid pickle issues)
+        env = self.env_fn()
+        policy = self.policy_fn()
+        runner = RolloutRunner(env, policy, self.strategy, self.max_steps)
+
+        while True:
+            try:
+                # Wait for commands from trainer
+                msg = self.control_queue.get(timeout=DEFAULT_WORKER_TIMEOUT)
+
+                if msg == "collect":
+                    # Run rollout collection
+                    buffer = runner.collect()
+
+                    # Serialize buffer to pure Python types for queue transmission
+                    episodes_data = [ep.to_dict() for ep in buffer.episodes]
+                    self.send_queue.put(episodes_data)
+
+                elif msg == "exit":
+                    break
+
+            except queue.Empty:
+                # Timeout - continue listening for commands
+                continue
+
+    def start(self):
+        """Launch the worker process."""
+        if self.process is not None:
+            raise RuntimeError("Worker process already started")
+
+        self.process = mp.Process(target=self.run, daemon=True)
+        self.process.start()
+
+    def collect_async(self):
+        """Request a rollout without blocking."""
+        if self.process is None:
+            raise RuntimeError("Worker process not started")
+
+        self.control_queue.put("collect")
+
+    def has_data(self) -> bool:
+        """Check if rollout result is ready."""
+        return not self.send_queue.empty()
+
+    def get_buffer(self) -> Buffer:
+        """
+        Retrieve the collected buffer.
+
+        Call only after has_data() returns True.
+
+        Returns:
+            Buffer instance with collected episodes
+        """
+        episodes_data = self.send_queue.get()
+
+        # Reconstruct buffer from serialized data
+        buffer = Buffer(max_episodes=len(episodes_data))
+        for ep_dict in episodes_data:
+            buffer.add_episode_from_dict(ep_dict)
+
+        return buffer
+
+    def close(self):
+        """Shut down the worker process gracefully."""
+        if self.process is None or self.is_closed:
+            return
+
+        try:
+            # Step 1: Signal shutdown to worker process
+            self.control_queue.put("exit")
+
+            # Step 2: Wait for graceful shutdown
+            self.process.join(timeout=3)
+
+            # Step 3: Force termination if needed
+            if self.process.is_alive():
+                self.process.terminate()
+                self.process.join(timeout=1)
+
+            # Step 4: Properly drain and close queues to prevent file descriptor errors
+            # This prevents the background feeder threads from trying to close already-closed FDs
+            self._cleanup_queues()
+
+        except Exception:
+            # Ignore cleanup errors during testing - they're just noise
+            pass
+        finally:
+            # Step 5: Mark as closed
+            self.is_closed = True
+            self.process = None
+
+    def _cleanup_queues(self):
+        """
+        Properly cleanup multiprocessing queues to prevent file descriptor errors.
+
+        The issue: Multiprocessing queues have background "feeder" threads that can
+        try to close file descriptors after the main process has already closed them.
+
+        Solution: Explicitly drain and close queues before process termination.
+        """
+        try:
+            # Cancel the queue's join thread to prevent hanging on exit
+            # This tells the queue not to wait for the feeder thread
+            if hasattr(self.send_queue, 'cancel_join_thread'):
+                self.send_queue.cancel_join_thread()
+            if hasattr(self.control_queue, 'cancel_join_thread'):
+                self.control_queue.cancel_join_thread()
+
+            # Drain any remaining items from queues to prevent deadlock
+            # Empty queues close more cleanly
+            try:
+                while not self.send_queue.empty():
+                    self.send_queue.get_nowait()
+            except queue.Empty:
+                pass  # Queue might be closed already
+
+            try:
+                while not self.control_queue.empty():
+                    self.control_queue.get_nowait()
+            except queue.Empty:
+                pass  # Queue might be closed already
+
+            # Now close the queues properly
+            if hasattr(self.send_queue, 'close'):
+                self.send_queue.close()
+            if hasattr(self.control_queue, 'close'):
+                self.control_queue.close()
+
+        except Exception:
+            # During testing, ignore queue cleanup errors
+            # They're artifacts of rapid process termination
+            pass
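
A minimal driver sketch for the RolloutWorker above (not part of the diff). The env/policy factories and the strategy object are placeholders for whatever the package's RolloutRunner actually accepts:

# Hypothetical usage sketch; make_env, make_policy, and strategy are assumptions.
import time
from textpolicy.rollout.worker import RolloutWorker

def make_env():      # placeholder: environment factory passed to the worker process
    ...

def make_policy():   # placeholder: policy factory passed to the worker process
    ...

strategy = ...       # placeholder: a RolloutStrategy instance (e.g. a PPO-style strategy)

worker = RolloutWorker(env_fn=make_env, policy_fn=make_policy, strategy=strategy)
worker.start()            # spawn the daemon process
worker.collect_async()    # enqueue a non-blocking "collect" command

while not worker.has_data():   # poll until the serialized episodes are queued
    time.sleep(0.1)

buffer = worker.get_buffer()   # Buffer reconstructed from the queued episode dicts
worker.close()                 # signal "exit", join the process, drain and close queues
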
@@ -0,0 +1,14 @@
+# textpolicy/training/__init__.py
+"""
+Unified training infrastructure for all RL algorithms.
+"""
+
+from .trainer import Trainer
+from .rollout_manager import RolloutManager
+from .metrics import TrainingMetrics
+
+__all__ = [
+    "Trainer",
+    "RolloutManager",
+    "TrainingMetrics"
+]
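
The new training package therefore exposes a three-name public API; a typical import is simply:

from textpolicy.training import Trainer, RolloutManager, TrainingMetrics
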
@@ -0,0 +1,242 @@
+# textpolicy/training/metrics.py
+"""
+Training metrics collection and analysis for all RL algorithms.
+"""
+
+from typing import Dict, Any, Optional, cast
+import mlx.core as mx  # type: ignore
+from collections import defaultdict, deque
+
+
+class TrainingMetrics:
+    """
+    Lightweight metrics collector optimized for MLX training.
+
+    Tracks algorithm-agnostic and algorithm-specific metrics
+    with minimal overhead during training.
+    """
+
+    def __init__(self, history_length: int = 100):
+        """
+        Initialize metrics collector.
+
+        Args:
+            history_length: Number of recent values to keep for rolling averages
+        """
+        self.history_length = history_length
+        self.metrics = defaultdict(lambda: deque(maxlen=history_length))
+        self.total_steps = 0
+
+    def update(self, metrics_dict: Dict[str, float]):
+        """
+        Update metrics with new values.
+
+        Args:
+            metrics_dict: Dictionary of metric_name -> value
+        """
+        for name, value in metrics_dict.items():
+            self.metrics[name].append(value)
+
+        if 'step' in metrics_dict:
+            self.total_steps = metrics_dict['step']
+
+    def get_latest(self, metric_name: str) -> Optional[float]:
+        """Get the most recent value for a metric."""
+        if metric_name in self.metrics and self.metrics[metric_name]:
+            return self.metrics[metric_name][-1]
+        return None
+
+    def get_mean(self, metric_name: str, last_n: Optional[int] = None) -> Optional[float]:
+        """
+        Get mean of recent metric values.
+
+        Args:
+            metric_name: Name of the metric
+            last_n: Number of recent values to average (None for all)
+
+        Returns:
+            Mean value or None if metric doesn't exist
+        """
+        if metric_name not in self.metrics or not self.metrics[metric_name]:
+            return None
+
+        values = list(self.metrics[metric_name])
+        if last_n is not None:
+            values = values[-last_n:]
+
+        return sum(values) / len(values) if values else None
+
+    def get_summary(self) -> Dict[str, Any]:
+        """
+        Get comprehensive metrics summary.
+
+        Returns:
+            Dictionary with latest, mean, and other statistics
+        """
+        summary = {
+            'total_steps': self.total_steps,
+            'metrics': {}
+        }
+
+        for metric_name, values in self.metrics.items():
+            if not values:
+                continue
+
+            values_list = list(values)
+            summary['metrics'][metric_name] = {
+                'latest': values_list[-1],
+                'mean': sum(values_list) / len(values_list),
+                'min': min(values_list),
+                'max': max(values_list),
+                'count': len(values_list)
+            }
+
+        return summary
+
+    def reset(self):
+        """Reset all metrics."""
+        self.metrics.clear()
+        self.total_steps = 0
+
+    def __len__(self) -> int:
+        """Return number of metrics being tracked."""
+        return len(self.metrics)
+
+
+class RolloutMetrics:
+    """
+    Metrics specific to rollout collection phase.
+    """
+
+    def __init__(self):
+        self.episodes_collected = 0
+        self.total_reward = 0.0
+        self.episode_lengths = []
+        self.episode_rewards = []
+
+    def add_episode(self, reward: float, length: int):
+        """Add metrics from a completed episode."""
+        self.episodes_collected += 1
+        self.total_reward += reward
+        self.episode_lengths.append(length)
+        self.episode_rewards.append(reward)
+
+    def get_summary(self) -> Dict[str, float]:
+        """Get rollout metrics summary."""
+        if not self.episode_rewards:
+            return {
+                'episodes_collected': 0,
+                'mean_reward': 0.0,
+                'mean_length': 0.0,
+                'total_reward': 0.0
+            }
+
+        return {
+            'episodes_collected': self.episodes_collected,
+            'mean_reward': sum(self.episode_rewards) / len(self.episode_rewards),
+            'mean_length': sum(self.episode_lengths) / len(self.episode_lengths),
+            'total_reward': self.total_reward,
+            'min_reward': min(self.episode_rewards),
+            'max_reward': max(self.episode_rewards),
+            'min_length': min(self.episode_lengths),
+            'max_length': max(self.episode_lengths)
+        }
+
+    def reset(self):
+        """Reset rollout metrics."""
+        self.episodes_collected = 0
+        self.total_reward = 0.0
+        self.episode_lengths.clear()
+        self.episode_rewards.clear()
+
+
+def log_metrics(
+    metrics: Dict[str, float],
+    step: int,
+    logger: Optional[Any] = None,
+    prefix: str = ""
+):
+    """
+    Log metrics to console and optionally to external logger.
+
+    Args:
+        metrics: Metrics dictionary
+        step: Training step number
+        logger: Optional external logger (wandb, tensorboard, etc.)
+        prefix: Prefix for metric names
+    """
+    # Console logging
+    metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
+    print(f"Step {step} | {prefix}{metrics_str}")
+
+    # External logger
+    if logger is not None and hasattr(logger, 'log'):
+        prefixed_metrics = {f"{prefix}{k}": v for k, v in metrics.items()}
+        logger.log(prefixed_metrics, step=step)
+
+
+def compute_explained_variance(predicted: mx.array, targets: mx.array) -> float:
+    """
+    Compute explained variance for value function evaluation.
+
+    Explained variance = 1 - Var(targets - predicted) / Var(targets)
+
+    Args:
+        predicted: Predicted values
+        targets: Target values
+
+    Returns:
+        Explained variance (1.0 = perfect, 0.0 = no better than mean)
+    """
+    target_var = mx.var(targets)
+    if target_var == 0:
+        return 0.0
+
+    residual_var = mx.var(targets - predicted)
+    explained_var = 1.0 - (residual_var / target_var)
+
+    return cast(float, explained_var.item())  # MLX scalar array .item() returns Python float for float dtypes
+
+
+def compute_policy_metrics(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    clip_ratio: float = 0.2
+) -> Dict[str, float]:
+    """
+    Compute standard policy optimization metrics.
+
+    Args:
+        old_logprobs: Log probabilities from rollout
+        new_logprobs: Log probabilities from current policy
+        clip_ratio: Clipping ratio used in loss
+
+    Returns:
+        Dictionary of policy metrics
+    """
+    # Importance ratio
+    ratio = mx.exp(new_logprobs - old_logprobs)
+
+    # Clipping statistics
+    clip_lower = 1 - clip_ratio
+    clip_upper = 1 + clip_ratio
+    clipped = (ratio < clip_lower) | (ratio > clip_upper)
+    clip_fraction = mx.mean(clipped.astype(mx.float32))
+
+    # KL divergence approximation
+    kl_div = mx.mean(old_logprobs - new_logprobs)
+
+    # Policy change magnitude
+    ratio_mean = mx.mean(ratio)
+    ratio_std = mx.std(ratio)
+
+    # Policy entropy (negative log probability mean)
+    entropy_mean = mx.mean(new_logprobs)
+
+    return {
+        'policy/ratio_mean': cast(float, ratio_mean.item()),
+        'policy/ratio_std': cast(float, ratio_std.item()),
+        'policy/clip_fraction': cast(float, clip_fraction.item()),
+        'policy/kl_divergence': cast(float, kl_div.item()),
+        'policy/entropy': -cast(float, entropy_mean.item())  # Negative entropy: -E[log(p)]
+    }
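
A small, self-contained sketch of the collector and the explained-variance helper above; the metric names and toy numbers are illustrative only:

import mlx.core as mx
from textpolicy.training.metrics import TrainingMetrics, compute_explained_variance

tm = TrainingMetrics(history_length=100)
tm.update({'loss': 0.52, 'reward': 1.0, 'step': 1})
tm.update({'loss': 0.48, 'reward': 3.0, 'step': 2})

print(tm.get_latest('loss'))             # 0.48
print(tm.get_mean('reward'))             # 2.0
print(tm.get_summary()['total_steps'])   # 2 (taken from the 'step' entry)

# Explained variance = 1 - Var(targets - predicted) / Var(targets)
targets = mx.array([1.0, 2.0, 3.0, 4.0])
predicted = mx.array([1.1, 1.9, 3.2, 3.8])
print(compute_explained_variance(predicted, targets))  # ~0.98, close to a perfect value fit
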
@@ -0,0 +1,78 @@
+# textpolicy/training/rollout_manager.py
+"""
+Lightweight rollout manager that integrates with existing rollout system.
+"""
+
+from typing import Callable, Any
+from textpolicy.rollout import RolloutCoordinator
+from textpolicy.buffer import Buffer
+from .metrics import RolloutMetrics
+
+
+class RolloutManager:
+    """
+    Simple manager that wraps the existing RolloutCoordinator
+    with metrics tracking and convenient interface.
+    """
+
+    def __init__(
+        self,
+        env_fn: Callable[[], Any],
+        policy_fn: Callable[[], Any],
+        algorithm: str = 'grpo',
+        num_workers: int = 0,
+        max_steps: int = 1000,
+        max_episodes: int = 100
+    ):
+        """
+        Initialize rollout manager.
+
+        Args:
+            env_fn: Function that creates environment instances
+            policy_fn: Function that creates policy instances
+            algorithm: Algorithm name ('grpo', 'ppo', etc.)
+            num_workers: Number of worker processes (0 = single-process)
+            max_steps: Maximum steps per rollout
+            max_episodes: Maximum episodes to buffer
+        """
+        self.coordinator = RolloutCoordinator(
+            env_fn=env_fn,
+            policy_fn=policy_fn,
+            algorithm=algorithm,
+            num_workers=num_workers,
+            max_steps=max_steps,
+            max_episodes=max_episodes
+        )
+
+        self.metrics = RolloutMetrics()
+
+    def collect(self) -> Buffer:
+        """
+        Collect rollout data and update metrics.
+
+        Returns:
+            Buffer containing collected episodes
+        """
+        # Use existing rollout system
+        buffer = self.coordinator.collect()
+
+        # Update metrics
+        for episode in buffer.episodes:
+            episode_data = episode.to_tensor_dict()
+            reward = episode_data['rew'].sum().item()
+            length = len(episode_data['obs'])
+            self.metrics.add_episode(reward, length)
+
+        return buffer
+
+    def get_metrics(self) -> dict:
+        """Get rollout collection metrics."""
+        return self.metrics.get_summary()
+
+    def reset_metrics(self):
+        """Reset rollout metrics."""
+        self.metrics.reset()
+
+    def close(self):
+        """Cleanup resources."""
+        self.coordinator.close()
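
As with the worker earlier, a hypothetical end-to-end sketch of the manager above; make_env and make_policy stand in for whatever factories the surrounding training script defines and are not part of the diff:

# Hypothetical usage sketch; the env/policy factories are assumptions.
from textpolicy.training import RolloutManager

def make_env():      # placeholder environment factory
    ...

def make_policy():   # placeholder policy factory
    ...

manager = RolloutManager(
    env_fn=make_env,
    policy_fn=make_policy,
    algorithm='grpo',   # default shown above
    num_workers=0,      # 0 = single-process collection
)

buffer = manager.collect()     # Buffer of episodes; per-episode metrics updated as a side effect
print(manager.get_metrics())   # e.g. episodes_collected, mean_reward, mean_length
manager.reset_metrics()
manager.close()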