textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,338 @@
1
+ # textpolicy/rewards/integrated_system.py
2
+ """
3
+ Integrated rollout reward system combining rewards and verifiers.
4
+
5
+ This system provides a unified interface for:
6
+ 1. Computing rewards at the rollout level
7
+ 2. Verifying text quality
8
+ 3. Filtering episodes based on quality thresholds
9
+ 4. MLX-optimized batch processing
10
+ """
11
+
12
+ from typing import Dict, List, Any, Tuple
13
+ import mlx.core as mx
14
+ from dataclasses import dataclass
15
+ import numpy as np
16
+
17
+ from .rollout_rewards import RolloutRewardProcessor, RewardConfig
18
+ from .verifiers import create_default_verifier_pipeline
19
+
20
+
21
+ @dataclass
22
+ class IntegratedRewardConfig:
23
+ """Configuration for integrated reward and verification system."""
24
+ # Reward configuration
25
+ reward_config: RewardConfig
26
+
27
+ # Verification configuration
28
+ enable_verification: bool = True
29
+ verification_threshold: float = 0.7 # Minimum verification score
30
+ strict_filtering: bool = False # If True, reject episodes below threshold
31
+
32
+ # Quality control
33
+ min_reward_threshold: float = 0.3
34
+ max_reward_threshold: float = 1.0
35
+
36
+ # Batch processing
37
+ batch_size: int = 32
38
+ enable_mlx_compilation: bool = True
39
+
40
+
41
class IntegratedRolloutRewardSystem:
    """
    Integrated system for rollout-level reward computation and verification.

    Episodes are scored in batch by a RolloutRewardProcessor and, when
    verification is enabled, checked by a verifier pipeline. The scores are
    then used to split episodes into accepted and rejected sets, with the
    scoring metadata attached to each episode dictionary.
    """

    def __init__(self, config: IntegratedRewardConfig):
        """
        Initialize integrated reward system.

        Args:
            config: Configuration for the integrated system
        """
        self.config = config

        # Reward processor performs the rollout-level reward computation.
        self.reward_processor = RolloutRewardProcessor(config.reward_config)

        # Verification is optional; when disabled, filtering is based on
        # the reward thresholds only (see process_episodes).
        if config.enable_verification:
            self.verifier = create_default_verifier_pipeline()
        else:
            self.verifier = None

    def process_episodes(
        self,
        episodes: List[Dict[str, Any]]
    ) -> Tuple[mx.array, List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Process episodes with rewards and verification.

        Args:
            episodes: List of episode dictionaries

        Returns:
            (rewards, accepted_episodes, rejected_episodes)
        """
        if not episodes:
            # Nothing to score: empty reward array and empty splits.
            return mx.array([]), [], []

        # Compute rewards for the whole batch at once.
        rewards = self.reward_processor.process_episode_rewards(episodes)

        if self.verifier:
            # Verify quality and filter on both rewards and verification.
            verification_results = self._verify_episodes(episodes)
            accepted, rejected = self._filter_episodes(
                episodes, rewards, verification_results
            )
        else:
            # No verification - accept all episodes inside the reward band.
            accepted, rejected = self._filter_by_rewards_only(episodes, rewards)

        return rewards, accepted, rejected

    def _verify_episodes(
        self,
        episodes: List[Dict[str, Any]]
    ) -> List[Any]:
        """Verify quality of episodes via the verifier pipeline."""
        # Missing prompt/response keys degrade gracefully to empty strings.
        prompts = [ep.get('prompt', '') for ep in episodes]
        responses = [ep.get('response', '') for ep in episodes]

        return self.verifier.verify_batch(prompts, responses)

    def _reward_in_range(self, reward) -> bool:
        """Return True if reward lies within the configured [min, max] band."""
        return (
            self.config.min_reward_threshold
            <= reward
            <= self.config.max_reward_threshold
        )

    def _filter_episodes(
        self,
        episodes: List[Dict[str, Any]],
        rewards: mx.array,
        verification_results: List[Any]
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Filter episodes based on rewards and verification.

        With strict_filtering both criteria must pass; otherwise passing
        either criterion is enough.

        Returns:
            (accepted_episodes, rejected_episodes), each episode augmented
            with reward/verification metadata.
        """
        accepted = []
        rejected = []

        for episode, reward, verification in zip(
            episodes, rewards, verification_results
        ):
            # Check reward threshold
            reward_ok = self._reward_in_range(reward)

            # Check verification threshold
            verification_ok = (
                verification.score >= self.config.verification_threshold
            )

            # Determine acceptance
            if self.config.strict_filtering:
                # Both must pass
                is_accepted = reward_ok and verification_ok
            else:
                # Either can pass (more lenient)
                is_accepted = reward_ok or verification_ok

            # Attach scoring metadata so downstream consumers can inspect
            # why an episode was accepted or rejected.
            episode_with_metadata = episode.copy()
            episode_with_metadata.update({
                'reward': float(reward),
                'verification_score': verification.score,
                'verification_result': verification.result.value,
                'verification_message': verification.message,
                'verification_details': verification.details
            })

            if is_accepted:
                accepted.append(episode_with_metadata)
            else:
                rejected.append(episode_with_metadata)

        return accepted, rejected

    def _filter_by_rewards_only(
        self,
        episodes: List[Dict[str, Any]],
        rewards: mx.array
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Filter episodes based only on the configured reward band."""
        accepted = []
        rejected = []

        for episode, reward in zip(episodes, rewards):
            episode_with_metadata = episode.copy()
            episode_with_metadata['reward'] = float(reward)

            if self._reward_in_range(reward):
                accepted.append(episode_with_metadata)
            else:
                rejected.append(episode_with_metadata)

        return accepted, rejected

    def process_buffer(
        self,
        buffer
    ) -> Tuple[mx.array, List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Process all episodes in a buffer.

        Args:
            buffer: Buffer instance containing episodes

        Returns:
            (rewards, accepted_episodes, rejected_episodes)
        """
        # Extract episodes from buffer, then reuse the batch pipeline.
        episodes = self._extract_episodes_from_buffer(buffer)

        return self.process_episodes(episodes)

    def _extract_episodes_from_buffer(self, buffer) -> List[Dict[str, Any]]:
        """Extract episodes from buffer in the expected dict format."""
        episodes = []

        for episode in buffer.storage.episodes:
            # Convert episode to dict format.
            # NOTE(review): assumes the first observation is the prompt and
            # the last action is the response — confirm against the Episode
            # type in textpolicy/buffer/episode.py.
            episode_dict = {
                'prompt': episode.obs[0] if episode.obs else '',
                'response': episode.act[-1] if episode.act else '',
                'metadata': {
                    'length': len(episode.obs),
                    'logprobs': episode.logprob,
                    'values': episode.value,
                    # id() is only unique for the lifetime of the episode
                    # object; not a stable cross-run identifier.
                    'episode_id': id(episode)
                }
            }
            episodes.append(episode_dict)

        return episodes

    def get_quality_metrics(
        self,
        episodes: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Compute quality metrics for a batch of episodes.

        Args:
            episodes: List of processed episodes with metadata

        Returns:
            Dictionary of quality metrics (empty dict for an empty batch)
        """
        if not episodes:
            return {}

        # Extract metrics; episodes without verification metadata default
        # to a perfect score so they do not skew the statistics downward.
        rewards = [ep.get('reward', 0.0) for ep in episodes]
        verification_scores = [ep.get('verification_score', 1.0) for ep in episodes]

        # Compute statistics
        metrics = {
            'num_episodes': len(episodes),
            'reward_stats': {
                'mean': float(np.mean(rewards)),
                'std': float(np.std(rewards)),
                'min': float(np.min(rewards)),
                'max': float(np.max(rewards)),
                'median': float(np.median(rewards))
            }
        }

        if self.verifier:
            metrics['verification_stats'] = {
                'mean': float(np.mean(verification_scores)),
                'std': float(np.std(verification_scores)),
                'min': float(np.min(verification_scores)),
                'max': float(np.max(verification_scores)),
                'median': float(np.median(verification_scores))
            }

            # Count verification results
            verification_results = [ep.get('verification_result', 'pass') for ep in episodes]
            metrics['verification_counts'] = {
                'pass': verification_results.count('pass'),
                'warning': verification_results.count('warning'),
                'fail': verification_results.count('fail')
            }

        return metrics

    def close(self):
        """Cleanup resources held by the reward processor."""
        if self.reward_processor:
            self.reward_processor.close()
267
+
268
+
269
+ # Pure function interface for easy integration
270
def create_integrated_reward_system(config: IntegratedRewardConfig) -> IntegratedRolloutRewardSystem:
    """
    Build an IntegratedRolloutRewardSystem from a configuration object.

    Args:
        config: Settings for the reward processor, verification, and
            filtering behavior of the new system.

    Returns:
        A freshly constructed IntegratedRolloutRewardSystem.
    """
    system = IntegratedRolloutRewardSystem(config)
    return system
281
+
282
+
283
def process_episodes_with_quality_control(
    episodes: List[Dict[str, Any]],
    config: IntegratedRewardConfig
) -> Tuple[mx.array, List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any]]:
    """
    Score and filter a batch of episodes using a short-lived reward system.

    A temporary IntegratedRolloutRewardSystem is created from *config*,
    used once, and always closed — even if processing raises.

    Args:
        episodes: Episode dictionaries to score and filter.
        config: Configuration for the temporary system.

    Returns:
        (rewards, accepted_episodes, rejected_episodes, quality_metrics)
    """
    reward_system = IntegratedRolloutRewardSystem(config)
    try:
        batch_rewards, kept, dropped = reward_system.process_episodes(episodes)
        metrics = reward_system.get_quality_metrics(kept + dropped)
        return batch_rewards, kept, dropped, metrics
    finally:
        # Release processor resources regardless of success or failure.
        reward_system.close()
304
+
305
+
306
+ # MLX-optimized batch processing
307
@mx.compile
def compute_integrated_rewards(
    base_rewards: mx.array,
    verification_scores: mx.array,
    reward_weight: float = 0.7,
    verification_weight: float = 0.3
) -> mx.array:
    """
    MLX-compiled weighted blend of reward and verification scores.

    Args:
        base_rewards: Base reward scores [batch_size]
        verification_scores: Verification scores [batch_size]
        reward_weight: Weight for base rewards
        verification_weight: Weight for verification scores

    Returns:
        Combined scores [batch_size]
    """
    # Normalize so the two weights sum to one; each weight is divided by
    # the total individually (same op order as a plain convex combination).
    weight_sum = reward_weight + verification_weight
    w_reward = reward_weight / weight_sum
    w_verify = verification_weight / weight_sum

    # Convex combination of the two score vectors.
    blended = w_reward * base_rewards + w_verify * verification_scores

    return blended
338
+