textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,557 @@
# textpolicy/generation/mlx_generation.py
"""
Complete MLX-LM text generation functions for RL training.

This module provides proper integration with MLX-LM for text generation RL,
including correct logprob extraction for policy gradient training.

Key functions:
- load_model: Load MLX model and tokenizer
- generate_tokens: Generate text with logprob tracking
- compute_logprobs: Extract logprobs for RL training
- create_policy: Create policy function for rollout collection
"""

from __future__ import annotations
from typing import Dict, Optional, Tuple, Any, Callable
import mlx.core as mx
import mlx.nn as nn
try:
    from mlx_lm import load, generate
    HAS_MLX_LM = True
except ImportError:
    HAS_MLX_LM = False
    print("Warning: mlx_lm not found. Using fallback implementations.")

try:
    from mlx_lm.sample_utils import make_sampler, make_logits_processors
except ImportError:
    # Sampling utilities unavailable; the wrapper helpers below check these
    # names and fall back to None in that case.
    make_sampler = None
    make_logits_processors = None


def _get_eos_configs_for_model(
    model_path: str,
    tokenizer_config: Optional[Dict]
) -> Tuple[Optional[Dict], Dict[str, Any]]:
    """
    Determine tokenizer_config and model_config for proper EOS handling based on model type.
    """
    model_config: Dict[str, Any] = {}
    if tokenizer_config is None and "Qwen" in model_path:
        tokenizer_config = {}
    if "Qwen" in model_path:
        # For Qwen Instruct variants, let tokenizer.eos_token_id (<|im_end|>) prevail;
        # override only for base Qwen to use <|endoftext|> (151643) as EOS.
        if "Instruct" not in model_path:
            eos_id = 151643
            model_config["eos_token_id"] = eos_id
    return tokenizer_config, model_config


def _prepare_tokenizer(tokenizer: Any, verbose: bool) -> None:
    """
    Configure tokenizer verbosity and ensure EOS token IDs for stopping.
    """
    tokenizer.verbose = verbose
    # Normalize the EOS-related fields so MLX-LM stops on the tokenizer's EOS token
    eos_id = getattr(tokenizer, 'eos_token_id', None)
    if eos_id is not None:
        # Keep eos_token consistent with eos_token_id for natural stopping
        tokenizer.eos_token_id = eos_id
        tokenizer.eos_token = tokenizer.convert_ids_to_tokens(eos_id)
        tokenizer.eos_token_ids = [eos_id]
        # Align pad_token to EOS to ensure MLX-LM uses EOS for padding/stopping
        tokenizer.pad_token_id = eos_id
        tokenizer.pad_token = tokenizer.eos_token


def _make_eos_safe_sampler(temp: float, top_p: float) -> Any:
    """
    Build a sampler that does not prune low-probability tokens (e.g., EOS) and encourages natural stopping.
    """
    if make_sampler is not None:
        # Use conservative sampling parameters to encourage natural EOS generation:
        # disable min_p filtering and keep several candidate tokens in consideration.
        return make_sampler(
            temp=temp,
            top_p=top_p,
            min_p=0.0,  # Don't filter out low-probability tokens like EOS
            min_tokens_to_keep=2,  # Keep at least 2 tokens so EOS has a chance
        )
    else:
        # Fallback when mlx_lm.sample_utils is not available
        return None


def _make_logits_processors(repetition_penalty: float) -> Any:
    """
    Create logits processors to enforce repetition penalty.
    """
    if make_logits_processors is not None:
        return make_logits_processors(
            repetition_penalty=repetition_penalty,
            repetition_context_size=20,
        )
    else:
        # Fallback when mlx_lm.sample_utils is not available
        return None


def _extract_response_tokens(
    response: Any,
    prompt_list: Any,
    tokenizer: Any,
) -> mx.array:
    """
    Extract response token IDs from a raw generation output (string or list).
    Enhanced to better handle EOS tokens and edge cases.
    """
    if isinstance(response, list):
        return mx.array(response)
    try:
        full_tokens = tokenizer.encode(response)
        eos_id = getattr(tokenizer, 'eos_token_id', None)

        # First, try to find the EOS token and include it in the response
        if eos_id is not None and eos_id in full_tokens:
            idx = full_tokens.index(eos_id)
            # Include the EOS token in the response for proper reward calculation
            resp = full_tokens[len(prompt_list): idx + 1]
        else:
            # No EOS found - extract the response portion without EOS
            try:
                prompt_text = tokenizer.decode(prompt_list)
                if response.startswith(prompt_text):
                    tail = response[len(prompt_text):]
                else:
                    tail = response
                resp = tokenizer.encode(tail.strip()) if tail.strip() else []
            except Exception:
                # Fallback: encode the whole response and hope for the best
                resp = tokenizer.encode(response) if response else []

        return mx.array(resp) if resp else mx.array([])
    except Exception:
        return mx.array([])


def load_model(
    model_path: str,
    adapter_path: Optional[str] = None,
    tokenizer_config: Optional[Dict] = None,
    verbose: bool = False
) -> Tuple[nn.Module, Any]:
    """
    Load MLX model and tokenizer for RL training.

    This function properly loads MLX-LM models with support for LoRA adapters
    and ensures compatibility with our training system. Automatically configures
    proper EOS tokens for Qwen models to ensure correct generation stopping.

    Args:
        model_path: Path or HuggingFace model ID
        adapter_path: Optional LoRA adapter path
        tokenizer_config: Optional tokenizer configuration for EOS tokens
        verbose: Enable debug logging for chat template application

    Returns:
        (model, tokenizer): MLX model and tokenizer instances
    """
    if not HAS_MLX_LM:
        raise ImportError("mlx_lm is required. Install with: pip install mlx-lm")

    print(f"Loading MLX model: {model_path}")
    if adapter_path:
        print(f"Loading with LoRA adapters: {adapter_path}")

    # Configure model & tokenizer for EOS handling based on model type
    tokenizer_config, model_config = _get_eos_configs_for_model(model_path, tokenizer_config)
    model, tokenizer = load(
        path_or_hf_repo=model_path,
        adapter_path=adapter_path,
        tokenizer_config=tokenizer_config,
        model_config=model_config,
        lazy=False,
    )
    _prepare_tokenizer(tokenizer, verbose)
    print("✓ Model loaded successfully")
    return model, tokenizer
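A minimal load-and-inspect sketch (editor's illustration, not part of the package file; the model id below is only an example of an MLX-compatible checkpoint):

    from textpolicy.generation.mlx_generation import load_model

    # Example checkpoint; any MLX-LM compatible model id or local path should work.
    model, tokenizer = load_model("Qwen/Qwen2.5-0.5B-Instruct", verbose=True)
    print(tokenizer.eos_token, tokenizer.eos_token_id)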


def generate_tokens(
    model: nn.Module,
    tokenizer: Any,
    prompt_tokens: mx.array,
    max_tokens: int = 50,
    temperature: float = 0.7,  # Lower default temperature for more stable generation
    top_p: float = 0.9,  # Lower top_p for more focused sampling
    repetition_penalty: float = 1.1  # Add repetition penalty to prevent loops
) -> Tuple[mx.array, Dict[str, Any]]:
    """Generate response tokens with proper MLX-LM integration and EOS token support."""
    if not HAS_MLX_LM:
        return _simple_generate(model, prompt_tokens, max_tokens, temperature)

    prompt_list = prompt_tokens.tolist()

    # Use stream_generate instead of generate to get proper EOS token handling.
    # This is the core fix - stream_generate respects EOS tokens, generate() does not.
    try:
        from mlx_lm import stream_generate

        # EOS-safe sampling with reduced temperature for more predictable stopping
        optimized_temperature = min(temperature, 0.7)
        sampler = _make_eos_safe_sampler(optimized_temperature, top_p)
        logits_processors = _make_logits_processors(repetition_penalty)

        # Use stream_generate to get token-by-token generation with EOS detection
        response_segments = list(stream_generate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt_list,  # type: ignore
            max_tokens=max_tokens,
            sampler=sampler,
            logits_processors=logits_processors,
        ))

        # Extract tokens from response segments and detect natural EOS stopping
        response_token_list = []

        for segment in response_segments:
            response_token_list.append(segment.token)
            # Check if this segment indicates natural stopping (EOS token)
            if hasattr(segment, 'finish_reason') and segment.finish_reason == "stop":
                break

        # Convert to MLX array
        response_tokens = mx.array(response_token_list) if response_token_list else mx.array([])

    except ImportError:
        # Fall back to the plain generate() API if stream_generate is unavailable
        print("WARNING: stream_generate not available, using fallback generate method")
        optimized_temperature = min(temperature, 0.7)
        sampler = _make_eos_safe_sampler(optimized_temperature, top_p)
        logits_processors = _make_logits_processors(repetition_penalty)

        response = generate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt_list,  # type: ignore
            max_tokens=max_tokens,
            sampler=sampler,
            logits_processors=logits_processors,
            verbose=False,
        )

        response_tokens = _extract_response_tokens(response, prompt_list, tokenizer)

    # Compute logprobs for the response tokens
    logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
    return response_tokens, {'logprob': logprobs}
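A generation round-trip sketch (editor's illustration, not part of the package file), assuming a model and tokenizer loaded as above; the prompt and model id are placeholders:

    import mlx.core as mx
    from textpolicy.generation.mlx_generation import load_model, generate_tokens

    model, tokenizer = load_model("Qwen/Qwen2.5-0.5B-Instruct")
    prompt_tokens = mx.array(tokenizer.encode("Write a haiku about the ocean."))
    response_tokens, info = generate_tokens(
        model, tokenizer, prompt_tokens,
        max_tokens=64, temperature=0.7, top_p=0.9,
    )
    print(tokenizer.decode(response_tokens.tolist()))
    print("per-token logprobs:", info["logprob"].shape)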


def _truncate_repetitive_text(text: str, max_repetitions: int = 3) -> str:
    """
    Truncate text if it contains excessive repetitions.

    This helps prevent the model from generating endless loops of the same tokens.
    """
    words = text.split()
    if len(words) < 4:
        return text

    # Check for word repetition
    for i in range(len(words) - max_repetitions):
        if len(set(words[i:i+max_repetitions])) == 1:
            # Found repetition, truncate here
            return ' '.join(words[:i])

    # Check for character repetition (like "5555555")
    for i in range(len(text) - max_repetitions):
        if len(set(text[i:i+max_repetitions])) == 1:
            # Found character repetition, truncate here
            return text[:i]

    return text


def _simple_generate(
    model: nn.Module,
    prompt_tokens: mx.array,
    max_tokens: int,
    temperature: float
) -> Tuple[mx.array, Dict[str, Any]]:
    """
    Simple fallback generation for development without MLX-LM.

    This provides basic autoregressive generation for testing when
    MLX-LM is not available.
    """
    current_tokens = prompt_tokens
    generated = []

    for _ in range(max_tokens):
        # Model forward pass
        logits = model(current_tokens[None])  # Add batch dimension
        next_token_logits = logits[0, -1, :]  # Last token logits

        # Sample next token (mx.random.categorical expects unnormalized logits,
        # not probabilities); fall back to greedy decoding when temperature is 0.
        if temperature > 0:
            scaled_logits = next_token_logits / temperature
            next_token = mx.random.categorical(scaled_logits[None])[0]
        else:
            next_token = mx.argmax(next_token_logits)

        # Add to sequence
        generated.append(next_token)
        current_tokens = mx.concatenate([current_tokens, next_token[None]])

        # Stop on EOS (approximate) - avoid .item() calls
        if len(generated) > 5 and next_token < 5:  # Simple stop condition
            break

    response_tokens = mx.array(generated) if generated else mx.array([2])

    # Compute simple logprobs
    logprobs = compute_logprobs(model, prompt_tokens, response_tokens)

    return response_tokens, {'logprob': logprobs}


def compute_logprobs(
    model: nn.Module,
    prompt_tokens: mx.array,
    response_tokens: mx.array
) -> mx.array:
    """
    Extract log-probabilities of response_tokens under model via teacher-forcing.
    Raises on dimension mismatch or nan/inf values; warns on positive values.
    """
    if len(response_tokens) == 0:
        return mx.array([])

    full_sequence = mx.concatenate([prompt_tokens, response_tokens])
    model_input = full_sequence[None] if full_sequence.ndim == 1 else full_sequence
    logits = model(model_input)
    prompt_len, response_len = len(prompt_tokens), len(response_tokens)
    # Logits at position i predict token i + 1, so the slice starting at
    # prompt_len - 1 scores exactly the response tokens.
    prediction_logits = logits[0, prompt_len-1:prompt_len-1+response_len, :]
    if prediction_logits.shape[0] != response_len:
        raise ValueError(
            f"Logits/tokens mismatch: {prediction_logits.shape[0]} vs {response_len}"
        )

    log_probs = prediction_logits - mx.logsumexp(prediction_logits, axis=-1, keepdims=True)
    selected = log_probs[mx.arange(response_len), response_tokens]
    if mx.any(mx.isnan(selected)) or mx.any(mx.isinf(selected)):
        raise ValueError("Invalid logprobs (nan/inf)")
    if mx.any(selected > 0):
        print("Warning: positive logprobs detected")
    return selected
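A toy check of the slicing logic above (editor's illustration): with a 3-token prompt and a 2-token response, the logits at sequence positions 2 and 3 are the ones that predict the two response tokens, so the slice has exactly response_len rows.

    import mlx.core as mx

    prompt_len, response_len, vocab = 3, 2, 5
    logits = mx.random.normal((1, prompt_len + response_len, vocab))
    prediction_logits = logits[0, prompt_len - 1: prompt_len - 1 + response_len, :]
    assert prediction_logits.shape[0] == response_len
    # Row-wise log-softmax, as in compute_logprobs
    log_probs = prediction_logits - mx.logsumexp(prediction_logits, axis=-1, keepdims=True)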


def encode(tokenizer: Any, text: str) -> mx.array:
    """
    Convert text to MLX token array.

    Args:
        tokenizer: MLX tokenizer
        text: Input text string

    Returns:
        Token array as MLX array
    """
    tokens = tokenizer.encode(text)
    return mx.array(tokens, dtype=mx.int32)


def decode(tokenizer: Any, tokens: mx.array) -> str:
    """
    Convert MLX token array to text.

    Args:
        tokenizer: MLX tokenizer
        tokens: Token array

    Returns:
        Decoded text string
    """
    token_list = tokens.tolist()
    return tokenizer.decode(token_list)


def create_policy(
    model: nn.Module,
    tokenizer: Any,
    generation_params: Optional[Dict[str, Any]] = None
) -> Callable[[mx.array], Tuple[mx.array, Dict[str, Any]]]:
    """
    Create a policy function for RL training with automatic chat template support.

    This returns a pure function that can be used by rollout systems
    to generate responses and collect the data needed for training.

    Automatically applies chat templates for instruction models
    to enable proper EOS token generation and natural stopping behavior.

    Args:
        model: MLX model
        tokenizer: MLX tokenizer
        generation_params: Generation parameters (max_tokens, temperature, etc.)

    Returns:
        Policy function: (prompt_tokens) -> (response_tokens, info)
    """
    params = generation_params or {}
    max_tokens = params.get('max_tokens', 50)
    temperature = params.get('temperature', 0.8)
    top_p = params.get('top_p', 0.95)

    def policy_fn(prompt_tokens: mx.array, deterministic: bool = False) -> Tuple[mx.array, Dict[str, Any]]:
        """
        Policy function that generates responses for RL training with automatic chat template support.

        Automatically applies chat templates for instruction models to enable
        proper EOS token generation. This allows models to naturally end responses with
        appropriate end-of-sequence tokens instead of being artificially truncated.

        Args:
            prompt_tokens: Input prompt tokens
            deterministic: Whether to use deterministic generation

        Returns:
            (response_tokens, generation_info): Response and metadata for training
        """
        # Auto-apply chat template for instruction models
        processed_tokens = prompt_tokens

        try:
            # Decode tokens to check if a chat template is needed
            if hasattr(tokenizer, 'decode'):
                raw_prompt = tokenizer.decode(prompt_tokens.tolist())
            else:
                # Fallback for tokenizers without a decode method
                raw_prompt = str(prompt_tokens.tolist())

            # Let the tokenizer decide if a chat template is needed.
            # This works for ANY instruction model (Qwen, Llama, Mistral, etc.)
            needs_formatting = (
                hasattr(tokenizer, 'apply_chat_template') and
                # Only apply if not already formatted (avoid double-formatting)
                not any(marker in raw_prompt for marker in ['<|im_start|>', '<|endoftext|>', '<|assistant|>'])
            )

            if needs_formatting:
                # Convert to messages format and apply the chat template.
                # This uses the tokenizer's built-in knowledge of its own chat format.
                messages = [{"role": "user", "content": raw_prompt.strip()}]
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True  # Adds <|im_start|>assistant\n for response generation
                )

                # Re-encode with proper formatting for EOS generation
                if hasattr(tokenizer, 'encode'):
                    processed_tokens = mx.array(tokenizer.encode(formatted_prompt))
                else:
                    # Fallback if the tokenizer doesn't have an encode method
                    processed_tokens = prompt_tokens

                # Debug logging (only in verbose mode to avoid noise)
                if hasattr(tokenizer, 'verbose') and tokenizer.verbose:
                    print(f"Applied chat template: '{formatted_prompt[:100]}...'")  # First 100 chars for debugging

        except Exception:
            # Fall back to the original tokens if formatting fails.
            # This ensures robustness and backward compatibility.
            pass

        # Generate response with processed tokens
        temp = 0.0 if deterministic else temperature
        return generate_tokens(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=processed_tokens,  # Use formatted tokens for proper EOS generation
            max_tokens=max_tokens,
            temperature=temp,
            top_p=top_p
        )

    return policy_fn
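A rollout-style usage sketch for the policy function (editor's illustration; prompts and model id are placeholders, and the reward call reuses compute_reward defined below):

    import mlx.core as mx
    from textpolicy.generation.mlx_generation import load_model, create_policy, compute_reward

    model, tokenizer = load_model("Qwen/Qwen2.5-0.5B-Instruct")
    policy = create_policy(model, tokenizer, {"max_tokens": 32, "temperature": 0.7})

    for prompt in ["What is MLX?", "Name three prime numbers."]:
        prompt_tokens = mx.array(tokenizer.encode(prompt))
        response_tokens, info = policy(prompt_tokens)
        response = tokenizer.decode(response_tokens.tolist())
        print(prompt, "->", response, "| reward:", compute_reward(prompt, response))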


def compute_reward(
    prompt: str,
    response: str,
    reward_type: str = "length",
    **kwargs
) -> float:
    """
    Simple reward computation for RL training.

    This provides basic reward functions for testing. In practice,
    you would use the sophisticated reward system from textpolicy.rewards.

    Args:
        prompt: Input prompt text
        response: Generated response text
        reward_type: Type of reward to compute
        **kwargs: Additional parameters for reward computation

    Returns:
        Reward score
    """
    if reward_type == "length":
        target_length = kwargs.get('target_length', 30)
        actual_length = len(response.split())
        # Simple length-based reward
        diff = abs(actual_length - target_length)
        return max(0.0, 1.0 - diff / target_length)

    elif reward_type == "keyword":
        keywords = kwargs.get('keywords', ['good', 'great', 'excellent'])
        count = sum(1 for kw in keywords if kw.lower() in response.lower())
        return count / len(keywords)

    else:
        # Default: simple response quality heuristic
        if len(response.strip()) == 0:
            return 0.0
        if len(response.split()) < 5:
            return 0.2
        return 0.5
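Worked examples of the built-in reward types (editor's illustration): with the default length reward and target_length=30, a 24-word response scores 1 - 6/30 = 0.8; the keyword reward is simply the fraction of keywords that appear in the response.

    from textpolicy.generation.mlx_generation import compute_reward

    score = compute_reward(
        prompt="Describe the weather.",
        response="It is a good and great day.",
        reward_type="keyword",
        keywords=["good", "great", "excellent"],
    )
    assert abs(score - 2 / 3) < 1e-9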


# Convenience function for complete setup
def create_setup(
    model_path: str,
    generation_params: Optional[Dict[str, Any]] = None,
    adapter_path: Optional[str] = None
) -> Tuple[Callable, nn.Module, Any]:
    """
    Complete setup for MLX-LM RL training.

    This function combines model loading and policy creation for
    convenient setup of RL training systems.

    Args:
        model_path: Path or HuggingFace model ID
        generation_params: Generation parameters
        adapter_path: Optional LoRA adapter path

    Returns:
        (policy_fn, model, tokenizer): Complete setup for RL training
    """
    # Load model and tokenizer
    model, tokenizer = load_model(model_path, adapter_path)

    # Create policy function
    policy_fn = create_policy(model, tokenizer, generation_params)

    return policy_fn, model, tokenizer
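An end-to-end sketch tying the pieces together (editor's illustration; the model id is an example):

    from textpolicy.generation.mlx_generation import create_setup, encode, decode

    policy_fn, model, tokenizer = create_setup(
        "Qwen/Qwen2.5-0.5B-Instruct",
        generation_params={"max_tokens": 48, "temperature": 0.8},
    )
    tokens, info = policy_fn(encode(tokenizer, "Summarize what a policy gradient is."))
    print(decode(tokenizer, tokens))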