textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,797 @@
+# textpolicy/environment/text_generation.py
+"""
+Text Generation Environment for Testing MLX RL Training.
+
+This environment provides measurable text generation tasks to validate that
+models are actually learning through RL training, not just going through the motions.
+
+Key features:
+- Consistent, reproducible text generation tasks
+- Before/after learning validation metrics
+- Integration with the MLX generation system
+- Support for various text generation benchmarks
+"""
+
+from typing import Dict, List, Optional, Tuple, Any, Callable
+import mlx.core as mx
+import random
+from dataclasses import dataclass
+from .base import Environment
+from .task_suites import register_task_suite, get_task_suite
+
+# Import our generation functions
+from ..generation.mlx_generation import encode, decode, generate_tokens
+
+
+@dataclass
+class TextGenerationTask:
+    """A single text generation task with validation criteria."""
+    prompt: str
+    target_keywords: List[str]
+    target_length_range: Tuple[int, int]  # (min_words, max_words)
+    difficulty: float  # 0.0 to 1.0
+    category: str
+    evaluation_criteria: Dict[str, Any]
+
+
+# Default task suites for registration and internal use.
+def _default_basic_tasks() -> List[TextGenerationTask]:
+    return [
+        TextGenerationTask(
+            prompt="Write a brief explanation of machine learning.",
+            target_keywords=["algorithm", "data", "learn"],
+            target_length_range=(20, 40),
+            difficulty=0.3,
+            category="length_control",
+            evaluation_criteria={"keyword_weight": 0.4, "length_weight": 0.6}
+        ),
+        TextGenerationTask(
+            prompt="Describe the benefits of renewable energy in one paragraph.",
+            target_keywords=["environment", "sustainable", "clean"],
+            target_length_range=(30, 50),
+            difficulty=0.4,
+            category="length_control",
+            evaluation_criteria={"keyword_weight": 0.5, "length_weight": 0.5}
+        ),
+        TextGenerationTask(
+            prompt="Explain how computers work.",
+            target_keywords=["processor", "memory", "software", "hardware"],
+            target_length_range=(25, 45),
+            difficulty=0.5,
+            category="keyword_inclusion",
+            evaluation_criteria={"keyword_weight": 0.7, "length_weight": 0.3}
+        ),
+        TextGenerationTask(
+            prompt="Write about the importance of education.",
+            target_keywords=["knowledge", "skills", "future", "learning"],
+            target_length_range=(20, 40),
+            difficulty=0.4,
+            category="keyword_inclusion",
+            evaluation_criteria={"keyword_weight": 0.6, "length_weight": 0.4}
+        ),
+        TextGenerationTask(
+            prompt="Explain the process of photosynthesis step by step.",
+            target_keywords=["sunlight", "carbon", "oxygen", "glucose"],
+            target_length_range=(35, 60),
+            difficulty=0.6,
+            category="coherence",
+            evaluation_criteria={"keyword_weight": 0.3, "length_weight": 0.3, "coherence_weight": 0.4}
+        ),
+    ]
+
+
+def _default_challenging_tasks() -> List[TextGenerationTask]:
+    return [
+        TextGenerationTask(
+            prompt="Compare and contrast neural networks and traditional algorithms.",
+            target_keywords=["pattern", "weights", "training", "classification", "regression"],
+            target_length_range=(50, 80),
+            difficulty=0.8,
+            category="comparison",
+            evaluation_criteria={"keyword_weight": 0.4, "length_weight": 0.3, "coherence_weight": 0.3}
+        ),
+        TextGenerationTask(
+            prompt="Analyze the ethical implications of artificial intelligence.",
+            target_keywords=["bias", "privacy", "autonomy", "responsibility", "society"],
+            target_length_range=(60, 100),
+            difficulty=0.9,
+            category="analysis",
+            evaluation_criteria={"keyword_weight": 0.3, "length_weight": 0.2, "coherence_weight": 0.5}
+        ),
+    ]
+
+
+# Register defaults at import time to make them discoverable via the registry.
+register_task_suite("basic", _default_basic_tasks)
+register_task_suite("challenging", _default_challenging_tasks)
+
+
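The registry calls above make suites discoverable by name. A minimal sketch of registering a custom suite, assuming register_task_suite takes a suite name plus a zero-argument loader and get_task_suite resolves it by name (as the defaults above and _create_task_suite below suggest); the suite name and task values here are hypothetical:

from textpolicy.environment.text_generation import TextGenerationTask
from textpolicy.environment.task_suites import register_task_suite

def _my_tiny_suite():
    # Loader returning the task list; called through the registry.
    return [
        TextGenerationTask(
            prompt="Summarize the water cycle in two sentences.",
            target_keywords=["evaporation", "rain"],
            target_length_range=(15, 30),
            difficulty=0.3,
            category="length_control",
            evaluation_criteria={"keyword_weight": 0.5, "length_weight": 0.5},
        )
    ]

register_task_suite("tiny", _my_tiny_suite)
# TextGenerationEnvironment(..., task_suite="tiny") would then resolve this
# suite through get_task_suite() inside _create_task_suite().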
+class TextGenerationEnvironment(Environment):
+    """
+    Environment for testing text generation learning with MLX models.
+
+    This environment provides a suite of text generation tasks that allow
+    measuring model improvement through RL training. It integrates directly
+    with our MLX generation system and reward functions.
+
+    Key validation approach:
+    1. Pre-training baseline: Measure model performance on the task suite
+    2. Post-training comparison: Measure the same model after RL training
+    3. Learning validation: Prove statistically significant improvement
+    """
+
+    def __init__(
+        self,
+        model: Any,
+        tokenizer: Any,
+        task_suite: str = "basic",
+        num_episodes: int = 50,
+        generation_params: Optional[Dict[str, Any]] = None,
+        seed: int = 42
+    ):
+        """
+        Initialize text generation testing environment.
+
+        Args:
+            model: MLX model for text generation
+            tokenizer: MLX tokenizer
+            task_suite: Which task suite to use ("basic", "challenging", "custom")
+            num_episodes: Number of episodes per evaluation
+            generation_params: Parameters for text generation
+            seed: Random seed for reproducible evaluation
+        """
+        super().__init__()
+
+        # Validate critical dependencies early to produce clear, actionable errors.
+        # This environment integrates directly with MLX generation; both model and
+        # tokenizer are required. Without them, encode/decode/generate would fail
+        # later with obscure attribute errors.
+        if model is None:
+            raise ValueError("TextGenerationEnvironment requires a valid MLX model (got None)")
+        if tokenizer is None:
+            raise ValueError("TextGenerationEnvironment requires a valid tokenizer (got None)")
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self.generation_params = generation_params or {
+            'max_tokens': 50,
+            'temperature': 0.8,
+            'top_p': 0.95
+        }
+
+        # Create task suite for evaluation
+        self.tasks = self._create_task_suite(task_suite)
+        self.task_suite = task_suite  # remember suite type for cloning
+        self.num_episodes = num_episodes
+        self.current_episode = 0
+        self.current_task = None
+
+        # Performance tracking for learning validation
+        self.baseline_scores = []
+        self.current_scores = []
+
+        # Environment state
+        random.seed(seed)
+        self._episode_data = {}  # populated per-episode by reset()
+
+        # Initialization complete; environment ready for evaluation
+        # (Debug prints removed for production efficiency)
+
+    def _create_task_suite(self, suite_type: str) -> List[TextGenerationTask]:
+        """
+        Create a suite of text generation tasks for evaluation.
+
+        These tasks are designed to be:
+        - Measurable: Clear success criteria
+        - Diverse: Cover different generation challenges
+        - Reproducible: Same tasks for before/after comparison
+        """
+        # Prefer registry-based suites when available; fall back to defaults here.
+        # First, try the registry-based loader (see environment.task_suites).
+        # This enables custom suites without hardcoding them here.
+        registered = get_task_suite(suite_type)
+        if registered is not None:
+            return registered
+
+        if suite_type == "basic":
+            return _default_basic_tasks()
+        elif suite_type == "challenging":
+            return _default_challenging_tasks()
+        else:  # custom or fallback
+            return [
+                TextGenerationTask(
+                    prompt="Tell me about your favorite topic.",
+                    target_keywords=["interesting", "because", "example"],
+                    target_length_range=(15, 35),
+                    difficulty=0.2,
+                    category="open_ended",
+                    evaluation_criteria={"keyword_weight": 0.5, "length_weight": 0.5}
+                )
+            ]
+
+    def reset(self) -> Tuple[Any, Dict[str, Any]]:
+        """
+        Reset environment to start a new episode.
+
+        Returns:
+            (observation, info): Initial prompt and episode metadata
+        """
+        # Select next task (cycle through tasks)
+        task_index = self.current_episode % len(self.tasks)
+        self.current_task = self.tasks[task_index]
+
+        # Reset episode state
+        self._episode_data = {
+            'prompt': self.current_task.prompt,
+            'task': self.current_task,
+            'responses': [],
+            'scores': []
+        }
+
+        # Return initial observation (the prompt to generate from)
+        observation = encode(self.tokenizer, self.current_task.prompt)
+
+        info = {
+            'episode': self.current_episode,
+            'task_category': self.current_task.category,
+            'difficulty': self.current_task.difficulty,
+            'target_keywords': self.current_task.target_keywords,
+            'target_length_range': self.current_task.target_length_range
+        }
+
+        return observation, info
+
+    def step(self, action: Any) -> Dict[str, Any]:
+        """
+        Take a step in the environment by evaluating a generated text response.
+
+        Args:
+            action: Generated response tokens (MLX array)
+
+        Returns:
+            Step result with observation, reward, termination status, and info
+        """
+        if self.current_task is None:
+            raise ValueError("Environment not reset - call reset() first")
+
+        # Decode response from tokens
+        if hasattr(action, 'tolist'):
+            # Action is an MLX array of tokens
+            response_text = decode(self.tokenizer, action)
+        else:
+            # Action might already be text
+            response_text = str(action)
+
+        # Compute reward using our reward system
+        reward_score = self._evaluate_response(
+            prompt=self.current_task.prompt,
+            response=response_text,
+            task=self.current_task
+        )
+
+        # Store episode data for analysis
+        self._episode_data['responses'].append(response_text)
+        self._episode_data['scores'].append(reward_score)
+
+        # Episode terminates after each generation (single-turn tasks)
+        terminated = True
+        truncated = False
+
+        # Prepare next observation (empty since the episode ended)
+        next_observation = mx.array([])
+
+        info = {
+            'response': response_text,
+            'reward_score': reward_score,
+            'task_category': self.current_task.category,
+            'target_keywords_found': [kw for kw in self.current_task.target_keywords
+                                      if kw.lower() in response_text.lower()],
+            'response_length': len(response_text.split()),
+            'target_length_range': self.current_task.target_length_range
+        }
+
+        # Move to next episode
+        self.current_episode += 1
+
+        return {
+            'observation': next_observation,
+            'reward': reward_score,
+            'terminated': terminated,
+            'truncated': truncated,
+            'info': info
+        }
+
+    def _evaluate_response(self, prompt: str, response: str, task: TextGenerationTask) -> float:
+        """
+        Evaluate response quality using task-specific criteria.
+
+        This function integrates with our reward system to provide
+        consistent, measurable evaluation of text generation quality.
+        """
+        criteria = task.evaluation_criteria
+        total_score = 0.0
+
+        # Length-based scoring
+        if 'length_weight' in criteria:
+            word_count = len(response.split())
+            min_len, max_len = task.target_length_range
+            target_len = (min_len + max_len) / 2
+
+            # Score based on proximity to target length
+            if min_len <= word_count <= max_len:
+                length_score = 1.0
+            else:
+                # Penalty for being outside range
+                distance = min(abs(word_count - min_len), abs(word_count - max_len))
+                length_score = max(0.0, 1.0 - distance / target_len)
+
+            total_score += criteria['length_weight'] * length_score
+
+        # Keyword inclusion scoring
+        if 'keyword_weight' in criteria:
+            keywords_found = sum(1 for kw in task.target_keywords
+                                 if kw.lower() in response.lower())
+            keyword_score = keywords_found / len(task.target_keywords)
+            total_score += criteria['keyword_weight'] * keyword_score
+
+        # Coherence scoring (simple heuristic)
+        if 'coherence_weight' in criteria:
+            # Use our existing coherence evaluation
+            coherence_score = self._simple_coherence_score(response)
+            total_score += criteria['coherence_weight'] * coherence_score
+
+        return total_score
+
+    def _simple_coherence_score(self, text: str) -> float:
+        """Simple coherence scoring based on structure indicators."""
+        if not text.strip():
+            return 0.0
+
+        # Basic coherence indicators
+        sentences = [s.strip() for s in text.split('.') if s.strip()]
+        if len(sentences) < 2:
+            return 0.5  # Single sentence is moderately coherent
+
+        # Look for logical connectors
+        connectors = ['therefore', 'however', 'moreover', 'furthermore', 'because', 'since']
+        connector_count = sum(1 for conn in connectors if conn in text.lower())
+
+        # Coherence score based on structure
+        connector_score = min(1.0, connector_count / 2.0)  # 2+ connectors is good
+        sentence_score = min(1.0, len(sentences) / 3.0)  # 3+ sentences is good
+
+        return (connector_score + sentence_score) / 2.0
+
+    def evaluate_model(self, mode: str = "current") -> Dict[str, float]:
+        """
+        Evaluate model performance on the full task suite.
+
+        This function runs through all tasks and computes aggregate
+        performance metrics to measure learning progress.
+
+        Args:
+            mode: "baseline" (store baseline) or "current" (compare to baseline)
+
+        Returns:
+            Performance metrics dictionary
+        """
+        print(f"Running {mode} evaluation on {self.num_episodes} episodes...")
+
+        all_scores = []
+        category_scores = {}
+
+        # Reset episode counter for evaluation
+        original_episode = self.current_episode
+        self.current_episode = 0
+
+        try:
+            for episode in range(self.num_episodes):
+                # Reset environment
+                observation, info = self.reset()
+
+                # Generate response using current model
+                response_tokens, generation_info = generate_tokens(
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    prompt_tokens=observation,
+                    **self.generation_params
+                )
+
+                # Take step to get reward
+                step_result = self.step(response_tokens)
+                score = step_result['reward']
+                category = step_result['info']['task_category']
+
+                all_scores.append(score)
+                if category not in category_scores:
+                    category_scores[category] = []
+                category_scores[category].append(score)
+
+        finally:
+            # Restore episode counter
+            self.current_episode = original_episode
+
+        # Compute aggregate metrics
+        mean_score = float(mx.mean(mx.array(all_scores)))
+        std_score = float(mx.std(mx.array(all_scores)))
+
+        metrics = {
+            'mean_score': mean_score,
+            'std_score': std_score,
+            'num_episodes': self.num_episodes,
+            'category_breakdown': {
+                cat: float(mx.mean(mx.array(scores)))
+                for cat, scores in category_scores.items()
+            }
+        }
+
+        # Store results based on mode
+        if mode == "baseline":
+            self.baseline_scores = all_scores
+            print(f"✓ Baseline evaluation complete: {mean_score:.3f} ± {std_score:.3f}")
+        else:
+            self.current_scores = all_scores
+
+            # Compute learning improvement if we have a baseline
+            if self.baseline_scores:
+                baseline_mean = float(mx.mean(mx.array(self.baseline_scores)))
+                improvement = mean_score - baseline_mean
+                improvement_pct = (improvement / baseline_mean) * 100 if baseline_mean > 0 else 0
+
+                metrics['baseline_score'] = baseline_mean
+                metrics['improvement'] = improvement
+                metrics['improvement_percent'] = improvement_pct
+
+                print(f"Current evaluation complete: {mean_score:.3f} ± {std_score:.3f}")
+                print(f"  Improvement: {improvement:+.3f} ({improvement_pct:+.1f}%)")
+
+                # Statistical significance test (simple)
+                if improvement > 2 * std_score:  # Rough 2-sigma test
+                    print("  LEARNING DETECTED: Statistically significant improvement!")
+                else:
+                    print("  Learning uncertain: Improvement not statistically significant")
+
+        return metrics
+
+    @property
+    def observation_space(self) -> Any:
+        """Observation space is tokenized text (variable length)."""
+        return "TokenizedText"  # Placeholder - MLX doesn't need gym spaces
+
+    @property
+    def action_space(self) -> Any:
+        """Action space is generated text tokens (variable length)."""
+        return "GeneratedTokens"  # Placeholder - MLX doesn't need gym spaces
+
+    def clone(self) -> 'TextGenerationEnvironment':
+        """Create a multiprocessing clone with the same configuration.
+
+        This returns a new environment instance that references the same
+        model/tokenizer objects. On some systems, MLX models are not picklable;
+        for process spawning, prefer passing an environment factory (env_fn)
+        so model/tokenizer can be constructed in each process. See rollout.coordinator.
+        """
+        # Delegate to the same constructor with preserved parameters
+        return TextGenerationEnvironment(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            task_suite=self.task_suite,
+            num_episodes=self.num_episodes,
+            generation_params=self.generation_params,
+            seed=random.randint(0, 10000)
+        )
+
+
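The clone() docstring above recommends an environment factory for process spawning. A hedged sketch of that env_fn pattern, assuming mlx_lm.load() is available for model loading (the model id below is a placeholder) and a rollout API that accepts a zero-argument factory:

def make_env():
    # Construct the model and tokenizer inside each worker process instead of
    # pickling them; MLX models may not be picklable on some systems.
    from mlx_lm import load  # assumption: mlx-lm is installed
    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # hypothetical model id
    return TextGenerationEnvironment(
        model=model,
        tokenizer=tokenizer,
        task_suite="basic",
        num_episodes=10,
    )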
+def create_text_generation_test_env(
+    model: Any,
+    tokenizer: Any,
+    task_suite: str = "basic",
+    num_episodes: int = 50,
+    **kwargs
+) -> TextGenerationEnvironment:
+    """
+    Factory function to create a text generation testing environment.
+
+    This is the main entry point for creating environments to test
+    whether RL training actually improves model performance.
+
+    Args:
+        model: MLX model for text generation
+        tokenizer: MLX tokenizer
+        task_suite: Which task suite to use for evaluation
+        num_episodes: Number of episodes per evaluation
+        **kwargs: Additional environment parameters
+
+    Returns:
+        Configured TextGenerationEnvironment ready for testing
+    """
+    return TextGenerationEnvironment(
+        model=model,
+        tokenizer=tokenizer,
+        task_suite=task_suite,
+        num_episodes=num_episodes,
+        **kwargs
+    )
+
+
+def validate_learning_progress(
+    env: TextGenerationEnvironment,
+    pre_training_metrics: Dict[str, float],
+    post_training_metrics: Dict[str, float]
+) -> Dict[str, Any]:
+    """
+    Pure function to validate that learning actually occurred.
+
+    This function provides statistical analysis to show that
+    RL training resulted in measurable improvement.
+
+    Args:
+        env: The environment used for testing
+        pre_training_metrics: Metrics before training
+        post_training_metrics: Metrics after training
+
+    Returns:
+        Learning validation report
+    """
+    pre_mean = pre_training_metrics['mean_score']
+    improvement = post_training_metrics['mean_score'] - pre_mean
+    # Guard against a zero baseline, mirroring evaluate_model()
+    improvement_pct = (improvement / pre_mean) * 100 if pre_mean > 0 else 0
+
+    # Simple statistical significance test
+    pre_std = pre_training_metrics['std_score']
+    post_std = post_training_metrics['std_score']
+    pooled_std = (pre_std + post_std) / 2
+
+    significance_threshold = 2 * pooled_std  # Rough 2-sigma test
+    is_significant = abs(improvement) > significance_threshold
+
+    validation_report = {
+        'learning_detected': improvement > 0 and is_significant,
+        'improvement_score': improvement,
+        'improvement_percent': improvement_pct,
+        'statistical_significance': is_significant,
+        'significance_threshold': significance_threshold,
+        'pre_training_score': pre_training_metrics['mean_score'],
+        'post_training_score': post_training_metrics['mean_score'],
+        'recommendation': (
+            "LEARNING CONFIRMED: Model shows statistically significant improvement"
+            if improvement > 0 and is_significant
+            else "LEARNING UNCERTAIN: No significant improvement detected"
+        )
+    }
+
+    return validation_report
+
+
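Taken together with evaluate_model() above, the intended before/after workflow might look like the following sketch; train_policy is a hypothetical stand-in for an RL training loop (for example, the GRPO trainer elsewhere in this package), not an API defined in this module:

env = create_text_generation_test_env(model, tokenizer, task_suite="basic", num_episodes=20)

pre_metrics = env.evaluate_model(mode="baseline")    # stores baseline scores on the env
train_policy(model, env)                             # hypothetical: your RL training loop
post_metrics = env.evaluate_model(mode="current")    # prints improvement vs. baseline

report = validate_learning_progress(env, pre_metrics, post_metrics)
print(report["recommendation"])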
+class TextGenerationEnv(Environment):
+    """
+    Simple text generation environment for RL training.
+
+    This is a lightweight wrapper around TextGenerationEnvironment that provides
+    the simple interface expected by training examples. It's designed for:
+    - Simple prompt-based training tasks
+    - External reward function integration
+    - Basic RL training workflows
+
+    For comprehensive testing and validation, use TextGenerationEnvironment instead.
+    """
+
+    def __init__(
+        self,
+        prompts: List[str],
+        reward_fn: Callable[..., float],
+        max_tokens: int = 25,
+        seed: int = 42,
+        tokenizer: Any = None,
+        examples: Optional[List[dict]] = None
+    ):
+        """
+        Initialize simple text generation environment.
+
+        Args:
+            prompts: List of prompts to cycle through
+            reward_fn: Function that computes a reward; step() calls it with the
+                keyword arguments (prompt, completion, example, tokenizer, truncated)
+            max_tokens: Maximum tokens to generate per response
+            seed: Random seed for reproducible behavior
+            tokenizer: Tokenizer for converting prompts to tokens (required for MLX compatibility)
+            examples: Optional list of example dicts to pass to the reward function. If provided,
+                must have the same length as prompts; examples[i] is passed when prompts[i] is used.
+        """
+        super().__init__()
+
+        if tokenizer is None:
+            raise ValueError("tokenizer is required for TextGenerationEnv to work with the MLX rollout system")
+
+        if examples is not None and len(examples) != len(prompts):
+            raise ValueError(f"examples length ({len(examples)}) must match prompts length ({len(prompts)})")
+
+        self.prompts = prompts
+        self.examples = examples if examples is not None else [{} for _ in prompts]
+        self.reward_fn = reward_fn
+        self.max_tokens = max_tokens
+        self.tokenizer = tokenizer
+        self.current_episode = 0
+        self.current_prompt = None
+
+        # Environment state
+        random.seed(seed)
+
+        # Debug prints removed for production efficiency
+
+    def reset(self) -> Tuple[Any, Dict[str, Any]]:
+        """
+        Reset environment to start a new episode.
+
+        Returns:
+            (observation, info): Current prompt tokens and episode metadata
+        """
+        # Cycle through prompts
+        prompt_index = self.current_episode % len(self.prompts)
+        self.current_prompt = self.prompts[prompt_index]
+
+        # Tokenize the prompt for MLX compatibility (encode is imported at module level)
+        observation = encode(self.tokenizer, self.current_prompt)
+
+        info = {
+            'episode': self.current_episode,
+            'prompt_index': prompt_index,
+            'max_tokens': self.max_tokens,
+            'prompt_text': self.current_prompt  # Keep original text for reward computation
+        }
+
+        return observation, info
+
+    def step(self, action: Any) -> Dict[str, Any]:
+        """
+        Take a step in the environment by evaluating generated text.
+
+        Args:
+            action: Generated text response (string or token array)
+
+        Returns:
+            Dictionary with keys: observation, reward, terminated, truncated, info.
+            This matches the Environment base class contract. The rollout runner
+            normalizes both dict and tuple returns, so returning a dict here keeps
+            interfaces consistent and compatible with rollouts.
+        """
+        if self.current_prompt is None:
+            raise ValueError("Environment not reset - call reset() first")
+
+        # Handle different action types - properly decode token arrays to text
+        if hasattr(action, 'tolist'):
+            # Action is an MLX array of tokens - decode to text using the tokenizer
+            try:
+                response_text = decode(self.tokenizer, action)
+            except Exception as e:
+                print(f"WARNING: Failed to decode MLX action array: {e}")
+                # Fallback: try to handle as raw tokens
+                try:
+                    response_text = self.tokenizer.decode(action.tolist())
+                except Exception as e2:
+                    print(f"WARNING: Fallback decode also failed: {e2}")
+                    response_text = "Generated response (decode failed)"
+        elif isinstance(action, list) and len(action) > 0 and isinstance(action[0], (int, float)):
+            # Action is a Python list of token IDs - decode to text
+            try:
+                response_text = self.tokenizer.decode(action)
+            except Exception as e:
+                print(f"WARNING: Failed to decode token list: {e}")
+                response_text = "Generated response (decode failed)"
+        else:
+            # Action is already text or something else
+            response_text = str(action)
+
+        # Detect whether the response was likely cut off by the max_tokens limit.
+        # The word count only approximates the token count, so reaching 95% of
+        # the limit is treated as likely truncation.
+        response_tokens = len(response_text.split()) if response_text else 0
+        truncated = response_tokens >= (self.max_tokens * 0.95)
+
+        # Episode terminates after each generation (single-turn tasks)
+        terminated = True
+
+        # Compute reward using the provided reward function.
+        # Pass the tokenizer for EOS token detection along with the truncation flag.
+        prompt_index = self.current_episode % len(self.prompts)
+        reward = self.reward_fn(
+            prompt=self.current_prompt,
+            completion=response_text,
+            example=self.examples[prompt_index],
+            tokenizer=self.tokenizer,  # Pass tokenizer for EOS detection
+            truncated=truncated  # Pass truncation flag from the environment
+        )
+
+        # Prepare next observation (empty MLX array since the episode ended)
+        next_observation = mx.array([])
+
+        info = {
+            'response': response_text,
+            'reward': reward,
+            'prompt': self.current_prompt,
+            'episode': self.current_episode
+        }
+
+        # Move to next episode
+        self.current_episode += 1
+
+        # Return the unified dict format per the Environment contract.
+        # Runner code normalizes both dict and tuple step results, so this
+        # remains fully compatible with rollout collection while aligning
+        # with the base interface and other adapters (GymAdapter, VectorizedEnvironment).
+        return {
+            'observation': next_observation,
+            'reward': reward,
+            'terminated': terminated,
+            'truncated': truncated,
+            'info': info,
+        }
+
+    @property
+    def observation_space(self) -> Any:
+        """Observation space is text prompts (variable length)."""
+        return "TextPrompt"  # Placeholder - MLX doesn't need gym spaces
+
+    @property
+    def action_space(self) -> Any:
+        """Action space is generated text responses (variable length)."""
+        return "GeneratedText"  # Placeholder - MLX doesn't need gym spaces
+
+    def clone(self) -> 'TextGenerationEnv':
+        """Create a clone for multiprocessing."""
+        return TextGenerationEnv(
+            prompts=self.prompts.copy(),
+            reward_fn=self.reward_fn,
+            max_tokens=self.max_tokens,
+            tokenizer=self.tokenizer,  # Tokenizer is required for MLX compatibility
+            seed=random.randint(0, 10000),  # New seed for variety
+            examples=self.examples  # Preserve examples so cloned rewards match
+        )
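
A hedged sketch of driving TextGenerationEnv directly with a toy reward function. The reward function must accept the keyword arguments step() passes (prompt, completion, example, tokenizer, truncated); the tokenizer is assumed to come from the user's own model-loading code:

def keyword_reward(prompt, completion, example, tokenizer=None, truncated=False):
    # Toy reward: 1.0 if the completion mentions the expected answer
    return 1.0 if "paris" in completion.lower() else 0.0

env = TextGenerationEnv(
    prompts=["What is the capital of France?"],
    reward_fn=keyword_reward,
    max_tokens=25,
    tokenizer=tokenizer,  # e.g. the tokenizer returned by your model loader
)

obs, info = env.reset()                      # tokenized prompt
result = env.step("The capital is Paris.")   # step() also accepts raw text
print(result["reward"], result["info"]["response"])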