textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,369 @@
1
+ # textpolicy/rewards/verifiers.py
2
+ """
3
+ Text quality verifiers following retrain's decorator-based pattern.
4
+
5
+ Verifiers provide boolean pre-filtering for reward functions,
6
+ following retrain's philosophy of efficient quality control.
7
+
8
+ All verifiers follow the signature: (prompt: str, completion: str, example: Dict[str, Any]) -> bool
9
+ """
10
+
11
+ import warnings
12
+ from typing import Dict, List, Optional, Any
13
+ import re
14
+ from dataclasses import dataclass
15
+ from enum import Enum
16
+ from .registry import verifier
17
+
18
+
19
# Emitted once at import time so anyone importing this module still sees the
# notice about the legacy VerificationResult / VerificationReport types.
# stacklevel=2 attributes the warning to the importing module, not this file.
warnings.warn(
    "VerificationResult and VerificationReport are deprecated; use boolean verifiers instead.",
    DeprecationWarning,
    stacklevel=2,
)
24
+
25
# Legacy types for backward compatibility (deprecated)
class VerificationResult(Enum):
    """Result of verification check (deprecated - use boolean verifiers now)."""

    # String values mirror the lowercase member names.
    PASS = "pass"
    FAIL = "fail"
    WARNING = "warning"
31
+
32
+
33
@dataclass
class VerificationReport:
    """Report from verification check (deprecated - use boolean verifiers now)."""

    # Overall outcome of the check.
    result: VerificationResult
    # Confidence score in [0.0, 1.0].
    score: float
    # Arbitrary per-check diagnostic data.
    details: Dict[str, Any]
    # Human-readable summary of the outcome.
    message: str
40
+
41
+
42
@verifier
def length_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    min_length: int = 10,
    max_length: int = 500,
    **kwargs
) -> bool:
    """
    Check that the completion's word count falls within the allowed bounds.

    Per-example overrides ('min_length' / 'max_length' keys in ``example``)
    take precedence over the keyword defaults.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain length constraints)
        min_length: Minimum required word count
        max_length: Maximum allowed word count
        **kwargs: Additional parameters

    Returns:
        True if length is appropriate, False otherwise
    """
    # Per-example constraints win over the keyword defaults.
    min_length = example.get('min_length', min_length)
    max_length = example.get('max_length', max_length)

    return min_length <= len(completion.split()) <= max_length
73
+
74
+
75
@verifier
def toxicity_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    **kwargs
) -> bool:
    """
    Reject completions that match simple toxic-keyword patterns.

    Detection is keyword/regex based; the ``example`` dict may supply extra
    patterns under 'toxic_patterns'. A real deployment should use a proper
    toxicity classifier instead.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain custom toxic patterns)
        **kwargs: Additional parameters

    Returns:
        True if content is non-toxic, False if toxic content detected
    """
    # Baseline keyword patterns - simple detection only.
    patterns = [
        r'\b(hate|kill|die|stupid|idiot|racist|sexist)\b',
        r'\b(fuck|shit|damn|hell)\b',
        r'\b(violence|abuse|harassment)\b'
    ]
    # Allow per-example additions.
    patterns.extend(example.get('toxic_patterns', []))

    combined = re.compile('|'.join(patterns), re.IGNORECASE)

    # Pass only when no pattern matches anywhere in the completion.
    return combined.search(completion.lower()) is None
110
+
111
+
112
@verifier
def coherence_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    min_coherence_score: float = 0.5,
    **kwargs
) -> bool:
    """
    Score coherence via cheap heuristics and compare against a threshold.

    The score averages two signals: presence of discourse connectors and a
    sentence-length sanity check. An example-level 'min_coherence_score'
    overrides the keyword default.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain coherence requirements)
        min_coherence_score: Minimum coherence score threshold
        **kwargs: Additional parameters

    Returns:
        True if coherent enough, False otherwise
    """
    min_coherence_score = example.get('min_coherence_score', min_coherence_score)

    # Discourse connectors serve as a cheap proxy for logical flow.
    connector_patterns = (
        r'\b(therefore|thus|however|moreover|furthermore)\b',
        r'\b(first|second|third|finally|in conclusion)\b',
        r'\b(because|since|as a result|consequently)\b'
    )
    connector_re = re.compile('|'.join(connector_patterns), re.IGNORECASE)
    num_connectors = len(connector_re.findall(completion))

    # Sentence-level structure: must contain at least one non-empty sentence.
    sentences = [s for s in re.split(r'[.!?]+', completion) if s.strip()]
    if not sentences:
        return False

    mean_words = sum(len(s.split()) for s in sentences) / len(sentences)

    # Two or more connectors saturate the connector score.
    connector_score = min(1.0, num_connectors / 2.0)
    # 5-25 words per sentence is treated as well-formed; otherwise half credit.
    structure_score = 1.0 if 5 <= mean_words <= 25 else 0.5

    return (connector_score + structure_score) / 2.0 >= min_coherence_score
164
+
165
+
166
@verifier
def factual_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    min_factual_score: float = 0.6,
    max_uncertainty_count: int = 2,
    **kwargs
) -> bool:
    """
    Heuristic confidence check based on hedging phrases and contradictions.

    Counts uncertainty ("i think", "maybe", ...) phrases and crude yes/no,
    true/false contradictions, derives a confidence score, and requires both
    the score and the hedge count to be within bounds. Example-level
    'min_factual_score' / 'max_uncertainty_count' override the defaults.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain factual requirements)
        min_factual_score: Minimum factual confidence score
        max_uncertainty_count: Maximum allowed uncertainty phrases
        **kwargs: Additional parameters

    Returns:
        True if factually sound, False otherwise
    """
    min_factual_score = example.get('min_factual_score', min_factual_score)
    max_uncertainty_count = example.get('max_uncertainty_count', max_uncertainty_count)

    # Hedging phrases that suggest low factual confidence.
    hedge_patterns = (
        r'\b(i think|maybe|perhaps|possibly|not sure)\b',
        r'\b(might be|could be|i believe|seems like)\b',
        r'\b(probably|likely|unclear|unknown)\b'
    )
    hedge_re = re.compile('|'.join(hedge_patterns), re.IGNORECASE)
    uncertainty_count = len(hedge_re.findall(completion.lower()))

    # Crude contradiction heuristic: opposing answer words in one response.
    lowered = completion.lower()
    contradictions = sum(
        1 for pos, neg in (('yes', 'no'), ('true', 'false'))
        if pos in lowered and neg in lowered
    )

    # Each hedge costs 0.2 (capped at 1.0 total); each contradiction 0.3.
    factual_score = max(
        0.0,
        1.0 - min(1.0, uncertainty_count * 0.2) - contradictions * 0.3,
    )

    return factual_score >= min_factual_score and uncertainty_count <= max_uncertainty_count
221
+
222
+
223
+ # Additional verifiers following retrain's patterns
224
+
225
@verifier
def has_greeting(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    required_greeting: str = "hello",
    **kwargs
) -> bool:
    """
    Verify that the completion contains a greeting phrase.

    Case-insensitive substring match. The ``example`` dict may override the
    default phrase via a 'required_greeting' key.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may specify required greeting)
        required_greeting: The greeting phrase to look for
        **kwargs: Additional parameters

    Returns:
        True if greeting is present, False otherwise
    """
    greeting = example.get('required_greeting', required_greeting)
    return greeting.lower() in completion.lower()
250
+
251
+
252
@verifier
def no_empty_response(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    **kwargs
) -> bool:
    """
    Verify that the completion is not empty or whitespace-only.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context
        **kwargs: Additional parameters

    Returns:
        True if completion has content, False if empty
    """
    # Whitespace-only completions count as empty.
    return completion.strip() != ""
272
+
273
+
274
@verifier
def contains_keywords(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    required_keywords: Optional[List[str]] = None,
    **kwargs
) -> bool:
    """
    Verify that the completion contains every required keyword.

    When ``required_keywords`` is not given, the list is taken from the
    example's 'required_keywords' key; an empty list always passes.
    Matching is case-insensitive substring search.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain required_keywords)
        required_keywords: List of keywords that must be present
        **kwargs: Additional parameters

    Returns:
        True if all required keywords are present, False otherwise
    """
    keywords = required_keywords
    if keywords is None:
        keywords = example.get('required_keywords', [])

    if not keywords:
        # Nothing required -> trivially satisfied.
        return True

    haystack = completion.lower()
    for keyword in keywords:
        if keyword.lower() not in haystack:
            return False
    return True
303
+
304
+
305
+ # Legacy compatibility functions (deprecated)
306
def create_default_verifier_pipeline():
    """
    Create default verification pipeline (deprecated).

    Retained only so legacy imports keep working; it performs no setup.

    Returns:
        None: there is no pipeline object anymore.
    """
    # Consistency fix: the rest of this module signals deprecation via
    # warnings.warn(DeprecationWarning), not the logging module, so callers
    # can manage it with standard warning filters.
    warnings.warn(
        "create_default_verifier_pipeline is deprecated. Use registry-based "
        "verifiers with apply_verifiers_to_reward instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return None
312
+
313
+
314
def create_custom_verifier_pipeline(verifier_configs: List[Dict[str, Any]]):
    """
    Create custom verification pipeline (deprecated).

    Args:
        verifier_configs: Ignored; kept only for signature compatibility.

    Returns:
        None: there is no pipeline object anymore.
    """
    # Consistency fix: signal deprecation via warnings.warn(DeprecationWarning)
    # like the rest of this module, instead of the logging module.
    warnings.warn(
        "create_custom_verifier_pipeline is deprecated. Use registry-based "
        "verifiers with apply_verifiers_to_reward instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return None
320
+
321
+
322
+ # Legacy class aliases for backward compatibility
323
class TextVerifier:
    """Base class for text quality verifiers (deprecated)."""

    def verify(self, prompt: str, response: str):
        """Deprecated no-op check; always returns False after warning."""
        # Superseded by the registry-based @verifier function pattern.
        warnings.warn(
            "TextVerifier.verify is deprecated; use registry-based @verifier functions instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return False
333
+
334
+
335
class LengthVerifier(TextVerifier):
    """Deprecated alias kept for import compatibility.

    Use the ``length_verifier`` registry function instead.
    """
338
+
339
+
340
class ToxicityVerifier(TextVerifier):
    """Deprecated alias kept for import compatibility.

    Use the ``toxicity_verifier`` registry function instead.
    """
343
+
344
+
345
class CoherenceVerifier(TextVerifier):
    """Deprecated alias kept for import compatibility.

    Use the ``coherence_verifier`` registry function instead.
    """
348
+
349
+
350
class FactualVerifier(TextVerifier):
    """Deprecated alias kept for import compatibility.

    Use the ``factual_verifier`` registry function instead.
    """
353
+
354
+
355
class VerificationPipeline:
    """Legacy verification pipeline (deprecated - use registry-based approach)."""

    def __init__(self, verifiers):
        # Preserve the original shim behaviour: a log-level warning at
        # construction time; the verifiers argument is not used.
        import logging

        logging.getLogger(__name__).warning(
            "VerificationPipeline is deprecated. Use registry-based verifiers with apply_verifiers_to_reward instead."
        )

    def verify_batch(self, prompts, responses):
        """Deprecated no-op; always returns an empty list."""
        warnings.warn(
            "VerificationPipeline.verify_batch is deprecated; use registry-based verifiers instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return []
@@ -0,0 +1,44 @@
1
# textpolicy/rollout/__init__.py
"""
Modular rollout system for TextPolicy.

Main components:
- RolloutCoordinator: Interface for rollout collection
- RolloutStrategy: Algorithm-specific rollout behavior
- RolloutWorker: Multi-process worker management
- BufferAggregator: Multi-worker data coordination
"""

from .rollout import RolloutCoordinator, create_rollout_coordinator
from .base import RolloutStrategy
from .worker import RolloutWorker
from .runner import RolloutRunner
from .aggregator import BufferAggregator
from .strategy import PPOStrategy, GRPOStrategy, create_strategy

# Backwards compatibility exports
# NOTE: the *_Legacy names are plain aliases of the current classes, kept so
# older import paths keep resolving; new code should use the direct names.
from .runner import RolloutRunner as RolloutRunner_Legacy
from .aggregator import BufferAggregator as BufferAggregator_Legacy
from .worker import RolloutWorker as RolloutWorker_Legacy

__all__ = [
    # Main public interface
    'RolloutCoordinator',
    'create_rollout_coordinator',

    # Core components
    'RolloutStrategy',
    'RolloutWorker',
    'RolloutRunner',
    'BufferAggregator',

    # Strategies
    'PPOStrategy',
    'GRPOStrategy',
    'create_strategy',

    # Legacy compatibility
    'RolloutRunner_Legacy',
    'BufferAggregator_Legacy',
    'RolloutWorker_Legacy',
]
@@ -0,0 +1,145 @@
1
+ # textpolicy/rollout/aggregator.py
2
+ """
3
+ Multi-worker buffer aggregation and coordination.
4
+ """
5
+
6
+ from typing import List, Optional
7
+ import multiprocessing as mp
8
+ from textpolicy.buffer import Buffer
9
+ from .worker import RolloutWorker
10
+ from .base import DEFAULT_MAX_EPISODES
11
+
12
+
13
class BufferAggregator:
    """
    Aggregates episodes from multiple RolloutWorkers.

    Coordinates data collection from multiple processes:
    - Maintains bounded buffer of complete episodes
    - Provides sampling interface for trainer
    - Handles queue management and data consumption
    - Thread-safe operation for async collection
    """

    def __init__(self, num_workers: int, max_episodes: int = DEFAULT_MAX_EPISODES):
        """
        Initialize aggregator for multi-worker coordination.

        Args:
            num_workers: Number of RolloutWorker processes to coordinate
            max_episodes: Maximum episodes to keep (oldest are dropped)
        """
        self.num_workers = num_workers
        self.buffer = Buffer(max_episodes=max_episodes)
        # One queue slot per worker id; populated lazily via add_worker().
        self._worker_queues: List[Optional[mp.Queue]] = [None] * num_workers

    def add_worker(self, worker: RolloutWorker, worker_id: int):
        """
        Register a worker's send queue for data consumption.

        Args:
            worker: RolloutWorker instance
            worker_id: Unique worker ID (0 <= id < num_workers)

        Raises:
            ValueError: If worker_id is out of bounds
        """
        if not (0 <= worker_id < self.num_workers):
            raise ValueError(f"worker_id must be in [0, {self.num_workers - 1}]")
        self._worker_queues[worker_id] = worker.send_queue

    def consume_from_worker(self, worker_id: int) -> bool:
        """
        Try to consume new episodes from specific worker.

        Args:
            worker_id: ID of the worker to consume from

        Returns:
            True if data was consumed, False if no data available

        Raises:
            ValueError: If worker_id has no registered queue
        """
        from queue import Empty  # stdlib exception raised by mp.Queue

        queue = self._worker_queues[worker_id]
        if queue is None:
            raise ValueError(f"No queue registered for worker {worker_id}")

        # Bug fix: mp.Queue.empty() is documented as unreliable across
        # processes, and checking empty() before get() is a check-then-act
        # race. A single non-blocking get avoids both problems.
        try:
            episodes_data = queue.get_nowait()
        except Empty:
            return False

        # Add serialized episodes to the aggregated buffer
        for ep_dict in episodes_data:
            self.buffer.add_episode_from_dict(ep_dict)

        return True

    def consume_all(self) -> int:
        """
        Consume available data from all workers.

        Non-blocking operation that checks all worker queues.

        Returns:
            Number of workers that had data ready
        """
        count = 0
        for wid in range(self.num_workers):
            # Skip workers that have not been registered yet.
            if self._worker_queues[wid] is None:
                continue
            if self.consume_from_worker(wid):
                count += 1
        return count

    def ready(self, min_episodes: int = 1) -> bool:
        """
        Check if enough episodes have been collected for training.

        Args:
            min_episodes: Minimum number of episodes required

        Returns:
            True if buffer has at least min_episodes
        """
        return self.buffer.ready(min_episodes)

    def sample_latest_steps(self, n: int) -> dict:
        """
        Sample the N most recent steps for training.

        Args:
            n: Number of steps to sample

        Returns:
            Dict of MLX arrays (obs, act, rew, done, etc.)
        """
        return self.buffer.sample_latest_steps(n)

    def sample_episodes(self, k: int, order: str = 'desc') -> dict:
        """
        Sample up to k episodes for training.

        Args:
            k: Number of episodes to sample
            order: 'asc' (oldest first) or 'desc' (newest first)

        Returns:
            Dict of MLX arrays containing episode data
        """
        return self.buffer.sample_episodes(k, order)

    def clear(self):
        """Clear all collected episodes from the buffer."""
        self.buffer.clear()

    def __len__(self) -> int:
        """Total number of steps across all episodes."""
        return len(self.buffer)

    @property
    def episode_count(self) -> int:
        """Number of complete episodes currently stored."""
        return len(self.buffer.episodes)
@@ -0,0 +1,108 @@
1
+ # textpolicy/rollout/base.py
2
+ """
3
+ Base classes and protocols for the rollout system.
4
+ """
5
+
6
+ from typing import Callable, Dict, Any, Tuple, Protocol
7
+ import mlx.core as mx # type: ignore
8
+ from textpolicy.buffer import Buffer
9
+
10
+
11
class RolloutStrategy(Protocol):
    """
    Protocol for rollout strategies.

    Structural (duck-typed) interface: implementations are matched by shape,
    not inheritance, so concrete strategies need not subclass this.

    Defines the interface used by RolloutRunner for algorithm-specific behavior.
    Each strategy encapsulates how to:
    - Select actions from policy outputs
    - Store transition data in buffers
    - Handle algorithm-specific requirements
    """

    def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
        """
        Select an action using the policy.

        Args:
            policy: Function that takes obs and returns (action, extras)
            obs: MLX array observation

        Returns:
            action: mx.array (scalar or tensor)
            extras: Dict of additional data (e.g. logprob, value)
        """
        ...

    def store_transition(self, buffer: Buffer, **data) -> None:
        """
        Store a transition in the buffer.

        Args:
            buffer: Buffer instance
            **data: Transition data (obs, act, rew, next_obs, done, logprob, value, etc.)
        """
        ...
45
+
46
+
47
# Common constants and configurations
DEFAULT_MAX_STEPS = 1000
DEFAULT_MAX_EPISODES = 100
DEFAULT_WORKER_TIMEOUT = 1.0

# Supported transition data keys for validation
REQUIRED_TRANSITION_KEYS = {'obs', 'act', 'rew', 'next_obs', 'done'}
OPTIONAL_TRANSITION_KEYS = {'timeout', 'logprob', 'value', 'entropy'}
ALL_TRANSITION_KEYS = REQUIRED_TRANSITION_KEYS | OPTIONAL_TRANSITION_KEYS


def validate_transition_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate and filter transition data.

    Ensures every required transition key is present, then drops any entries
    whose keys are not recognised (neither required nor optional).

    Args:
        data: Dictionary containing transition data

    Returns:
        Filtered dictionary with only valid keys

    Raises:
        ValueError: If required keys are missing
    """
    # Guard clause: every required key must be present.
    missing_keys = REQUIRED_TRANSITION_KEYS.difference(data)
    if missing_keys:
        raise ValueError(f"Missing required transition keys: {missing_keys}")

    # Keep only recognised keys, preserving the caller's insertion order.
    return {key: data[key] for key in data if key in ALL_TRANSITION_KEYS}
80
+
81
+
82
def serialize_mx_array(arr: mx.array) -> Any:
    """
    Serialize MLX array for multiprocessing communication.

    Args:
        arr: MLX array to serialize

    Returns:
        Python scalar (for 0-d input) or nested list, suitable for
        queue transmission
    """
    # 0-d arrays become plain Python scalars; anything else becomes a list.
    return arr.item() if arr.ndim == 0 else arr.tolist()
96
+
97
+
98
def deserialize_to_mx_array(data: Any) -> mx.array:
    """
    Deserialize data back to MLX array.

    Inverse of ``serialize_mx_array``: accepts the Python scalar or (nested)
    list produced there and rebuilds an MLX array from it.

    Args:
        data: Python scalar or list

    Returns:
        MLX array
    """
    return mx.array(data)