textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
+# textpolicy/rollout/rollout.py
+"""
+Main rollout coordinator and public interface.
+"""
+
+from typing import List, Callable, Any
+from .worker import RolloutWorker
+from .runner import RolloutRunner
+from .aggregator import BufferAggregator
+from .strategy import create_strategy
+from textpolicy.buffer import Buffer
+
+
+class RolloutCoordinator:
+    """
+    Coordinator for rollout collection.
+
+    Provides unified interface for:
+    - Single-process rollouts (for debugging/small-scale)
+    - Multi-process rollouts (for production/performance)
+    - Strategy management and worker coordination
+    """
+
+    def __init__(
+        self,
+        env_fn: Callable[[], Any],
+        policy_fn: Callable[[], Any],
+        algorithm: str,
+        num_workers: int = 0,
+        max_steps: int = 1000,
+        max_episodes: int = 100
+    ):
+        """
+        Initialize rollout coordinator.
+
+        Args:
+            env_fn: Function that creates environment instances
+            policy_fn: Function that creates policy instances
+            algorithm: Algorithm name ('ppo', 'grpo', etc.)
+            num_workers: Number of worker processes (0 = single-process mode)
+            max_steps: Maximum steps per rollout
+            max_episodes: Maximum episodes to buffer
+        """
+        self.env_fn = env_fn
+        self.policy_fn = policy_fn
+        self.algorithm = algorithm
+        self.num_workers = num_workers
+        self.max_steps = max_steps
+
+        # Create strategy for algorithm
+        self.strategy = create_strategy(algorithm)
+
+        # Setup for multi-process or single-process mode
+        if num_workers > 0:
+            self._setup_multiprocess(max_episodes)
+        else:
+            self._setup_singleprocess()
+
+    def _setup_multiprocess(self, max_episodes: int):
+        """Setup multi-process rollout collection."""
+        self.aggregator = BufferAggregator(self.num_workers, max_episodes)
+        self.workers: List[RolloutWorker] = []
+
+        # Create and start worker processes
+        for i in range(self.num_workers):
+            worker = RolloutWorker(
+                env_fn=self.env_fn,
+                policy_fn=self.policy_fn,
+                strategy=self.strategy,
+                max_steps=self.max_steps
+            )
+            self.workers.append(worker)
+            self.aggregator.add_worker(worker, i)
+            worker.start()
+
+    def _setup_singleprocess(self):
+        """Setup single-process rollout collection."""
+        self.aggregator = None
+        self.workers = []
+
+        # Create single runner for direct use
+        env = self.env_fn()
+        policy = self.policy_fn()
+        self.runner = RolloutRunner(env, policy, self.strategy, self.max_steps)
+
+    def collect(self) -> Buffer:
+        """
+        Collect rollout data.
+
+        Returns:
+            Buffer containing collected episodes
+        """
+        if self.num_workers > 0:
+            return self._collect_multiprocess()
+        else:
+            return self._collect_singleprocess()
+
+    def _collect_multiprocess(self) -> Buffer:
+        """Collect data using multiple worker processes."""
+        # Request rollouts from all workers
+        for worker in self.workers:
+            worker.collect_async()
+
+        # Wait for and consume results
+        while not self.aggregator.ready(min_episodes=self.num_workers):
+            self.aggregator.consume_all()
+
+        # Return aggregated buffer
+        # Note: In practice, trainer would manage this more carefully
+        return self.aggregator.buffer
+
+    def _collect_singleprocess(self) -> Buffer:
+        """Collect data using single process."""
+        return self.runner.collect()
+
+    def close(self):
+        """Cleanup resources."""
+        if self.workers:
+            for worker in self.workers:
+                worker.close()
+
+
+# Public API functions for external use
+def create_rollout_coordinator(
+    env_fn: Callable[[], Any],
+    policy_fn: Callable[[], Any],
+    algorithm: str,
+    **kwargs
+) -> RolloutCoordinator:
+    """
+    Factory function for creating rollout coordinators.
+
+    Args:
+        env_fn: Environment factory function
+        policy_fn: Policy factory function
+        algorithm: Algorithm name
+        **kwargs: Additional configuration options
+
+    Returns:
+        RolloutCoordinator instance
+    """
+    return RolloutCoordinator(env_fn, policy_fn, algorithm, **kwargs)
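For orientation, a minimal call-shape sketch of the coordinator API shown above (not part of the wheel contents): `make_env` and `make_policy` are hypothetical placeholder factories, and the import path simply mirrors the module path in this diff.

import textpolicy  # package under diff
from textpolicy.rollout.rollout import create_rollout_coordinator

def make_env():
    ...  # hypothetical: returns an env that follows the dict-based step contract used by RolloutRunner

def make_policy():
    ...  # hypothetical: returns a callable obs -> (action, {"logprob": ..., ...})

coordinator = create_rollout_coordinator(
    env_fn=make_env,
    policy_fn=make_policy,
    algorithm="grpo",   # resolved via create_strategy()
    num_workers=0,      # 0 = single-process mode backed by RolloutRunner
    max_steps=256,
)
buffer = coordinator.collect()   # Buffer of collected episodes
coordinator.close()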
@@ -0,0 +1,280 @@
|
|
|
1
|
+
# textpolicy/rollout/runner.py
|
|
2
|
+
"""
|
|
3
|
+
Core rollout collection engine.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Callable, Dict, Optional, Tuple
|
|
7
|
+
import mlx.core as mx # type: ignore
|
|
8
|
+
from textpolicy.buffer import Buffer
|
|
9
|
+
from .base import RolloutStrategy, DEFAULT_MAX_STEPS
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RolloutRunner:
|
|
13
|
+
"""
|
|
14
|
+
Collects rollouts from a single environment using a policy.
|
|
15
|
+
|
|
16
|
+
Designed for MLX and Apple Silicon:
|
|
17
|
+
- Minimized Python overhead in the collection loop
|
|
18
|
+
- MLX array conversions performed once per step
|
|
19
|
+
- Python lists for storage to reduce overhead
|
|
20
|
+
- Policy inference on GPU/ANE when available
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
env,
|
|
26
|
+
policy: Optional[Callable[[mx.array], Tuple[mx.array, Dict[str, mx.array]]]] = None,
|
|
27
|
+
strategy: Optional[RolloutStrategy] = None,
|
|
28
|
+
max_steps: int = DEFAULT_MAX_STEPS,
|
|
29
|
+
agent = None # Alternative API: pass agent instead of policy/strategy
|
|
30
|
+
) -> None:
|
|
31
|
+
"""
|
|
32
|
+
Initialize rollout runner.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
env: Environment instance (must implement gym interface)
|
|
36
|
+
policy: Policy function that takes obs and returns (action, extras)
|
|
37
|
+
strategy: RolloutStrategy defining algorithm-specific behavior
|
|
38
|
+
max_steps: Maximum steps per rollout collection
|
|
39
|
+
agent: Alternative API - Agent object containing policy and rollout_strategy
|
|
40
|
+
"""
|
|
41
|
+
self.env = env
|
|
42
|
+
|
|
43
|
+
# Support both direct policy/strategy and agent-based initialization
|
|
44
|
+
if agent is not None:
|
|
45
|
+
# Extract policy and strategy from agent for backward compatibility with tests
|
|
46
|
+
self.policy = agent.policy
|
|
47
|
+
self.strategy = agent.rollout_strategy
|
|
48
|
+
else:
|
|
49
|
+
# Direct policy/strategy initialization (current API)
|
|
50
|
+
if policy is None or strategy is None:
|
|
51
|
+
raise ValueError("Must provide either agent or both policy and strategy")
|
|
52
|
+
self.policy = policy
|
|
53
|
+
self.strategy = strategy
|
|
54
|
+
|
|
55
|
+
self.max_steps = max_steps
|
|
56
|
+
self.buffer = Buffer(max_episodes=10)
|
|
57
|
+
self.step_count = 0 # Track total steps collected for test compatibility
|
|
58
|
+
|
|
59
|
+
def _normalize_step_result(self, step_result):
|
|
60
|
+
"""
|
|
61
|
+
Normalize Environment.step results to a tuple
|
|
62
|
+
(next_obs, reward, terminated, truncated, info).
|
|
63
|
+
|
|
64
|
+
Enforces dict-shaped step results per Environment contract.
|
|
65
|
+
Raises TypeError for tuple-based results.
|
|
66
|
+
"""
|
|
67
|
+
if not isinstance(step_result, dict):
|
|
68
|
+
raise TypeError(
|
|
69
|
+
"Environment.step must return a dict with keys: observation, reward, "
|
|
70
|
+
"terminated, truncated, info. Tuple returns are not supported."
|
|
71
|
+
)
|
|
72
|
+
return (
|
|
73
|
+
step_result.get("observation"),
|
|
74
|
+
step_result.get("reward"),
|
|
75
|
+
step_result.get("terminated"),
|
|
76
|
+
step_result.get("truncated"),
|
|
77
|
+
step_result.get("info", {}),
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
def collect(self) -> Buffer:
|
|
81
|
+
"""
|
|
82
|
+
Run one complete rollout collection.
|
|
83
|
+
|
|
84
|
+
MLX efficiency guidelines:
|
|
85
|
+
- Policy runs on GPU/ANE via MLX arrays
|
|
86
|
+
- Buffer stores as Python lists
|
|
87
|
+
- Batch array conversions reduce memory transfers
|
|
88
|
+
- Strategy handles algorithm differences
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Buffer containing collected episodes
|
|
92
|
+
"""
|
|
93
|
+
obs, _ = self.env.reset()
|
|
94
|
+
|
|
95
|
+
# Batch observations for efficient MLX conversion
|
|
96
|
+
# Collect multiple observations before converting to MLX arrays
|
|
97
|
+
batch_size = min(32, self.max_steps) # Batch size for array conversion
|
|
98
|
+
obs_batch = []
|
|
99
|
+
|
|
100
|
+
for step in range(self.max_steps):
|
|
101
|
+
obs_batch.append(obs)
|
|
102
|
+
|
|
103
|
+
# Process batch when full or on final step
|
|
104
|
+
if len(obs_batch) == batch_size or step == self.max_steps - 1:
|
|
105
|
+
# Single MLX array conversion for the entire batch
|
|
106
|
+
try:
|
|
107
|
+
import numpy as np
|
|
108
|
+
obs_batch_np = np.array(obs_batch)
|
|
109
|
+
obs_batch_mx = mx.array(obs_batch_np)
|
|
110
|
+
except (ValueError, TypeError):
|
|
111
|
+
# Fallback to individual conversion if batch conversion fails
|
|
112
|
+
obs_batch_mx = mx.stack([mx.array(o) for o in obs_batch])
|
|
113
|
+
|
|
114
|
+
# Process each observation in the batch
|
|
115
|
+
for i, obs_single in enumerate(obs_batch):
|
|
116
|
+
# Extract single observation from batch
|
|
117
|
+
obs_mx = obs_batch_mx[i] if len(obs_batch) > 1 else obs_batch_mx
|
|
118
|
+
|
|
119
|
+
# Policy forward pass - runs on GPU/ANE
|
|
120
|
+
action_mx, extra = self.strategy.select_action(self.policy, obs_mx)
|
|
121
|
+
|
|
122
|
+
# Convert action back to Python format (once per step)
|
|
123
|
+
if action_mx.ndim == 0:
|
|
124
|
+
action = action_mx.item() # Scalar action (discrete environments)
|
|
125
|
+
else:
|
|
126
|
+
action = action_mx.tolist() # Vector action (continuous environments)
|
|
127
|
+
|
|
128
|
+
# Environment step (CPU-bound). Normalize to a standard tuple.
|
|
129
|
+
step_result = self.env.step(action)
|
|
130
|
+
next_obs, reward, done, trunc, info = self._normalize_step_result(step_result)
|
|
131
|
+
|
|
132
|
+
# Store transition using strategy-specific logic
|
|
133
|
+
# Strategy handles filtering and algorithm-specific data
|
|
134
|
+
self.strategy.store_transition(
|
|
135
|
+
buffer=self.buffer,
|
|
136
|
+
obs=obs_single, # Use original observation
|
|
137
|
+
act=action,
|
|
138
|
+
rew=reward,
|
|
139
|
+
next_obs=next_obs,
|
|
140
|
+
done=done,
|
|
141
|
+
timeout=trunc,
|
|
142
|
+
**extra # Algorithm-specific data (logprob, value, etc.)
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
# Handle episode boundaries
|
|
146
|
+
if done or trunc:
|
|
147
|
+
obs, _ = self.env.reset()
|
|
148
|
+
else:
|
|
149
|
+
obs = next_obs
|
|
150
|
+
|
|
151
|
+
# Break if we've reached max steps
|
|
152
|
+
if step >= self.max_steps - 1:
|
|
153
|
+
break
|
|
154
|
+
|
|
155
|
+
# Clear batch for next iteration
|
|
156
|
+
obs_batch = []
|
|
157
|
+
|
|
158
|
+
return self.buffer
|
|
159
|
+
|
|
160
|
+
def collect_episode(self, deterministic=False):
|
|
161
|
+
"""
|
|
162
|
+
Collect a single episode and return as trajectory (backward compatibility).
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
deterministic: Whether to use deterministic policy actions
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
List of transition dictionaries for test compatibility
|
|
169
|
+
"""
|
|
170
|
+
# Reset buffer for clean episode collection
|
|
171
|
+
self.buffer.clear()
|
|
172
|
+
|
|
173
|
+
# Handle both old-style (obs only) and new-style (obs, info) reset
|
|
174
|
+
reset_result = self.env.reset()
|
|
175
|
+
if isinstance(reset_result, tuple):
|
|
176
|
+
obs, _ = reset_result
|
|
177
|
+
else:
|
|
178
|
+
obs = reset_result
|
|
179
|
+
trajectory = []
|
|
180
|
+
|
|
181
|
+
for step in range(self.max_steps):
|
|
182
|
+
# Convert observation to MLX array
|
|
183
|
+
obs_mx = mx.array(obs)
|
|
184
|
+
|
|
185
|
+
# Policy forward pass - pass through deterministic flag to policy
|
|
186
|
+
if hasattr(self.policy, '__call__'):
|
|
187
|
+
# If policy supports deterministic parameter, use it
|
|
188
|
+
try:
|
|
189
|
+
action_mx, extra = self.policy(obs_mx, deterministic=deterministic)
|
|
190
|
+
except TypeError:
|
|
191
|
+
# Fallback: use strategy select_action which handles deterministic behavior
|
|
192
|
+
action_mx, extra = self.strategy.select_action(self.policy, obs_mx)
|
|
193
|
+
else:
|
|
194
|
+
action_mx, extra = self.strategy.select_action(self.policy, obs_mx)
|
|
195
|
+
|
|
196
|
+
# Convert action back to Python format
|
|
197
|
+
if action_mx.ndim == 0:
|
|
198
|
+
action = action_mx.item() # Scalar action (discrete environments)
|
|
199
|
+
else:
|
|
200
|
+
action = action_mx.tolist() # Vector action (continuous environments)
|
|
201
|
+
|
|
202
|
+
# Environment step: normalize dict/tuple to a standard tuple
|
|
203
|
+
step_result = self.env.step(action)
|
|
204
|
+
next_obs, reward, done, trunc, info = self._normalize_step_result(step_result)
|
|
205
|
+
|
|
206
|
+
# Build transition dictionary for test compatibility
|
|
207
|
+
transition = {
|
|
208
|
+
'obs': obs,
|
|
209
|
+
'act': action,
|
|
210
|
+
'rew': reward,
|
|
211
|
+
'next_obs': next_obs,
|
|
212
|
+
'done': done or trunc, # Combine terminated/truncated for backward compatibility
|
|
213
|
+
**extra # Include logprob, value, entropy from strategy
|
|
214
|
+
}
|
|
215
|
+
trajectory.append(transition)
|
|
216
|
+
|
|
217
|
+
# Track step count for tests
|
|
218
|
+
self.step_count += 1
|
|
219
|
+
|
|
220
|
+
# Handle episode boundaries
|
|
221
|
+
if done or trunc:
|
|
222
|
+
break
|
|
223
|
+
else:
|
|
224
|
+
obs = next_obs
|
|
225
|
+
|
|
226
|
+
return trajectory
|
|
227
|
+
|
|
228
|
+
def collect_rollout(self, num_episodes: int, collect_stats=False):
|
|
229
|
+
"""
|
|
230
|
+
Collect multiple episodes (backward compatibility).
|
|
231
|
+
|
|
232
|
+
Args:
|
|
233
|
+
num_episodes: Number of episodes to collect
|
|
234
|
+
collect_stats: Whether to collect statistics (for test compatibility)
|
|
235
|
+
|
|
236
|
+
Returns:
|
|
237
|
+
List of episode trajectories
|
|
238
|
+
"""
|
|
239
|
+
trajectories = []
|
|
240
|
+
for _ in range(num_episodes):
|
|
241
|
+
trajectory = self.collect_episode()
|
|
242
|
+
trajectories.append(trajectory)
|
|
243
|
+
|
|
244
|
+
# Store trajectories for statistics if requested
|
|
245
|
+
if collect_stats:
|
|
246
|
+
self._last_trajectories = trajectories
|
|
247
|
+
|
|
248
|
+
return trajectories
|
|
249
|
+
|
|
250
|
+
def get_statistics(self):
|
|
251
|
+
"""
|
|
252
|
+
Get rollout statistics (for test compatibility).
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
Dictionary of rollout statistics
|
|
256
|
+
"""
|
|
257
|
+
if not hasattr(self, '_last_trajectories'):
|
|
258
|
+
return {
|
|
259
|
+
'total_episodes': 0,
|
|
260
|
+
'total_steps': self.step_count,
|
|
261
|
+
'avg_episode_length': 0,
|
|
262
|
+
'avg_episode_reward': 0
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
total_episodes = len(self._last_trajectories)
|
|
266
|
+
total_steps = sum(len(traj) for traj in self._last_trajectories)
|
|
267
|
+
avg_episode_length = total_steps / max(total_episodes, 1)
|
|
268
|
+
|
|
269
|
+
total_reward = sum(
|
|
270
|
+
sum(transition['rew'] for transition in traj)
|
|
271
|
+
for traj in self._last_trajectories
|
|
272
|
+
)
|
|
273
|
+
avg_episode_reward = total_reward / max(total_episodes, 1)
|
|
274
|
+
|
|
275
|
+
return {
|
|
276
|
+
'total_episodes': total_episodes,
|
|
277
|
+
'total_steps': total_steps,
|
|
278
|
+
'avg_episode_length': avg_episode_length,
|
|
279
|
+
'avg_episode_reward': avg_episode_reward
|
|
280
|
+
}
|
|
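The runner enforces a dict-shaped step() contract, so a minimal sketch of a conforming environment and policy may help; `ToyEnv` and `toy_policy` are illustrative stand-ins rather than package code, and they assume only the interfaces visible in this diff.

import mlx.core as mx
from textpolicy.rollout.runner import RolloutRunner
from textpolicy.rollout.strategy import create_strategy

class ToyEnv:
    """Minimal illustrative environment following the dict-based step contract."""
    def reset(self):
        return [0.0, 0.0], {}  # (observation, info)

    def step(self, action):
        return {
            "observation": [0.0, 0.0],
            "reward": 1.0,
            "terminated": True,   # single-step episodes
            "truncated": False,
            "info": {},
        }

def toy_policy(obs, deterministic=False):
    # Policies return (action, extras); extras are merged into each stored transition.
    return mx.array(0), {"logprob": mx.array(0.0)}

runner = RolloutRunner(ToyEnv(), policy=toy_policy, strategy=create_strategy("grpo"))
trajectory = runner.collect_episode()  # e.g. [{'obs': ..., 'act': 0, 'rew': 1.0, 'done': True, 'logprob': ...}]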
@@ -0,0 +1,208 @@
+# textpolicy/rollout/strategy.py
+"""
+Algorithm-specific rollout strategies.
+"""
+
+from typing import Callable, Dict, Any, Tuple
+import mlx.core as mx  # type: ignore
+from textpolicy.buffer import Buffer
+from .base import RolloutStrategy, validate_transition_data
+
+
+class PPOStrategy(RolloutStrategy):
+    """
+    Rollout strategy for Proximal Policy Optimization (PPO).
+
+    PPO requires:
+    - Action probabilities (logprob) for policy gradient
+    - Value function estimates for advantage calculation
+    - Policy and value function trained together
+
+    Expected policy output:
+        (action, {"logprob": mx.array, "value": mx.array, "entropy": mx.array})
+
+    Stored in buffer: obs, act, rew, next_obs, done, timeout, logprob, value
+    Filtered out: entropy (computed during training, not stored)
+    """
+
+    def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
+        """
+        Select action using PPO policy.
+
+        PPO TRAINING BEHAVIOR: Uses stochastic policy (deterministic=False) for exploration
+        during training data collection. This provides the variety of experiences needed
+        for robust policy learning. Evaluation uses deterministic=True for consistent
+        performance measurement.
+
+        Args:
+            policy: Policy function returning (action, extras)
+            obs: MLX array observation
+
+        Returns:
+            action: Selected action as MLX array (sampled stochastically)
+            extras: Dict with logprob, value, entropy
+        """
+        # Use stochastic policy for training data collection (explore during training, evaluate deterministically)
+        return policy(obs, deterministic=False)
+
+    def store_transition(self, buffer: Buffer, **data) -> None:
+        """
+        Store PPO transition data in buffer.
+
+        Filters data to include only what the buffer supports and PPO needs.
+        Validates required fields are present.
+
+        Args:
+            buffer: Buffer instance to store data
+            **data: Transition data including obs, act, rew, etc.
+        """
+        # Validate and filter transition data
+        filtered_data = validate_transition_data(data)
+
+        # Remove entropy if present (not stored, computed during training)
+        filtered_data.pop('entropy', None)
+
+        # Store in buffer
+        buffer.add(**filtered_data)  # type: ignore
+
+
+class GRPOStrategy(RolloutStrategy):
+    """
+    Rollout strategy for Group Relative Policy Optimization (GRPO).
+
+    GRPO characteristics:
+    - No value function required (uses group-relative advantages)
+    - Only needs action probabilities for policy gradient
+    - Advantage computed relative to group performance
+
+    Expected policy output:
+        (action, {"logprob": mx.array, "entropy": mx.array})
+
+    Stored in buffer: obs, act, rew, next_obs, done, timeout, logprob
+    Filtered out: value (not used), entropy (computed during training)
+    """
+
+    def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
+        """
+        Select action using GRPO policy.
+
+        GRPO TRAINING BEHAVIOR: Uses stochastic policy (deterministic=False) for exploration
+        during training data collection, consistent with PPO approach.
+
+        Args:
+            policy: Policy function returning (action, extras)
+            obs: MLX array observation
+
+        Returns:
+            action: Selected action as MLX array (sampled stochastically)
+            extras: Dict with logprob, entropy (no value function)
+        """
+        # Use stochastic policy for training data collection
+        return policy(obs, deterministic=False)
+
+    def store_transition(self, buffer: Buffer, **data) -> None:
+        """
+        Store GRPO transition data in buffer.
+
+        GRPO doesn't use value functions, so value data is filtered out.
+        Only stores what's needed for group-relative advantage computation.
+
+        Args:
+            buffer: Buffer instance to store data
+            **data: Transition data including obs, act, rew, etc.
+        """
+        # Validate and filter transition data
+        filtered_data = validate_transition_data(data)
+
+        # Remove fields not used by GRPO
+        filtered_data.pop('value', None)    # No value function in GRPO
+        filtered_data.pop('entropy', None)  # Computed during training
+
+        # Store in buffer
+        buffer.add(**filtered_data)  # type: ignore
+
+
+class DreamerV3Strategy(RolloutStrategy):
+    """
+    Rollout strategy for DreamerV3.
+
+    DreamerV3 is model-based RL that requires:
+    - Real environment interactions for world model learning
+    - Action probabilities for actor training
+    - State representations for RSSM dynamics learning
+
+    Expected policy output:
+        (action, {"logprob": mx.array, "value": mx.array, "entropy": mx.array,
+                  "state": mx.array, "embed": mx.array})
+
+    Stored in buffer: obs, act, rew, next_obs, done, timeout, logprob, value, state, embed
+    Filtered out: entropy (computed during training)
+    """
+
+    def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
+        """
+        Select action using DreamerV3 policy.
+
+        DreamerV3 uses the world model's posterior state for action selection.
+        During training, we use stochastic policy for exploration.
+
+        Args:
+            policy: DreamerV3 policy function returning (action, extras)
+            obs: MLX array observation
+
+        Returns:
+            action: Selected action as MLX array (sampled stochastically)
+            extras: Dict with logprob, value, entropy, state, embed
+        """
+        # Use stochastic policy during training for exploration
+        return policy(obs, deterministic=False)
+
+    def store_transition(self, buffer: Buffer, **data) -> None:
+        """
+        Store DreamerV3 transition data in buffer.
+
+        DreamerV3 stores additional state representations and embeddings
+        needed for world model learning and imagination.
+
+        Args:
+            buffer: Buffer instance to store data
+            **data: Transition data including obs, act, rew, state, embed, etc.
+        """
+        # Validate and filter transition data
+        filtered_data = validate_transition_data(data)
+
+        # Remove entropy if present (not stored, computed during training)
+        filtered_data.pop('entropy', None)
+
+        # Store in buffer (including DreamerV3-specific fields like state, embed)
+        buffer.add(**filtered_data)  # type: ignore
+
+
+# Strategy registry for factory pattern
+STRATEGY_REGISTRY = {
+    'ppo': PPOStrategy,
+    'grpo': GRPOStrategy,
+    'gspo': GRPOStrategy,  # Alias for backwards compatibility
+    'dreamerv3': DreamerV3Strategy,
+}
+
+
+def create_strategy(algorithm: str) -> RolloutStrategy:
+    """
+    Factory function for creating rollout strategies.
+
+    Args:
+        algorithm: Algorithm name ('ppo', 'grpo', etc.)
+
+    Returns:
+        RolloutStrategy instance
+
+    Raises:
+        ValueError: If algorithm is not supported
+    """
+    if algorithm not in STRATEGY_REGISTRY:
+        available = list(STRATEGY_REGISTRY.keys())
+        raise ValueError(f"Unknown algorithm '{algorithm}'. Available: {available}")
+
+    strategy_class = STRATEGY_REGISTRY[algorithm]
+    return strategy_class()
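Because STRATEGY_REGISTRY is a plain dict keyed by algorithm name, wiring in an external algorithm appears to be a matter of subclassing RolloutStrategy and adding an entry. A minimal sketch, assuming the base class only requires select_action and store_transition (as the built-in strategies suggest); `MyStrategy` and 'my_algo' are hypothetical names, not part of the package.

from textpolicy.rollout.base import RolloutStrategy, validate_transition_data
from textpolicy.rollout.strategy import STRATEGY_REGISTRY, create_strategy

class MyStrategy(RolloutStrategy):
    """Hypothetical strategy: stochastic actions, logprob-only storage."""

    def select_action(self, policy, obs):
        return policy(obs, deterministic=False)

    def store_transition(self, buffer, **data):
        filtered = validate_transition_data(data)
        filtered.pop('entropy', None)  # mirror the built-in strategies: entropy is recomputed during training
        buffer.add(**filtered)

STRATEGY_REGISTRY['my_algo'] = MyStrategy
strategy = create_strategy('my_algo')  # now resolvable by name, like 'ppo' or 'grpo'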