textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,447 @@
# textpolicy/rewards/mlx_batch_processor.py
"""
MLX-optimized batch processing system following DESIGN_GUIDELINES.md principles.

This module implements pure function composition with zero abstraction cost,
designed for maximum efficiency on Apple Silicon using MLX compilation.

Features:
- Pure function composition (no classes, no dispatch)
- MLX compilation for optimal performance
- Vectorized batch processing
- Single interface for all reward/verifier combinations
- Integrates with retrain's philosophy
"""

import mlx.core as mx
from typing import List, Dict, Any, Callable, Coroutine, Union
import asyncio

from .registry import (
    RewardFunction, RewardConfig,
    get_reward_function,
    apply_verifiers_to_reward, REWARD_REGISTRY, VERIFIER_REGISTRY
)


# Pure function composition following DESIGN_GUIDELINES.md Option 3
def create_batch_reward_processor(
    reward_configs: List[RewardConfig],
    enable_mlx_compilation: bool = True
) -> Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array]:
    """
    Pure function factory for creating MLX-optimized batch processors.

    Following DESIGN_GUIDELINES.md:
    - Pure function composition over class hierarchies
    - Zero abstraction cost
    - Single training loop for all algorithms
    - MLX compilation ready

    Args:
        reward_configs: List of reward configurations
        enable_mlx_compilation: Whether to enable MLX compilation

    Returns:
        Pure function: (prompts, completions, examples) -> rewards [batch_size]
    """
    # Pre-load all reward functions (zero cost at runtime)
    loaded_functions: List[RewardFunction] = []
    weights: List[float] = []

    for config in reward_configs:
        # Create configured reward function with verifiers
        reward_func = create_configured_reward_function(config)
        loaded_functions.append(reward_func)
        weights.append(config.weight)

    # Convert weights to MLX array for efficient computation
    weights_array = mx.array(weights)

    # Pure function implementation
    def batch_processor(
        prompts: List[str],
        completions: List[str],
        examples: List[Dict[str, Any]]
    ) -> mx.array:
        """
        Process batch of episodes with pure function composition.

        Args:
            prompts: List of input prompts
            completions: List of generated completions
            examples: List of example contexts

        Returns:
            Combined rewards [batch_size]
        """
        batch_size = len(prompts)

        # Compute rewards for each function
        all_rewards = []

        for func_idx, reward_func in enumerate(loaded_functions):
            batch_rewards = []

            # Process each sample
            for i in range(batch_size):
                try:
                    reward = reward_func(prompts[i], completions[i], examples[i])
                    batch_rewards.append(float(reward))
                except Exception:
                    batch_rewards.append(0.0)

            all_rewards.append(mx.array(batch_rewards))

        # Stack rewards and apply weights
        if all_rewards:
            reward_matrix = mx.stack(all_rewards, axis=1)  # [batch_size, num_functions]
            return mx.sum(reward_matrix * weights_array, axis=1)
        else:
            return mx.zeros(batch_size)

    # Optionally compile for maximum performance
    if enable_mlx_compilation:
        # Create compiled version of core computation
        @mx.compile
        def compiled_weighted_sum(reward_matrix: mx.array, weights: mx.array) -> mx.array:
            return mx.sum(reward_matrix * weights, axis=1)

        # Update function to use compiled computation
        def optimized_processor(prompts, completions, examples):
            # Use original function for reward computation
            reward_matrix = mx.zeros((len(prompts), len(loaded_functions)))

            for func_idx, reward_func in enumerate(loaded_functions):
                batch_rewards = []
                for i in range(len(prompts)):
                    try:
                        reward = reward_func(prompts[i], completions[i], examples[i])
                        batch_rewards.append(float(reward))
                    except Exception:
                        batch_rewards.append(0.0)

                # Calculate difference between new values and current values at the specified indices
                current_values = reward_matrix[:, func_idx]
                new_values = mx.array(batch_rewards)
                diff = new_values - current_values

                # Use add method to update values at specified indices
                reward_matrix = reward_matrix.at[:, func_idx].add(diff)

            # Use compiled function for final computation
            return compiled_weighted_sum(reward_matrix, weights_array)

        return optimized_processor

    return batch_processor
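
For illustration (not part of the module), a minimal usage sketch of create_batch_reward_processor; the reward name and parameter values are placeholders, and it assumes 'length_reward' is registered in .registry:

# Example (illustrative, not part of mlx_batch_processor.py):
# Assumes 'length_reward' is registered in .registry.
configs = [
    RewardConfig(name='length_reward', weight=1.0,
                 params={'target_length': 50, 'tolerance': 0.2},
                 verifiers=[], verifier_penalty=0.0),
]
process_batch = create_batch_reward_processor(configs, enable_mlx_compilation=False)
rewards = process_batch(
    ["Summarize MLX in one sentence."],                  # prompts
    ["MLX is an array framework for Apple silicon."],    # completions
    [{}],                                                 # examples
)
# rewards: mx.array of shape [batch_size]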


def create_configured_reward_function(config: RewardConfig) -> RewardFunction:
    """
    Create a configured reward function following retrain's patterns.

    This is a pure function that creates other pure functions,
    maintaining zero abstraction cost.
    """
    # Get base reward function
    base_reward_func = get_reward_function(config.name)
    if base_reward_func is None:
        raise ValueError(f"Reward function '{config.name}' not found in registry")

    # Create parameter-injected function
    def reward_with_params(prompt: str, completion: str, example: Dict[str, Any]) -> float:
        # Merge config params with example
        merged_example = {**example, **config.params}
        return base_reward_func(prompt, completion, merged_example)

    # Apply verifiers if specified (following retrain's pre-filtering)
    if config.verifiers:
        reward_with_params = apply_verifiers_to_reward(
            reward_with_params,
            config.verifiers,
            config.verifier_penalty
        )

    return reward_with_params
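
A brief sketch of the parameter-injection pattern above (illustrative only; the registry contents are an assumption). Because of the merge order {**example, **config.params}, keys from config.params override per-sample keys in example:

# Example (illustrative, not part of mlx_batch_processor.py):
# Assumes 'keyword_reward' is registered in .registry.
cfg = RewardConfig(name='keyword_reward', weight=1.0,
                   params={'keywords': ['mlx']},
                   verifiers=[], verifier_penalty=0.0)
score_fn = create_configured_reward_function(cfg)
# config.params wins over example keys: {'keywords': ['unused']} is overridden here.
score = score_fn("Explain MLX.", "MLX targets Apple silicon.", {'keywords': ['unused']})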


# MLX-compiled vectorized operations for maximum performance
@mx.compile
def compute_length_rewards_vectorized(
    completion_lengths: mx.array,
    target_length: float,
    tolerance: float
) -> mx.array:
    """
    MLX-compiled vectorized length reward computation.

    Args:
        completion_lengths: Word counts [batch_size]
        target_length: Target word count
        tolerance: Tolerance fraction

    Returns:
        Length rewards [batch_size]
    """
    deviations = mx.abs(completion_lengths - target_length) / target_length

    # Vectorized conditional computation
    within_tolerance = deviations <= tolerance
    beyond_tolerance = ~within_tolerance

    # Linear decay for beyond tolerance
    decay_rewards = mx.maximum(
        0.0,
        1.0 - (deviations - tolerance) / (1.0 - tolerance)
    )

    # Combine results
    rewards = within_tolerance * 1.0 + beyond_tolerance * decay_rewards
    return rewards
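
To make the piecewise formula above concrete (an illustrative calculation, not from the package): with target_length = 50 and tolerance = 0.2, a 55-word completion has deviation |55 - 50| / 50 = 0.1 <= 0.2 and scores 1.0; a 90-word completion has deviation 0.8 and scores max(0, 1 - (0.8 - 0.2) / (1 - 0.2)) = 0.25; at deviation 1.0 or more (100+ words in this setting) the reward bottoms out at 0.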


@mx.compile
def compute_keyword_rewards_vectorized(
    keyword_matches: mx.array,
    total_keywords: mx.array,
    bonus_matches: mx.array,
    total_bonus_keywords: mx.array,
    bonus_multiplier: float
) -> mx.array:
    """
    MLX-compiled vectorized keyword reward computation.

    Args:
        keyword_matches: Number of keyword matches [batch_size]
        total_keywords: Total keywords for each sample [batch_size]
        bonus_matches: Bonus keyword matches [batch_size]
        total_bonus_keywords: Total bonus keywords [batch_size]
        bonus_multiplier: Bonus multiplier

    Returns:
        Keyword rewards [batch_size]
    """
    # Avoid division by zero
    safe_total_keywords = mx.maximum(total_keywords, 1.0)
    safe_total_bonus = mx.maximum(total_bonus_keywords, 1.0)

    base_rewards = keyword_matches / safe_total_keywords
    bonus_rewards = (bonus_matches / safe_total_bonus) * bonus_multiplier

    # Clip to reasonable range
    total_rewards = mx.minimum(base_rewards + bonus_rewards, 2.0)
    return total_rewards
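
A similarly concrete reading of the keyword formula (illustrative numbers): with 4 configured keywords of which 3 appear in the completion, 2 bonus keywords (keywords not already present in the prompt) of which 1 appears, and bonus_multiplier = 1.0, the reward is 3/4 + (1/2) * 1.0 = 1.25; the mx.minimum cap keeps any combination at or below 2.0.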


def create_mlx_optimized_batch_processor(
    reward_configs: List[RewardConfig]
) -> Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array]:
    """
    Create fully MLX-optimized batch processor for maximum Apple Silicon performance.

    This implementation follows DESIGN_GUIDELINES.md by:
    1. Using pure function composition
    2. Maximizing MLX compilation opportunities
    3. Minimizing memory allocations
    4. Utilizing unified memory efficiently
    """

    # Pre-compile all possible reward components
    compiled_functions = {}

    # Check which reward types we need and pre-compile them
    for config in reward_configs:
        if config.name == 'length_reward':
            compiled_functions['length'] = compute_length_rewards_vectorized
        elif config.name == 'keyword_reward':
            compiled_functions['keyword'] = compute_keyword_rewards_vectorized

    weights = mx.array([config.weight for config in reward_configs])

    def optimized_batch_processor(
        prompts: List[str],
        completions: List[str],
        examples: List[Dict[str, Any]]
    ) -> mx.array:
        """
        Fully optimized batch processor using MLX compilation.
        """
        batch_size = len(prompts)

        # Collect all reward arrays
        all_rewards = []

        # Process each reward type
        for config_idx, config in enumerate(reward_configs):
            if config.name == 'length_reward' and 'length' in compiled_functions:
                # Use vectorized length computation
                lengths = mx.array([len(comp.split()) for comp in completions])
                target = config.params.get('target_length', 50)
                tolerance = config.params.get('tolerance', 0.2)

                rewards = compiled_functions['length'](lengths, float(target), tolerance)
                all_rewards.append(rewards)

            elif config.name == 'keyword_reward' and 'keyword' in compiled_functions:
                # Use vectorized keyword computation
                keywords = config.params.get('keywords', [])
                if keywords:
                    # Preprocess keyword matches
                    keyword_matches = []
                    bonus_matches = []

                    for i, (prompt, completion) in enumerate(zip(prompts, completions)):
                        comp_lower = completion.lower()
                        prompt_lower = prompt.lower()

                        matches = sum(1 for kw in keywords if kw.lower() in comp_lower)
                        bonus_kws = [kw for kw in keywords if kw.lower() not in prompt_lower]
                        bonus = sum(1 for kw in bonus_kws if kw.lower() in comp_lower)

                        keyword_matches.append(matches)
                        bonus_matches.append(bonus)

                    # Vectorized computation
                    match_array = mx.array(keyword_matches)
                    bonus_array = mx.array(bonus_matches)
                    total_kw = mx.full((batch_size,), len(keywords))
                    total_bonus = mx.array([len([kw for kw in keywords if kw.lower() not in prompts[i].lower()]) for i in range(batch_size)])
                    multiplier = config.params.get('bonus_multiplier', 1.0)

                    rewards = compiled_functions['keyword'](
                        match_array, total_kw, bonus_array, total_bonus, multiplier
                    )
                    all_rewards.append(rewards)
                else:
                    # No keywords specified
                    all_rewards.append(mx.zeros(batch_size))
            else:
                # Fallback to individual function calls
                reward_func = create_configured_reward_function(config)
                batch_rewards = []

                for i in range(batch_size):
                    try:
                        reward = reward_func(prompts[i], completions[i], examples[i])
                        batch_rewards.append(float(reward))
                    except Exception:
                        batch_rewards.append(0.0)

                all_rewards.append(mx.array(batch_rewards))

        # Combine all rewards
        if all_rewards:
            reward_matrix = mx.stack(all_rewards, axis=1)  # [batch_size, num_rewards]
            return mx.sum(reward_matrix * weights, axis=1)
        else:
            return mx.zeros(batch_size)

    return optimized_batch_processor
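
For illustration (not part of the module), a sketch of how the fast path above might be driven; the parameter values are arbitrary, and the names simply select the vectorized helpers defined earlier in this file:

# Example (illustrative, not part of mlx_batch_processor.py):
configs = [
    RewardConfig(name='length_reward', weight=0.7,
                 params={'target_length': 50, 'tolerance': 0.2},
                 verifiers=[], verifier_penalty=0.0),
    RewardConfig(name='keyword_reward', weight=0.3,
                 params={'keywords': ['mlx', 'unified memory'], 'bonus_multiplier': 1.0},
                 verifiers=[], verifier_penalty=0.0),
]
processor = create_mlx_optimized_batch_processor(configs)
# prompts and completions are equal-length List[str] prepared by the caller.
scores = processor(prompts, completions, [{} for _ in prompts])  # mx.array [batch_size]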


# Async processing for external reward models
async def create_async_batch_processor(
    reward_configs: List[RewardConfig],
    max_workers: int = 4
) -> Callable[[List[str], List[str], List[Dict[str, Any]]], Coroutine[Any, Any, mx.array]]:
    """
    Create async batch processor for external reward models.

    Maintains pure function composition while enabling async operations.
    """

    # Separate local and external reward configs
    local_configs = []
    external_configs = []

    for config in reward_configs:
        if config.params.get('external_url'):
            external_configs.append(config)
        else:
            local_configs.append(config)

    # Create local processor
    local_processor = create_mlx_optimized_batch_processor(local_configs) if local_configs else None

    async def async_batch_processor(
        prompts: List[str],
        completions: List[str],
        examples: List[Dict[str, Any]]
    ) -> mx.array:
        """Async batch processor combining local and external rewards."""
        batch_size = len(prompts)

        # Process local rewards synchronously
        local_rewards = None
        if local_processor:
            local_rewards = local_processor(prompts, completions, examples)

        # Process external rewards asynchronously
        external_rewards = None
        if external_configs:
            # Placeholder for external API calls
            # In practice, this would make HTTP requests to external reward models
            external_rewards = mx.zeros((batch_size, len(external_configs)))

        # Combine results. The local processor has already applied the local
        # config weights and returns a single combined column [batch_size];
        # external rewards are still per-config and need their weights applied here.
        combined = mx.zeros(batch_size)

        if local_rewards is not None:
            combined = combined + local_rewards

        if external_rewards is not None:
            external_weights = mx.array([config.weight for config in external_configs])
            combined = combined + mx.sum(external_rewards * external_weights, axis=1)

        return combined

    return async_batch_processor
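
An illustrative driver for the async path (not part of the module). The external handling above is a stub that returns zeros, so the 'external_url' value and the 'external_judge' name below are placeholders:

# Example (illustrative, not part of mlx_batch_processor.py):
async def score_batch(prompts, completions):
    configs = [
        RewardConfig(name='length_reward', weight=1.0,
                     params={'target_length': 50, 'tolerance': 0.2},
                     verifiers=[], verifier_penalty=0.0),
        RewardConfig(name='external_judge', weight=0.5,
                     params={'external_url': 'https://example.invalid/score'},
                     verifiers=[], verifier_penalty=0.0),
    ]
    processor = await create_async_batch_processor(configs)
    return await processor(prompts, completions, [{} for _ in prompts])

# rewards = asyncio.run(score_batch(prompts, completions))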


# Utility functions for integration
def list_available_processors() -> Dict[str, List[str]]:
    """List all available reward and verifier functions."""
    return {
        "rewards": list(REWARD_REGISTRY.keys()),
        "verifiers": list(VERIFIER_REGISTRY.keys()),
        "compiled_optimizations": ["length_reward", "keyword_reward"]
    }


async def create_processor_from_config(
    config_dict: Dict[str, Any]
) -> Union[Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array],
           Callable[[List[str], List[str], List[Dict[str, Any]]], Coroutine[Any, Any, mx.array]]]:
    """
    Create processor from configuration dictionary following retrain's patterns.

    Args:
        config_dict: Configuration with reward specifications

    Returns:
        Batch processor function
    """
    reward_configs = []

    for name, config in config_dict.items():
        reward_config = RewardConfig(
            name=name,
            weight=config.get('weight', 1.0),
            params=config.get('params', {}),
            verifiers=config.get('verifiers', []),
            verifier_penalty=config.get('verifier_penalty', 0.0)
        )
        reward_configs.append(reward_config)

    # Choose optimal processor based on configuration
    has_external = any(cfg.params.get('external_url') for cfg in reward_configs)

    if has_external:
        # Use async processor for external models
        return await create_async_batch_processor(reward_configs)
    else:
        # Use MLX-optimized processor for local computation
        return create_mlx_optimized_batch_processor(reward_configs)
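
For illustration (not part of the package), the dictionary shape create_processor_from_config expects, inferred from the keys it reads above; the coroutine must be awaited, so asyncio.run is used here:

# Example (illustrative, not part of mlx_batch_processor.py):
config = {
    'length_reward': {'weight': 0.7, 'params': {'target_length': 50, 'tolerance': 0.2}},
    'keyword_reward': {'weight': 0.3, 'params': {'keywords': ['mlx']}},
}
processor = asyncio.run(create_processor_from_config(config))
# No 'external_url' anywhere, so this returns the synchronous MLX-optimized processor.
# prompts and completions are equal-length List[str] prepared by the caller.
scores = processor(prompts, completions, [{} for _ in prompts])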