textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
textpolicy/generation/mlx_generation.py
@@ -0,0 +1,557 @@
+ # textpolicy/generation/mlx_generation.py
+ """
+ Complete MLX-LM text generation functions for RL training.
+
+ This module provides proper integration with MLX-LM for text generation RL,
+ including correct logprob extraction for policy gradient training.
+
+ Key functions:
+ - load_model: Load MLX model and tokenizer
+ - generate_tokens: Generate text with logprob tracking
+ - compute_logprobs: Extract logprobs for RL training
+ - create_policy: Create policy function for rollout collection
+ """
+
+ from __future__ import annotations
+ from typing import Dict, Optional, Tuple, Any, Callable
+ import mlx.core as mx
+ import mlx.nn as nn
+ try:
+     from mlx_lm import load, generate
+     HAS_MLX_LM = True
+ except ImportError:
+     HAS_MLX_LM = False
+     print("Warning: mlx_lm not found. Using fallback implementations.")
+
+ try:
+     from mlx_lm.sample_utils import make_sampler, make_logits_processors
+ except ImportError:
+     # Sampling utilities unavailable; later helpers check for None and fall back.
+     make_sampler = None
+     make_logits_processors = None
+
+
+ def _get_eos_configs_for_model(
+     model_path: str,
+     tokenizer_config: Optional[Dict]
+ ) -> Tuple[Optional[Dict], Dict[str, Any]]:
+     """
+     Determine tokenizer_config and model_config for proper EOS handling based on model type.
+     """
+     model_config: Dict[str, Any] = {}
+     if tokenizer_config is None and "Qwen" in model_path:
+         tokenizer_config = {}
+     if "Qwen" in model_path:
+         # For Qwen Instruct variants, let tokenizer.eos_token_id (<|im_end|>) prevail;
+         # override only for base Qwen to use <|endoftext|> (151643) as EOS.
+         if "Instruct" not in model_path:
+             eos_id = 151643
+             model_config["eos_token_id"] = eos_id
+     return tokenizer_config, model_config
+
+
+ def _prepare_tokenizer(tokenizer: Any, verbose: bool) -> None:
+     """
+     Configure tokenizer verbosity and ensure EOS token IDs for stopping.
+     """
+     tokenizer.verbose = verbose
+     # Normalize EOS-related attributes so generation stops on the tokenizer's EOS token
+     eos_id = getattr(tokenizer, 'eos_token_id', None)
+     if eos_id is not None:
+         # Keep eos_token consistent with eos_token_id for natural stopping
+         tokenizer.eos_token_id = eos_id
+         tokenizer.eos_token = tokenizer.convert_ids_to_tokens(eos_id)
+         tokenizer.eos_token_ids = [eos_id]
+         # Align pad_token to EOS to ensure MLX-LM uses EOS for padding/stopping
+         tokenizer.pad_token_id = eos_id
+         tokenizer.pad_token = tokenizer.eos_token
+
+
+ def _make_eos_safe_sampler(temp: float, top_p: float) -> Any:
+     """
+     Build a sampler that does not prune low-probability tokens (e.g., EOS) and encourages natural stopping.
+     """
+     if make_sampler is not None:
+         # Use more conservative sampling parameters to encourage natural EOS generation
+         # Lower min_p and ensure we keep more tokens in consideration
+         return make_sampler(
+             temp=temp,
+             top_p=top_p,
+             min_p=0.0,  # Don't filter out low-probability tokens like EOS
+             min_tokens_to_keep=2,  # Keep at least 2 tokens to ensure EOS has a chance
+         )
+     else:
+         # Fallback when mlx_lm.sample_utils is not available
+         return None
+
+
+ def _make_logits_processors(repetition_penalty: float) -> Any:
+     """
+     Create logits processors to enforce repetition penalty.
+     """
+     if make_logits_processors is not None:
+         return make_logits_processors(
+             repetition_penalty=repetition_penalty,
+             repetition_context_size=20,
+         )
+     else:
+         # Fallback when mlx_lm.sample_utils is not available
+         return None
+
+
+ def _extract_response_tokens(
+     response: Any,
+     prompt_list: Any,
+     tokenizer: Any,
+ ) -> mx.array:
+     """
+     Extract response token IDs from a raw generation output (string or list).
+     Enhanced to better handle EOS tokens and edge cases.
+     """
+     if isinstance(response, list):
+         return mx.array(response)
+     try:
+         full_tokens = tokenizer.encode(response)
+         eos_id = getattr(tokenizer, 'eos_token_id', None)
+
+         # First, try to find EOS token and include it in response
+         if eos_id is not None and eos_id in full_tokens:
+             idx = full_tokens.index(eos_id)
+             # Include EOS token in response for proper reward calculation
+             resp = full_tokens[len(prompt_list): idx + 1]
+         else:
+             # No EOS found - extract response portion without EOS
+             try:
+                 prompt_text = tokenizer.decode(prompt_list)
+                 if response.startswith(prompt_text):
+                     tail = response[len(prompt_text):]
+                 else:
+                     tail = response
+                 resp = tokenizer.encode(tail.strip()) if tail.strip() else []
+             except Exception:
+                 # Fallback: encode the whole response and hope for the best
+                 resp = tokenizer.encode(response) if response else []
+
+         return mx.array(resp) if resp else mx.array([])
+     except Exception:
+         return mx.array([])
+
+
+ def load_model(
+     model_path: str,
+     adapter_path: Optional[str] = None,
+     tokenizer_config: Optional[Dict] = None,
+     verbose: bool = False
+ ) -> Tuple[nn.Module, Any]:
+     """
+     Load MLX model and tokenizer for RL training.
+
+     This function properly loads MLX-LM models with support for LoRA adapters
+     and ensures compatibility with our training system. Automatically configures
+     proper EOS tokens for Qwen models to ensure correct generation stopping.
+
+     Args:
+         model_path: Path or HuggingFace model ID
+         adapter_path: Optional LoRA adapter path
+         tokenizer_config: Optional tokenizer configuration for EOS tokens
+         verbose: Enable debug logging for chat template application
+
+     Returns:
+         (model, tokenizer): MLX model and tokenizer instances
+     """
+     if not HAS_MLX_LM:
+         raise ImportError("mlx_lm is required. Install with: pip install mlx-lm")
+
+     print(f"Loading MLX model: {model_path}")
+     if adapter_path:
+         print(f"Loading with LoRA adapters: {adapter_path}")
+
+     # Configure model & tokenizer for EOS handling based on model type
+     tokenizer_config, model_config = _get_eos_configs_for_model(model_path, tokenizer_config)
+     model, tokenizer = load(
+         path_or_hf_repo=model_path,
+         adapter_path=adapter_path,
+         tokenizer_config=tokenizer_config,
+         model_config=model_config,
+         lazy=False,
+     )
+     _prepare_tokenizer(tokenizer, verbose)
+     print("✓ Model loaded successfully")
+     return model, tokenizer
+
+
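For orientation, a minimal usage sketch of the loader above; the model id is a placeholder, not something the wheel pins:

    model, tokenizer = load_model("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder HF repo id
    prompt_tokens = mx.array(tokenizer.encode("Write a haiku about the sea."))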
+ def generate_tokens(
+     model: nn.Module,
+     tokenizer: Any,
+     prompt_tokens: mx.array,
+     max_tokens: int = 50,
+     temperature: float = 0.7,  # Lower default temperature for more stable generation
+     top_p: float = 0.9,  # Lower top_p for more focused sampling
+     repetition_penalty: float = 1.1  # Add repetition penalty to prevent loops
+ ) -> Tuple[mx.array, Dict[str, Any]]:
+     """Generate response tokens with proper MLX-LM integration and EOS token support."""
+     if not HAS_MLX_LM:
+         return _simple_generate(model, prompt_tokens, max_tokens, temperature)
+
+     prompt_list = prompt_tokens.tolist()
+
+     # Use stream_generate instead of generate to get proper EOS token handling
+     # This is the core fix - stream_generate respects EOS tokens, generate() does not
+     try:
+         from mlx_lm import stream_generate
+
+         # EOS-safe sampling with reduced temperature for more predictable stopping
+         optimized_temperature = min(temperature, 0.7)
+         sampler = _make_eos_safe_sampler(optimized_temperature, top_p)
+         logits_processors = _make_logits_processors(repetition_penalty)
+
+         # Use stream_generate to get token-by-token generation with EOS detection
+         response_segments = list(stream_generate(
+             model=model,
+             tokenizer=tokenizer,
+             prompt=prompt_list,  # type: ignore
+             max_tokens=max_tokens,
+             sampler=sampler,
+             logits_processors=logits_processors,
+         ))
+
+         # Extract tokens from response segments and detect natural EOS stopping
+         response_token_list = []
+         for segment in response_segments:
+             response_token_list.append(segment.token)
+             # Check if this segment indicates natural stopping (EOS token)
+             if hasattr(segment, 'finish_reason') and segment.finish_reason == "stop":
+                 break
+
+         # Convert to MLX array
+         response_tokens = mx.array(response_token_list) if response_token_list else mx.array([])
+
+     except ImportError:
+         # Fallback to the plain generate() method if stream_generate is unavailable
+         print("WARNING: stream_generate not available, using fallback generate method")
+         optimized_temperature = min(temperature, 0.7)
+         sampler = _make_eos_safe_sampler(optimized_temperature, top_p)
+         logits_processors = _make_logits_processors(repetition_penalty)
+
+         response = generate(
+             model=model,
+             tokenizer=tokenizer,
+             prompt=prompt_list,  # type: ignore
+             max_tokens=max_tokens,
+             sampler=sampler,
+             logits_processors=logits_processors,
+             verbose=False,
+         )
+
+         response_tokens = _extract_response_tokens(response, prompt_list, tokenizer)
+
+     # Compute logprobs for the response tokens
+     logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
+     return response_tokens, {'logprob': logprobs}
+
+
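A rough sketch of calling generate_tokens directly; the prompt text and sampling values are illustrative only:

    response_tokens, info = generate_tokens(
        model, tokenizer, prompt_tokens,
        max_tokens=32, temperature=0.7, top_p=0.9,
    )
    print(tokenizer.decode(response_tokens.tolist()))
    print(info['logprob'])  # one log-probability per generated token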
+ def _truncate_repetitive_text(text: str, max_repetitions: int = 3) -> str:
+     """
+     Truncate text if it contains excessive repetitions.
+
+     This helps prevent the model from generating endless loops of the same tokens.
+     """
+     words = text.split()
+     if len(words) < 4:
+         return text
+
+     # Check for word repetition
+     for i in range(len(words) - max_repetitions):
+         if len(set(words[i:i+max_repetitions])) == 1:
+             # Found repetition, truncate here
+             return ' '.join(words[:i])
+
+     # Check for character repetition (like "5555555")
+     for i in range(len(text) - max_repetitions):
+         if len(set(text[i:i+max_repetitions])) == 1:
+             # Found character repetition, truncate here
+             return text[:i]
+
+     return text
+
+
+ def _simple_generate(
+     model: nn.Module,
+     prompt_tokens: mx.array,
+     max_tokens: int,
+     temperature: float
+ ) -> Tuple[mx.array, Dict[str, Any]]:
+     """
+     Simple fallback generation for development without MLX-LM.
+
+     This provides basic autoregressive generation for testing when
+     MLX-LM is not available.
+     """
+     current_tokens = prompt_tokens
+     generated = []
+
+     for _ in range(max_tokens):
+         # Model forward pass
+         logits = model(current_tokens[None])  # Add batch dimension
+         next_token_logits = logits[0, -1, :]  # Last token logits
+
+         # Temperature scaling
+         if temperature > 0:
+             scaled_logits = next_token_logits / temperature
+         else:
+             scaled_logits = next_token_logits
+
+         # Sample next token (mx.random.categorical expects unnormalized logits)
+         next_token = mx.random.categorical(scaled_logits[None])[0]
+
+         # Add to sequence
+         generated.append(next_token)
+         current_tokens = mx.concatenate([current_tokens, next_token[None]])
+
+         # Stop on EOS (approximate) - avoid .item() calls
+         if len(generated) > 5 and next_token < 5:  # Simple stop condition
+             break
+
+     response_tokens = mx.array(generated) if generated else mx.array([2])
+
+     # Compute simple logprobs
+     logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
+
+     return response_tokens, {'logprob': logprobs}
+
+
+ def compute_logprobs(
+     model: nn.Module,
+     prompt_tokens: mx.array,
+     response_tokens: mx.array
+ ) -> mx.array:
+     """
+     Extract log-probabilities of response_tokens under model via teacher-forcing.
+     Raises on dimension mismatch or invalid (nan/inf) values; warns on positive values.
+     """
+     if len(response_tokens) == 0:
+         return mx.array([])
+
+     full_sequence = mx.concatenate([prompt_tokens, response_tokens])
+     model_input = full_sequence[None] if full_sequence.ndim == 1 else full_sequence
+     logits = model(model_input)
+     prompt_len, response_len = len(prompt_tokens), len(response_tokens)
+     prediction_logits = logits[0, prompt_len-1:prompt_len-1+response_len, :]
+     if prediction_logits.shape[0] != response_len:
+         raise ValueError(
+             f"Logits/tokens mismatch: {prediction_logits.shape[0]} vs {response_len}"
+         )
+
+     log_probs = prediction_logits - mx.logsumexp(prediction_logits, axis=-1, keepdims=True)
+     selected = log_probs[mx.arange(response_len), response_tokens]
+     if mx.any(mx.isnan(selected)) or mx.any(mx.isinf(selected)):
+         raise ValueError("Invalid logprobs (nan/inf)")
+     if mx.any(selected > 0):
+         print("Warning: positive logprobs detected")
+     return selected
+
+
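The slice above relies on the standard teacher-forcing shift: for a prompt of length P and a response of length R, the logits at position P-1+t score the token that follows it, i.e. response token t, so log p(r_t | prompt, r_<t) = log_softmax(logits[0, P-1+t, :])[r_t]. A short sanity-check sketch under that assumption:

    logprobs = compute_logprobs(model, prompt_tokens, response_tokens)
    assert logprobs.shape == (len(response_tokens),)
    sequence_logprob = mx.sum(logprobs)  # total log-likelihood of the sampled response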
+ def encode(tokenizer: Any, text: str) -> mx.array:
+     """
+     Convert text to MLX token array.
+
+     Args:
+         tokenizer: MLX tokenizer
+         text: Input text string
+
+     Returns:
+         Token array as MLX array
+     """
+     tokens = tokenizer.encode(text)
+     return mx.array(tokens, dtype=mx.int32)
+
+
+ def decode(tokenizer: Any, tokens: mx.array) -> str:
+     """
+     Convert MLX token array to text.
+
+     Args:
+         tokenizer: MLX tokenizer
+         tokens: Token array
+
+     Returns:
+         Decoded text string
+     """
+     token_list = tokens.tolist()
+     return tokenizer.decode(token_list)
+
+
+ def create_policy(
+     model: nn.Module,
+     tokenizer: Any,
+     generation_params: Optional[Dict[str, Any]] = None
+ ) -> Callable[[mx.array], Tuple[mx.array, Dict[str, Any]]]:
+     """
+     Create a policy function for RL training with automatic chat template support.
+
+     This returns a pure function that can be used by rollout systems
+     to generate responses and collect the data needed for training.
+
+     Automatically applies chat templates for instruction models
+     to enable proper EOS token generation and natural stopping behavior.
+
+     Args:
+         model: MLX model
+         tokenizer: MLX tokenizer
+         generation_params: Generation parameters (max_tokens, temperature, etc.)
+
+     Returns:
+         Policy function: (prompt_tokens) -> (response_tokens, info)
+     """
+     params = generation_params or {}
+     max_tokens = params.get('max_tokens', 50)
+     temperature = params.get('temperature', 0.8)
+     top_p = params.get('top_p', 0.95)
+
+     def policy_fn(prompt_tokens: mx.array, deterministic: bool = False) -> Tuple[mx.array, Dict[str, Any]]:
+         """
+         Policy function that generates responses for RL training with automatic chat template support.
+
+         Automatically applies chat templates for instruction models to enable
+         proper EOS token generation. This allows models to naturally end responses with
+         appropriate end-of-sequence tokens instead of being artificially truncated.
+
+         Args:
+             prompt_tokens: Input prompt tokens
+             deterministic: Whether to use deterministic generation
+
+         Returns:
+             (response_tokens, generation_info): Response and metadata for training
+         """
+         # Auto-apply chat template for instruction models
+         processed_tokens = prompt_tokens
+
+         try:
+             # Decode tokens to check if chat template is needed
+             if hasattr(tokenizer, 'decode'):
+                 raw_prompt = tokenizer.decode(prompt_tokens.tolist())
+             else:
+                 # Fallback for tokenizers without decode method
+                 raw_prompt = str(prompt_tokens.tolist())
+
+             # Let the tokenizer decide if chat template is needed
+             # This works for ANY instruction model (Qwen, Llama, Mistral, etc.)
+             needs_formatting = (
+                 hasattr(tokenizer, 'apply_chat_template') and
+                 # Only apply if not already formatted (avoid double-formatting)
+                 not any(marker in raw_prompt for marker in ['<|im_start|>', '<|endoftext|>', '<|assistant|>'])
+             )
+
+             if needs_formatting:
+                 # Convert to messages format and apply chat template
+                 # This uses the tokenizer's built-in knowledge of its own chat format
+                 messages = [{"role": "user", "content": raw_prompt.strip()}]
+                 formatted_prompt = tokenizer.apply_chat_template(
+                     messages,
+                     tokenize=False,
+                     add_generation_prompt=True  # Adds <|im_start|>assistant\n for response generation
+                 )
+
+                 # Re-encode with proper formatting for EOS generation
+                 if hasattr(tokenizer, 'encode'):
+                     processed_tokens = mx.array(tokenizer.encode(formatted_prompt))
+                 else:
+                     # Fallback if tokenizer doesn't have encode method
+                     processed_tokens = prompt_tokens
+
+                 # Debug logging (only in verbose mode to avoid noise)
+                 if hasattr(tokenizer, 'verbose') and tokenizer.verbose:
+                     print(f"Applied chat template: '{formatted_prompt[:100]}...'")  # Show first 100 chars for debugging
+
+         except Exception:
+             # Fallback to original tokens if formatting fails
+             # This ensures robustness and backward compatibility
+             pass
+
+         # Generate response with processed tokens
+         temp = 0.0 if deterministic else temperature
+         return generate_tokens(
+             model=model,
+             tokenizer=tokenizer,
+             prompt_tokens=processed_tokens,  # Use formatted tokens for proper EOS generation
+             max_tokens=max_tokens,
+             temperature=temp,
+             top_p=top_p
+         )
+
+     return policy_fn
+
+
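A minimal sketch of driving the returned policy from a rollout loop; prompt and parameters are placeholders:

    policy = create_policy(model, tokenizer, {"max_tokens": 64, "temperature": 0.8})
    response_tokens, info = policy(prompt_tokens)                 # stochastic rollout
    greedy_tokens, _ = policy(prompt_tokens, deterministic=True)  # greedy evaluation pass
    per_token_logprobs = info["logprob"]  # consumed later by the policy-gradient update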
+ def compute_reward(
+     prompt: str,
+     response: str,
+     reward_type: str = "length",
+     **kwargs
+ ) -> float:
+     """
+     Simple reward computation for RL training.
+
+     This provides basic reward functions for testing. In practice,
+     you would use the sophisticated reward system from textpolicy.rewards.
+
+     Args:
+         prompt: Input prompt text
+         response: Generated response text
+         reward_type: Type of reward to compute
+         **kwargs: Additional parameters for reward computation
+
+     Returns:
+         Reward score
+     """
+     if reward_type == "length":
+         target_length = kwargs.get('target_length', 30)
+         actual_length = len(response.split())
+         # Simple length-based reward
+         diff = abs(actual_length - target_length)
+         return max(0.0, 1.0 - diff / target_length)
+
+     elif reward_type == "keyword":
+         keywords = kwargs.get('keywords', ['good', 'great', 'excellent'])
+         count = sum(1 for kw in keywords if kw.lower() in response.lower())
+         return count / len(keywords)
+
+     else:
+         # Default: simple response quality heuristic
+         if len(response.strip()) == 0:
+             return 0.0
+         if len(response.split()) < 5:
+             return 0.2
+         return 0.5
+
+
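As a worked example of the default length reward: with target_length = 30, a 20-word response scores max(0, 1 - 10/30) ≈ 0.67 and a 30-word response scores 1.0, for instance:

    compute_reward("any prompt", "word " * 20, reward_type="length")  # ≈ 0.67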
+ # Convenience function for complete setup
+ def create_setup(
+     model_path: str,
+     generation_params: Optional[Dict[str, Any]] = None,
+     adapter_path: Optional[str] = None
+ ) -> Tuple[Callable, nn.Module, Any]:
+     """
+     Complete setup for MLX-LM RL training.
+
+     This function combines model loading and policy creation for
+     convenient setup of RL training systems.
+
+     Args:
+         model_path: Path or HuggingFace model ID
+         generation_params: Generation parameters
+         adapter_path: Optional LoRA adapter path
+
+     Returns:
+         (policy_fn, model, tokenizer): Complete setup for RL training
+     """
+     # Load model and tokenizer
+     model, tokenizer = load_model(model_path, adapter_path)
+
+     # Create policy function
+     policy_fn = create_policy(model, tokenizer, generation_params)
+
+     return policy_fn, model, tokenizer
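End to end, the module is meant to be wired up roughly like this; the model id, prompt, and reward choice are illustrative placeholders:

    policy_fn, model, tokenizer = create_setup(
        "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
        generation_params={"max_tokens": 64, "temperature": 0.7},
    )
    prompt = mx.array(tokenizer.encode("Summarize: MLX runs on Apple silicon."))
    response, info = policy_fn(prompt)
    reward = compute_reward("", tokenizer.decode(response.tolist()), reward_type="length")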