textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
+ # textpolicy/rollout/rollout.py
+ """
+ Main rollout coordinator and public interface.
+ """
+
+ from typing import List, Callable, Any
+ from .worker import RolloutWorker
+ from .runner import RolloutRunner
+ from .aggregator import BufferAggregator
+ from .strategy import create_strategy
+ from textpolicy.buffer import Buffer
+
+
+ class RolloutCoordinator:
+     """
+     Coordinator for rollout collection.
+
+     Provides unified interface for:
+     - Single-process rollouts (for debugging/small-scale)
+     - Multi-process rollouts (for production/performance)
+     - Strategy management and worker coordination
+     """
+
+     def __init__(
+         self,
+         env_fn: Callable[[], Any],
+         policy_fn: Callable[[], Any],
+         algorithm: str,
+         num_workers: int = 0,
+         max_steps: int = 1000,
+         max_episodes: int = 100
+     ):
+         """
+         Initialize rollout coordinator.
+
+         Args:
+             env_fn: Function that creates environment instances
+             policy_fn: Function that creates policy instances
+             algorithm: Algorithm name ('ppo', 'grpo', etc.)
+             num_workers: Number of worker processes (0 = single-process mode)
+             max_steps: Maximum steps per rollout
+             max_episodes: Maximum episodes to buffer
+         """
+         self.env_fn = env_fn
+         self.policy_fn = policy_fn
+         self.algorithm = algorithm
+         self.num_workers = num_workers
+         self.max_steps = max_steps
+
+         # Create strategy for algorithm
+         self.strategy = create_strategy(algorithm)
+
+         # Setup for multi-process or single-process mode
+         if num_workers > 0:
+             self._setup_multiprocess(max_episodes)
+         else:
+             self._setup_singleprocess()
+
+     def _setup_multiprocess(self, max_episodes: int):
+         """Setup multi-process rollout collection."""
+         self.aggregator = BufferAggregator(self.num_workers, max_episodes)
+         self.workers: List[RolloutWorker] = []
+
+         # Create and start worker processes
+         for i in range(self.num_workers):
+             worker = RolloutWorker(
+                 env_fn=self.env_fn,
+                 policy_fn=self.policy_fn,
+                 strategy=self.strategy,
+                 max_steps=self.max_steps
+             )
+             self.workers.append(worker)
+             self.aggregator.add_worker(worker, i)
+             worker.start()
+
+     def _setup_singleprocess(self):
+         """Setup single-process rollout collection."""
+         self.aggregator = None
+         self.workers = []
+
+         # Create single runner for direct use
+         env = self.env_fn()
+         policy = self.policy_fn()
+         self.runner = RolloutRunner(env, policy, self.strategy, self.max_steps)
+
+     def collect(self) -> Buffer:
+         """
+         Collect rollout data.
+
+         Returns:
+             Buffer containing collected episodes
+         """
+         if self.num_workers > 0:
+             return self._collect_multiprocess()
+         else:
+             return self._collect_singleprocess()
+
+     def _collect_multiprocess(self) -> Buffer:
+         """Collect data using multiple worker processes."""
+         # Request rollouts from all workers
+         for worker in self.workers:
+             worker.collect_async()
+
+         # Wait for and consume results
+         while not self.aggregator.ready(min_episodes=self.num_workers):
+             self.aggregator.consume_all()
+
+         # Return aggregated buffer
+         # Note: In practice, trainer would manage this more carefully
+         return self.aggregator.buffer
+
+     def _collect_singleprocess(self) -> Buffer:
+         """Collect data using single process."""
+         return self.runner.collect()
+
+     def close(self):
+         """Cleanup resources."""
+         if self.workers:
+             for worker in self.workers:
+                 worker.close()
+
+
+ # Public API functions for external use
+ def create_rollout_coordinator(
+     env_fn: Callable[[], Any],
+     policy_fn: Callable[[], Any],
+     algorithm: str,
+     **kwargs
+ ) -> RolloutCoordinator:
+     """
+     Factory function for creating rollout coordinators.
+
+     Args:
+         env_fn: Environment factory function
+         policy_fn: Policy factory function
+         algorithm: Algorithm name
+         **kwargs: Additional configuration options
+
+     Returns:
+         RolloutCoordinator instance
+     """
+     return RolloutCoordinator(env_fn, policy_fn, algorithm, **kwargs)
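For orientation, a minimal usage sketch of the coordinator API introduced above. Both factory functions here (`make_env`, `make_policy`) are hypothetical placeholders; they must satisfy the environment and policy contracts spelled out in `runner.py` and `strategy.py` below.

```python
# Sketch only: make_env / make_policy are hypothetical factories, not part of this package.
from textpolicy.rollout.rollout import create_rollout_coordinator

def make_env():
    ...  # return an env whose step() yields the dict described in runner.py

def make_policy():
    ...  # return policy(obs, deterministic=...) -> (action, {"logprob": ..., ...})

coordinator = create_rollout_coordinator(
    env_fn=make_env,
    policy_fn=make_policy,
    algorithm="grpo",   # any key in STRATEGY_REGISTRY (see strategy.py)
    num_workers=0,      # 0 = in-process RolloutRunner; >0 starts RolloutWorker processes
    max_steps=256,
)
buffer = coordinator.collect()  # Buffer of collected episodes
coordinator.close()
```

In multi-process mode (`num_workers > 0`), the same `collect()` call fans out to the `RolloutWorker` processes and returns the aggregated buffer.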
@@ -0,0 +1,280 @@
+ # textpolicy/rollout/runner.py
+ """
+ Core rollout collection engine.
+ """
+
+ from typing import Callable, Dict, Optional, Tuple
+ import mlx.core as mx  # type: ignore
+ from textpolicy.buffer import Buffer
+ from .base import RolloutStrategy, DEFAULT_MAX_STEPS
+
+
+ class RolloutRunner:
+     """
+     Collects rollouts from a single environment using a policy.
+
+     Designed for MLX and Apple Silicon:
+     - Minimized Python overhead in the collection loop
+     - MLX array conversions performed once per step
+     - Python lists for storage to reduce overhead
+     - Policy inference on GPU/ANE when available
+     """
+
+     def __init__(
+         self,
+         env,
+         policy: Optional[Callable[[mx.array], Tuple[mx.array, Dict[str, mx.array]]]] = None,
+         strategy: Optional[RolloutStrategy] = None,
+         max_steps: int = DEFAULT_MAX_STEPS,
+         agent=None  # Alternative API: pass agent instead of policy/strategy
+     ) -> None:
+         """
+         Initialize rollout runner.
+
+         Args:
+             env: Environment instance (must implement gym interface)
+             policy: Policy function that takes obs and returns (action, extras)
+             strategy: RolloutStrategy defining algorithm-specific behavior
+             max_steps: Maximum steps per rollout collection
+             agent: Alternative API - Agent object containing policy and rollout_strategy
+         """
+         self.env = env
+
+         # Support both direct policy/strategy and agent-based initialization
+         if agent is not None:
+             # Extract policy and strategy from agent for backward compatibility with tests
+             self.policy = agent.policy
+             self.strategy = agent.rollout_strategy
+         else:
+             # Direct policy/strategy initialization (current API)
+             if policy is None or strategy is None:
+                 raise ValueError("Must provide either agent or both policy and strategy")
+             self.policy = policy
+             self.strategy = strategy
+
+         self.max_steps = max_steps
+         self.buffer = Buffer(max_episodes=10)
+         self.step_count = 0  # Track total steps collected for test compatibility
+
+     def _normalize_step_result(self, step_result):
+         """
+         Normalize Environment.step results to a tuple
+         (next_obs, reward, terminated, truncated, info).
+
+         Enforces dict-shaped step results per Environment contract.
+         Raises TypeError for tuple-based results.
+         """
+         if not isinstance(step_result, dict):
+             raise TypeError(
+                 "Environment.step must return a dict with keys: observation, reward, "
+                 "terminated, truncated, info. Tuple returns are not supported."
+             )
+         return (
+             step_result.get("observation"),
+             step_result.get("reward"),
+             step_result.get("terminated"),
+             step_result.get("truncated"),
+             step_result.get("info", {}),
+         )
+
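To make the contract enforced by `_normalize_step_result` concrete, here is a hypothetical toy environment (not part of the package) returning the expected dict shape from `step()` and the `(observation, info)` pair from `reset()`:

```python
# Hypothetical toy environment illustrating the dict-shaped step() result
# that _normalize_step_result accepts; tuple-style gym results would raise TypeError.
class ToyEnv:
    def reset(self):
        return 0.0, {}  # (observation, info), as RolloutRunner.collect() expects

    def step(self, action):
        return {
            "observation": 1.0,   # next observation
            "reward": 1.0,
            "terminated": True,   # natural episode end
            "truncated": False,   # time-limit cutoff
            "info": {},
        }
```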
+     def collect(self) -> Buffer:
+         """
+         Run one complete rollout collection.
+
+         MLX efficiency guidelines:
+         - Policy runs on GPU/ANE via MLX arrays
+         - Buffer stores as Python lists
+         - Batch array conversions reduce memory transfers
+         - Strategy handles algorithm differences
+
+         Returns:
+             Buffer containing collected episodes
+         """
+         obs, _ = self.env.reset()
+
+         # Batch observations for efficient MLX conversion
+         # Collect multiple observations before converting to MLX arrays
+         batch_size = min(32, self.max_steps)  # Batch size for array conversion
+         obs_batch = []
+
+         for step in range(self.max_steps):
+             obs_batch.append(obs)
+
+             # Process batch when full or on final step
+             if len(obs_batch) == batch_size or step == self.max_steps - 1:
+                 # Single MLX array conversion for the entire batch
+                 try:
+                     import numpy as np
+                     obs_batch_np = np.array(obs_batch)
+                     obs_batch_mx = mx.array(obs_batch_np)
+                 except (ValueError, TypeError):
+                     # Fallback to individual conversion if batch conversion fails
+                     obs_batch_mx = mx.stack([mx.array(o) for o in obs_batch])
+
+                 # Process each observation in the batch
+                 for i, obs_single in enumerate(obs_batch):
+                     # Extract single observation from batch (indexing is valid for any batch size)
+                     obs_mx = obs_batch_mx[i]
+
+                     # Policy forward pass - runs on GPU/ANE
+                     action_mx, extra = self.strategy.select_action(self.policy, obs_mx)
+
+                     # Convert action back to Python format (once per step)
+                     if action_mx.ndim == 0:
+                         action = action_mx.item()  # Scalar action (discrete environments)
+                     else:
+                         action = action_mx.tolist()  # Vector action (continuous environments)
+
+                     # Environment step (CPU-bound). Normalize to a standard tuple.
+                     step_result = self.env.step(action)
+                     next_obs, reward, done, trunc, info = self._normalize_step_result(step_result)
+
+                     # Store transition using strategy-specific logic
+                     # Strategy handles filtering and algorithm-specific data
+                     self.strategy.store_transition(
+                         buffer=self.buffer,
+                         obs=obs_single,  # Use original observation
+                         act=action,
+                         rew=reward,
+                         next_obs=next_obs,
+                         done=done,
+                         timeout=trunc,
+                         **extra  # Algorithm-specific data (logprob, value, etc.)
+                     )
+
+                     # Handle episode boundaries
+                     if done or trunc:
+                         obs, _ = self.env.reset()
+                     else:
+                         obs = next_obs
+
+                     # Break if we've reached max steps
+                     if step >= self.max_steps - 1:
+                         break
+
+                 # Clear batch for next iteration
+                 obs_batch = []
+
+         return self.buffer
+
+     def collect_episode(self, deterministic=False):
+         """
+         Collect a single episode and return as trajectory (backward compatibility).
+
+         Args:
+             deterministic: Whether to use deterministic policy actions
+
+         Returns:
+             List of transition dictionaries for test compatibility
+         """
+         # Reset buffer for clean episode collection
+         self.buffer.clear()
+
+         # Handle both old-style (obs only) and new-style (obs, info) reset
+         reset_result = self.env.reset()
+         if isinstance(reset_result, tuple):
+             obs, _ = reset_result
+         else:
+             obs = reset_result
+         trajectory = []
+
+         for step in range(self.max_steps):
+             # Convert observation to MLX array
+             obs_mx = mx.array(obs)
+
+             # Policy forward pass - pass through deterministic flag to policy
+             if hasattr(self.policy, '__call__'):
+                 # If policy supports deterministic parameter, use it
+                 try:
+                     action_mx, extra = self.policy(obs_mx, deterministic=deterministic)
+                 except TypeError:
+                     # Fallback: use strategy select_action which handles deterministic behavior
+                     action_mx, extra = self.strategy.select_action(self.policy, obs_mx)
+             else:
+                 action_mx, extra = self.strategy.select_action(self.policy, obs_mx)
+
+             # Convert action back to Python format
+             if action_mx.ndim == 0:
+                 action = action_mx.item()  # Scalar action (discrete environments)
+             else:
+                 action = action_mx.tolist()  # Vector action (continuous environments)
+
+             # Environment step: normalize dict result to a standard tuple
+             step_result = self.env.step(action)
+             next_obs, reward, done, trunc, info = self._normalize_step_result(step_result)
+
+             # Build transition dictionary for test compatibility
+             transition = {
+                 'obs': obs,
+                 'act': action,
+                 'rew': reward,
+                 'next_obs': next_obs,
+                 'done': done or trunc,  # Combine terminated/truncated for backward compatibility
+                 **extra  # Include logprob, value, entropy from strategy
+             }
+             trajectory.append(transition)
+
+             # Track step count for tests
+             self.step_count += 1
+
+             # Handle episode boundaries
+             if done or trunc:
+                 break
+             else:
+                 obs = next_obs
+
+         return trajectory
+
+     def collect_rollout(self, num_episodes: int, collect_stats=False):
+         """
+         Collect multiple episodes (backward compatibility).
+
+         Args:
+             num_episodes: Number of episodes to collect
+             collect_stats: Whether to collect statistics (for test compatibility)
+
+         Returns:
+             List of episode trajectories
+         """
+         trajectories = []
+         for _ in range(num_episodes):
+             trajectory = self.collect_episode()
+             trajectories.append(trajectory)
+
+         # Store trajectories for statistics if requested
+         if collect_stats:
+             self._last_trajectories = trajectories
+
+         return trajectories
+
+     def get_statistics(self):
+         """
+         Get rollout statistics (for test compatibility).
+
+         Returns:
+             Dictionary of rollout statistics
+         """
+         if not hasattr(self, '_last_trajectories'):
+             return {
+                 'total_episodes': 0,
+                 'total_steps': self.step_count,
+                 'avg_episode_length': 0,
+                 'avg_episode_reward': 0
+             }
+
+         total_episodes = len(self._last_trajectories)
+         total_steps = sum(len(traj) for traj in self._last_trajectories)
+         avg_episode_length = total_steps / max(total_episodes, 1)
+
+         total_reward = sum(
+             sum(transition['rew'] for transition in traj)
+             for traj in self._last_trajectories
+         )
+         avg_episode_reward = total_reward / max(total_episodes, 1)
+
+         return {
+             'total_episodes': total_episodes,
+             'total_steps': total_steps,
+             'avg_episode_length': avg_episode_length,
+             'avg_episode_reward': avg_episode_reward
+         }
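A sketch of driving the backward-compatible `collect_episode`/`collect_rollout` path with the `ToyEnv` above and a stub policy. The stub policy is hypothetical, and how the internal `Buffer` behaves beyond what this diff shows is assumed rather than verified.

```python
import mlx.core as mx

from textpolicy.rollout.runner import RolloutRunner
from textpolicy.rollout.strategy import create_strategy

def stub_policy(obs, deterministic=False):
    # Hypothetical policy matching the (action, extras) contract used by the strategies.
    return mx.array(0), {"logprob": mx.array(0.0), "entropy": mx.array(0.0)}

runner = RolloutRunner(ToyEnv(), policy=stub_policy,
                       strategy=create_strategy("grpo"), max_steps=8)
trajectories = runner.collect_rollout(num_episodes=2, collect_stats=True)
print(runner.get_statistics())  # total_episodes, total_steps, avg_episode_length, avg_episode_reward
```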
@@ -0,0 +1,208 @@
+ # textpolicy/rollout/strategy.py
+ """
+ Algorithm-specific rollout strategies.
+ """
+
+ from typing import Callable, Dict, Any, Tuple
+ import mlx.core as mx  # type: ignore
+ from textpolicy.buffer import Buffer
+ from .base import RolloutStrategy, validate_transition_data
+
+
+ class PPOStrategy(RolloutStrategy):
+     """
+     Rollout strategy for Proximal Policy Optimization (PPO).
+
+     PPO requires:
+     - Action probabilities (logprob) for policy gradient
+     - Value function estimates for advantage calculation
+     - Policy and value function trained together
+
+     Expected policy output:
+     (action, {"logprob": mx.array, "value": mx.array, "entropy": mx.array})
+
+     Stored in buffer: obs, act, rew, next_obs, done, timeout, logprob, value
+     Filtered out: entropy (computed during training, not stored)
+     """
+
+     def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
+         """
+         Select action using PPO policy.
+
+         PPO TRAINING BEHAVIOR: Uses stochastic policy (deterministic=False) for exploration
+         during training data collection. This provides the variety of experiences needed
+         for robust policy learning. Evaluation uses deterministic=True for consistent
+         performance measurement.
+
+         Args:
+             policy: Policy function returning (action, extras)
+             obs: MLX array observation
+
+         Returns:
+             action: Selected action as MLX array (sampled stochastically)
+             extras: Dict with logprob, value, entropy
+         """
+         # Use stochastic policy for training data collection (explore during training, evaluate deterministically)
+         return policy(obs, deterministic=False)
+
+     def store_transition(self, buffer: Buffer, **data) -> None:
+         """
+         Store PPO transition data in buffer.
+
+         Filters data to include only what the buffer supports and PPO needs.
+         Validates required fields are present.
+
+         Args:
+             buffer: Buffer instance to store data
+             **data: Transition data including obs, act, rew, etc.
+         """
+         # Validate and filter transition data
+         filtered_data = validate_transition_data(data)
+
+         # Remove entropy if present (not stored, computed during training)
+         filtered_data.pop('entropy', None)
+
+         # Store in buffer
+         buffer.add(**filtered_data)  # type: ignore
+
+
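Before the GRPO variant below, a hypothetical policy skeleton showing the `(action, extras)` pair PPOStrategy expects; the actor and critic heads are placeholders rather than textpolicy code.

```python
import mlx.core as mx

def ppo_style_policy(obs: mx.array, deterministic: bool = False):
    # Placeholder actor/critic heads; a real policy would compute these from obs.
    logits = mx.zeros((2,))
    value = mx.array(0.0)
    probs = mx.softmax(logits, axis=-1)
    action = mx.argmax(probs) if deterministic else mx.random.categorical(logits)
    logprob = mx.log(mx.take(probs, action))
    entropy = -mx.sum(probs * mx.log(probs))
    return action, {"logprob": logprob, "value": value, "entropy": entropy}
```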
+ class GRPOStrategy(RolloutStrategy):
+     """
+     Rollout strategy for Group Relative Policy Optimization (GRPO).
+
+     GRPO characteristics:
+     - No value function required (uses group-relative advantages)
+     - Only needs action probabilities for policy gradient
+     - Advantage computed relative to group performance
+
+     Expected policy output:
+     (action, {"logprob": mx.array, "entropy": mx.array})
+
+     Stored in buffer: obs, act, rew, next_obs, done, timeout, logprob
+     Filtered out: value (not used), entropy (computed during training)
+     """
+
+     def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
+         """
+         Select action using GRPO policy.
+
+         GRPO TRAINING BEHAVIOR: Uses stochastic policy (deterministic=False) for exploration
+         during training data collection, consistent with PPO approach.
+
+         Args:
+             policy: Policy function returning (action, extras)
+             obs: MLX array observation
+
+         Returns:
+             action: Selected action as MLX array (sampled stochastically)
+             extras: Dict with logprob, entropy (no value function)
+         """
+         # Use stochastic policy for training data collection
+         return policy(obs, deterministic=False)
+
+     def store_transition(self, buffer: Buffer, **data) -> None:
+         """
+         Store GRPO transition data in buffer.
+
+         GRPO doesn't use value functions, so value data is filtered out.
+         Only stores what's needed for group-relative advantage computation.
+
+         Args:
+             buffer: Buffer instance to store data
+             **data: Transition data including obs, act, rew, etc.
+         """
+         # Validate and filter transition data
+         filtered_data = validate_transition_data(data)
+
+         # Remove fields not used by GRPO
+         filtered_data.pop('value', None)  # No value function in GRPO
+         filtered_data.pop('entropy', None)  # Computed during training
+
+         # Store in buffer
+         buffer.add(**filtered_data)  # type: ignore
+
+
+ class DreamerV3Strategy(RolloutStrategy):
+     """
+     Rollout strategy for DreamerV3.
+
+     DreamerV3 is model-based RL that requires:
+     - Real environment interactions for world model learning
+     - Action probabilities for actor training
+     - State representations for RSSM dynamics learning
+
+     Expected policy output:
+     (action, {"logprob": mx.array, "value": mx.array, "entropy": mx.array,
+               "state": mx.array, "embed": mx.array})
+
+     Stored in buffer: obs, act, rew, next_obs, done, timeout, logprob, value, state, embed
+     Filtered out: entropy (computed during training)
+     """
+
+     def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
+         """
+         Select action using DreamerV3 policy.
+
+         DreamerV3 uses the world model's posterior state for action selection.
+         During training, we use stochastic policy for exploration.
+
+         Args:
+             policy: DreamerV3 policy function returning (action, extras)
+             obs: MLX array observation
+
+         Returns:
+             action: Selected action as MLX array (sampled stochastically)
+             extras: Dict with logprob, value, entropy, state, embed
+         """
+         # Use stochastic policy during training for exploration
+         return policy(obs, deterministic=False)
+
+     def store_transition(self, buffer: Buffer, **data) -> None:
+         """
+         Store DreamerV3 transition data in buffer.
+
+         DreamerV3 stores additional state representations and embeddings
+         needed for world model learning and imagination.
+
+         Args:
+             buffer: Buffer instance to store data
+             **data: Transition data including obs, act, rew, state, embed, etc.
+         """
+         # Validate and filter transition data
+         filtered_data = validate_transition_data(data)
+
+         # Remove entropy if present (not stored, computed during training)
+         filtered_data.pop('entropy', None)
+
+         # Store in buffer (including DreamerV3-specific fields like state, embed)
+         buffer.add(**filtered_data)  # type: ignore
+
+
+ # Strategy registry for factory pattern
+ STRATEGY_REGISTRY = {
+     'ppo': PPOStrategy,
+     'grpo': GRPOStrategy,
+     'gspo': GRPOStrategy,  # Alias for backwards compatibility
+     'dreamerv3': DreamerV3Strategy,
+ }
+
+
+ def create_strategy(algorithm: str) -> RolloutStrategy:
+     """
+     Factory function for creating rollout strategies.
+
+     Args:
+         algorithm: Algorithm name ('ppo', 'grpo', etc.)
+
+     Returns:
+         RolloutStrategy instance
+
+     Raises:
+         ValueError: If algorithm is not supported
+     """
+     if algorithm not in STRATEGY_REGISTRY:
+         available = list(STRATEGY_REGISTRY.keys())
+         raise ValueError(f"Unknown algorithm '{algorithm}'. Available: {available}")
+
+     strategy_class = STRATEGY_REGISTRY[algorithm]
+     return strategy_class()
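Finally, a usage sketch of the registry-backed factory added above. `MyGSPOVariant` is hypothetical; only the four keys shown in `STRATEGY_REGISTRY` ship with this release.

```python
from textpolicy.rollout.strategy import STRATEGY_REGISTRY, GRPOStrategy, create_strategy

strategy = create_strategy("ppo")            # -> PPOStrategy instance

try:
    create_strategy("maddpg")                # unknown algorithm name
except ValueError as err:
    print(err)                               # error message lists the available keys

class MyGSPOVariant(GRPOStrategy):
    """Hypothetical custom strategy reusing GRPO's storage behaviour."""

STRATEGY_REGISTRY["my_gspo_variant"] = MyGSPOVariant
strategy = create_strategy("my_gspo_variant")
```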