textpolicy-0.0.1-py3-none-any.whl → textpolicy-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,293 @@
+ # textpolicy/rewards/registry.py
+ """
+ Unified reward and verifier registry system following retrain's philosophy.
+
+ This module provides decorator-based registration for rewards and verifiers,
+ maintaining compatibility with MLX optimization and pure function composition.
+
+ Key principles:
+ - Decorator-based registration (@reward, @verifier)
+ - Function signature consistency: (prompt, completion, example, **kwargs)
+ - Pre-filtering verification approach
+ - Global registries for modularity
+ - MLX compilation support
+ """
+
+ from typing import Callable, Dict, Any, List, Optional, Union
+ import inspect
+ import functools
+ import mlx.core as mx
+ from dataclasses import dataclass
+
+ # Type definitions following retrain's patterns
+ RewardFunction = Callable[[str, str, Dict[str, Any]], float]
+ VerifierFunction = Callable[[str, str, Dict[str, Any]], bool]
+
+ # Global registries following retrain's architecture
+ REWARD_REGISTRY: Dict[str, RewardFunction] = {}
+ VERIFIER_REGISTRY: Dict[str, VerifierFunction] = {}
+
+ # Simple logging
+ import logging
+ logger = logging.getLogger(__name__)
+
+
+ def reward(_func: Optional[RewardFunction] = None, *, name: Optional[str] = None) -> Union[Callable[[RewardFunction], RewardFunction], RewardFunction]:
+     """
+     Decorator to register reward functions following retrain's pattern.
+
+     Usage:
+         @reward
+         def my_reward(prompt: str, completion: str, example: Dict[str, Any]) -> float:
+             return 1.0
+
+         @reward(name="custom_name")
+         def another_reward(prompt: str, completion: str, example: Dict[str, Any]) -> float:
+             return 0.5
+     """
+     def decorator_reward(func: RewardFunction) -> RewardFunction:
+         if not callable(func):
+             raise TypeError(f"Object {getattr(func, '__name__', '<unknown>')} must be callable to be registered as reward.")
+
+         registration_name = name if name is not None else func.__name__
+
+         # Validate function signature for consistency
+         sig = inspect.signature(func)
+         expected_params = ['prompt', 'completion', 'example']
+
+         if len(sig.parameters) < 3:
+             logger.warning(f"Reward function '{registration_name}' has fewer than 3 expected parameters. Ensure signature compatibility.")
+
+         param_names = list(sig.parameters.keys())
+         for i, expected in enumerate(expected_params):
+             if i < len(param_names) and param_names[i] != expected:
+                 logger.warning(f"Reward function '{registration_name}' parameter {i} is '{param_names[i]}', expected '{expected}'.")
+
+         if registration_name in REWARD_REGISTRY:
+             logger.warning(f"Reward function '{registration_name}' already registered. Overwriting.")
+
+         REWARD_REGISTRY[registration_name] = func
+         logger.info(f"Registered reward function: '{registration_name}' -> {func.__name__}")
+
+         return func
+
+     if _func is None:
+         # Called with parentheses: @reward() or @reward(name=...)
+         return decorator_reward
+     elif callable(_func):
+         # Called without parentheses: @reward
+         if name is not None:
+             raise TypeError("Cannot specify 'name' when using @reward without parentheses. Use @reward(name='...') instead.")
+         return decorator_reward(_func)
+     else:
+         raise TypeError("Invalid arguments supplied to @reward decorator.")
+
+
+ def verifier(_func: Optional[VerifierFunction] = None, *, name: Optional[str] = None) -> Union[Callable[[VerifierFunction], VerifierFunction], VerifierFunction]:
+     """
+     Decorator to register verifier functions following retrain's pattern.
+
+     Usage:
+         @verifier
+         def has_greeting(prompt: str, completion: str, example: Dict[str, Any]) -> bool:
+             return completion.lower().startswith("hello")
+
+         @verifier(name="custom_check")
+         def custom_verifier(prompt: str, completion: str, example: Dict[str, Any]) -> bool:
+             return len(completion) > 10
+     """
+     def decorator_verifier(func: VerifierFunction) -> VerifierFunction:
+         if not callable(func):
+             raise TypeError(f"Object {getattr(func, '__name__', '<unknown>')} must be callable to be registered as verifier.")
+
+         registration_name = name if name is not None else func.__name__
+
+         # Validate function signature
+         sig = inspect.signature(func)
+         if len(sig.parameters) < 3:
+             logger.warning(f"Verifier function '{registration_name}' has fewer than 3 expected parameters (prompt, completion, example).")
+
+         if registration_name in VERIFIER_REGISTRY:
+             logger.warning(f"Verifier function '{registration_name}' already registered. Overwriting.")
+
+         VERIFIER_REGISTRY[registration_name] = func
+         logger.info(f"Registered verifier function: '{registration_name}' -> {func.__name__}")
+
+         return func
+
+     if _func is None:
+         return decorator_verifier
+     elif callable(_func):
+         if name is not None:
+             raise TypeError("Cannot specify 'name' when using @verifier without parentheses. Use @verifier(name='...') instead.")
+         return decorator_verifier(_func)
+     else:
+         raise TypeError("Invalid arguments supplied to @verifier decorator.")
+
+
+ def get_reward_function(name: str) -> Optional[RewardFunction]:
+     """Retrieve a registered reward function by name."""
+     func = REWARD_REGISTRY.get(name)
+     if func is None:
+         available = list(REWARD_REGISTRY.keys())
+         logger.error(f"Reward function '{name}' not found. Available: {available}")
+     return func
+
+
+ def get_verifier_function(name: str) -> Optional[VerifierFunction]:
+     """Retrieve a registered verifier function by name."""
+     func = VERIFIER_REGISTRY.get(name)
+     if func is None:
+         available = list(VERIFIER_REGISTRY.keys())
+         logger.error(f"Verifier function '{name}' not found. Available: {available}")
+     return func
+
+
+ def apply_verifiers_to_reward(
+     original_reward_func: RewardFunction,
+     verifier_names: List[str],
+     penalty_on_failure: float = 0.0
+ ) -> RewardFunction:
+     """
+     Apply verifiers to a reward function following retrain's pre-filtering approach.
+
+     If any verifier fails, returns the penalty value without executing the reward function.
+     This follows retrain's philosophy of efficient pre-filtering.
+
+     Args:
+         original_reward_func: The reward function to wrap
+         verifier_names: List of verifier names to apply
+         penalty_on_failure: Value to return if any verifier fails
+
+     Returns:
+         Wrapped reward function that applies verifiers first
+     """
+     # Load verifier functions
+     loaded_verifiers: List[VerifierFunction] = []
+     missing_verifiers = []
+
+     for verifier_name in verifier_names:
+         verifier_func = get_verifier_function(verifier_name)
+         if verifier_func is None:
+             missing_verifiers.append(verifier_name)
+         else:
+             loaded_verifiers.append(verifier_func)
+
+     if missing_verifiers:
+         available = list(VERIFIER_REGISTRY.keys())
+         raise ValueError(f"Verifiers {missing_verifiers} not found for reward '{original_reward_func.__name__}'. Available: {available}")
+
+     @functools.wraps(original_reward_func)
+     def reward_with_verifiers(prompt: str, completion: str, example: Dict[str, Any], **kwargs) -> float:
+         # Apply verifiers first (pre-filtering approach)
+         for i, verifier_func in enumerate(loaded_verifiers):
+             verifier_name = verifier_names[i]
+             try:
+                 if not verifier_func(prompt, completion, example):
+                     logger.debug(f"Verifier '{verifier_name}' failed for reward '{original_reward_func.__name__}'. Applying penalty: {penalty_on_failure}")
+                     return penalty_on_failure
+             except Exception as e:
+                 logger.error(f"Verifier '{verifier_name}' errored: {e}. Applying penalty: {penalty_on_failure}")
+                 return penalty_on_failure
+
+         # All verifiers passed, execute reward function
+         try:
+             result = original_reward_func(prompt, completion, example, **kwargs)
+             return float(result)
+         except Exception as e:
+             logger.error(f"Reward function '{original_reward_func.__name__}' errored after verifiers passed: {e}. Returning 0.0")
+             return 0.0
+
+     # Set descriptive name
+     if verifier_names:
+         verifier_suffix = '_and_'.join(verifier_names)
+         reward_with_verifiers.__name__ = f"{original_reward_func.__name__}_verified_by_{verifier_suffix}"
+     else:
+         reward_with_verifiers.__name__ = original_reward_func.__name__
+
+     return reward_with_verifiers
+
+
+ @dataclass
+ class RewardConfig:
+     """Configuration for a reward function following retrain's patterns."""
+     name: str  # Name in REWARD_REGISTRY
+     weight: float = 1.0
+     params: Optional[Dict[str, Any]] = None
+     verifiers: Optional[List[str]] = None
+     verifier_penalty: float = 0.0
+
+     def __post_init__(self):
+         if self.params is None:
+             self.params = {}
+         if self.verifiers is None:
+             self.verifiers = []
+
+
+ def create_configured_reward_function(config: RewardConfig) -> RewardFunction:
+     """
+     Create a reward function from configuration following retrain's approach.
+
+     This function:
+     1. Loads the base reward function from registry
+     2. Applies verifiers if specified
+     3. Handles parameter passing
+     4. Returns a configured function ready for use
+     """
+     # Get base reward function
+     base_reward_func = get_reward_function(config.name)
+     if base_reward_func is None:
+         raise ValueError(f"Reward function '{config.name}' not found in registry")
+
+     # Create wrapper that handles params
+     def reward_with_params(prompt: str, completion: str, example: Dict[str, Any], **kwargs) -> float:
+         # Merge config params with runtime kwargs
+         merged_kwargs = {**config.params, **kwargs}
+         return base_reward_func(prompt, completion, example, **merged_kwargs)
+
+     reward_with_params.__name__ = f"{base_reward_func.__name__}_with_params"
+
+     # Apply verifiers if specified
+     if config.verifiers:
+         reward_with_params = apply_verifiers_to_reward(
+             reward_with_params,
+             config.verifiers,
+             config.verifier_penalty
+         )
+
+     return reward_with_params
+
+
+ # MLX optimization support
+ @mx.compile
+ def batch_reward_computation(
+     base_rewards: mx.array,
+     weights: mx.array
+ ) -> mx.array:
+     """
+     MLX-compiled function for efficient batch reward computation.
+
+     Args:
+         base_rewards: Individual reward scores [batch_size, num_rewards]
+         weights: Reward weights [num_rewards]
+
+     Returns:
+         Weighted combined rewards [batch_size]
+     """
+     return mx.sum(base_rewards * weights, axis=1)
+
+
+ def list_registered_functions() -> Dict[str, List[str]]:
+     """List all registered reward and verifier functions."""
+     return {
+         "rewards": list(REWARD_REGISTRY.keys()),
+         "verifiers": list(VERIFIER_REGISTRY.keys())
+     }
+
+
+ def clear_registries():
+     """Clear all registries (useful for testing)."""
+     global REWARD_REGISTRY, VERIFIER_REGISTRY
+     REWARD_REGISTRY.clear()
+     VERIFIER_REGISTRY.clear()
+     logger.info("Cleared all reward and verifier registries")
@@ -0,0 +1,410 @@
+ # textpolicy/rewards/rollout_rewards.py
+ """
+ Rollout-level reward processing system for efficient MLX training.
+
+ This system processes rewards at the episode/rollout level rather than
+ per-transition, enabling vectorized operations and batch processing
+ for optimal MLX performance.
+
+ Key features:
+ - Batch reward computation for entire episodes
+ - Vectorized operations using MLX
+ - Integration with rollout buffer system
+ - Support for async external reward models
+ - Pure function composition
+ """
+
+ from typing import Dict, List, Optional, Any
+ import mlx.core as mx
+ from dataclasses import dataclass
+ import asyncio
+ from concurrent.futures import ThreadPoolExecutor
+
+ # Optional dependency
+ try:
+     import aiohttp  # type: ignore
+     HAS_AIOHTTP = True
+ except ImportError:
+     HAS_AIOHTTP = False
+     aiohttp = None
+
+ # Import reward functions used in this module
+ # These are registered via the @reward decorator and provide the standard (prompt, completion, example) signature
+ from .basic import length_reward, keyword_reward, perplexity_reward, accuracy_reward
+
+
+ @dataclass
+ class RewardConfig:
+     """Configuration for rollout reward processing."""
+     # Basic reward weights
+     length_weight: float = 0.1
+     keyword_weight: float = 0.2
+     perplexity_weight: float = 0.3
+     accuracy_weight: float = 0.4
+
+     # Target parameters
+     target_length: int = 50
+     keywords: Optional[List[str]] = None
+
+     # External reward model
+     external_rm_url: Optional[str] = None
+     external_rm_timeout: float = 30.0
+
+     # Batch processing
+     batch_size: int = 32
+     max_workers: int = 4
+
+     def __post_init__(self):
+         if self.keywords is None:
+             self.keywords = []
+
+
+ class RolloutRewardProcessor:
+     """
+     Efficient rollout-level reward processor for MLX training.
+
+     Processes entire episodes in batches, enabling vectorized operations
+     and optimal memory usage on Apple Silicon.
+     """
+
+     def __init__(self, config: RewardConfig):
+         """
+         Initialize reward processor.
+
+         Args:
+             config: Reward processing configuration
+         """
+         self.config = config
+         self.executor = ThreadPoolExecutor(max_workers=config.max_workers)
+
+         # Pre-compile reward functions for MLX
+         self._compile_reward_functions()
+
+     def _compile_reward_functions(self):
+         """Pre-compile reward functions for MLX optimization."""
+         # These will be compiled when first called
+         self._compiled_functions = {}
+
+     def process_rollout(
+         self,
+         prompt: str,
+         completion: str,
+         example: Dict[str, Any]
+     ) -> mx.array:
+         """
+         Process a single rollout with prompt, completion, and example context.
+
+         This method provides the standard interface expected by tests and other components.
+         It follows the same signature as individual reward functions for consistency.
+
+         Args:
+             prompt: Input prompt text
+             completion: Generated completion text
+             example: Example context with target parameters and metadata
+
+         Returns:
+             MLX array with single reward value [1]
+         """
+         # Convert single rollout to episode format for consistency
+         episode = {
+             'prompt': prompt,
+             'response': completion,
+             'metadata': example
+         }
+
+         # Process as single episode and return scalar reward
+         rewards = self.process_episode_rewards([episode])
+         return rewards[0] if rewards.size > 0 else mx.array(0.0)
+
+     def process_batch_rollouts(
+         self,
+         prompts: List[str],
+         completions: List[str],
+         examples: List[Dict[str, Any]]
+     ) -> mx.array:
+         """
+         Process multiple rollouts in batch for efficient processing.
+
+         This method provides batch processing capability expected by tests and training loops.
+         It maintains the same interface as individual reward functions but processes multiple
+         examples simultaneously for better performance.
+
+         Args:
+             prompts: List of input prompt texts
+             completions: List of generated completion texts
+             examples: List of example contexts with target parameters
+
+         Returns:
+             MLX array of rewards [num_rollouts]
+         """
+         if len(prompts) != len(completions) or len(completions) != len(examples):
+             raise ValueError(f"Mismatched input lengths: prompts={len(prompts)}, "
+                              f"completions={len(completions)}, examples={len(examples)}")
+
+         # Convert to episode format for consistency with existing implementation
+         episodes = []
+         for prompt, completion, example in zip(prompts, completions, examples):
+             episode = {
+                 'prompt': prompt,
+                 'response': completion,
+                 'metadata': example
+             }
+             episodes.append(episode)
+
+         # Process batch using existing episode processing logic
+         return self.process_episode_rewards(episodes)
+
+     def process_episode_rewards(
+         self,
+         episodes: List[Dict[str, Any]]
+     ) -> mx.array:
+         """
+         Process rewards for a batch of episodes.
+
+         Args:
+             episodes: List of episode dictionaries with 'prompt' and 'response' fields
+
+         Returns:
+             MLX array of rewards [num_episodes]
+         """
+         if not episodes:
+             return mx.array([])
+
+         # Extract prompts and responses
+         prompts = [ep.get('prompt', '') for ep in episodes]
+         responses = [ep.get('response', '') for ep in episodes]
+
+         # Compute basic rewards (vectorized where possible)
+         rewards = self._compute_basic_rewards(prompts, responses)
+
+         # Add external reward model scores if configured
+         if self.config.external_rm_url:
+             external_rewards = self._get_external_rewards(episodes)
+             # Blend with basic rewards
+             alpha = 0.7  # Weight for external rewards
+             rewards = alpha * external_rewards + (1 - alpha) * rewards
+
+         return rewards
+
+     def _compute_basic_rewards(
+         self,
+         prompts: List[str],
+         responses: List[str]
+     ) -> mx.array:
+         """
+         Compute basic rewards using vectorized operations.
+
+         Args:
+             prompts: List of input prompts
+             responses: List of generated responses
+
+         Returns:
+             MLX array of combined rewards
+         """
+         # Initialize reward array
+         num_episodes = len(prompts)
+         rewards = mx.zeros(num_episodes)
+
+         # Length rewards - pass config as example dict for unified interface
+         # The @reward decorator enforces (prompt, completion, example) signature
+         if self.config.length_weight > 0:
+             length_rewards = mx.array([
+                 length_reward(p, r, {"target_length": self.config.target_length})  # type: ignore
+                 for p, r in zip(prompts, responses)
+             ])
+             rewards += self.config.length_weight * length_rewards
+
+         # Keyword rewards - pass config as example dict for unified interface
+         # The @reward decorator enforces (prompt, completion, example) signature
+         if self.config.keyword_weight > 0 and self.config.keywords:
+             keyword_rewards = mx.array([
+                 keyword_reward(p, r, {"keywords": self.config.keywords})  # type: ignore
+                 for p, r in zip(prompts, responses)
+             ])
+             rewards += self.config.keyword_weight * keyword_rewards
+
+         # Perplexity rewards - pass empty example dict for unified interface
+         # The @reward decorator enforces (prompt, completion, example) signature
+         if self.config.perplexity_weight > 0:
+             perplexity_rewards = mx.array([
+                 perplexity_reward(p, r, {})  # type: ignore
+                 for p, r in zip(prompts, responses)
+             ])
+             rewards += self.config.perplexity_weight * perplexity_rewards
+
+         # Accuracy rewards - pass empty example dict for unified interface
+         # The @reward decorator enforces (prompt, completion, example) signature
+         if self.config.accuracy_weight > 0:
+             accuracy_rewards = mx.array([
+                 accuracy_reward(p, r, {})  # type: ignore
+                 for p, r in zip(prompts, responses)
+             ])
+             rewards += self.config.accuracy_weight * accuracy_rewards
+
+         return rewards
+
+     def _get_external_rewards(
+         self,
+         episodes: List[Dict[str, Any]]
+     ) -> mx.array:
+         """
+         Get rewards from external reward model.
+
+         Args:
+             episodes: List of episode dictionaries
+
+         Returns:
+             MLX array of external rewards
+         """
+         # For now, return default rewards
+         # In practice, this would call external API or model
+         return mx.ones(len(episodes)) * 0.5
+
+     async def _async_external_rewards(
+         self,
+         episodes: List[Dict[str, Any]]
+     ) -> List[float]:
+         """
+         Async version for external reward model calls.
+
+         Args:
+             episodes: List of episode dictionaries
+
+         Returns:
+             List of reward scores
+         """
+         if not self.config.external_rm_url:
+             return [0.5] * len(episodes)
+
+         if not HAS_AIOHTTP or aiohttp is None:
+             return [0.5] * len(episodes)
+
+         # Type guard: aiohttp is guaranteed to be available from this point
+         # Assert for type checker that aiohttp is not None after the above checks
+         assert aiohttp is not None, "aiohttp should be available when HAS_AIOHTTP is True"
+
+         async def get_reward(session, episode):
+             payload = {
+                 "prompt": episode.get("prompt", ""),
+                 "response": episode.get("response", ""),
+                 "metadata": episode.get("metadata", {})
+             }
+
+             try:
+                 async with session.post(
+                     self.config.external_rm_url,
+                     json=payload,
+                     timeout=aiohttp.ClientTimeout(total=self.config.external_rm_timeout)
+                 ) as resp:
+                     if resp.status == 200:
+                         result = await resp.json()
+                         return result.get("reward", 0.5)
+                     else:
+                         return 0.5
+             except Exception:
+                 return 0.5
+
+         async with aiohttp.ClientSession() as session:
+             tasks = [get_reward(session, ep) for ep in episodes]
+             results = await asyncio.gather(*tasks)
+             return list(results)
+
+     def process_buffer_rewards(self, buffer) -> mx.array:
+         """
+         Process rewards for all episodes in a buffer.
+
+         Args:
+             buffer: Buffer instance containing episodes
+
+         Returns:
+             MLX array of episode rewards
+         """
+         # Extract episodes from buffer
+         episodes = []
+         for episode in buffer.storage.episodes:
+             # Convert episode to dict format
+             episode_dict = {
+                 'prompt': episode.obs[0] if episode.obs else '',
+                 'response': episode.act[-1] if episode.act else '',
+                 'metadata': {
+                     'length': len(episode.obs),
+                     'logprobs': episode.logprob,
+                     'values': episode.value
+                 }
+             }
+             episodes.append(episode_dict)
+
+         return self.process_episode_rewards(episodes)
+
+     def close(self):
+         """Cleanup resources."""
+         self.executor.shutdown(wait=True)
+
+
+ # Pure function interface for integration with rollout system
+ def create_rollout_reward_processor(config: RewardConfig) -> RolloutRewardProcessor:
+     """
+     Factory function for creating reward processors.
+
+     Args:
+         config: Reward processing configuration
+
+     Returns:
+         RolloutRewardProcessor instance
+     """
+     return RolloutRewardProcessor(config)
+
+
+ def process_episode_batch_rewards(
+     episodes: List[Dict[str, Any]],
+     config: RewardConfig
+ ) -> mx.array:
+     """
+     Pure function for processing episode rewards.
+
+     Args:
+         episodes: List of episode dictionaries
+         config: Reward configuration
+
+     Returns:
+         MLX array of rewards
+     """
+     processor = RolloutRewardProcessor(config)
+     try:
+         return processor.process_episode_rewards(episodes)
+     finally:
+         processor.close()
+
+
+ # MLX-compiled reward computation for high-performance training
+ @mx.compile
+ def compute_reward_vector(
+     response_lengths: mx.array,
+     keyword_matches: mx.array,
+     fluency_scores: mx.array,
+     accuracy_scores: mx.array,
+     weights: mx.array
+ ) -> mx.array:
+     """
+     MLX-compiled function for vectorized reward computation.
+
+     Args:
+         response_lengths: Normalized length scores [batch_size]
+         keyword_matches: Keyword match scores [batch_size]
+         fluency_scores: Fluency scores [batch_size]
+         accuracy_scores: Accuracy scores [batch_size]
+         weights: Reward weights [4] (length, keyword, fluency, accuracy)
+
+     Returns:
+         Combined rewards [batch_size]
+     """
+     # Weighted combination of reward components
+     rewards = (
+         weights[0] * response_lengths +
+         weights[1] * keyword_matches +
+         weights[2] * fluency_scores +
+         weights[3] * accuracy_scores
+     )
+
+     return rewards
+
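
To round out the picture, a minimal usage sketch of the rollout-level processor above, scoring a small batch with only the length and keyword terms enabled so no model or external reward server is involved. Illustrative only and not part of the diff; it assumes the names import from textpolicy.rewards.rollout_rewards as defined above, and the prompts, completions, and keyword are made up.

from textpolicy.rewards.rollout_rewards import RewardConfig, RolloutRewardProcessor

config = RewardConfig(
    length_weight=0.5,
    keyword_weight=0.5,
    perplexity_weight=0.0,  # skip model-based terms in this sketch
    accuracy_weight=0.0,
    target_length=20,
    keywords=["MLX"],
)
processor = RolloutRewardProcessor(config)
try:
    rewards = processor.process_batch_rollouts(
        prompts=["Which framework does training use?"] * 2,
        completions=["Training runs on MLX.", "No idea."],
        examples=[{}, {}],
    )
    print(rewards)  # mx.array of shape [2], one combined reward per rollout
finally:
    processor.close()  # shut down the ThreadPoolExecutor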