textpolicy 0.0.1-py3-none-any.whl → 0.1.1-py3-none-any.whl
This diff shows the changes between package versions that have been publicly released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,338 @@

```python
# textpolicy/rewards/integrated_system.py
"""
Integrated rollout reward system combining rewards and verifiers.

This system provides a unified interface for:
1. Computing rewards at the rollout level
2. Verifying text quality
3. Filtering episodes based on quality thresholds
4. MLX-optimized batch processing
"""

from typing import Dict, List, Any, Tuple
import mlx.core as mx
from dataclasses import dataclass
import numpy as np

from .rollout_rewards import RolloutRewardProcessor, RewardConfig
from .verifiers import create_default_verifier_pipeline


@dataclass
class IntegratedRewardConfig:
    """Configuration for integrated reward and verification system."""
    # Reward configuration
    reward_config: RewardConfig

    # Verification configuration
    enable_verification: bool = True
    verification_threshold: float = 0.7  # Minimum verification score
    strict_filtering: bool = False  # If True, episodes must pass both the reward and verification checks

    # Quality control
    min_reward_threshold: float = 0.3
    max_reward_threshold: float = 1.0

    # Batch processing
    batch_size: int = 32
    enable_mlx_compilation: bool = True


class IntegratedRolloutRewardSystem:
    """
    Integrated system for rollout-level reward computation and verification.

    Combines reward computation with quality verification for comprehensive
    episode evaluation and filtering.
    """

    def __init__(self, config: IntegratedRewardConfig):
        """
        Initialize the integrated reward system.

        Args:
            config: Configuration for the integrated system
        """
        self.config = config

        # Initialize reward processor
        self.reward_processor = RolloutRewardProcessor(config.reward_config)

        # Initialize verification pipeline
        if config.enable_verification:
            self.verifier = create_default_verifier_pipeline()
        else:
            self.verifier = None

    def process_episodes(
        self,
        episodes: List[Dict[str, Any]]
    ) -> Tuple[mx.array, List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Process episodes with rewards and verification.

        Args:
            episodes: List of episode dictionaries

        Returns:
            (rewards, accepted_episodes, rejected_episodes)
        """
        if not episodes:
            return mx.array([]), [], []

        # Compute rewards
        rewards = self.reward_processor.process_episode_rewards(episodes)

        # Verify quality if enabled
        if self.verifier:
            verification_results = self._verify_episodes(episodes)
            accepted, rejected = self._filter_episodes(
                episodes, rewards, verification_results
            )
        else:
            # No verification - accept all episodes within the reward range
            accepted, rejected = self._filter_by_rewards_only(episodes, rewards)

        return rewards, accepted, rejected

    def _verify_episodes(
        self,
        episodes: List[Dict[str, Any]]
    ) -> List[Any]:
        """Verify the quality of episodes."""
        prompts = [ep.get('prompt', '') for ep in episodes]
        responses = [ep.get('response', '') for ep in episodes]

        return self.verifier.verify_batch(prompts, responses)

    def _filter_episodes(
        self,
        episodes: List[Dict[str, Any]],
        rewards: mx.array,
        verification_results: List[Any]
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Filter episodes based on rewards and verification."""
        accepted = []
        rejected = []

        for episode, reward, verification in zip(
            episodes, rewards, verification_results
        ):
            # Check reward threshold
            reward_ok = (
                self.config.min_reward_threshold <= reward <= self.config.max_reward_threshold
            )

            # Check verification threshold
            verification_ok = verification.score >= self.config.verification_threshold

            # Determine acceptance
            if self.config.strict_filtering:
                # Both must pass
                is_accepted = reward_ok and verification_ok
            else:
                # Either can pass (more lenient)
                is_accepted = reward_ok or verification_ok

            # Add metadata
            episode_with_metadata = episode.copy()
            episode_with_metadata.update({
                'reward': float(reward),
                'verification_score': verification.score,
                'verification_result': verification.result.value,
                'verification_message': verification.message,
                'verification_details': verification.details
            })

            if is_accepted:
                accepted.append(episode_with_metadata)
            else:
                rejected.append(episode_with_metadata)

        return accepted, rejected

    def _filter_by_rewards_only(
        self,
        episodes: List[Dict[str, Any]],
        rewards: mx.array
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Filter episodes based only on rewards."""
        accepted = []
        rejected = []

        for episode, reward in zip(episodes, rewards):
            episode_with_metadata = episode.copy()
            episode_with_metadata['reward'] = float(reward)

            if self.config.min_reward_threshold <= reward <= self.config.max_reward_threshold:
                accepted.append(episode_with_metadata)
            else:
                rejected.append(episode_with_metadata)

        return accepted, rejected

    def process_buffer(
        self,
        buffer
    ) -> Tuple[mx.array, List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Process all episodes in a buffer.

        Args:
            buffer: Buffer instance containing episodes

        Returns:
            (rewards, accepted_episodes, rejected_episodes)
        """
        # Extract episodes from the buffer
        episodes = self._extract_episodes_from_buffer(buffer)

        return self.process_episodes(episodes)

    def _extract_episodes_from_buffer(self, buffer) -> List[Dict[str, Any]]:
        """Extract episodes from the buffer in the expected format."""
        episodes = []

        for episode in buffer.storage.episodes:
            # Convert episode to dict format
            episode_dict = {
                'prompt': episode.obs[0] if episode.obs else '',
                'response': episode.act[-1] if episode.act else '',
                'metadata': {
                    'length': len(episode.obs),
                    'logprobs': episode.logprob,
                    'values': episode.value,
                    'episode_id': id(episode)
                }
            }
            episodes.append(episode_dict)

        return episodes

    def get_quality_metrics(
        self,
        episodes: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Compute quality metrics for a batch of episodes.

        Args:
            episodes: List of processed episodes with metadata

        Returns:
            Dictionary of quality metrics
        """
        if not episodes:
            return {}

        # Extract metrics
        rewards = [ep.get('reward', 0.0) for ep in episodes]
        verification_scores = [ep.get('verification_score', 1.0) for ep in episodes]

        # Compute statistics
        metrics = {
            'num_episodes': len(episodes),
            'reward_stats': {
                'mean': float(np.mean(rewards)),
                'std': float(np.std(rewards)),
                'min': float(np.min(rewards)),
                'max': float(np.max(rewards)),
                'median': float(np.median(rewards))
            }
        }

        if self.verifier:
            metrics['verification_stats'] = {
                'mean': float(np.mean(verification_scores)),
                'std': float(np.std(verification_scores)),
                'min': float(np.min(verification_scores)),
                'max': float(np.max(verification_scores)),
                'median': float(np.median(verification_scores))
            }

            # Count verification results
            verification_results = [ep.get('verification_result', 'pass') for ep in episodes]
            metrics['verification_counts'] = {
                'pass': verification_results.count('pass'),
                'warning': verification_results.count('warning'),
                'fail': verification_results.count('fail')
            }

        return metrics

    def close(self):
        """Clean up resources."""
        if self.reward_processor:
            self.reward_processor.close()


# Pure function interface for easy integration
def create_integrated_reward_system(config: IntegratedRewardConfig) -> IntegratedRolloutRewardSystem:
    """
    Factory function for creating integrated reward systems.

    Args:
        config: Configuration for the integrated system

    Returns:
        IntegratedRolloutRewardSystem instance
    """
    return IntegratedRolloutRewardSystem(config)


def process_episodes_with_quality_control(
    episodes: List[Dict[str, Any]],
    config: IntegratedRewardConfig
) -> Tuple[mx.array, List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any]]:
    """
    Pure function for processing episodes with quality control.

    Args:
        episodes: List of episode dictionaries
        config: Configuration for the integrated system

    Returns:
        (rewards, accepted_episodes, rejected_episodes, quality_metrics)
    """
    system = IntegratedRolloutRewardSystem(config)
    try:
        rewards, accepted, rejected = system.process_episodes(episodes)
        quality_metrics = system.get_quality_metrics(accepted + rejected)
        return rewards, accepted, rejected, quality_metrics
    finally:
        system.close()


# MLX-optimized batch processing
@mx.compile
def compute_integrated_rewards(
    base_rewards: mx.array,
    verification_scores: mx.array,
    reward_weight: float = 0.7,
    verification_weight: float = 0.3
) -> mx.array:
    """
    MLX-compiled function for combining rewards and verification scores.

    Args:
        base_rewards: Base reward scores [batch_size]
        verification_scores: Verification scores [batch_size]
        reward_weight: Weight for base rewards
        verification_weight: Weight for verification scores

    Returns:
        Combined scores [batch_size]
    """
    # Normalize weights so they sum to one
    total_weight = reward_weight + verification_weight
    reward_weight = reward_weight / total_weight
    verification_weight = verification_weight / total_weight

    # Weighted combination
    combined_scores = (
        reward_weight * base_rewards +
        verification_weight * verification_scores
    )

    return combined_scores
```
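For orientation, here is a minimal usage sketch of the new module. It assumes `RewardConfig` (defined in `rollout_rewards.py`, also added in this release) can be constructed with defaults; its actual fields are not visible in this hunk.

```python
from textpolicy.rewards.integrated_system import (
    IntegratedRewardConfig,
    process_episodes_with_quality_control,
)
from textpolicy.rewards.rollout_rewards import RewardConfig

# Assumption: RewardConfig() is constructible with defaults; its fields
# live in rollout_rewards.py, which this hunk does not show.
config = IntegratedRewardConfig(
    reward_config=RewardConfig(),
    verification_threshold=0.7,
    strict_filtering=True,  # require reward range AND verification to pass
)

episodes = [
    {'prompt': 'What is 2 + 2?', 'response': '4'},
    {'prompt': 'Name a prime number.', 'response': 'banana'},
]

rewards, accepted, rejected, metrics = process_episodes_with_quality_control(
    episodes, config
)
print(len(accepted), len(rejected), metrics['reward_stats']['mean'])
```

With the default `strict_filtering=False`, an episode survives if it clears either the reward range or the verification threshold; setting it to `True` demands both, the stricter of the two policies implemented in `_filter_episodes`.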
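The filtering and metrics code also pins down what a verification result must look like: an object exposing `score`, `result` (an enum whose `.value` is one of `'pass'`, `'warning'`, `'fail'`, judging by `get_quality_metrics`), `message`, and `details`. The concrete type comes from `verifiers.py`, which this hunk does not show; a hypothetical shape consistent with that usage:

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict

# Hypothetical reconstruction: the real types live in
# textpolicy/rewards/verifiers.py, which this hunk does not show.
class VerificationResult(Enum):
    PASS = 'pass'
    WARNING = 'warning'
    FAIL = 'fail'

@dataclass
class VerificationOutcome:
    score: float                # compared against verification_threshold
    result: VerificationResult  # .value feeds the 'verification_result' metadata
    message: str = ''
    details: Dict[str, Any] = field(default_factory=dict)
```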
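Note that `compute_integrated_rewards` normalizes its weights to sum to one, so the defaults act as a 70/30 convex combination of reward and verification scores. A small worked example:

```python
import mlx.core as mx
from textpolicy.rewards.integrated_system import compute_integrated_rewards

base = mx.array([0.8, 0.2])
verif = mx.array([1.0, 0.5])

combined = compute_integrated_rewards(base, verif)
# 0.7 * 0.8 + 0.3 * 1.0 = 0.86; 0.7 * 0.2 + 0.3 * 0.5 = 0.29
print(combined)  # ≈ array([0.86, 0.29])
```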