textpolicy 0.0.1-py3-none-any.whl → 0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
textpolicy/generation/reload.py
@@ -0,0 +1,253 @@
+ # textpolicy/generation/reload.py
+ """
+ Pure functions for handling LoRA model reloading after training updates.
+
+ Critical for RL training: After each training step, the LoRA adapters are updated.
+ For the next rollout generation, we need to ensure the policy uses the updated model.
+
+ Following our design principles:
+ - Pure functions only
+ - No state management
+ - Zero abstraction cost
+ """
+
+ from typing import Dict, Tuple, Any, Optional, Callable
+ import mlx.core as mx  # type: ignore
+ import mlx.nn as nn  # type: ignore
+ from .mlx_generation import load_model  # type: ignore
+
+
+ def save_adapters(
+     model: nn.Module,
+     adapter_path: str
+ ) -> None:
+     """
+     Pure function to save LoRA adapter weights.
+
+     After training updates, save only the LoRA parameters to disk.
+     This is much faster than saving the full model.
+
+     Args:
+         model: Model with trained LoRA adapters
+         adapter_path: Path to save adapters
+     """
+     # Extract LoRA parameters
+     lora_params = {}
+     for name, param in model.named_parameters():
+         if 'lora_' in name.lower() and param.requires_grad:
+             lora_params[name] = param
+
+     # Save using MLX
+     mx.save_safetensors(adapter_path, lora_params)
+     print(f"✓ Saved LoRA adapters to {adapter_path}")
+
+
+ def reload_model(
+     base_model_path: str,
+     adapter_path: str,
+     tokenizer_config: Optional[Dict] = None
+ ) -> Tuple[nn.Module, Any]:
+     """
+     Pure function to reload model with updated LoRA adapters.
+
+     This is called after training to get an updated model for the next rollout.
+     Much more efficient than reloading the full model.
+
+     Args:
+         base_model_path: Path to base model
+         adapter_path: Path to updated LoRA adapters
+         tokenizer_config: Optional tokenizer config
+
+     Returns:
+         (updated_model, tokenizer): Model with updated adapters
+     """
+     # Load fresh model with updated adapters
+     model, tokenizer = load_model(
+         model_path=base_model_path,
+         adapter_path=adapter_path,
+         tokenizer_config=tokenizer_config
+     )
+
+     print("✓ Reloaded model with updated adapters")
+     return model, tokenizer
+
+
+ def create_reloadable_policy(
+     base_model_path: str,
+     initial_adapter_path: Optional[str],
+     generation_params: Dict[str, Any]
+ ) -> Tuple[Callable, Callable]:
+     """
+     Create a policy function that can be reloaded with updated LoRA adapters.
+
+     Returns both the policy function and a reload function for efficiency.
+
+     Args:
+         base_model_path: Path to base model
+         initial_adapter_path: Initial LoRA adapter path (can be None)
+         generation_params: Generation parameters
+
+     Returns:
+         (policy_fn, reload_fn): Policy function and reload function
+     """
+     # Load initial model
+     current_model, tokenizer = load_model(
+         model_path=base_model_path,
+         adapter_path=initial_adapter_path
+     )
+
+     # Store current state
+     current_state = {
+         'model': current_model,
+         'tokenizer': tokenizer,
+         'base_path': base_model_path,
+         'generation_params': generation_params
+     }
+
+     def policy_fn(obs: mx.array, deterministic: bool = False) -> Tuple[mx.array, Dict[str, mx.array]]:
+         """Policy function that always uses the current model."""
+         from .mlx_generation import generate_tokens
+
+         # Use current model (updated after each reload)
+         model = current_state['model']
+         tokenizer = current_state['tokenizer']
+
+         # Adjust temperature
+         temp = 0.0 if deterministic else generation_params.get("temperature", 0.8)
+
+         # Generate tokens
+         response_tokens, response_info = generate_tokens(
+             model=model,
+             tokenizer=tokenizer,
+             prompt_tokens=obs,
+             max_tokens=generation_params.get("max_tokens", 50),
+             temperature=temp,
+             top_p=generation_params.get("top_p", 0.95)
+         )
+
+         # Extract logprobs from response info
+         logprobs = response_info.get('logprob', mx.array([]))
+
+         extras = {
+             "logprob": logprobs,
+             "entropy": mx.mean(logprobs) if len(logprobs) > 0 else mx.array(0.0)
+         }
+
+         return response_tokens, extras
+
+     def reload_fn(adapter_path: str) -> None:
+         """Reload the model with updated adapters."""
+         updated_model, updated_tokenizer = reload_model(
+             base_model_path=current_state['base_path'],
+             adapter_path=adapter_path
+         )
+
+         # Update current state
+         current_state['model'] = updated_model
+         current_state['tokenizer'] = updated_tokenizer
+
+         print(f"✓ Policy reloaded with adapters from {adapter_path}")
+
+     return policy_fn, reload_fn
+
+
+ def create_training_loop_with_reload(
+     base_model_path: str,
+     adapter_save_path: str,
+     generation_params: Dict[str, Any],
+     trainer
+ ) -> Tuple[Callable, Callable]:
+     """
+     Create a complete training loop that handles LoRA reloading.
+
+     This solves the LoRA update problem by automatically saving and reloading
+     adapters after each training step.
+
+     Args:
+         base_model_path: Path to base model
+         adapter_save_path: Path to save LoRA adapters
+         generation_params: Generation parameters
+         trainer: GRPO trainer instance
+
+     Returns:
+         (train_and_reload_fn, policy_fn): Training function and policy
+     """
+     # Create reloadable policy
+     policy_fn, reload_fn = create_reloadable_policy(
+         base_model_path=base_model_path,
+         initial_adapter_path=None,  # Start with base model
+         generation_params=generation_params
+     )
+
+     def train_and_reload_fn(rollout_data) -> Dict[str, float]:
+         """
+         Train the model and reload policy with updated adapters.
+
+         Args:
+             rollout_data: Rollout data for training
+
+         Returns:
+             Training metrics
+         """
+         # 1. Train the model (updates LoRA adapters in-place)
+         metrics = trainer.train(rollout_data)
+
+         # 2. Save updated LoRA adapters
+         save_adapters(trainer.model, adapter_save_path)
+
+         # 3. Reload policy with updated adapters
+         reload_fn(adapter_save_path)
+
+         print("✓ Training step complete, policy updated")
+         return metrics
+
+     return train_and_reload_fn, policy_fn
+
+
+ # Simple wrapper for the most common use case
+ def create_auto_reload_setup(
+     model_path: str,
+     adapter_save_path: str = "./lora_adapters.safetensors",
+     **generation_params
+ ) -> Tuple[Callable, Callable, nn.Module, Any]:
+     """
+     Create complete auto-reloading setup for LoRA training.
+
+     This is the main function most users should call.
+
+     Args:
+         model_path: Path to base model
+         adapter_save_path: Where to save LoRA adapters
+         **generation_params: Generation parameters
+
+     Returns:
+         (policy_fn, reload_fn, model, tokenizer): Complete setup
+     """
+     from .lora import create_lora_setup
+
+     # Load and setup LoRA model
+     base_model, tokenizer = load_model(model_path)
+
+     lora_config = {
+         "lora_layers": generation_params.pop("lora_layers", 8),
+         "lora_rank": generation_params.pop("lora_rank", 8),
+         "lora_scale": generation_params.pop("lora_scale", 20.0)
+     }
+
+     lora_model, memory_stats = create_lora_setup(
+         model=base_model,
+         lora_config=lora_config
+     )
+
+     # Create reloadable policy
+     policy_fn, reload_fn = create_reloadable_policy(
+         base_model_path=model_path,
+         initial_adapter_path=None,
+         generation_params=generation_params
+     )
+
+     print("✓ Auto-reload setup complete")
+     print(f"  Memory savings: {memory_stats['memory_savings_percent']:.1f}%")
+     print(f"  Adapter save path: {adapter_save_path}")
+
+     return policy_fn, reload_fn, lora_model, tokenizer
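
For orientation, a minimal training-loop sketch driving the reload helpers above is shown below. Only create_auto_reload_setup, save_adapters, and the trainer.train/trainer.model usage are taken from this diff; the model path, the trainer construction, collect_rollouts(), and num_steps are hypothetical placeholders, not APIs defined here.

    # Sketch only: placeholders are marked in comments and are not part of this diff.
    from textpolicy.generation.reload import create_auto_reload_setup, save_adapters

    policy_fn, reload_fn, lora_model, tokenizer = create_auto_reload_setup(
        model_path="path/to/base-model",                 # placeholder path
        adapter_save_path="./lora_adapters.safetensors",
        temperature=0.8,
        max_tokens=50,
    )

    trainer = ...      # placeholder: a GRPO trainer instance (see textpolicy/training/trainer.py)
    num_steps = 10     # placeholder

    for step in range(num_steps):
        rollout_data = collect_rollouts(policy_fn)       # placeholder rollout collection
        metrics = trainer.train(rollout_data)            # updates LoRA adapters in-place
        save_adapters(trainer.model, "./lora_adapters.safetensors")  # persist updated LoRA weights
        reload_fn("./lora_adapters.safetensors")         # next rollouts use the updated adapters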
textpolicy/rewards/__init__.py
@@ -0,0 +1,137 @@
+ # textpolicy/rewards/__init__.py
+ """
+ Unified reward and verification system for MLX-optimized text generation.
+
+ Key principles:
+ - Decorator-based registration (@reward, @verifier)
+ - Pure function composition with zero abstraction cost
+ - MLX compilation for Apple Silicon optimization
+ - Modular system allowing custom rewards and verifiers
+ - Pre-filtering verification approach
+ - Signature consistency: (prompt, completion, example, **kwargs)
+
+ This system provides modular reward computation with MLX optimization.
+ """
+
+ # Registry system
+ from .registry import (
+     reward, verifier,  # Decorators for registration
+     RewardFunction, VerifierFunction, RewardConfig,
+     get_reward_function, get_verifier_function,
+     apply_verifiers_to_reward,
+     create_configured_reward_function,
+     list_registered_functions, clear_registries,
+     REWARD_REGISTRY, VERIFIER_REGISTRY
+ )
+
+ # Core reward functions (auto-registered via decorators)
+ from .basic import (
+     length_reward,
+     keyword_reward,
+     perplexity_reward,
+     accuracy_reward
+ )
+
+ # Core verifier functions (auto-registered via decorators)
+ from .verifiers import (
+     length_verifier,
+     toxicity_verifier,
+     coherence_verifier,
+     factual_verifier,
+     has_greeting,
+     no_empty_response,
+     contains_keywords,
+     # Legacy compatibility (deprecated)
+     create_default_verifier_pipeline,
+     create_custom_verifier_pipeline
+ )
+
+ # MLX-optimized batch processing (following DESIGN_GUIDELINES.md)
+ from .mlx_batch_processor import (
+     create_batch_reward_processor,
+     create_mlx_optimized_batch_processor,
+     create_async_batch_processor,
+     create_processor_from_config,
+     list_available_processors,
+     compute_length_rewards_vectorized,
+     compute_keyword_rewards_vectorized
+ )
+
+ # Core adapters and utilities for MLX reward system
+ from .adapters import (
+     # Configuration models
+     MLXRewardConfig, MLXRewardSystemConfig,
+     # Sample types
+     MLXSample, MLXExternalRewardModel,
+     # Math utilities
+     extract_boxed_answer, normalize_math_answer, compute_f1_score,
+     # Bridge functions
+     create_mlx_batch_adapter, create_mlx_system_from_config,
+     samples_to_mlx_format, mlx_format_to_samples,
+     get_available_adapters,
+     # Reward functions
+     math_accuracy_reward, f1_score_reward
+ )
+
+ # Legacy systems (maintained for backward compatibility)
+ from .rollout_rewards import (
+     RolloutRewardProcessor,
+     create_rollout_reward_processor,
+     process_episode_batch_rewards,
+     compute_reward_vector
+ )
+
+ from .integrated_system import (
+     IntegratedRewardConfig,
+     IntegratedRolloutRewardSystem,
+     create_integrated_reward_system,
+     process_episodes_with_quality_control,
+     compute_integrated_rewards
+ )
+
+ __all__ = [
+     # Registry system (primary interface)
+     "reward", "verifier",  # Decorators
+     "RewardFunction", "VerifierFunction", "RewardConfig",
+     "get_reward_function", "get_verifier_function",
+     "apply_verifiers_to_reward", "create_configured_reward_function",
+     "list_registered_functions", "clear_registries",
+     "REWARD_REGISTRY", "VERIFIER_REGISTRY",
+
+     # Core reward functions (auto-registered)
+     "length_reward", "keyword_reward", "perplexity_reward", "accuracy_reward",
+
+     # Core verifier functions (auto-registered)
+     "length_verifier", "toxicity_verifier", "coherence_verifier", "factual_verifier",
+     "has_greeting", "no_empty_response", "contains_keywords",
+
+     # MLX-optimized batch processing (primary interface for training)
+     "create_batch_reward_processor",
+     "create_mlx_optimized_batch_processor",
+     "create_async_batch_processor",
+     "create_processor_from_config",
+     "list_available_processors",
+     "compute_length_rewards_vectorized",
+     "compute_keyword_rewards_vectorized",
+
+     # Core adapters and utilities
+     "MLXRewardConfig", "MLXRewardSystemConfig",
+     "MLXSample", "MLXExternalRewardModel",
+     "extract_boxed_answer", "normalize_math_answer", "compute_f1_score",
+     "create_mlx_batch_adapter", "create_mlx_system_from_config",
+     "samples_to_mlx_format", "mlx_format_to_samples", "get_available_adapters",
+     "math_accuracy_reward", "f1_score_reward",
+
+     # Legacy systems (backward compatibility)
+     "RolloutRewardProcessor", "create_rollout_reward_processor",
+     "process_episode_batch_rewards", "compute_reward_vector",
+     "IntegratedRewardConfig", "IntegratedRolloutRewardSystem",
+     "create_integrated_reward_system", "process_episodes_with_quality_control",
+     "compute_integrated_rewards",
+     "create_default_verifier_pipeline", "create_custom_verifier_pipeline",
+ ]
+
+ # Import basic rewards and verifiers to trigger registration
+ import textpolicy.rewards.basic
+ import textpolicy.rewards.verifiers
+ import textpolicy.rewards.adapters  # Register adapted rewards
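
To illustrate the decorator-based registration the module docstring describes, a custom reward might look like the sketch below. The @reward decorator, the (prompt, completion, example, **kwargs) signature, and get_reward_function all appear in this diff; the bare-decorator usage and lookup-by-function-name behaviour are assumptions about the registry, not confirmed by it.

    # Sketch only: decorator arguments and registry lookup semantics are assumed.
    from textpolicy.rewards import reward, get_reward_function

    @reward  # assumed to register the function under its own name in REWARD_REGISTRY
    def brevity_reward(prompt, completion, example, **kwargs) -> float:
        """Hypothetical custom reward: prefer completions of 100 characters or fewer."""
        return 1.0 if len(completion) <= 100 else 100.0 / len(completion)

    fn = get_reward_function("brevity_reward")   # assumed lookup by registered name
    score = fn("What is 2+2?", "4", example=None)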