textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
# textpolicy/rewards/verifiers.py
|
|
2
|
+
"""
|
|
3
|
+
Text quality verifiers following retrain's decorator-based pattern.
|
|
4
|
+
|
|
5
|
+
Verifiers provide boolean pre-filtering for reward functions,
|
|
6
|
+
following retrain's philosophy of efficient quality control.
|
|
7
|
+
|
|
8
|
+
All verifiers follow the signature: (prompt: str, completion: str, example: Dict[str, Any]) -> bool
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Dict, List, Optional, Any
|
|
13
|
+
import re
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from enum import Enum
|
|
16
|
+
from .registry import verifier
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Import-time deprecation notice: anyone importing this module is nudged off
# the legacy VerificationResult / VerificationReport API defined below.
# NOTE(review): this fires even for callers that only use the boolean
# verifiers — confirm that is intended.
warnings.warn(
    "VerificationResult and VerificationReport are deprecated; use boolean verifiers instead.",
    DeprecationWarning,
    stacklevel=2,
)
|
|
24
|
+
|
|
25
|
+
# Legacy types for backward compatibility (deprecated)
|
|
26
|
+
class VerificationResult(Enum):
    """Result of verification check (deprecated - use boolean verifiers now)."""
    # String values kept stable so any serialized legacy reports still parse.
    PASS = "pass"
    FAIL = "fail"
    WARNING = "warning"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class VerificationReport:
    """Report from verification check (deprecated - use boolean verifiers now)."""
    result: VerificationResult  # overall pass/fail/warning outcome
    score: float  # 0.0 to 1.0
    details: Dict[str, Any]  # verifier-specific diagnostic payload
    message: str  # human-readable summary of the check
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@verifier
def length_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    min_length: int = 10,
    max_length: int = 500,
    **kwargs
) -> bool:
    """
    Verifies response length appropriateness following retrain's pattern.

    Per-example overrides ('min_length' / 'max_length' keys in *example*)
    take precedence over the keyword arguments.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain length constraints)
        min_length: Minimum required word count
        max_length: Maximum allowed word count
        **kwargs: Additional parameters

    Returns:
        True if length is appropriate, False otherwise
    """
    lower_bound = example.get('min_length', min_length)
    upper_bound = example.get('max_length', max_length)

    # Whitespace-delimited word count.
    n_words = len(completion.split())
    return lower_bound <= n_words <= upper_bound
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@verifier
def toxicity_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    **kwargs
) -> bool:
    """
    Verifies response is not toxic or inappropriate following retrain's pattern.

    Simple keyword-based detection; in practice, use a proper toxicity
    classifier.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain custom toxic patterns
            under the 'toxic_patterns' key)
        **kwargs: Additional parameters

    Returns:
        True if content is non-toxic, False if toxic content detected
    """
    # Default toxic patterns - simple keyword-based detection
    toxic_patterns = [
        r'\b(hate|kill|die|stupid|idiot|racist|sexist)\b',
        r'\b(fuck|shit|damn|hell)\b',
        r'\b(violence|abuse|harassment)\b'
    ]

    # Allow custom patterns from example
    if 'toxic_patterns' in example:
        toxic_patterns.extend(example['toxic_patterns'])

    toxic_regex = re.compile('|'.join(toxic_patterns), re.IGNORECASE)

    # re.IGNORECASE already handles case folding, so the previous
    # completion.lower() was redundant; search() short-circuits on the
    # first hit instead of collecting every match with findall().
    return toxic_regex.search(completion) is None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@verifier
def coherence_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    min_coherence_score: float = 0.5,
    **kwargs
) -> bool:
    """
    Verifies response coherence and logical flow following retrain's pattern.

    Scores the text on two simple heuristics — presence of discourse
    connectors and average sentence length — and passes when the mean of
    the two scores reaches the threshold.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain coherence requirements)
        min_coherence_score: Minimum coherence score threshold
        **kwargs: Additional parameters

    Returns:
        True if coherent enough, False otherwise
    """
    # Example-level threshold overrides the keyword default.
    threshold = example.get('min_coherence_score', min_coherence_score)

    # Phrases that signal logical structure / discourse flow.
    indicator_patterns = (
        r'\b(therefore|thus|however|moreover|furthermore)\b',
        r'\b(first|second|third|finally|in conclusion)\b',
        r'\b(because|since|as a result|consequently)\b'
    )
    indicator_re = re.compile('|'.join(indicator_patterns), re.IGNORECASE)
    connector_count = len(indicator_re.findall(completion))

    # Split into sentences, discarding whitespace-only fragments.
    fragments = [s for s in re.split(r'[.!?]+', completion) if s.strip()]
    if not fragments:
        return False

    mean_sentence_len = sum(len(frag.split()) for frag in fragments) / len(fragments)

    # Two or more connectors earns full marks; mid-range sentence length
    # (5-25 words) is considered well-formed.
    connector_score = min(1.0, connector_count / 2.0)
    sentence_score = 1.0 if 5 <= mean_sentence_len <= 25 else 0.5

    return (connector_score + sentence_score) / 2.0 >= threshold
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@verifier
def factual_verifier(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    min_factual_score: float = 0.6,
    max_uncertainty_count: int = 2,
    **kwargs
) -> bool:
    """
    Verifies factual accuracy and consistency following retrain's pattern.

    Heuristic only: counts hedging/uncertainty phrases and crude
    contradiction signals (both 'yes' and 'no', or both 'true' and 'false',
    present), then derives a confidence score from penalty terms.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain factual requirements)
        min_factual_score: Minimum factual confidence score
        max_uncertainty_count: Maximum allowed uncertainty phrases
        **kwargs: Additional parameters

    Returns:
        True if factually sound, False otherwise
    """
    # Get thresholds from example if specified
    if 'min_factual_score' in example:
        min_factual_score = example['min_factual_score']
    if 'max_uncertainty_count' in example:
        max_uncertainty_count = example['max_uncertainty_count']

    # Uncertainty indicators that suggest factual issues
    uncertainty_phrases = [
        r'\b(i think|maybe|perhaps|possibly|not sure)\b',
        r'\b(might be|could be|i believe|seems like)\b',
        r'\b(probably|likely|unclear|unknown)\b'
    ]

    uncertainty_regex = re.compile('|'.join(uncertainty_phrases), re.IGNORECASE)

    # Lowercase once instead of recomputing it for every check below.
    text = completion.lower()

    # Count uncertainty phrases
    uncertainty_count = len(uncertainty_regex.findall(text))

    # Check for contradictory statements (simple heuristic)
    contradictions = 0
    if 'yes' in text and 'no' in text:
        contradictions += 1
    if 'true' in text and 'false' in text:
        contradictions += 1

    # Each uncertainty phrase costs 0.2 (capped at 1.0); each contradiction 0.3.
    uncertainty_penalty = min(1.0, uncertainty_count * 0.2)
    contradiction_penalty = contradictions * 0.3

    factual_score = max(0.0, 1.0 - uncertainty_penalty - contradiction_penalty)

    return factual_score >= min_factual_score and uncertainty_count <= max_uncertainty_count
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# Additional verifiers following retrain's patterns
|
|
224
|
+
|
|
225
|
+
@verifier
def has_greeting(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    required_greeting: str = "hello",
    **kwargs
) -> bool:
    """
    Verifies that the completion contains a greeting.

    The greeting defaults to the `required_greeting` keyword, but an
    example-level 'required_greeting' entry overrides it.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may specify required greeting)
        required_greeting: The greeting phrase to look for
        **kwargs: Additional parameters

    Returns:
        True if greeting is present, False otherwise
    """
    greeting = example.get('required_greeting', required_greeting)
    # Case-insensitive substring match.
    return greeting.lower() in completion.lower()
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@verifier
def no_empty_response(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    **kwargs
) -> bool:
    """
    Verifies that the completion is not empty or whitespace-only.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context
        **kwargs: Additional parameters

    Returns:
        True if completion has content, False if empty
    """
    # Whitespace-only completions count as empty.
    stripped = completion.strip()
    return len(stripped) > 0
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@verifier
def contains_keywords(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    required_keywords: Optional[List[str]] = None,
    **kwargs
) -> bool:
    """
    Verifies that the completion contains required keywords.

    Keywords may be passed directly or supplied via the example's
    'required_keywords' key; an empty requirement always passes.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain required_keywords)
        required_keywords: List of keywords that must be present
        **kwargs: Additional parameters

    Returns:
        True if all required keywords are present, False otherwise
    """
    keywords = required_keywords
    if keywords is None:
        keywords = example.get('required_keywords', [])

    if not keywords:
        # No requirements, always pass.
        return True

    # Case-insensitive containment: every keyword must appear.
    haystack = completion.lower()
    for keyword in keywords:
        if keyword.lower() not in haystack:
            return False
    return True
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# Legacy compatibility functions (deprecated)
|
|
306
|
+
def create_default_verifier_pipeline():
    """Create default verification pipeline (deprecated - use registry-based verifiers)."""
    # Shim kept only so old imports do not crash; it does nothing.
    import logging

    logging.getLogger(__name__).warning(
        "create_default_verifier_pipeline is deprecated. Use registry-based verifiers with apply_verifiers_to_reward instead."
    )
    return None
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def create_custom_verifier_pipeline(verifier_configs: List[Dict[str, Any]]):
    """Create custom verification pipeline (deprecated - use registry-based verifiers)."""
    # Shim kept only so old imports do not crash; configs are ignored.
    import logging

    logging.getLogger(__name__).warning(
        "create_custom_verifier_pipeline is deprecated. Use registry-based verifiers with apply_verifiers_to_reward instead."
    )
    return None
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# Legacy class aliases for backward compatibility
|
|
323
|
+
class TextVerifier:
    """Base class for text quality verifiers (deprecated)."""

    def verify(self, prompt: str, response: str):
        """Emit a deprecation warning and reject (always returns False)."""
        # Legacy entry point: registry-based boolean @verifier functions
        # replaced this class-based API.
        warnings.warn(
            "TextVerifier.verify is deprecated; use registry-based @verifier functions instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return False
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class LengthVerifier(TextVerifier):
    # Empty shell kept only so legacy imports resolve.
    """Legacy length verifier (deprecated - use length_verifier function)."""
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class ToxicityVerifier(TextVerifier):
    # Empty shell kept only so legacy imports resolve.
    """Legacy toxicity verifier (deprecated - use toxicity_verifier function)."""
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class CoherenceVerifier(TextVerifier):
    # Empty shell kept only so legacy imports resolve.
    """Legacy coherence verifier (deprecated - use coherence_verifier function)."""
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class FactualVerifier(TextVerifier):
    # Empty shell kept only so legacy imports resolve.
    """Legacy factual verifier (deprecated - use factual_verifier function)."""
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class VerificationPipeline:
    """Legacy verification pipeline (deprecated - use registry-based approach)."""

    def __init__(self, verifiers):
        """
        Store the supplied verifiers and emit a deprecation notice.

        Args:
            verifiers: Sequence of verifier objects (retained for
                backward-compatible inspection; not executed).
        """
        import logging
        logger = logging.getLogger(__name__)
        logger.warning("VerificationPipeline is deprecated. Use registry-based verifiers with apply_verifiers_to_reward instead.")
        # Bug fix: the verifiers argument was previously discarded, so
        # legacy callers inspecting pipeline.verifiers hit AttributeError.
        self.verifiers = verifiers

    def verify_batch(self, prompts, responses):
        """Deprecated no-op: warns and returns an empty result list."""
        # Legacy method: deprecated in favour of registry-based verification pipeline
        warnings.warn(
            "VerificationPipeline.verify_batch is deprecated; use registry-based verifiers instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return []
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# textpolicy/rollout/__init__.py
|
|
2
|
+
"""
|
|
3
|
+
Modular rollout system for TextPolicy.
|
|
4
|
+
|
|
5
|
+
Main components:
|
|
6
|
+
- RolloutCoordinator: Interface for rollout collection
|
|
7
|
+
- RolloutStrategy: Algorithm-specific rollout behavior
|
|
8
|
+
- RolloutWorker: Multi-process worker management
|
|
9
|
+
- BufferAggregator: Multi-worker data coordination
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from .rollout import RolloutCoordinator, create_rollout_coordinator
|
|
13
|
+
from .base import RolloutStrategy
|
|
14
|
+
from .worker import RolloutWorker
|
|
15
|
+
from .runner import RolloutRunner
|
|
16
|
+
from .aggregator import BufferAggregator
|
|
17
|
+
from .strategy import PPOStrategy, GRPOStrategy, create_strategy
|
|
18
|
+
|
|
19
|
+
# Backwards compatibility exports
|
|
20
|
+
from .runner import RolloutRunner as RolloutRunner_Legacy
|
|
21
|
+
from .aggregator import BufferAggregator as BufferAggregator_Legacy
|
|
22
|
+
from .worker import RolloutWorker as RolloutWorker_Legacy
|
|
23
|
+
|
|
24
|
+
# Public surface of the rollout package: `from textpolicy.rollout import *`
# exposes exactly these names.
__all__ = [
    # Main public interface
    'RolloutCoordinator',
    'create_rollout_coordinator',

    # Core components
    'RolloutStrategy',
    'RolloutWorker',
    'RolloutRunner',
    'BufferAggregator',

    # Strategies
    'PPOStrategy',
    'GRPOStrategy',
    'create_strategy',

    # Legacy compatibility aliases kept so pre-refactor import paths work.
    'RolloutRunner_Legacy',
    'BufferAggregator_Legacy',
    'RolloutWorker_Legacy',
]
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# textpolicy/rollout/aggregator.py
|
|
2
|
+
"""
|
|
3
|
+
Multi-worker buffer aggregation and coordination.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import multiprocessing as mp
import queue
from typing import List, Optional

from textpolicy.buffer import Buffer

from .worker import RolloutWorker
from .base import DEFAULT_MAX_EPISODES
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BufferAggregator:
    """
    Aggregates episodes from multiple RolloutWorkers.

    Coordinates data collection from multiple processes:
    - Maintains bounded buffer of complete episodes
    - Provides sampling interface for trainer
    - Handles queue management and data consumption
    - Thread-safe operation for async collection
    """

    def __init__(self, num_workers: int, max_episodes: int = DEFAULT_MAX_EPISODES):
        """
        Initialize aggregator for multi-worker coordination.

        Args:
            num_workers: Number of RolloutWorker processes to coordinate
            max_episodes: Maximum episodes to keep (oldest are dropped)
        """
        self.num_workers = num_workers
        self.buffer = Buffer(max_episodes=max_episodes)
        # One queue slot per worker; populated by add_worker().
        self._worker_queues: List[Optional[mp.Queue]] = [None] * num_workers

    def add_worker(self, worker: RolloutWorker, worker_id: int):
        """
        Register a worker's send queue for data consumption.

        Args:
            worker: RolloutWorker instance
            worker_id: Unique worker ID (0 <= id < num_workers)

        Raises:
            ValueError: If worker_id is out of bounds
        """
        if not (0 <= worker_id < self.num_workers):
            raise ValueError(f"worker_id must be in [0, {self.num_workers - 1}]")
        self._worker_queues[worker_id] = worker.send_queue

    def consume_from_worker(self, worker_id: int) -> bool:
        """
        Try to consume new episodes from specific worker.

        Args:
            worker_id: ID of the worker to consume from

        Returns:
            True if data was consumed, False if no data available

        Raises:
            ValueError: If worker_id has no registered queue
        """
        q = self._worker_queues[worker_id]
        if q is None:
            raise ValueError(f"No queue registered for worker {worker_id}")

        # Bug fix: mp.Queue.empty() is documented as unreliable across
        # processes, so the previous empty()-then-get() pattern could either
        # miss in-flight data or block forever if another consumer drained
        # the queue in between. get_nowait() polls atomically without
        # blocking.
        try:
            episodes_data = q.get_nowait()
        except queue.Empty:
            return False

        # Add episodes to aggregated buffer
        for ep_dict in episodes_data:
            self.buffer.add_episode_from_dict(ep_dict)

        return True

    def consume_all(self) -> int:
        """
        Consume available data from all workers.

        Non-blocking operation that checks all worker queues.

        Returns:
            Number of workers that had data ready
        """
        count = 0
        for wid in range(self.num_workers):
            # Skip workers that never registered a queue.
            if self._worker_queues[wid] is None:
                continue
            if self.consume_from_worker(wid):
                count += 1
        return count

    def ready(self, min_episodes: int = 1) -> bool:
        """
        Check if enough episodes have been collected for training.

        Args:
            min_episodes: Minimum number of episodes required

        Returns:
            True if buffer has at least min_episodes
        """
        return self.buffer.ready(min_episodes)

    def sample_latest_steps(self, n: int) -> dict:
        """
        Sample the N most recent steps for training.

        Args:
            n: Number of steps to sample

        Returns:
            Dict of MLX arrays (obs, act, rew, done, etc.)
        """
        return self.buffer.sample_latest_steps(n)

    def sample_episodes(self, k: int, order: str = 'desc') -> dict:
        """
        Sample up to k episodes for training.

        Args:
            k: Number of episodes to sample
            order: 'asc' (oldest first) or 'desc' (newest first)

        Returns:
            Dict of MLX arrays containing episode data
        """
        return self.buffer.sample_episodes(k, order)

    def clear(self):
        """Clear all collected episodes from the buffer."""
        self.buffer.clear()

    def __len__(self) -> int:
        """Total number of steps across all episodes."""
        return len(self.buffer)

    @property
    def episode_count(self) -> int:
        """Number of complete episodes currently stored."""
        return len(self.buffer.episodes)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# textpolicy/rollout/base.py
|
|
2
|
+
"""
|
|
3
|
+
Base classes and protocols for the rollout system.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Callable, Dict, Any, Tuple, Protocol
|
|
7
|
+
import mlx.core as mx # type: ignore
|
|
8
|
+
from textpolicy.buffer import Buffer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RolloutStrategy(Protocol):
    """
    Protocol for rollout strategies.

    Defines the interface used by RolloutRunner for algorithm-specific behavior.
    Each strategy encapsulates how to:
    - Select actions from policy outputs
    - Store transition data in buffers
    - Handle algorithm-specific requirements

    Structural typing: any class implementing these two methods satisfies
    the protocol — no inheritance required.
    """

    def select_action(self, policy: Callable, obs: mx.array) -> Tuple[mx.array, Dict[str, Any]]:
        """
        Select an action using the policy.

        Args:
            policy: Function that takes obs and returns (action, extras)
            obs: MLX array observation

        Returns:
            action: mx.array (scalar or tensor)
            extras: Dict of additional data (e.g. logprob, value)
        """
        ...

    def store_transition(self, buffer: Buffer, **data) -> None:
        """
        Store a transition in the buffer.

        Args:
            buffer: Buffer instance
            **data: Transition data (obs, act, rew, next_obs, done, logprob, value, etc.)
        """
        ...
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Common constants and configurations

# Default cap on environment steps collected per rollout.
DEFAULT_MAX_STEPS = 1000
# Default cap on episodes retained by aggregating buffers (oldest dropped).
DEFAULT_MAX_EPISODES = 100
# Worker wait time in seconds — presumably a queue/poll timeout; confirm
# against the worker implementation.
DEFAULT_WORKER_TIMEOUT = 1.0

# Supported transition data keys for validation
# Keys every transition must provide (enforced by validate_transition_data).
REQUIRED_TRANSITION_KEYS = {'obs', 'act', 'rew', 'next_obs', 'done'}
# Keys a transition may optionally carry (algorithm-specific extras).
OPTIONAL_TRANSITION_KEYS = {'timeout', 'logprob', 'value', 'entropy'}
# Union of all keys accepted; anything else is silently filtered out.
ALL_TRANSITION_KEYS = REQUIRED_TRANSITION_KEYS | OPTIONAL_TRANSITION_KEYS
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def validate_transition_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate and filter transition data.

    Ensures every required transition key is present, then strips any keys
    outside the known transition schema.

    Args:
        data: Dictionary containing transition data

    Returns:
        Filtered dictionary with only valid keys

    Raises:
        ValueError: If required keys are missing
    """
    # Fail fast when any mandatory key is absent.
    missing_keys = REQUIRED_TRANSITION_KEYS.difference(data)
    if missing_keys:
        raise ValueError(f"Missing required transition keys: {missing_keys}")

    # Keep only entries belonging to the known schema.
    return {key: value for key, value in data.items() if key in ALL_TRANSITION_KEYS}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def serialize_mx_array(arr: mx.array) -> Any:
    """
    Serialize MLX array for multiprocessing communication.

    Args:
        arr: MLX array to serialize

    Returns:
        Python scalar or list suitable for queue transmission
    """
    # 0-d arrays become native scalars; anything else a (nested) list.
    return arr.item() if arr.ndim == 0 else arr.tolist()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def deserialize_to_mx_array(data: Any) -> mx.array:
    """
    Deserialize data back to MLX array.

    Inverse of serialize_mx_array: rebuilds an MLX array from the plain
    Python scalar or (nested) list that crossed the process boundary.

    Args:
        data: Python scalar or list

    Returns:
        MLX array
    """
    return mx.array(data)
|