textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,447 @@
+ # textpolicy/rewards/mlx_batch_processor.py
+ """
+ MLX-optimized batch processing system following DESIGN_GUIDELINES.md principles.
+
+ This module implements pure function composition with zero abstraction cost,
+ designed for maximum efficiency on Apple Silicon using MLX compilation.
+
+ Features:
+ - Pure function composition (no classes, no dispatch)
+ - MLX compilation for optimal performance
+ - Vectorized batch processing
+ - Single interface for all reward/verifier combinations
+ - Integrates with retrain's philosophy
+ """
+
+ import mlx.core as mx
+ from typing import List, Dict, Any, Callable, Coroutine, Union
+ import asyncio
+
+ from .registry import (
+     RewardFunction, RewardConfig,
+     get_reward_function,
+     apply_verifiers_to_reward, REWARD_REGISTRY, VERIFIER_REGISTRY
+ )
+
+
+ # Pure function composition following DESIGN_GUIDELINES.md Option 3
+ def create_batch_reward_processor(
+     reward_configs: List[RewardConfig],
+     enable_mlx_compilation: bool = True
+ ) -> Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array]:
+     """
+     Pure function factory for creating MLX-optimized batch processors.
+
+     Following DESIGN_GUIDELINES.md:
+     - Pure function composition over class hierarchies
+     - Zero abstraction cost
+     - Single training loop for all algorithms
+     - MLX compilation ready
+
+     Args:
+         reward_configs: List of reward configurations
+         enable_mlx_compilation: Whether to enable MLX compilation
+
+     Returns:
+         Pure function: (prompts, completions, examples) -> rewards [batch_size]
+     """
+     # Pre-load all reward functions (zero cost at runtime)
+     loaded_functions: List[RewardFunction] = []
+     weights: List[float] = []
+
+     for config in reward_configs:
+         # Create configured reward function with verifiers
+         reward_func = create_configured_reward_function(config)
+         loaded_functions.append(reward_func)
+         weights.append(config.weight)
+
+     # Convert weights to MLX array for efficient computation
+     weights_array = mx.array(weights)
+
+     # Pure function implementation
+     def batch_processor(
+         prompts: List[str],
+         completions: List[str],
+         examples: List[Dict[str, Any]]
+     ) -> mx.array:
+         """
+         Process batch of episodes with pure function composition.
+
+         Args:
+             prompts: List of input prompts
+             completions: List of generated completions
+             examples: List of example contexts
+
+         Returns:
+             Combined rewards [batch_size]
+         """
+         batch_size = len(prompts)
+
+         # Compute rewards for each function
+         all_rewards = []
+
+         for func_idx, reward_func in enumerate(loaded_functions):
+             batch_rewards = []
+
+             # Process each sample
+             for i in range(batch_size):
+                 try:
+                     reward = reward_func(prompts[i], completions[i], examples[i])
+                     batch_rewards.append(float(reward))
+                 except Exception:
+                     batch_rewards.append(0.0)
+
+             all_rewards.append(mx.array(batch_rewards))
+
+         # Stack rewards and apply weights
+         if all_rewards:
+             reward_matrix = mx.stack(all_rewards, axis=1)  # [batch_size, num_functions]
+             return mx.sum(reward_matrix * weights_array, axis=1)
+         else:
+             return mx.zeros(batch_size)
+
+     # Optionally compile for maximum performance
+     if enable_mlx_compilation:
+         # Create compiled version of core computation
+         @mx.compile
+         def compiled_weighted_sum(reward_matrix: mx.array, weights: mx.array) -> mx.array:
+             return mx.sum(reward_matrix * weights, axis=1)
+
+         # Update function to use compiled computation
+         def optimized_processor(prompts, completions, examples):
+             # Use original function for reward computation
+             reward_matrix = mx.zeros((len(prompts), len(loaded_functions)))
+
+             for func_idx, reward_func in enumerate(loaded_functions):
+                 batch_rewards = []
+                 for i in range(len(prompts)):
+                     try:
+                         reward = reward_func(prompts[i], completions[i], examples[i])
+                         batch_rewards.append(float(reward))
+                     except Exception:
+                         batch_rewards.append(0.0)
+
+                 # Calculate difference between new values and current values at the specified indices
+                 current_values = reward_matrix[:, func_idx]
+                 new_values = mx.array(batch_rewards)
+                 diff = new_values - current_values
+
+                 # Use add method to update values at specified indices
+                 reward_matrix = reward_matrix.at[:, func_idx].add(diff)
+
+             # Use compiled function for final computation
+             return compiled_weighted_sum(reward_matrix, weights_array)
+
+         return optimized_processor
+
+     return batch_processor
+
+
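For orientation, a minimal usage sketch of the factory above; it assumes that 'length_reward' and 'keyword_reward' are registered in REWARD_REGISTRY and that RewardConfig accepts the keyword fields used by create_processor_from_config further down:

    from textpolicy.rewards.registry import RewardConfig
    from textpolicy.rewards.mlx_batch_processor import create_batch_reward_processor

    # Hypothetical configs; reward names and params are assumptions, not package defaults
    configs = [
        RewardConfig(name="length_reward", weight=0.5,
                     params={"target_length": 50, "tolerance": 0.2},
                     verifiers=[], verifier_penalty=0.0),
        RewardConfig(name="keyword_reward", weight=0.5,
                     params={"keywords": ["policy", "reward"]},
                     verifiers=[], verifier_penalty=0.0),
    ]
    prompts = ["Summarize the policy update.", "Explain the reward model."]
    completions = ["The policy update adjusts rewards ...", "The reward model scores ..."]
    examples = [{}, {}]

    score = create_batch_reward_processor(configs, enable_mlx_compilation=False)
    rewards = score(prompts, completions, examples)  # mx.array of shape [batch_size]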
+ def create_configured_reward_function(config: RewardConfig) -> RewardFunction:
+     """
+     Create a configured reward function following retrain's patterns.
+
+     This is a pure function that creates other pure functions,
+     maintaining zero abstraction cost.
+     """
+     # Get base reward function
+     base_reward_func = get_reward_function(config.name)
+     if base_reward_func is None:
+         raise ValueError(f"Reward function '{config.name}' not found in registry")
+
+     # Create parameter-injected function
+     def reward_with_params(prompt: str, completion: str, example: Dict[str, Any]) -> float:
+         # Merge config params with example
+         merged_example = {**example, **config.params}
+         return base_reward_func(prompt, completion, merged_example)
+
+     # Apply verifiers if specified (following retrain's pre-filtering)
+     if config.verifiers:
+         reward_with_params = apply_verifiers_to_reward(
+             reward_with_params,
+             config.verifiers,
+             config.verifier_penalty
+         )
+
+     return reward_with_params
+
+
+ # MLX-compiled vectorized operations for maximum performance
+ @mx.compile
+ def compute_length_rewards_vectorized(
+     completion_lengths: mx.array,
+     target_length: float,
+     tolerance: float
+ ) -> mx.array:
+     """
+     MLX-compiled vectorized length reward computation.
+
+     Args:
+         completion_lengths: Word counts [batch_size]
+         target_length: Target word count
+         tolerance: Tolerance fraction
+
+     Returns:
+         Length rewards [batch_size]
+     """
+     deviations = mx.abs(completion_lengths - target_length) / target_length
+
+     # Vectorized conditional computation
+     within_tolerance = deviations <= tolerance
+     beyond_tolerance = ~within_tolerance
+
+     # Linear decay for beyond tolerance
+     decay_rewards = mx.maximum(
+         0.0,
+         1.0 - (deviations - tolerance) / (1.0 - tolerance)
+     )
+
+     # Combine results
+     rewards = within_tolerance * 1.0 + beyond_tolerance * decay_rewards
+     return rewards
+
+
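A small worked check of the decay formula above, with illustrative numbers chosen here (target of 50 words, 20% tolerance):

    import mlx.core as mx
    from textpolicy.rewards.mlx_batch_processor import compute_length_rewards_vectorized

    lengths = mx.array([50.0, 70.0, 120.0])  # word counts per completion
    rewards = compute_length_rewards_vectorized(lengths, 50.0, 0.2)
    # deviations = [0.0, 0.4, 1.4]; 0.0 is within tolerance -> 1.0,
    # 0.4 decays to 1 - (0.4 - 0.2) / 0.8 = 0.75, and 1.4 clamps to 0.0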
+ @mx.compile
+ def compute_keyword_rewards_vectorized(
+     keyword_matches: mx.array,
+     total_keywords: mx.array,
+     bonus_matches: mx.array,
+     total_bonus_keywords: mx.array,
+     bonus_multiplier: float
+ ) -> mx.array:
+     """
+     MLX-compiled vectorized keyword reward computation.
+
+     Args:
+         keyword_matches: Number of keyword matches [batch_size]
+         total_keywords: Total keywords for each sample [batch_size]
+         bonus_matches: Bonus keyword matches [batch_size]
+         total_bonus_keywords: Total bonus keywords [batch_size]
+         bonus_multiplier: Bonus multiplier
+
+     Returns:
+         Keyword rewards [batch_size]
+     """
+     # Avoid division by zero
+     safe_total_keywords = mx.maximum(total_keywords, 1.0)
+     safe_total_bonus = mx.maximum(total_bonus_keywords, 1.0)
+
+     base_rewards = keyword_matches / safe_total_keywords
+     bonus_rewards = (bonus_matches / safe_total_bonus) * bonus_multiplier
+
+     # Clip to reasonable range
+     total_rewards = mx.minimum(base_rewards + bonus_rewards, 2.0)
+     return total_rewards
+
+
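Likewise, a hypothetical call showing how the base and bonus terms combine (values are illustrative only):

    import mlx.core as mx
    from textpolicy.rewards.mlx_batch_processor import compute_keyword_rewards_vectorized

    matches     = mx.array([2.0, 0.0])   # keywords found in each completion
    totals      = mx.array([3.0, 3.0])   # keywords requested per sample
    bonus       = mx.array([1.0, 0.0])   # matches that were absent from the prompt
    bonus_total = mx.array([2.0, 2.0])
    rewards = compute_keyword_rewards_vectorized(matches, totals, bonus, bonus_total, 0.5)
    # base = [2/3, 0.0]; bonus = 0.5 * [1/2, 0.0]; capped sum ≈ [0.92, 0.0]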
+ def create_mlx_optimized_batch_processor(
+     reward_configs: List[RewardConfig]
+ ) -> Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array]:
+     """
+     Create fully MLX-optimized batch processor for maximum Apple Silicon performance.
+
+     This implementation follows DESIGN_GUIDELINES.md by:
+     1. Using pure function composition
+     2. Maximizing MLX compilation opportunities
+     3. Minimizing memory allocations
+     4. Utilizing unified memory efficiently
+     """
+
+     # Pre-compile all possible reward components
+     compiled_functions = {}
+
+     # Check which reward types we need and pre-compile them
+     for config in reward_configs:
+         if config.name == 'length_reward':
+             compiled_functions['length'] = compute_length_rewards_vectorized
+         elif config.name == 'keyword_reward':
+             compiled_functions['keyword'] = compute_keyword_rewards_vectorized
+
+     weights = mx.array([config.weight for config in reward_configs])
+
+     def optimized_batch_processor(
+         prompts: List[str],
+         completions: List[str],
+         examples: List[Dict[str, Any]]
+     ) -> mx.array:
+         """
+         Fully optimized batch processor using MLX compilation.
+         """
+         batch_size = len(prompts)
+
+         # Collect all reward arrays
+         all_rewards = []
+
+         # Process each reward type
+         for config_idx, config in enumerate(reward_configs):
+             if config.name == 'length_reward' and 'length' in compiled_functions:
+                 # Use vectorized length computation
+                 lengths = mx.array([len(comp.split()) for comp in completions])
+                 target = config.params.get('target_length', 50)
+                 tolerance = config.params.get('tolerance', 0.2)
+
+                 rewards = compiled_functions['length'](lengths, float(target), tolerance)
+                 all_rewards.append(rewards)
+
+             elif config.name == 'keyword_reward' and 'keyword' in compiled_functions:
+                 # Use vectorized keyword computation
+                 keywords = config.params.get('keywords', [])
+                 if keywords:
+                     # Preprocess keyword matches
+                     keyword_matches = []
+                     bonus_matches = []
+
+                     for i, (prompt, completion) in enumerate(zip(prompts, completions)):
+                         comp_lower = completion.lower()
+                         prompt_lower = prompt.lower()
+
+                         matches = sum(1 for kw in keywords if kw.lower() in comp_lower)
+                         bonus_kws = [kw for kw in keywords if kw.lower() not in prompt_lower]
+                         bonus = sum(1 for kw in bonus_kws if kw.lower() in comp_lower)
+
+                         keyword_matches.append(matches)
+                         bonus_matches.append(bonus)
+
+                     # Vectorized computation
+                     match_array = mx.array(keyword_matches)
+                     bonus_array = mx.array(bonus_matches)
+                     total_kw = mx.full((batch_size,), len(keywords))
+                     total_bonus = mx.array([len([kw for kw in keywords if kw.lower() not in prompts[i].lower()]) for i in range(batch_size)])
+                     multiplier = config.params.get('bonus_multiplier', 1.0)
+
+                     rewards = compiled_functions['keyword'](
+                         match_array, total_kw, bonus_array, total_bonus, multiplier
+                     )
+                     all_rewards.append(rewards)
+                 else:
+                     # No keywords specified
+                     all_rewards.append(mx.zeros(batch_size))
+             else:
+                 # Fallback to individual function calls
+                 reward_func = create_configured_reward_function(config)
+                 batch_rewards = []
+
+                 for i in range(batch_size):
+                     try:
+                         reward = reward_func(prompts[i], completions[i], examples[i])
+                         batch_rewards.append(float(reward))
+                     except Exception:
+                         batch_rewards.append(0.0)
+
+                 all_rewards.append(mx.array(batch_rewards))
+
+         # Combine all rewards
+         if all_rewards:
+             reward_matrix = mx.stack(all_rewards, axis=1)  # [batch_size, num_rewards]
+             return mx.sum(reward_matrix * weights, axis=1)
+         else:
+             return mx.zeros(batch_size)
+
+     return optimized_batch_processor
+
+
+ # Async processing for external reward models
+ async def create_async_batch_processor(
+     reward_configs: List[RewardConfig],
+     max_workers: int = 4
+ ) -> Callable[[List[str], List[str], List[Dict[str, Any]]], Coroutine[Any, Any, mx.array]]:
+     """
+     Create async batch processor for external reward models.
+
+     Maintains pure function composition while enabling async operations.
+     """
+
+     # Separate local and external reward configs
+     local_configs = []
+     external_configs = []
+
+     for config in reward_configs:
+         if config.params.get('external_url'):
+             external_configs.append(config)
+         else:
+             local_configs.append(config)
+
+     # Create local processor
+     local_processor = create_mlx_optimized_batch_processor(local_configs) if local_configs else None
+
+     async def async_batch_processor(
+         prompts: List[str],
+         completions: List[str],
+         examples: List[Dict[str, Any]]
+     ) -> mx.array:
+         """Async batch processor combining local and external rewards."""
+         batch_size = len(prompts)
+
+         # Process local rewards synchronously
+         local_rewards = None
+         if local_processor:
+             local_rewards = local_processor(prompts, completions, examples)
+
+         # Process external rewards asynchronously
+         external_rewards = None
+         if external_configs:
+             # Placeholder for external API calls
+             # In practice, this would make HTTP requests to external reward models
+             external_rewards = mx.zeros((batch_size, len(external_configs)))
+
+         # Combine results
+         if local_rewards is not None and external_rewards is not None:
+             all_rewards = mx.concatenate([local_rewards.reshape(-1, len(local_configs)), external_rewards], axis=1)
+         elif local_rewards is not None:
+             all_rewards = local_rewards.reshape(-1, len(local_configs))
+         elif external_rewards is not None:
+             all_rewards = external_rewards
+         else:
+             all_rewards = mx.zeros((batch_size, 1))
+
+         # Apply final weights
+         all_weights = mx.array([config.weight for config in reward_configs])
+         return mx.sum(all_rewards * all_weights, axis=1)
+
+     return async_batch_processor
+
+
+ # Utility functions for integration
+ def list_available_processors() -> Dict[str, List[str]]:
+     """List all available reward and verifier functions."""
+     return {
+         "rewards": list(REWARD_REGISTRY.keys()),
+         "verifiers": list(VERIFIER_REGISTRY.keys()),
+         "compiled_optimizations": ["length_reward", "keyword_reward"]
+     }
+
+
+ async def create_processor_from_config(
+     config_dict: Dict[str, Any]
+ ) -> Union[Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array],
+            Callable[[List[str], List[str], List[Dict[str, Any]]], Coroutine[Any, Any, mx.array]]]:
+     """
+     Create processor from configuration dictionary following retrain's patterns.
+
+     Args:
+         config_dict: Configuration with reward specifications
+
+     Returns:
+         Batch processor function
+     """
+     reward_configs = []
+
+     for name, config in config_dict.items():
+         reward_config = RewardConfig(
+             name=name,
+             weight=config.get('weight', 1.0),
+             params=config.get('params', {}),
+             verifiers=config.get('verifiers', []),
+             verifier_penalty=config.get('verifier_penalty', 0.0)
+         )
+         reward_configs.append(reward_config)
+
+     # Choose optimal processor based on configuration
+     has_external = any(cfg.params.get('external_url') for cfg in reward_configs)
+
+     if has_external:
+         # Use async processor for external models
+         return await create_async_batch_processor(reward_configs)
+     else:
+         # Use MLX-optimized processor for local computation
+         return create_mlx_optimized_batch_processor(reward_configs)
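A sketch of how this config-driven entry point might be invoked; the 'length_reward' key and its parameter names are assumptions based on the optimized paths above, not documented package defaults:

    import asyncio
    from textpolicy.rewards.mlx_batch_processor import create_processor_from_config

    # Hypothetical configuration dictionary
    config = {
        "length_reward": {"weight": 1.0, "params": {"target_length": 50, "tolerance": 0.2}},
    }
    processor = asyncio.run(create_processor_from_config(config))

    prompts = ["Summarize the policy update."]
    completions = ["The policy update adjusts rewards ..."]
    examples = [{}]
    rewards = processor(prompts, completions, examples)  # local-only configs yield a synchronous processor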