textpolicy-0.0.1-py3-none-any.whl → textpolicy-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,253 @@
+# textpolicy/generation/reload.py
+"""
+Pure functions for handling LoRA model reloading after training updates.
+
+Critical for RL training: After each training step, the LoRA adapters are updated.
+For the next rollout generation, we need to ensure the policy uses the updated model.
+
+Following our design principles:
+- Pure functions only
+- No state management
+- Zero abstraction cost
+"""
+
+from typing import Dict, Tuple, Any, Optional, Callable
+import mlx.core as mx # type: ignore
+import mlx.nn as nn # type: ignore
+from .mlx_generation import load_model# type: ignore
+
+
+def save_adapters(
+    model: nn.Module,
+    adapter_path: str
+) -> None:
+    """
+    Pure function to save LoRA adapter weights.
+
+    After training updates, save only the LoRA parameters to disk.
+    This is much faster than saving the full model.
+
+    Args:
+        model: Model with trained LoRA adapters
+        adapter_path: Path to save adapters
+    """
+    # Extract LoRA parameters
+    lora_params = {}
+    for name, param in model.named_parameters():
+        if 'lora_' in name.lower() and param.requires_grad:
+            lora_params[name] = param
+
+    # Save using MLX
+    mx.save_safetensors(adapter_path, lora_params)
+    print(f"✓ Saved LoRA adapters to {adapter_path}")
+
+
+def reload_model(
+    base_model_path: str,
+    adapter_path: str,
+    tokenizer_config: Optional[Dict] = None
+) -> Tuple[nn.Module, Any]:
+    """
+    Pure function to reload model with updated LoRA adapters.
+
+    This is called after training to get an updated model for the next rollout.
+    Much more efficient than reloading the full model.
+
+    Args:
+        base_model_path: Path to base model
+        adapter_path: Path to updated LoRA adapters
+        tokenizer_config: Optional tokenizer config
+
+    Returns:
+        (updated_model, tokenizer): Model with updated adapters
+    """
+    # Load fresh model with updated adapters
+    model, tokenizer = load_model(
+        model_path=base_model_path,
+        adapter_path=adapter_path,
+        tokenizer_config=tokenizer_config
+    )
+
+    print("✓ Reloaded model with updated adapters")
+    return model, tokenizer
+
+
+def create_reloadable_policy(
+    base_model_path: str,
+    initial_adapter_path: Optional[str],
+    generation_params: Dict[str, Any]
+) -> Tuple[Callable, Callable]:
+    """
+    Create a policy function that can be reloaded with updated LoRA adapters.
+
+    Returns both the policy function and a reload function for efficiency.
+
+    Args:
+        base_model_path: Path to base model
+        initial_adapter_path: Initial LoRA adapter path (can be None)
+        generation_params: Generation parameters
+
+    Returns:
+        (policy_fn, reload_fn): Policy function and reload function
+    """
+    # Load initial model
+    current_model, tokenizer = load_model(
+        model_path=base_model_path,
+        adapter_path=initial_adapter_path
+    )
+
+    # Store current state
+    current_state = {
+        'model': current_model,
+        'tokenizer': tokenizer,
+        'base_path': base_model_path,
+        'generation_params': generation_params
+    }
+
+    def policy_fn(obs: mx.array, deterministic: bool = False) -> Tuple[mx.array, Dict[str, mx.array]]:
+        """Policy function that always uses the current model."""
+        from .mlx_generation import generate_tokens
+
+        # Use current model (updated after each reload)
+        model = current_state['model']
+        tokenizer = current_state['tokenizer']
+
+        # Adjust temperature
+        temp = 0.0 if deterministic else generation_params.get("temperature", 0.8)
+
+        # Generate tokens
+        response_tokens, response_info = generate_tokens(
+            model=model,
+            tokenizer=tokenizer,
+            prompt_tokens=obs,
+            max_tokens=generation_params.get("max_tokens", 50),
+            temperature=temp,
+            top_p=generation_params.get("top_p", 0.95)
+        )
+
+        # Extract logprobs from response info
+        logprobs = response_info.get('logprob', mx.array([]))
+
+        extras = {
+            "logprob": logprobs,
+            "entropy": mx.mean(logprobs) if len(logprobs) > 0 else mx.array(0.0)
+        }
+
+        return response_tokens, extras
+
+    def reload_fn(adapter_path: str) -> None:
+        """Reload the model with updated adapters."""
+        updated_model, updated_tokenizer = reload_model(
+            base_model_path=current_state['base_path'],
+            adapter_path=adapter_path
+        )
+
+        # Update current state
+        current_state['model'] = updated_model
+        current_state['tokenizer'] = updated_tokenizer
+
+        print(f"✓ Policy reloaded with adapters from {adapter_path}")
+
+    return policy_fn, reload_fn
+
+
+def create_training_loop_with_reload(
+    base_model_path: str,
+    adapter_save_path: str,
+    generation_params: Dict[str, Any],
+    trainer
+) -> Tuple[Callable, Callable]:
+    """
+    Create a complete training loop that handles LoRA reloading.
+
+    This solves the LoRA update problem by automatically saving and reloading
+    adapters after each training step.
+
+    Args:
+        base_model_path: Path to base model
+        adapter_save_path: Path to save LoRA adapters
+        generation_params: Generation parameters
+        trainer: GRPO trainer instance
+
+    Returns:
+        (train_and_reload_fn, policy_fn): Training function and policy
+    """
+    # Create reloadable policy
+    policy_fn, reload_fn = create_reloadable_policy(
+        base_model_path=base_model_path,
+        initial_adapter_path=None, # Start with base model
+        generation_params=generation_params
+    )
+
+    def train_and_reload_fn(rollout_data) -> Dict[str, float]:
+        """
+        Train the model and reload policy with updated adapters.
+
+        Args:
+            rollout_data: Rollout data for training
+
+        Returns:
+            Training metrics
+        """
+        # 1. Train the model (updates LoRA adapters in-place)
+        metrics = trainer.train(rollout_data)
+
+        # 2. Save updated LoRA adapters
+        save_adapters(trainer.model, adapter_save_path)
+
+        # 3. Reload policy with updated adapters
+        reload_fn(adapter_save_path)
+
+        print("✓ Training step complete, policy updated")
+        return metrics
+
+    return train_and_reload_fn, policy_fn
+
+
+# Simple wrapper for the most common use case
+def create_auto_reload_setup(
+    model_path: str,
+    adapter_save_path: str = "./lora_adapters.safetensors",
+    **generation_params
+) -> Tuple[Callable, Callable, nn.Module, Any]:
+    """
+    Create complete auto-reloading setup for LoRA training.
+
+    This is the main function most users should call.
+
+    Args:
+        model_path: Path to base model
+        adapter_save_path: Where to save LoRA adapters
+        **generation_params: Generation parameters
+
+    Returns:
+        (policy_fn, reload_fn, model, tokenizer): Complete setup
+    """
+    from .lora import create_lora_setup
+
+    # Load and setup LoRA model
+    base_model, tokenizer = load_model(model_path)
+
+    lora_config = {
+        "lora_layers": generation_params.pop("lora_layers", 8),
+        "lora_rank": generation_params.pop("lora_rank", 8),
+        "lora_scale": generation_params.pop("lora_scale", 20.0)
+    }
+
+    lora_model, memory_stats = create_lora_setup(
+        model=base_model,
+        lora_config=lora_config
+    )
+
+    # Create reloadable policy
+    policy_fn, reload_fn = create_reloadable_policy(
+        base_model_path=model_path,
+        initial_adapter_path=None,
+        generation_params=generation_params
+    )
+
+    print("✓ Auto-reload setup complete")
+    print(f"  Memory savings: {memory_stats['memory_savings_percent']:.1f}%")
+    print(f"  Adapter save path: {adapter_save_path}")
+
+    return policy_fn, reload_fn, lora_model, tokenizer
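
Taken together, this hunk wires a save → reload cycle around each training step: the trainer updates the LoRA weights in place, save_adapters writes them to disk, and reload_fn swaps the refreshed model into the policy closure before the next rollout. A minimal sketch of how these pieces might be driven, assuming a GRPO trainer exposing .train() and .model and a hypothetical rollout collector; everything other than the textpolicy names shown in the hunk is a placeholder:

# Hypothetical wiring of the reload cycle above; trainer, collect_rollouts,
# num_steps, and the model id are illustrative placeholders.
from textpolicy.generation.reload import create_training_loop_with_reload

train_step, policy_fn = create_training_loop_with_reload(
    base_model_path="path/to/base-model",        # placeholder model path
    adapter_save_path="./lora_adapters.safetensors",
    generation_params={"temperature": 0.8, "max_tokens": 50, "top_p": 0.95},
    trainer=trainer,                             # assumed GRPO trainer instance
)

for step in range(num_steps):                    # num_steps: placeholder
    rollout_batch = collect_rollouts(policy_fn)  # assumed rollout collector
    metrics = train_step(rollout_batch)          # trains, saves adapters, reloads policy
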
@@ -0,0 +1,137 @@
+# textpolicy/rewards/__init__.py
+"""
+Unified reward and verification system for MLX-optimized text generation.
+
+Key principles:
+- Decorator-based registration (@reward, @verifier)
+- Pure function composition with zero abstraction cost
+- MLX compilation for Apple Silicon optimization
+- Modular system allowing custom rewards and verifiers
+- Pre-filtering verification approach
+- Signature consistency: (prompt, completion, example, **kwargs)
+
+This system provides modular reward computation with MLX optimization.
+"""
+
+# Registry system
+from .registry import (
+    reward, verifier, # Decorators for registration
+    RewardFunction, VerifierFunction, RewardConfig,
+    get_reward_function, get_verifier_function,
+    apply_verifiers_to_reward,
+    create_configured_reward_function,
+    list_registered_functions, clear_registries,
+    REWARD_REGISTRY, VERIFIER_REGISTRY
+)
+
+# Core reward functions (auto-registered via decorators)
+from .basic import (
+    length_reward,
+    keyword_reward,
+    perplexity_reward,
+    accuracy_reward
+)
+
+# Core verifier functions (auto-registered via decorators)
+from .verifiers import (
+    length_verifier,
+    toxicity_verifier,
+    coherence_verifier,
+    factual_verifier,
+    has_greeting,
+    no_empty_response,
+    contains_keywords,
+    # Legacy compatibility (deprecated)
+    create_default_verifier_pipeline,
+    create_custom_verifier_pipeline
+)
+
+# MLX-optimized batch processing (following DESIGN_GUIDELINES.md)
+from .mlx_batch_processor import (
+    create_batch_reward_processor,
+    create_mlx_optimized_batch_processor,
+    create_async_batch_processor,
+    create_processor_from_config,
+    list_available_processors,
+    compute_length_rewards_vectorized,
+    compute_keyword_rewards_vectorized
+)
+
+# Core adapters and utilities for MLX reward system
+from .adapters import (
+    # Configuration models
+    MLXRewardConfig, MLXRewardSystemConfig,
+    # Sample types
+    MLXSample, MLXExternalRewardModel,
+    # Math utilities
+    extract_boxed_answer, normalize_math_answer, compute_f1_score,
+    # Bridge functions
+    create_mlx_batch_adapter, create_mlx_system_from_config,
+    samples_to_mlx_format, mlx_format_to_samples,
+    get_available_adapters,
+    # Reward functions
+    math_accuracy_reward, f1_score_reward
+)
+
+# Legacy systems (maintained for backward compatibility)
+from .rollout_rewards import (
+    RolloutRewardProcessor,
+    create_rollout_reward_processor,
+    process_episode_batch_rewards,
+    compute_reward_vector
+)
+
+from .integrated_system import (
+    IntegratedRewardConfig,
+    IntegratedRolloutRewardSystem,
+    create_integrated_reward_system,
+    process_episodes_with_quality_control,
+    compute_integrated_rewards
+)
+
+__all__ = [
+    # Registry system (primary interface)
+    "reward", "verifier", # Decorators
+    "RewardFunction", "VerifierFunction", "RewardConfig",
+    "get_reward_function", "get_verifier_function",
+    "apply_verifiers_to_reward", "create_configured_reward_function",
+    "list_registered_functions", "clear_registries",
+    "REWARD_REGISTRY", "VERIFIER_REGISTRY",
+
+    # Core reward functions (auto-registered)
+    "length_reward", "keyword_reward", "perplexity_reward", "accuracy_reward",
+
+    # Core verifier functions (auto-registered)
+    "length_verifier", "toxicity_verifier", "coherence_verifier", "factual_verifier",
+    "has_greeting", "no_empty_response", "contains_keywords",
+
+    # MLX-optimized batch processing (primary interface for training)
+    "create_batch_reward_processor",
+    "create_mlx_optimized_batch_processor",
+    "create_async_batch_processor",
+    "create_processor_from_config",
+    "list_available_processors",
+    "compute_length_rewards_vectorized",
+    "compute_keyword_rewards_vectorized",
+
+    # Core adapters and utilities
+    "MLXRewardConfig", "MLXRewardSystemConfig",
+    "MLXSample", "MLXExternalRewardModel",
+    "extract_boxed_answer", "normalize_math_answer", "compute_f1_score",
+    "create_mlx_batch_adapter", "create_mlx_system_from_config",
+    "samples_to_mlx_format", "mlx_format_to_samples", "get_available_adapters",
+    "math_accuracy_reward", "f1_score_reward",
+
+    # Legacy systems (backward compatibility)
+    "RolloutRewardProcessor", "create_rollout_reward_processor",
+    "process_episode_batch_rewards", "compute_reward_vector",
+    "IntegratedRewardConfig", "IntegratedRolloutRewardSystem",
+    "create_integrated_reward_system", "process_episodes_with_quality_control",
+    "compute_integrated_rewards",
+    "create_default_verifier_pipeline", "create_custom_verifier_pipeline",
+]
+
+# Import basic rewards and verifiers to trigger registration
+import textpolicy.rewards.basic
+import textpolicy.rewards.verifiers
+import textpolicy.rewards.adapters # Register adapted rewards
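
The module docstring pins the public contract: rewards and verifiers are plain functions with the signature (prompt, completion, example, **kwargs), registered via the @reward and @verifier decorators and retrievable by name. A hedged sketch of a custom reward under that contract; the bare-decorator form and lookup-by-function-name are assumptions not shown in this diff, and everything except the imported textpolicy names is illustrative:

# Hypothetical custom reward; only @reward and get_reward_function are taken
# from textpolicy.rewards, the rest is a toy example.
from textpolicy.rewards import reward, get_reward_function

@reward  # assumed bare-decorator registration under the function's name
def concise_reward(prompt, completion, example, **kwargs):
    """Score completions closer to a target word count higher (toy example)."""
    target = kwargs.get("target_words", 50)
    n_words = len(completion.split())
    return max(0.0, 1.0 - abs(n_words - target) / target)

fn = get_reward_function("concise_reward")   # assumed lookup by registered name
score = fn("Explain LoRA.", "LoRA adds low-rank adapters...", example=None)
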
|