textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,293 @@
# textpolicy/rewards/registry.py
"""
Unified reward and verifier registry system following retrain's philosophy.

This module provides decorator-based registration for rewards and verifiers,
maintaining compatibility with MLX optimization and pure function composition.

Key principles:
- Decorator-based registration (@reward, @verifier)
- Function signature consistency: (prompt, completion, example, **kwargs)
- Pre-filtering verification approach
- Global registries for modularity
- MLX compilation support
"""

from typing import Callable, Dict, Any, List, Optional, Union
import inspect
import functools
import mlx.core as mx
from dataclasses import dataclass

# Type definitions following retrain's patterns
RewardFunction = Callable[[str, str, Dict[str, Any]], float]
VerifierFunction = Callable[[str, str, Dict[str, Any]], bool]

# Global registries following retrain's architecture
REWARD_REGISTRY: Dict[str, RewardFunction] = {}
VERIFIER_REGISTRY: Dict[str, VerifierFunction] = {}

# Simple logging
import logging
logger = logging.getLogger(__name__)


def reward(_func: Optional[RewardFunction] = None, *, name: Optional[str] = None) -> Union[Callable[[RewardFunction], RewardFunction], RewardFunction]:
    """
    Decorator to register reward functions following retrain's pattern.

    Usage:
        @reward
        def my_reward(prompt: str, completion: str, example: Dict[str, Any]) -> float:
            return 1.0

        @reward(name="custom_name")
        def another_reward(prompt: str, completion: str, example: Dict[str, Any]) -> float:
            return 0.5
    """
    def decorator_reward(func: RewardFunction) -> RewardFunction:
        if not callable(func):
            raise TypeError(f"Object {getattr(func, '__name__', '<unknown>')} must be callable to be registered as reward.")

        registration_name = name if name is not None else func.__name__

        # Validate function signature for consistency
        sig = inspect.signature(func)
        expected_params = ['prompt', 'completion', 'example']

        if len(sig.parameters) < 3:
            logger.warning(f"Reward function '{registration_name}' has fewer than 3 expected parameters. Ensure signature compatibility.")

        param_names = list(sig.parameters.keys())
        for i, expected in enumerate(expected_params):
            if i < len(param_names) and param_names[i] != expected:
                logger.warning(f"Reward function '{registration_name}' parameter {i} is '{param_names[i]}', expected '{expected}'.")

        if registration_name in REWARD_REGISTRY:
            logger.warning(f"Reward function '{registration_name}' already registered. Overwriting.")

        REWARD_REGISTRY[registration_name] = func
        logger.info(f"Registered reward function: '{registration_name}' -> {func.__name__}")

        return func

    if _func is None:
        # Called with parentheses: @reward() or @reward(name=...)
        return decorator_reward
    elif callable(_func):
        # Called without parentheses: @reward
        if name is not None:
            raise TypeError("Cannot specify 'name' when using @reward without parentheses. Use @reward(name='...') instead.")
        return decorator_reward(_func)
    else:
        raise TypeError("Invalid arguments supplied to @reward decorator.")


def verifier(_func: Optional[VerifierFunction] = None, *, name: Optional[str] = None) -> Union[Callable[[VerifierFunction], VerifierFunction], VerifierFunction]:
    """
    Decorator to register verifier functions following retrain's pattern.

    Usage:
        @verifier
        def has_greeting(prompt: str, completion: str, example: Dict[str, Any]) -> bool:
            return completion.lower().startswith("hello")

        @verifier(name="custom_check")
        def custom_verifier(prompt: str, completion: str, example: Dict[str, Any]) -> bool:
            return len(completion) > 10
    """
    def decorator_verifier(func: VerifierFunction) -> VerifierFunction:
        if not callable(func):
            raise TypeError(f"Object {getattr(func, '__name__', '<unknown>')} must be callable to be registered as verifier.")

        registration_name = name if name is not None else func.__name__

        # Validate function signature
        sig = inspect.signature(func)
        if len(sig.parameters) < 3:
            logger.warning(f"Verifier function '{registration_name}' has fewer than 3 expected parameters (prompt, completion, example).")

        if registration_name in VERIFIER_REGISTRY:
            logger.warning(f"Verifier function '{registration_name}' already registered. Overwriting.")

        VERIFIER_REGISTRY[registration_name] = func
        logger.info(f"Registered verifier function: '{registration_name}' -> {func.__name__}")

        return func

    if _func is None:
        return decorator_verifier
    elif callable(_func):
        if name is not None:
            raise TypeError("Cannot specify 'name' when using @verifier without parentheses. Use @verifier(name='...') instead.")
        return decorator_verifier(_func)
    else:
        raise TypeError("Invalid arguments supplied to @verifier decorator.")


def get_reward_function(name: str) -> Optional[RewardFunction]:
    """Retrieve a registered reward function by name."""
    func = REWARD_REGISTRY.get(name)
    if func is None:
        available = list(REWARD_REGISTRY.keys())
        logger.error(f"Reward function '{name}' not found. Available: {available}")
    return func


def get_verifier_function(name: str) -> Optional[VerifierFunction]:
    """Retrieve a registered verifier function by name."""
    func = VERIFIER_REGISTRY.get(name)
    if func is None:
        available = list(VERIFIER_REGISTRY.keys())
        logger.error(f"Verifier function '{name}' not found. Available: {available}")
    return func


def apply_verifiers_to_reward(
    original_reward_func: RewardFunction,
    verifier_names: List[str],
    penalty_on_failure: float = 0.0
) -> RewardFunction:
    """
    Apply verifiers to a reward function following retrain's pre-filtering approach.

    If any verifier fails, returns the penalty value without executing the reward function.
    This follows retrain's philosophy of efficient pre-filtering.

    Args:
        original_reward_func: The reward function to wrap
        verifier_names: List of verifier names to apply
        penalty_on_failure: Value to return if any verifier fails

    Returns:
        Wrapped reward function that applies verifiers first
    """
    # Load verifier functions
    loaded_verifiers: List[VerifierFunction] = []
    missing_verifiers = []

    for verifier_name in verifier_names:
        verifier_func = get_verifier_function(verifier_name)
        if verifier_func is None:
            missing_verifiers.append(verifier_name)
        else:
            loaded_verifiers.append(verifier_func)

    if missing_verifiers:
        available = list(VERIFIER_REGISTRY.keys())
        raise ValueError(f"Verifiers {missing_verifiers} not found for reward '{original_reward_func.__name__}'. Available: {available}")

    @functools.wraps(original_reward_func)
    def reward_with_verifiers(prompt: str, completion: str, example: Dict[str, Any], **kwargs) -> float:
        # Apply verifiers first (pre-filtering approach)
        for i, verifier_func in enumerate(loaded_verifiers):
            verifier_name = verifier_names[i]
            try:
                if not verifier_func(prompt, completion, example):
                    logger.debug(f"Verifier '{verifier_name}' failed for reward '{original_reward_func.__name__}'. Applying penalty: {penalty_on_failure}")
                    return penalty_on_failure
            except Exception as e:
                logger.error(f"Verifier '{verifier_name}' errored: {e}. Applying penalty: {penalty_on_failure}")
                return penalty_on_failure

        # All verifiers passed, execute reward function
        try:
            result = original_reward_func(prompt, completion, example, **kwargs)
            return float(result)
        except Exception as e:
            logger.error(f"Reward function '{original_reward_func.__name__}' errored after verifiers passed: {e}. Returning 0.0")
            return 0.0

    # Set descriptive name
    if verifier_names:
        verifier_suffix = '_and_'.join(verifier_names)
        reward_with_verifiers.__name__ = f"{original_reward_func.__name__}_verified_by_{verifier_suffix}"
    else:
        reward_with_verifiers.__name__ = original_reward_func.__name__

    return reward_with_verifiers


@dataclass
class RewardConfig:
    """Configuration for a reward function following retrain's patterns."""
    name: str  # Name in REWARD_REGISTRY
    weight: float = 1.0
    params: Optional[Dict[str, Any]] = None
    verifiers: Optional[List[str]] = None
    verifier_penalty: float = 0.0

    def __post_init__(self):
        if self.params is None:
            self.params = {}
        if self.verifiers is None:
            self.verifiers = []


def create_configured_reward_function(config: RewardConfig) -> RewardFunction:
    """
    Create a reward function from configuration following retrain's approach.

    This function:
    1. Loads the base reward function from registry
    2. Applies verifiers if specified
    3. Handles parameter passing
    4. Returns a configured function ready for use
    """
    # Get base reward function
    base_reward_func = get_reward_function(config.name)
    if base_reward_func is None:
        raise ValueError(f"Reward function '{config.name}' not found in registry")

    # Create wrapper that handles params
    def reward_with_params(prompt: str, completion: str, example: Dict[str, Any], **kwargs) -> float:
        # Merge config params with runtime kwargs
        merged_kwargs = {**config.params, **kwargs}
        return base_reward_func(prompt, completion, example, **merged_kwargs)

    reward_with_params.__name__ = f"{base_reward_func.__name__}_with_params"

    # Apply verifiers if specified
    if config.verifiers:
        reward_with_params = apply_verifiers_to_reward(
            reward_with_params,
            config.verifiers,
            config.verifier_penalty
        )

    return reward_with_params


# MLX optimization support
@mx.compile
def batch_reward_computation(
    base_rewards: mx.array,
    weights: mx.array
) -> mx.array:
    """
    MLX-compiled function for efficient batch reward computation.

    Args:
        base_rewards: Individual reward scores [batch_size, num_rewards]
        weights: Reward weights [num_rewards]

    Returns:
        Weighted combined rewards [batch_size]
    """
    return mx.sum(base_rewards * weights, axis=1)


def list_registered_functions() -> Dict[str, List[str]]:
    """List all registered reward and verifier functions."""
    return {
        "rewards": list(REWARD_REGISTRY.keys()),
        "verifiers": list(VERIFIER_REGISTRY.keys())
    }


def clear_registries():
    """Clear all registries (useful for testing)."""
    global REWARD_REGISTRY, VERIFIER_REGISTRY
    REWARD_REGISTRY.clear()
    VERIFIER_REGISTRY.clear()
    logger.info("Cleared all reward and verifier registries")
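Taken together, the registry is meant to be driven through the decorators and RewardConfig. Below is a minimal usage sketch; the function names, bodies, and config values are illustrative assumptions, not part of the published package.

# Hypothetical usage of textpolicy.rewards.registry; names and values are illustrative.
from textpolicy.rewards.registry import (
    reward, verifier, RewardConfig, create_configured_reward_function
)

@verifier(name="non_empty")
def non_empty(prompt, completion, example):
    # Pre-filter: reject empty completions before any reward is computed.
    return len(completion.strip()) > 0

@reward(name="exact_match")
def exact_match(prompt, completion, example, **kwargs):
    # 1.0 when the completion equals the reference answer carried in the example.
    return float(completion.strip() == example.get("answer", ""))

config = RewardConfig(
    name="exact_match",
    verifiers=["non_empty"],   # applied first; a failure short-circuits to verifier_penalty
    verifier_penalty=-1.0,
)
score_fn = create_configured_reward_function(config)

print(score_fn("What is 2+2?", "4", {"answer": "4"}))   # 1.0
print(score_fn("What is 2+2?", "  ", {"answer": "4"}))  # -1.0, verifier fails

Because a failed verifier returns the penalty without invoking the reward, expensive reward computations only run on completions that pass the cheap pre-filter checks.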
@@ -0,0 +1,410 @@
# textpolicy/rewards/rollout_rewards.py
"""
Rollout-level reward processing system for efficient MLX training.

This system processes rewards at the episode/rollout level rather than
per-transition, enabling vectorized operations and batch processing
for optimal MLX performance.

Key features:
- Batch reward computation for entire episodes
- Vectorized operations using MLX
- Integration with rollout buffer system
- Support for async external reward models
- Pure function composition
"""

from typing import Dict, List, Optional, Any
import mlx.core as mx
from dataclasses import dataclass
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Optional dependency
try:
    import aiohttp  # type: ignore
    HAS_AIOHTTP = True
except ImportError:
    HAS_AIOHTTP = False
    aiohttp = None

# Import reward functions used in this module
# These are registered via the @reward decorator and provide the standard (prompt, completion, example) signature
from .basic import length_reward, keyword_reward, perplexity_reward, accuracy_reward


@dataclass
class RewardConfig:
    """Configuration for rollout reward processing."""
    # Basic reward weights
    length_weight: float = 0.1
    keyword_weight: float = 0.2
    perplexity_weight: float = 0.3
    accuracy_weight: float = 0.4

    # Target parameters
    target_length: int = 50
    keywords: Optional[List[str]] = None

    # External reward model
    external_rm_url: Optional[str] = None
    external_rm_timeout: float = 30.0

    # Batch processing
    batch_size: int = 32
    max_workers: int = 4

    def __post_init__(self):
        if self.keywords is None:
            self.keywords = []


class RolloutRewardProcessor:
    """
    Efficient rollout-level reward processor for MLX training.

    Processes entire episodes in batches, enabling vectorized operations
    and optimal memory usage on Apple Silicon.
    """

    def __init__(self, config: RewardConfig):
        """
        Initialize reward processor.

        Args:
            config: Reward processing configuration
        """
        self.config = config
        self.executor = ThreadPoolExecutor(max_workers=config.max_workers)

        # Pre-compile reward functions for MLX
        self._compile_reward_functions()

    def _compile_reward_functions(self):
        """Pre-compile reward functions for MLX optimization."""
        # These will be compiled when first called
        self._compiled_functions = {}

    def process_rollout(
        self,
        prompt: str,
        completion: str,
        example: Dict[str, Any]
    ) -> mx.array:
        """
        Process a single rollout with prompt, completion, and example context.

        This method provides the standard interface expected by tests and other components.
        It follows the same signature as individual reward functions for consistency.

        Args:
            prompt: Input prompt text
            completion: Generated completion text
            example: Example context with target parameters and metadata

        Returns:
            MLX array with single reward value [1]
        """
        # Convert single rollout to episode format for consistency
        episode = {
            'prompt': prompt,
            'response': completion,
            'metadata': example
        }

        # Process as single episode and return scalar reward
        rewards = self.process_episode_rewards([episode])
        return rewards[0] if rewards.size > 0 else mx.array(0.0)

    def process_batch_rollouts(
        self,
        prompts: List[str],
        completions: List[str],
        examples: List[Dict[str, Any]]
    ) -> mx.array:
        """
        Process multiple rollouts in batch for efficient processing.

        This method provides batch processing capability expected by tests and training loops.
        It maintains the same interface as individual reward functions but processes multiple
        examples simultaneously for better performance.

        Args:
            prompts: List of input prompt texts
            completions: List of generated completion texts
            examples: List of example contexts with target parameters

        Returns:
            MLX array of rewards [num_rollouts]
        """
        if len(prompts) != len(completions) or len(completions) != len(examples):
            raise ValueError(f"Mismatched input lengths: prompts={len(prompts)}, "
                             f"completions={len(completions)}, examples={len(examples)}")

        # Convert to episode format for consistency with existing implementation
        episodes = []
        for prompt, completion, example in zip(prompts, completions, examples):
            episode = {
                'prompt': prompt,
                'response': completion,
                'metadata': example
            }
            episodes.append(episode)

        # Process batch using existing episode processing logic
        return self.process_episode_rewards(episodes)

    def process_episode_rewards(
        self,
        episodes: List[Dict[str, Any]]
    ) -> mx.array:
        """
        Process rewards for a batch of episodes.

        Args:
            episodes: List of episode dictionaries with 'prompt' and 'response' fields

        Returns:
            MLX array of rewards [num_episodes]
        """
        if not episodes:
            return mx.array([])

        # Extract prompts and responses
        prompts = [ep.get('prompt', '') for ep in episodes]
        responses = [ep.get('response', '') for ep in episodes]

        # Compute basic rewards (vectorized where possible)
        rewards = self._compute_basic_rewards(prompts, responses)

        # Add external reward model scores if configured
        if self.config.external_rm_url:
            external_rewards = self._get_external_rewards(episodes)
            # Blend with basic rewards
            alpha = 0.7  # Weight for external rewards
            rewards = alpha * external_rewards + (1 - alpha) * rewards

        return rewards

    def _compute_basic_rewards(
        self,
        prompts: List[str],
        responses: List[str]
    ) -> mx.array:
        """
        Compute basic rewards using vectorized operations.

        Args:
            prompts: List of input prompts
            responses: List of generated responses

        Returns:
            MLX array of combined rewards
        """
        # Initialize reward array
        num_episodes = len(prompts)
        rewards = mx.zeros(num_episodes)

        # Length rewards - pass config as example dict for unified interface
        # The @reward decorator enforces (prompt, completion, example) signature
        if self.config.length_weight > 0:
            length_rewards = mx.array([
                length_reward(p, r, {"target_length": self.config.target_length})  # type: ignore
                for p, r in zip(prompts, responses)
            ])
            rewards += self.config.length_weight * length_rewards

        # Keyword rewards - pass config as example dict for unified interface
        # The @reward decorator enforces (prompt, completion, example) signature
        if self.config.keyword_weight > 0 and self.config.keywords:
            keyword_rewards = mx.array([
                keyword_reward(p, r, {"keywords": self.config.keywords})  # type: ignore
                for p, r in zip(prompts, responses)
            ])
            rewards += self.config.keyword_weight * keyword_rewards

        # Perplexity rewards - pass empty example dict for unified interface
        # The @reward decorator enforces (prompt, completion, example) signature
        if self.config.perplexity_weight > 0:
            perplexity_rewards = mx.array([
                perplexity_reward(p, r, {})  # type: ignore
                for p, r in zip(prompts, responses)
            ])
            rewards += self.config.perplexity_weight * perplexity_rewards

        # Accuracy rewards - pass empty example dict for unified interface
        # The @reward decorator enforces (prompt, completion, example) signature
        if self.config.accuracy_weight > 0:
            accuracy_rewards = mx.array([
                accuracy_reward(p, r, {})  # type: ignore
                for p, r in zip(prompts, responses)
            ])
            rewards += self.config.accuracy_weight * accuracy_rewards

        return rewards

    def _get_external_rewards(
        self,
        episodes: List[Dict[str, Any]]
    ) -> mx.array:
        """
        Get rewards from external reward model.

        Args:
            episodes: List of episode dictionaries

        Returns:
            MLX array of external rewards
        """
        # For now, return default rewards
        # In practice, this would call external API or model
        return mx.ones(len(episodes)) * 0.5

    async def _async_external_rewards(
        self,
        episodes: List[Dict[str, Any]]
    ) -> List[float]:
        """
        Async version for external reward model calls.

        Args:
            episodes: List of episode dictionaries

        Returns:
            List of reward scores
        """
        if not self.config.external_rm_url:
            return [0.5] * len(episodes)

        if not HAS_AIOHTTP or aiohttp is None:
            return [0.5] * len(episodes)

        # Type guard: aiohttp is guaranteed to be available from this point
        # Assert for type checker that aiohttp is not None after the above checks
        assert aiohttp is not None, "aiohttp should be available when HAS_AIOHTTP is True"

        async def get_reward(session, episode):
            payload = {
                "prompt": episode.get("prompt", ""),
                "response": episode.get("response", ""),
                "metadata": episode.get("metadata", {})
            }

            try:
                async with session.post(
                    self.config.external_rm_url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.external_rm_timeout)
                ) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        return result.get("reward", 0.5)
                    else:
                        return 0.5
            except Exception:
                return 0.5

        async with aiohttp.ClientSession() as session:
            tasks = [get_reward(session, ep) for ep in episodes]
            results = await asyncio.gather(*tasks)
            return list(results)

    def process_buffer_rewards(self, buffer) -> mx.array:
        """
        Process rewards for all episodes in a buffer.

        Args:
            buffer: Buffer instance containing episodes

        Returns:
            MLX array of episode rewards
        """
        # Extract episodes from buffer
        episodes = []
        for episode in buffer.storage.episodes:
            # Convert episode to dict format
            episode_dict = {
                'prompt': episode.obs[0] if episode.obs else '',
                'response': episode.act[-1] if episode.act else '',
                'metadata': {
                    'length': len(episode.obs),
                    'logprobs': episode.logprob,
                    'values': episode.value
                }
            }
            episodes.append(episode_dict)

        return self.process_episode_rewards(episodes)

    def close(self):
        """Cleanup resources."""
        self.executor.shutdown(wait=True)


# Pure function interface for integration with rollout system
def create_rollout_reward_processor(config: RewardConfig) -> RolloutRewardProcessor:
    """
    Factory function for creating reward processors.

    Args:
        config: Reward processing configuration

    Returns:
        RolloutRewardProcessor instance
    """
    return RolloutRewardProcessor(config)


def process_episode_batch_rewards(
    episodes: List[Dict[str, Any]],
    config: RewardConfig
) -> mx.array:
    """
    Pure function for processing episode rewards.

    Args:
        episodes: List of episode dictionaries
        config: Reward configuration

    Returns:
        MLX array of rewards
    """
    processor = RolloutRewardProcessor(config)
    try:
        return processor.process_episode_rewards(episodes)
    finally:
        processor.close()


# MLX-compiled reward computation for high-performance training
@mx.compile
def compute_reward_vector(
    response_lengths: mx.array,
    keyword_matches: mx.array,
    fluency_scores: mx.array,
    accuracy_scores: mx.array,
    weights: mx.array
) -> mx.array:
    """
    MLX-compiled function for vectorized reward computation.

    Args:
        response_lengths: Normalized length scores [batch_size]
        keyword_matches: Keyword match scores [batch_size]
        fluency_scores: Fluency scores [batch_size]
        accuracy_scores: Accuracy scores [batch_size]
        weights: Reward weights [4] (length, keyword, fluency, accuracy)

    Returns:
        Combined rewards [batch_size]
    """
    # Weighted combination of reward components
    rewards = (
        weights[0] * response_lengths +
        weights[1] * keyword_matches +
        weights[2] * fluency_scores +
        weights[3] * accuracy_scores
    )

    return rewards
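As a rough sketch of how the batch interface above might be exercised from a training loop: the prompts, completions, and weights below are invented for illustration, and the sketch assumes the length/keyword rewards imported from .basic accept the (prompt, completion, example) dicts shown.

# Illustrative driver for RolloutRewardProcessor; example data is made up.
from textpolicy.rewards.rollout_rewards import RewardConfig, RolloutRewardProcessor

config = RewardConfig(
    length_weight=0.5,
    keyword_weight=0.5,
    perplexity_weight=0.0,   # model-based components disabled for this sketch
    accuracy_weight=0.0,
    target_length=20,
    keywords=["mlx", "apple silicon"],
)

processor = RolloutRewardProcessor(config)
try:
    prompts = ["Explain MLX in one sentence.", "Which chips does MLX target?"]
    completions = [
        "MLX is an array framework tuned for Apple Silicon.",
        "It targets Apple Silicon machines.",
    ]
    examples = [{}, {}]

    rewards = processor.process_batch_rollouts(prompts, completions, examples)
    print(rewards.shape)  # (2,): one combined reward per rollout
finally:
    processor.close()  # releases the ThreadPoolExecutor

For one-off use, process_episode_batch_rewards wraps the same flow in a pure function that constructs and closes the processor for you.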