textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,789 @@
# textpolicy/environment/text_generation.py
"""
Text Generation Environment for Testing MLX RL Training.

This environment provides measurable text generation tasks to validate that
models are actually learning through RL training, not just going through the motions.

Key features:
- Consistent, reproducible text generation tasks
- Before/after learning validation metrics
- Integration with MLX generation system
- Support for various text generation benchmarks
"""

from typing import Dict, List, Optional, Tuple, Any, Callable
import mlx.core as mx
import random
from dataclasses import dataclass
from .base import Environment
from .task_suites import register_task_suite, get_task_suite

# Import our generation functions
from ..generation.mlx_generation import encode, decode, generate_tokens


@dataclass
class TextGenerationTask:
    """A single text generation task with validation criteria."""
    prompt: str
    target_keywords: List[str]
    target_length_range: Tuple[int, int]  # (min_words, max_words)
    difficulty: float  # 0.0 to 1.0
    category: str
    evaluation_criteria: Dict[str, Any]

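Tasks are plain dataclass instances, so a suite can be extended with ordinary constructor calls. A minimal sketch (prompt, keywords, and weights invented for illustration):

    example_task = TextGenerationTask(
        prompt="Summarize the water cycle in a few sentences.",
        target_keywords=["evaporation", "condensation", "precipitation"],
        target_length_range=(25, 45),  # min/max words
        difficulty=0.5,
        category="keyword_inclusion",
        evaluation_criteria={"keyword_weight": 0.6, "length_weight": 0.4},
    )
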
# Default task suites for registration and internal use.
def _default_basic_tasks() -> List[TextGenerationTask]:
    return [
        TextGenerationTask(
            prompt="Write a brief explanation of machine learning.",
            target_keywords=["algorithm", "data", "learn"],
            target_length_range=(20, 40),
            difficulty=0.3,
            category="length_control",
            evaluation_criteria={"keyword_weight": 0.4, "length_weight": 0.6}
        ),
        TextGenerationTask(
            prompt="Describe the benefits of renewable energy in one paragraph.",
            target_keywords=["environment", "sustainable", "clean"],
            target_length_range=(30, 50),
            difficulty=0.4,
            category="length_control",
            evaluation_criteria={"keyword_weight": 0.5, "length_weight": 0.5}
        ),
        TextGenerationTask(
            prompt="Explain how computers work.",
            target_keywords=["processor", "memory", "software", "hardware"],
            target_length_range=(25, 45),
            difficulty=0.5,
            category="keyword_inclusion",
            evaluation_criteria={"keyword_weight": 0.7, "length_weight": 0.3}
        ),
        TextGenerationTask(
            prompt="Write about the importance of education.",
            target_keywords=["knowledge", "skills", "future", "learning"],
            target_length_range=(20, 40),
            difficulty=0.4,
            category="keyword_inclusion",
            evaluation_criteria={"keyword_weight": 0.6, "length_weight": 0.4}
        ),
        TextGenerationTask(
            prompt="Explain the process of photosynthesis step by step.",
            target_keywords=["sunlight", "carbon", "oxygen", "glucose"],
            target_length_range=(35, 60),
            difficulty=0.6,
            category="coherence",
            evaluation_criteria={"keyword_weight": 0.3, "length_weight": 0.3, "coherence_weight": 0.4}
        ),
    ]


def _default_challenging_tasks() -> List[TextGenerationTask]:
    return [
        TextGenerationTask(
            prompt="Compare and contrast neural networks and traditional algorithms.",
            target_keywords=["pattern", "weights", "training", "classification", "regression"],
            target_length_range=(50, 80),
            difficulty=0.8,
            category="comparison",
            evaluation_criteria={"keyword_weight": 0.4, "length_weight": 0.3, "coherence_weight": 0.3}
        ),
        TextGenerationTask(
            prompt="Analyze the ethical implications of artificial intelligence.",
            target_keywords=["bias", "privacy", "autonomy", "responsibility", "society"],
            target_length_range=(60, 100),
            difficulty=0.9,
            category="analysis",
            evaluation_criteria={"keyword_weight": 0.3, "length_weight": 0.2, "coherence_weight": 0.5}
        ),
    ]


# Register defaults at import time to make them discoverable via the registry.
register_task_suite("basic", _default_basic_tasks)
register_task_suite("challenging", _default_challenging_tasks)

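Since the defaults are registered at import time, a downstream user can register an extra suite the same way before constructing an environment, then select it by name via the task_suite argument. A hedged sketch (the "tiny" suite name and its task values are invented):

    def _my_tiny_suite() -> List[TextGenerationTask]:
        return [
            TextGenerationTask(
                prompt="Define entropy in simple terms.",
                target_keywords=["disorder", "energy"],
                target_length_range=(15, 30),
                difficulty=0.3,
                category="keyword_inclusion",
                evaluation_criteria={"keyword_weight": 0.7, "length_weight": 0.3},
            )
        ]

    register_task_suite("tiny", _my_tiny_suite)
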
class TextGenerationEnvironment(Environment):
    """
    Environment for testing text generation learning with MLX models.

    This environment provides a suite of text generation tasks that allow
    measuring model improvement through RL training. It integrates directly
    with our MLX generation system and reward functions.

    Key validation approach:
    1. Pre-training baseline: Measure model performance on the task suite
    2. Post-training comparison: Measure the same model after RL training
    3. Learning validation: Check for statistically significant improvement
    """

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        task_suite: str = "basic",
        num_episodes: int = 50,
        generation_params: Optional[Dict[str, Any]] = None,
        seed: int = 42
    ):
        """
        Initialize text generation testing environment.

        Args:
            model: MLX model for text generation
            tokenizer: MLX tokenizer
            task_suite: Which task suite to use ("basic", "challenging", "custom")
            num_episodes: Number of episodes per evaluation
            generation_params: Parameters for text generation
            seed: Random seed for reproducible evaluation
        """
        super().__init__()

        # Validate critical dependencies early to produce clear, actionable errors.
        # This environment integrates directly with MLX generation; both model and
        # tokenizer are required. Without them, encode/decode/generate would fail
        # later with obscure attribute errors.
        if model is None:
            raise ValueError("TextGenerationEnvironment requires a valid MLX model (got None)")
        if tokenizer is None:
            raise ValueError("TextGenerationEnvironment requires a valid tokenizer (got None)")

        self.model = model
        self.tokenizer = tokenizer
        self.generation_params = generation_params or {
            'max_tokens': 50,
            'temperature': 0.8,
            'top_p': 0.95
        }

        # Create task suite for evaluation
        self.tasks = self._create_task_suite(task_suite)
        self.task_suite = task_suite  # remember suite type for cloning
        self.num_episodes = num_episodes
        self.current_episode = 0
        self.current_task = None

        # Performance tracking for learning validation
        self.baseline_scores = []
        self.current_scores = []

        # Environment state
        random.seed(seed)
        self._episode_data = {}  # populated per episode in reset()

        # Initialization complete; environment is ready for evaluation.

    def _create_task_suite(self, suite_type: str) -> List[TextGenerationTask]:
        """
        Create a suite of text generation tasks for evaluation.

        These tasks are designed to be:
        - Measurable: Clear success criteria
        - Diverse: Cover different generation challenges
        - Reproducible: Same tasks for before/after comparison
        """
        # Prefer registry-based suites when available; fall back to the defaults here.
        # The registry loader (see environment.task_suites) enables custom suites
        # without hardcoding them in this module.
        registered = get_task_suite(suite_type)
        if registered is not None:
            return registered

        if suite_type == "basic":
            return _default_basic_tasks()

        elif suite_type == "challenging":
            return _default_challenging_tasks()

        else:  # custom or fallback
            return [
                TextGenerationTask(
                    prompt="Tell me about your favorite topic.",
                    target_keywords=["interesting", "because", "example"],
                    target_length_range=(15, 35),
                    difficulty=0.2,
                    category="open_ended",
                    evaluation_criteria={"keyword_weight": 0.5, "length_weight": 0.5}
                )
            ]

    def reset(self) -> Tuple[Any, Dict[str, Any]]:
        """
        Reset environment to start a new episode.

        Returns:
            (observation, info): Initial prompt and episode metadata
        """
        # Select next task (cycle through tasks)
        task_index = self.current_episode % len(self.tasks)
        self.current_task = self.tasks[task_index]

        # Reset episode state
        self._episode_data = {
            'prompt': self.current_task.prompt,
            'task': self.current_task,
            'responses': [],
            'scores': []
        }

        # Return initial observation (the prompt to generate from)
        observation = encode(self.tokenizer, self.current_task.prompt)

        info = {
            'episode': self.current_episode,
            'task_category': self.current_task.category,
            'difficulty': self.current_task.difficulty,
            'target_keywords': self.current_task.target_keywords,
            'target_length_range': self.current_task.target_length_range
        }

        return observation, info

    def step(self, action: Any) -> Dict[str, Any]:
        """
        Take a step in the environment by evaluating the generated response.

        Args:
            action: Generated response tokens (MLX array)

        Returns:
            Step result with observation, reward, termination status, and info
        """
        if self.current_task is None:
            raise ValueError("Environment not reset - call reset() first")

        # Decode response from tokens
        if hasattr(action, 'tolist'):
            # Action is an MLX array of tokens
            response_text = decode(self.tokenizer, action)
        else:
            # Action might already be text
            response_text = str(action)

        # Compute reward using our reward system
        reward_score = self._evaluate_response(
            prompt=self.current_task.prompt,
            response=response_text,
            task=self.current_task
        )

        # Store episode data for analysis
        self._episode_data['responses'].append(response_text)
        self._episode_data['scores'].append(reward_score)

        # Episode terminates after each generation (single-turn tasks)
        terminated = True
        truncated = False

        # Prepare next observation (empty since the episode ended)
        next_observation = mx.array([])

        info = {
            'response': response_text,
            'reward_score': reward_score,
            'task_category': self.current_task.category,
            'target_keywords_found': [kw for kw in self.current_task.target_keywords
                                      if kw.lower() in response_text.lower()],
            'response_length': len(response_text.split()),
            'target_length_range': self.current_task.target_length_range
        }

        # Move to next episode
        self.current_episode += 1

        return {
            'observation': next_observation,
            'reward': reward_score,
            'terminated': terminated,
            'truncated': truncated,
            'info': info
        }

    def _evaluate_response(self, prompt: str, response: str, task: TextGenerationTask) -> float:
        """
        Evaluate response quality using task-specific criteria.

        This method integrates with our reward system to provide
        consistent, measurable evaluation of text generation quality.
        """
        criteria = task.evaluation_criteria
        total_score = 0.0

        # Length-based scoring
        if 'length_weight' in criteria:
            word_count = len(response.split())
            min_len, max_len = task.target_length_range
            target_len = (min_len + max_len) / 2

            # Score based on proximity to the target length
            if min_len <= word_count <= max_len:
                length_score = 1.0
            else:
                # Penalty proportional to the distance outside the range
                distance = min(abs(word_count - min_len), abs(word_count - max_len))
                length_score = max(0.0, 1.0 - distance / target_len)

            total_score += criteria['length_weight'] * length_score

        # Keyword inclusion scoring
        if 'keyword_weight' in criteria:
            keywords_found = sum(1 for kw in task.target_keywords
                                 if kw.lower() in response.lower())
            keyword_score = keywords_found / len(task.target_keywords)
            total_score += criteria['keyword_weight'] * keyword_score

        # Coherence scoring (simple heuristic)
        if 'coherence_weight' in criteria:
            coherence_score = self._simple_coherence_score(response)
            total_score += criteria['coherence_weight'] * coherence_score

        return total_score

    def _simple_coherence_score(self, text: str) -> float:
        """Simple coherence scoring based on structure indicators."""
        if not text.strip():
            return 0.0

        # Basic coherence indicators
        sentences = [s.strip() for s in text.split('.') if s.strip()]
        if len(sentences) < 2:
            return 0.5  # A single sentence is moderately coherent

        # Look for logical connectors
        connectors = ['therefore', 'however', 'moreover', 'furthermore', 'because', 'since']
        connector_count = sum(1 for conn in connectors if conn in text.lower())

        # Coherence score based on structure
        connector_score = min(1.0, connector_count / 2.0)  # 2+ connectors is good
        sentence_score = min(1.0, len(sentences) / 3.0)    # 3+ sentences is good

        return (connector_score + sentence_score) / 2.0

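To make the weighting concrete: for the first basic task (keyword_weight 0.4, length_weight 0.6), a 30-word response that mentions "data" and "learn" but not "algorithm" scores 0.4 * 2/3 + 0.6 * 1.0 ≈ 0.87. The same arithmetic as a runnable sketch (the response text is invented):

    task = _default_basic_tasks()[0]  # keywords: algorithm, data, learn; range (20, 40)
    response = " ".join(["data helps systems learn patterns"] * 6)  # exactly 30 words
    keyword_score = sum(kw in response for kw in task.target_keywords) / len(task.target_keywords)  # 2/3
    length_score = 1.0  # 30 words falls inside the (20, 40) target range
    reward = 0.4 * keyword_score + 0.6 * length_score  # ≈ 0.867
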
    def evaluate_model(self, mode: str = "current") -> Dict[str, Any]:
        """
        Evaluate model performance on the full task suite.

        This function runs through all tasks and computes aggregate
        performance metrics to measure learning progress.

        Args:
            mode: "baseline" (store baseline) or "current" (compare to baseline)

        Returns:
            Performance metrics dictionary
        """
        print(f"Running {mode} evaluation on {self.num_episodes} episodes...")

        all_scores = []
        category_scores = {}

        # Reset episode counter for evaluation
        original_episode = self.current_episode
        self.current_episode = 0

        try:
            for episode in range(self.num_episodes):
                # Reset environment
                observation, info = self.reset()

                # Generate response using the current model
                response_tokens, generation_info = generate_tokens(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    prompt_tokens=observation,
                    **self.generation_params
                )

                # Take a step to get the reward
                step_result = self.step(response_tokens)
                score = step_result['reward']
                category = step_result['info']['task_category']

                all_scores.append(score)
                if category not in category_scores:
                    category_scores[category] = []
                category_scores[category].append(score)

        finally:
            # Restore episode counter
            self.current_episode = original_episode

        # Compute aggregate metrics
        mean_score = float(mx.mean(mx.array(all_scores)))
        std_score = float(mx.std(mx.array(all_scores)))

        metrics = {
            'mean_score': mean_score,
            'std_score': std_score,
            'num_episodes': self.num_episodes,
            'category_breakdown': {
                cat: float(mx.mean(mx.array(scores)))
                for cat, scores in category_scores.items()
            }
        }

        # Store results based on mode
        if mode == "baseline":
            self.baseline_scores = all_scores
            print(f"✓ Baseline evaluation complete: {mean_score:.3f} ± {std_score:.3f}")
        else:
            self.current_scores = all_scores

            # Compute learning improvement if we have a baseline
            if self.baseline_scores:
                baseline_mean = float(mx.mean(mx.array(self.baseline_scores)))
                improvement = mean_score - baseline_mean
                improvement_pct = (improvement / baseline_mean) * 100 if baseline_mean > 0 else 0

                metrics['baseline_score'] = baseline_mean
                metrics['improvement'] = improvement
                metrics['improvement_percent'] = improvement_pct

                print(f"Current evaluation complete: {mean_score:.3f} ± {std_score:.3f}")
                print(f"  Improvement: {improvement:+.3f} ({improvement_pct:+.1f}%)")

                # Statistical significance test (simple)
                if improvement > 2 * std_score:  # Rough 2-sigma test
                    print("  LEARNING DETECTED: Statistically significant improvement!")
                else:
                    print("  Learning uncertain: Improvement not statistically significant")

        return metrics

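The 2-sigma check above is deliberately rough: it compares the mean improvement against twice the per-episode standard deviation rather than a standard error, so it is conservative. For example, with a baseline mean of 0.45, a current mean of 0.62, and std_score = 0.07, the improvement of 0.17 exceeds 2 * 0.07 = 0.14 and would be reported as significant.
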
    @property
    def observation_space(self) -> Any:
        """Observation space is tokenized text (variable length)."""
        return "TokenizedText"  # Placeholder - MLX doesn't need gym spaces

    @property
    def action_space(self) -> Any:
        """Action space is generated text tokens (variable length)."""
        return "GeneratedTokens"  # Placeholder - MLX doesn't need gym spaces

    def clone(self) -> 'TextGenerationEnvironment':
        """Create a multiprocessing clone with the same configuration.

        This returns a new environment instance that references the same
        model/tokenizer objects. On some systems, MLX models are not picklable;
        for process spawning, prefer passing an environment factory (env_fn)
        so model/tokenizer can be constructed in each process. See rollout.coordinator.
        """
        # Delegate to the same constructor with preserved parameters
        return TextGenerationEnvironment(
            model=self.model,
            tokenizer=self.tokenizer,
            task_suite=self.task_suite,
            num_episodes=self.num_episodes,
            generation_params=self.generation_params,
            seed=random.randint(0, 10000)
        )


def create_text_generation_test_env(
    model: Any,
    tokenizer: Any,
    task_suite: str = "basic",
    num_episodes: int = 50,
    **kwargs
) -> TextGenerationEnvironment:
    """
    Factory function to create a text generation testing environment.

    This is the main entry point for creating environments to test
    whether RL training actually improves model performance.

    Args:
        model: MLX model for text generation
        tokenizer: MLX tokenizer
        task_suite: Which task suite to use for evaluation
        num_episodes: Number of episodes per evaluation
        **kwargs: Additional environment parameters

    Returns:
        Configured TextGenerationEnvironment ready for testing
    """
    return TextGenerationEnvironment(
        model=model,
        tokenizer=tokenizer,
        task_suite=task_suite,
        num_episodes=num_episodes,
        **kwargs
    )

def validate_learning_progress(
    env: TextGenerationEnvironment,
    pre_training_metrics: Dict[str, float],
    post_training_metrics: Dict[str, float]
) -> Dict[str, Any]:
    """
    Pure function to validate that learning actually occurred.

    This function provides a simple statistical analysis to check whether
    RL training resulted in measurable improvement.

    Args:
        env: The environment used for testing
        pre_training_metrics: Metrics before training
        post_training_metrics: Metrics after training

    Returns:
        Learning validation report
    """
    pre_mean = pre_training_metrics['mean_score']
    improvement = post_training_metrics['mean_score'] - pre_mean
    # Guard against a zero baseline to avoid division by zero.
    improvement_pct = (improvement / pre_mean) * 100 if pre_mean > 0 else 0.0

    # Simple statistical significance test
    pre_std = pre_training_metrics['std_score']
    post_std = post_training_metrics['std_score']
    pooled_std = (pre_std + post_std) / 2

    significance_threshold = 2 * pooled_std  # Rough 2-sigma test
    is_significant = abs(improvement) > significance_threshold

    validation_report = {
        'learning_detected': improvement > 0 and is_significant,
        'improvement_score': improvement,
        'improvement_percent': improvement_pct,
        'statistical_significance': is_significant,
        'significance_threshold': significance_threshold,
        'pre_training_score': pre_training_metrics['mean_score'],
        'post_training_score': post_training_metrics['mean_score'],
        'recommendation': (
            "LEARNING CONFIRMED: Model shows statistically significant improvement"
            if improvement > 0 and is_significant
            else "LEARNING UNCERTAIN: No significant improvement detected"
        )
    }

    return validation_report

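Taken together, the intended before/after workflow looks roughly like the sketch below; model and tokenizer are assumed to be loaded already, and the training step is a placeholder:

    env = create_text_generation_test_env(model, tokenizer, task_suite="basic", num_episodes=20)
    pre = env.evaluate_model(mode="baseline")   # stores baseline scores on the env
    # ... run RL training on the same model (e.g. with textpolicy.training) ...
    post = env.evaluate_model(mode="current")   # compares against the stored baseline
    report = validate_learning_progress(env, pre, post)
    print(report['recommendation'])
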
class TextGenerationEnv(Environment):
    """
    Simple text generation environment for RL training.

    This is a lightweight wrapper around TextGenerationEnvironment that provides
    the simple interface expected by training examples. It's designed for:
    - Simple prompt-based training tasks
    - External reward function integration
    - Basic RL training workflows

    For comprehensive testing and validation, use TextGenerationEnvironment instead.
    """

    def __init__(
        self,
        prompts: List[str],
        reward_fn: Callable[..., float],
        max_tokens: int = 25,
        seed: int = 42,
        tokenizer: Any = None
    ):
        """
        Initialize simple text generation environment.

        Args:
            prompts: List of prompts to cycle through
            reward_fn: Reward function; step() calls it with keyword arguments
                (prompt, completion, example, tokenizer, truncated) and expects a float
            max_tokens: Maximum tokens to generate per response
            seed: Random seed for reproducible behavior
            tokenizer: Tokenizer for converting prompts to tokens (required for MLX compatibility)
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer is required for TextGenerationEnv to work with the MLX rollout system")

        self.prompts = prompts
        self.reward_fn = reward_fn
        self.max_tokens = max_tokens
        self.tokenizer = tokenizer
        self.current_episode = 0
        self.current_prompt = None

        # Environment state
        random.seed(seed)

    def reset(self) -> Tuple[Any, Dict[str, Any]]:
        """
        Reset environment to start a new episode.

        Returns:
            (observation, info): Current prompt tokens and episode metadata
        """
        # Cycle through prompts
        prompt_index = self.current_episode % len(self.prompts)
        self.current_prompt = self.prompts[prompt_index]

        # Tokenize the prompt for MLX compatibility (encode is imported at module level)
        observation = encode(self.tokenizer, self.current_prompt)

        info = {
            'episode': self.current_episode,
            'prompt_index': prompt_index,
            'max_tokens': self.max_tokens,
            'prompt_text': self.current_prompt  # Keep original text for reward computation
        }

        return observation, info

    def step(self, action: Any) -> Dict[str, Any]:
        """
        Take a step in the environment by evaluating the generated text.

        Args:
            action: Generated text response (string or token array)

        Returns:
            Dictionary with keys: observation, reward, terminated, truncated, info.
            This matches the Environment base class contract. The rollout runner
            normalizes both dict and tuple returns, so returning a dict here keeps
            interfaces consistent and compatible with rollouts.
        """
        if self.current_prompt is None:
            raise ValueError("Environment not reset - call reset() first")

        # Handle different action types - properly decode token arrays to text
        if hasattr(action, 'tolist'):
            # Action is an MLX array of tokens - decode to text using the tokenizer
            try:
                response_text = decode(self.tokenizer, action)
            except Exception as e:
                print(f"WARNING: Failed to decode MLX action array: {e}")
                # Fallback: try to handle as raw tokens
                try:
                    response_text = self.tokenizer.decode(action.tolist())
                except Exception as e2:
                    print(f"WARNING: Fallback decode also failed: {e2}")
                    response_text = "Generated response (decode failed)"
        elif isinstance(action, list) and len(action) > 0 and isinstance(action[0], (int, float)):
            # Action is a Python list of token IDs - decode to text
            try:
                response_text = self.tokenizer.decode(action)
            except Exception as e:
                print(f"WARNING: Failed to decode token list: {e}")
                response_text = "Generated response (decode failed)"
        else:
            # Action is already text or something else
            response_text = str(action)

        # Heuristically detect truncation by the max_tokens limit: the word count
        # approximates the token count, so a response near the limit was likely cut off.
        word_count = len(response_text.split()) if response_text else 0
        truncated = word_count >= (self.max_tokens * 0.95)  # within 95% of the limit counts as truncated

        # Episode terminates after each generation (single-turn tasks)
        terminated = True

        # Compute reward using the provided reward function.
        # The tokenizer is passed through for EOS detection, along with the truncation flag.
        reward = self.reward_fn(
            prompt=self.current_prompt,
            completion=response_text,
            example={},
            tokenizer=self.tokenizer,
            truncated=truncated
        )

        # Prepare next observation (empty MLX array since the episode ended)
        next_observation = mx.array([])

        info = {
            'response': response_text,
            'reward': reward,
            'prompt': self.current_prompt,
            'episode': self.current_episode
        }

        # Move to next episode
        self.current_episode += 1

        # Return the unified dict format per the Environment contract. The runner
        # normalizes both dict and tuple step results, so this remains fully
        # compatible with rollout collection while aligning with our base
        # interface and other adapters (GymAdapter, VectorizedEnvironment).
        return {
            'observation': next_observation,
            'reward': reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': info,
        }

    @property
    def observation_space(self) -> Any:
        """Observation space is text prompts (variable length)."""
        return "TextPrompt"  # Placeholder - MLX doesn't need gym spaces

    @property
    def action_space(self) -> Any:
        """Action space is generated text responses (variable length)."""
        return "GeneratedText"  # Placeholder - MLX doesn't need gym spaces

    def clone(self) -> 'TextGenerationEnv':
        """Create a clone for multiprocessing."""
        return TextGenerationEnv(
            prompts=self.prompts.copy(),
            reward_fn=self.reward_fn,
            max_tokens=self.max_tokens,
            tokenizer=self.tokenizer,  # Tokenizer is required for MLX compatibility
            seed=random.randint(0, 10000)  # New seed for variety
        )
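
Note that step() invokes the reward function with the keyword arguments prompt, completion, example, tokenizer, and truncated, so a compatible callable must accept all five. A toy sketch (the scoring rule is invented):

    def keyword_reward(prompt: str, completion: str, example: dict,
                       tokenizer: Any = None, truncated: bool = False) -> float:
        # Toy rule: reward mentioning "because", with a small truncation penalty.
        score = 1.0 if "because" in completion.lower() else 0.0
        return score - (0.2 if truncated else 0.0)

    env = TextGenerationEnv(
        prompts=["Explain why the sky is blue."],
        reward_fn=keyword_reward,
        max_tokens=25,
        tokenizer=tokenizer,  # assumed: an already-loaded MLX tokenizer
    )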