@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207) hide show
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,647 @@
1
+ """
2
+ Error Recovery and Graceful Degradation for Training Pipeline
3
+
4
+ Provides utilities for handling failures gracefully:
5
+ - Database connection recovery
6
+ - Malformed data handling
7
+ - Service health monitoring
8
+ - Graceful shutdown
9
+ - Retry logic with backoff
10
+
11
+ Philosophy:
12
+ - Fail fast for programmer errors
13
+ - Recover gracefully from transient errors
14
+ - Always log what went wrong
15
+ - Never lose training progress silently
16
+ """
17
+
18
+ import asyncio
19
+ import functools
20
+ import json
21
+ import logging
22
+ import os
23
+ import signal
24
+ import sys
25
+ import time
26
+ from dataclasses import dataclass, field
27
+ from enum import Enum
28
+ from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
29
+
30
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Generic return type used to annotate the retry decorators' wrappers.
T = TypeVar("T")
33
+
34
+
35
+ # ============================================================================
36
+ # Error Categories
37
+ # ============================================================================
38
+
39
class ErrorCategory(Enum):
    """
    Coarse classes of failure, used to pick a handling strategy.

    The comment on each member states the intended response.
    """

    TRANSIENT = "transient"  # temporary condition: retrying makes sense
    CONFIGURATION = "configuration"  # fix the configuration, then restart
    DATA_VALIDATION = "data_validation"  # bad item: skip it and continue
    INFRASTRUCTURE = "infrastructure"  # an external service is down
    FATAL = "fatal"  # cannot continue
56
+
57
+
58
class TrainingError(Exception):
    """
    Structured training error for handling decisions.

    Inherits from Exception so it can be raised and caught properly.

    Args:
        category: ErrorCategory used to decide how the error is handled.
        message: Human-readable description; also passed to Exception.
        component: Name of the pipeline component that failed.
        recoverable: Whether the raiser considers the failure recoverable.
        details: Optional extra context; a fresh empty dict when omitted.
        original_exception: The underlying exception, if this error wraps
            one. It is also attached as ``__cause__`` so tracebacks show the
            root cause even when the raiser did not use ``raise ... from``.
    """

    def __init__(
        self,
        category: ErrorCategory,
        message: str,
        component: str,
        recoverable: bool,
        details: Optional[Dict[str, Any]] = None,
        original_exception: Optional[Exception] = None,
    ):
        super().__init__(message)
        self.category = category
        self.message = message
        self.component = component
        self.recoverable = recoverable
        self.details = details or {}
        self.original_exception = original_exception
        if original_exception is not None:
            # Chain the wrapped exception for traceback display. An explicit
            # `raise TrainingError(...) from other` still overrides this,
            # because `raise ... from` assigns __cause__ at raise time.
            self.__cause__ = original_exception

    def __str__(self) -> str:
        return f"[{self.category.value}] {self.component}: {self.message}"
84
+
85
+
86
+ # ============================================================================
87
+ # Error Classification
88
+ # ============================================================================
89
+
90
def classify_error(exception: Exception) -> ErrorCategory:
    """
    Classify an exception into an error category for handling decisions.

    This helps decide whether to retry, skip, or abort.

    Checks, in order:
      1. A ``TrainingError`` already carries an explicit category, which is
         returned as-is.
      2. Built-in ``ConnectionError`` / ``TimeoutError`` instances, or an
         exception class name containing "Connection", "Timeout" or
         "Network" -> TRANSIENT.
      3. Keyword matching on the lower-cased message: TRANSIENT, then
         CONFIGURATION, then DATA_VALIDATION, then INFRASTRUCTURE.
      4. An exception class name containing "JSON", "Decode", "Parse" or
         "Validation" -> DATA_VALIDATION. This catches e.g.
         ``json.JSONDecodeError``, whose message ("Expecting value: ...")
         contains none of the keywords in step 3.
      5. Anything else -> FATAL.

    Args:
        exception: The exception to classify.

    Returns:
        The ErrorCategory that best describes the exception.
    """
    # A TrainingError's str() is "[category] component: message". Keyword
    # matching on that text could contradict the explicit category (e.g. a
    # TRANSIENT error with a neutral message would fall through to FATAL).
    if isinstance(exception, TrainingError):
        return exception.category

    error_str = str(exception).lower()
    exception_type = type(exception).__name__

    # Connection errors - transient. The isinstance check also covers
    # built-in subclasses whose names lack these words (e.g. BrokenPipeError).
    if isinstance(exception, (ConnectionError, TimeoutError)):
        return ErrorCategory.TRANSIENT
    if any(x in exception_type for x in ("Connection", "Timeout", "Network")):
        return ErrorCategory.TRANSIENT

    if any(x in error_str for x in (
        "connection refused",
        "connection reset",
        "timeout",
        "timed out",
        "temporary failure",
        "service unavailable",
    )):
        return ErrorCategory.TRANSIENT

    # Configuration errors
    if any(x in error_str for x in (
        "not set",
        "not configured",
        "invalid config",
        "missing required",
    )):
        return ErrorCategory.CONFIGURATION

    # Data validation errors
    if any(x in error_str for x in (
        "json",
        "parse",
        "decode",
        "invalid data",
        "schema",
        "validation",
    )):
        return ErrorCategory.DATA_VALIDATION

    # Infrastructure errors
    if any(x in error_str for x in (
        "database",
        "redis",
        "cuda",
        "gpu",
        "out of memory",
    )):
        return ErrorCategory.INFRASTRUCTURE

    # Decode/validation exception types whose messages carry no data
    # keyword. Checked last, so it only reclassifies what would be FATAL.
    if any(x in exception_type for x in ("JSON", "Decode", "Parse", "Validation")):
        return ErrorCategory.DATA_VALIDATION

    # Default to fatal for unknown errors
    return ErrorCategory.FATAL
144
+
145
+
146
def is_recoverable(exception: Exception) -> bool:
    """
    Check if an error is recoverable (worth retrying).

    A ``TrainingError`` carries an explicit ``recoverable`` flag set by the
    code that raised it, and that flag is returned. For any other exception
    the answer comes from ``classify_error``: TRANSIENT and DATA_VALIDATION
    count as recoverable; CONFIGURATION, INFRASTRUCTURE and FATAL do not.

    NOTE(review): the retry decorators consult this function, so
    DATA_VALIDATION errors get retried, although the ErrorCategory comments
    say such items should be skipped. Confirm whether DATA_VALIDATION
    belongs here.

    Args:
        exception: The exception to inspect.

    Returns:
        True if the error is considered recoverable.
    """
    if isinstance(exception, TrainingError):
        return exception.recoverable
    category = classify_error(exception)
    return category in (ErrorCategory.TRANSIENT, ErrorCategory.DATA_VALIDATION)
150
+
151
+
152
+ # ============================================================================
153
+ # Retry Logic
154
+ # ============================================================================
155
+
156
def with_retry(
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    retryable_exceptions: tuple = (Exception,),
) -> Callable:
    """
    Decorator for retry with exponential backoff (synchronous functions).

    Retry rules:
    - An exception is retried only if it is an instance of
      ``retryable_exceptions`` AND ``is_recoverable()`` accepts it.
    - A non-recoverable exception is re-raised immediately.
    - Exceptions outside ``retryable_exceptions`` propagate untouched.
    - After the last attempt fails, that exception is re-raised.

    The wait between attempts starts at ``initial_delay``, is multiplied by
    ``backoff_factor`` after each failure, and is capped at ``max_delay``.
    It uses ``time.sleep``, so it blocks the calling thread.

    Args:
        max_attempts: Maximum number of attempts (must be >= 1)
        initial_delay: Initial delay between retries (seconds)
        max_delay: Maximum delay between retries (seconds)
        backoff_factor: Multiplier for delay after each attempt
        retryable_exceptions: Tuple of exceptions that trigger retry

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Raised when the
            decorator is created, so a misconfiguration fails fast instead
            of the wrapped function silently never running.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> T:
            delay = initial_delay

            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except retryable_exceptions as e:
                    if not is_recoverable(e):
                        logger.error(f"{func.__name__} failed with non-recoverable error: {e}")
                        raise

                    if attempt == max_attempts:
                        logger.error(
                            f"{func.__name__} failed after {max_attempts} attempts: {e}"
                        )
                        # Bare raise re-raises with the original traceback.
                        raise

                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt}/{max_attempts}), "
                        f"retrying in {delay:.1f}s: {e}"
                    )
                    time.sleep(delay)
                    delay = min(delay * backoff_factor, max_delay)

            # Unreachable: max_attempts >= 1, so the loop always returns or raises.
            raise RuntimeError(f"{func.__name__} failed with no exception captured")

        return wrapper

    return decorator
207
+
208
+
209
def with_retry_async(
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    retryable_exceptions: tuple = (Exception,),
) -> Callable:
    """
    Async version of retry decorator.

    Same retry rules as the synchronous variant:
    - An exception is retried only if it is an instance of
      ``retryable_exceptions`` AND ``is_recoverable()`` accepts it.
    - A non-recoverable exception is re-raised immediately.
    - After the final attempt fails, that exception is re-raised.

    Waits use ``await asyncio.sleep``, so the event loop stays free between
    attempts.

    Args:
        max_attempts: Maximum number of attempts (must be >= 1)
        initial_delay: Initial delay between retries (seconds)
        max_delay: Maximum delay between retries (seconds)
        backoff_factor: Multiplier for delay after each attempt
        retryable_exceptions: Tuple of exceptions that trigger retry

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Raised when the
            decorator is created, so a misconfiguration fails fast instead
            of the wrapped coroutine silently never running.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            delay = initial_delay

            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except retryable_exceptions as e:
                    if not is_recoverable(e):
                        logger.error(f"{func.__name__} failed with non-recoverable error: {e}")
                        raise

                    if attempt == max_attempts:
                        logger.error(
                            f"{func.__name__} failed after {max_attempts} attempts: {e}"
                        )
                        # Bare raise re-raises with the original traceback.
                        raise

                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt}/{max_attempts}), "
                        f"retrying in {delay:.1f}s: {e}"
                    )
                    await asyncio.sleep(delay)
                    delay = min(delay * backoff_factor, max_delay)

            # Unreachable: max_attempts >= 1, so the loop always returns or raises.
            raise RuntimeError(f"{func.__name__} failed with no exception captured")

        return wrapper

    return decorator
251
+
252
+
253
+ # ============================================================================
254
+ # Data Recovery
255
+ # ============================================================================
256
+
257
@dataclass
class RecoveryResult:
    """
    Outcome of a best-effort data recovery attempt.

    Fields are documented inline below.
    """

    success: bool  # whether the recovery attempt succeeded
    data: Any = None  # recovered value, or a substitute chosen by the producer
    fallback_used: bool = False  # True when `data` is a fallback, not parsed input
    errors: List[str] = field(default_factory=list)  # messages describing failures
264
+
265
+
266
def recover_json_parse(
    json_str: Union[str, bytes, bytearray, List[Any], Dict[str, Any]],
    fallback: Any = None,
) -> RecoveryResult:
    """
    Attempt to parse JSON with fallback on failure. Never raises for bad input.

    Args:
        json_str: JSON text to parse (``str``, ``bytes`` or ``bytearray``).
            A value that is already a decoded ``list`` or ``dict`` is passed
            through unchanged as a successful result.
        fallback: Value to return when the input is empty or cannot be
            parsed. ``None`` selects an empty list, so ``None`` itself
            cannot be used as the fallback.

    Returns:
        RecoveryResult, depending on the input:
        - empty/falsy: success=True, fallback_used=True, data=fallback.
        - parsed or already decoded: success=True, data=the value.
        - unparseable or unsupported type: success=False,
          fallback_used=True, data=fallback, errors=[one message].
    """
    fallback_value = fallback if fallback is not None else []

    if not json_str:
        return RecoveryResult(success=True, data=fallback_value, fallback_used=True)

    # Already-decoded data (e.g. a steps value that a loader decoded before
    # it got here): json.loads would raise TypeError on it, so pass it through.
    if isinstance(json_str, (list, dict)):
        return RecoveryResult(success=True, data=json_str)

    try:
        data = json.loads(json_str)
    except (ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError (malformed JSON) and
        # UnicodeDecodeError (undecodable bytes). TypeError covers input
        # types json.loads does not accept, such as numbers.
        return RecoveryResult(
            success=False,
            data=fallback_value,
            fallback_used=True,
            errors=[f"JSON parse error: {e}"],
        )
    return RecoveryResult(success=True, data=data)
297
+
298
+
299
def recover_trajectory_archetype(
    trajectory: Dict[str, Any],
    default: str = "default",
) -> str:
    """
    Extract archetype from trajectory with fallback logic.

    Tries, in order:
      1. ``trajectory["archetype"]``
      2. Each step in ``stepsJson`` / ``steps_json``, in order: that step's
         ``action.parameters.archetype``, then its ``action.result.archetype``.
         The first truthy value found is returned.
      3. ``default``

    Malformed step data is skipped rather than raising, since this is a
    best-effort recovery helper. That includes:
    - unparseable JSON, or steps that do not decode to a list;
    - steps, actions, parameters or results that are not objects.

    Args:
        trajectory: Trajectory record (camelCase or snake_case keys).
        default: Archetype returned when none can be found.

    Returns:
        Extracted or default archetype
    """
    # Try trajectory level
    archetype = trajectory.get("archetype")
    if archetype:
        return archetype

    # Try steps
    steps_json = trajectory.get("stepsJson", trajectory.get("steps_json", "[]"))
    result = recover_json_parse(steps_json, [])
    if not result.success or not isinstance(result.data, list):
        return default

    for step in result.data:
        if not isinstance(step, dict):
            continue
        action = step.get("action")
        if not isinstance(action, dict):
            continue
        # Parameters take precedence over result within the same step.
        for section_key in ("parameters", "result"):
            section = action.get(section_key)
            if isinstance(section, dict):
                found = section.get("archetype")
                if found:
                    return found

    return default
339
+
340
+
341
def filter_valid_trajectories(
    trajectories: List[Dict[str, Any]],
    min_steps: int = 1,
    require_pnl: bool = False,
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Filter trajectories, separating valid from invalid.

    A trajectory is valid when all of these hold:
    - it has a truthy ``trajectoryId`` / ``trajectory_id``;
    - its ``stepsJson`` / ``steps_json`` decodes to a list with at least
      ``min_steps`` entries;
    - if ``require_pnl`` is set, ``finalPnL`` / ``final_pnl`` is not None.

    Args:
        trajectories: Trajectory records to check.
        min_steps: Minimum number of steps required.
        require_pnl: Whether a final PnL value is mandatory.

    Returns:
        Tuple of (valid_trajectories, invalid_trajectories), each in input order.

    Side effects:
        Each invalid trajectory dict is mutated in place: the list of
        rejection reasons is stored under ``_validation_errors``.
    """
    valid: List[Dict[str, Any]] = []
    invalid: List[Dict[str, Any]] = []

    for traj in trajectories:
        errors: List[str] = []

        # Check required fields
        if not traj.get("trajectoryId") and not traj.get("trajectory_id"):
            errors.append("missing trajectoryId")

        # Check steps
        steps_json = traj.get("stepsJson", traj.get("steps_json", "[]"))
        steps_result = recover_json_parse(steps_json, [])
        steps = steps_result.data

        if not steps_result.success:
            errors.append(f"invalid stepsJson: {steps_result.errors}")
        elif not isinstance(steps, list):
            # Valid JSON that is not a list ("null", a number, an object)
            # cannot hold steps; len() would raise or miscount.
            errors.append(f"invalid stepsJson: expected a list, got {type(steps).__name__}")
        elif len(steps) < min_steps:
            errors.append(f"insufficient steps: {len(steps)} < {min_steps}")

        # Check PnL if required
        if require_pnl and traj.get("finalPnL", traj.get("final_pnl")) is None:
            errors.append("missing finalPnL")

        if errors:
            traj["_validation_errors"] = errors
            invalid.append(traj)
        else:
            valid.append(traj)

    return valid, invalid
384
+
385
+
386
+ # ============================================================================
387
+ # Database Recovery
388
+ # ============================================================================
389
+
390
class DatabaseConnectionManager:
    """
    Manages an asyncpg connection pool with retry on creation.

    Handles:
    - Lazy pool creation with retry (``_create_pool`` is wrapped by
      ``with_retry_async``)
    - Throttled health checking (``SELECT 1`` at most once per interval)
    - Connection pooling via asyncpg (up to ``pool_size`` connections)

    Reconnection is not automatic. When ``health_check()`` returns False, a
    caller can ``close()`` the manager and call ``get_pool()`` again to build
    a fresh pool.
    """

    def __init__(
        self,
        database_url: str,
        pool_size: int = 5,
        max_retries: int = 3,
    ):
        self.database_url = database_url
        self.pool_size = pool_size
        # NOTE(review): stored but not read by this class; _create_pool's
        # retry count is fixed by its decorator (max_attempts=3).
        self.max_retries = max_retries
        self._pool = None
        self._last_health_check = 0.0
        # Result of the most recent real health check, returned while the
        # throttle window is open.
        self._last_health_result = False
        self._health_check_interval = 30.0  # seconds

    async def get_pool(self):
        """
        Get database pool, creating it if necessary.

        NOTE(review): two coroutines calling this before the pool exists can
        both reach _create_pool() and create two pools; guard with an
        asyncio.Lock if concurrent first use is possible.
        """
        if self._pool is None:
            await self._create_pool()
        return self._pool

    @with_retry_async(max_attempts=3, initial_delay=2.0)
    async def _create_pool(self):
        """Create database connection pool with retry"""
        import asyncpg

        logger.info(f"Creating database connection pool (size={self.pool_size})...")
        self._pool = await asyncpg.create_pool(
            self.database_url,
            # asyncpg rejects min_size > max_size, so a pool_size below 2
            # must lower the minimum as well.
            min_size=min(2, self.pool_size),
            max_size=self.pool_size,
            command_timeout=60,
        )
        # A new pool invalidates any cached health result from a previous one.
        self._last_health_check = 0.0
        logger.info("Database connection pool created")

    async def health_check(self) -> bool:
        """
        Check if database connection is healthy.

        Runs ``SELECT 1`` on a pooled connection at most once per
        ``_health_check_interval`` seconds. Calls inside that window return
        the cached result of the last real check.

        Returns:
            False if there is no pool or the most recent check failed,
            otherwise True.
        """
        if self._pool is None:
            return False

        now = time.time()

        # Skip if checked recently, but report what that check found.
        if now - self._last_health_check < self._health_check_interval:
            return self._last_health_result

        self._last_health_check = now

        try:
            async with self._pool.acquire() as conn:
                await conn.fetchval("SELECT 1")
        except Exception as e:
            logger.warning(f"Database health check failed: {e}")
            self._last_health_result = False
        else:
            self._last_health_result = True
        return self._last_health_result

    async def close(self):
        """Close database pool and forget any cached health result."""
        if self._pool is not None:
            await self._pool.close()
            self._pool = None
            self._last_health_check = 0.0
            logger.info("Database connection pool closed")
461
+
462
+
463
+ # ============================================================================
464
+ # Graceful Shutdown
465
+ # ============================================================================
466
+
467
class GracefulShutdown:
    """
    Manages graceful shutdown of training pipeline.

    On the first SIGINT/SIGTERM it sets ``shutdown_requested``, which the
    training loop is expected to poll. It then runs ``checkpoint_callback``
    if one was given. A second signal exits immediately with status 1.

    Used as a context manager, it installs the handlers on entry and
    restores the previous ones on exit.

    Args:
        shutdown_timeout: Stored on the instance. NOTE(review): nothing in
            this class reads it, so no forced-exit timeout is enforced
            here. Confirm whether a caller relies on it.
        checkpoint_callback: Zero-argument callable run on the first
            shutdown signal. Exceptions it raises are logged, not propagated.
    """

    def __init__(
        self,
        shutdown_timeout: float = 30.0,
        checkpoint_callback: Optional[Callable] = None,
    ):
        self.shutdown_timeout = shutdown_timeout
        self.checkpoint_callback = checkpoint_callback
        self._shutdown_requested = False
        self._original_handlers: Dict[int, Any] = {}

    @property
    def shutdown_requested(self) -> bool:
        """Check if shutdown has been requested"""
        return self._shutdown_requested

    def install_handlers(self):
        """
        Install signal handlers for graceful shutdown.

        Safe to call more than once: restore_handlers() reinstates the
        handler that was active before the first install. Per the
        ``signal.signal`` documentation, this must run in the main thread,
        or ValueError is raised.
        """
        for sig in (signal.SIGINT, signal.SIGTERM):
            # Remember the pre-existing handler only once; a repeated
            # install would otherwise record our own handler as "original".
            if sig not in self._original_handlers:
                self._original_handlers[sig] = signal.getsignal(sig)
            signal.signal(sig, self._handle_signal)

        logger.debug("Graceful shutdown handlers installed")

    def restore_handlers(self):
        """Restore the signal handlers that were active before install_handlers()."""
        for sig, handler in self._original_handlers.items():
            # getsignal() returns None when the previous handler was not
            # installed from Python. signal.signal() rejects None, so fall
            # back to the default action.
            signal.signal(sig, handler if handler is not None else signal.SIG_DFL)

        self._original_handlers.clear()
        logger.debug("Original signal handlers restored")

    def _handle_signal(self, signum: int, frame):
        """Handle shutdown signal: first one is graceful, a second forces exit."""
        if self._shutdown_requested:
            logger.warning("Forced shutdown requested - exiting immediately")
            sys.exit(1)

        sig_name = signal.Signals(signum).name
        logger.info(f"Received {sig_name} - initiating graceful shutdown...")
        self._shutdown_requested = True

        # Save checkpoint if callback provided
        if self.checkpoint_callback:
            logger.info("Saving checkpoint before shutdown...")
            try:
                self.checkpoint_callback()
                logger.info("Checkpoint saved successfully")
            except Exception as e:
                # Best-effort: keep shutting down, but keep the traceback.
                logger.exception(f"Failed to save checkpoint: {e}")

    def __enter__(self):
        self.install_handlers()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore_handlers()
534
+
535
+
536
+ # ============================================================================
537
+ # Progress Tracking
538
+ # ============================================================================
539
+
540
@dataclass
class TrainingProgress:
    """
    Tracks training progress for recovery purposes.

    Serialize with ``to_checkpoint()`` and rebuild with ``from_checkpoint()``
    so a resumed run keeps its counters.
    """

    # Maximum errors to keep in memory (prevents unbounded growth during
    # long runs). Deliberately unannotated: in a @dataclass an annotated
    # class attribute becomes an instance field and the first positional
    # __init__ parameter, which is wrong for a constant.
    MAX_ERRORS_IN_MEMORY = 200

    current_step: int = 0
    total_steps: int = 0
    trajectories_processed: int = 0
    trajectories_skipped: int = 0
    last_checkpoint_step: int = 0
    # Most recent error messages; add_error() caps the length.
    errors_encountered: List[str] = field(default_factory=list)
    # time.time() timestamp that elapsed_time is measured from.
    start_time: float = field(default_factory=time.time)
    total_errors_count: int = 0  # Track total even when list is truncated

    def add_error(self, error: str) -> None:
        """Add an error, truncating old errors if list grows too large"""
        self.errors_encountered.append(error)
        self.total_errors_count += 1
        # Keep only the most recent errors
        if len(self.errors_encountered) > self.MAX_ERRORS_IN_MEMORY:
            self.errors_encountered = self.errors_encountered[-self.MAX_ERRORS_IN_MEMORY:]

    @property
    def elapsed_time(self) -> float:
        """Seconds since training started (including time restored from a checkpoint)"""
        return time.time() - self.start_time

    @property
    def progress_pct(self) -> float:
        """Progress percentage (0-100); 0.0 when total_steps is 0"""
        if self.total_steps == 0:
            return 0.0
        return (self.current_step / self.total_steps) * 100

    def to_checkpoint(self) -> Dict[str, Any]:
        """Convert to checkpoint-compatible dict"""
        return {
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "trajectories_processed": self.trajectories_processed,
            "trajectories_skipped": self.trajectories_skipped,
            "last_checkpoint_step": self.last_checkpoint_step,
            "errors_encountered": self.errors_encountered[-100:],  # Keep last 100 in checkpoint
            "total_errors_count": self.total_errors_count,
            "elapsed_time": self.elapsed_time,
        }

    @classmethod
    def from_checkpoint(cls, data: Dict[str, Any]) -> "TrainingProgress":
        """
        Restore from a dict produced by to_checkpoint().

        Missing keys fall back to defaults. Two details of the restore:
        - The error list is copied, so later add_error() calls do not mutate
          the caller's checkpoint data.
        - start_time is back-dated by the saved elapsed_time, so elapsed
          time keeps accumulating across a resume instead of restarting at 0.
        """
        errors = list(data.get("errors_encountered") or [])
        elapsed = float(data.get("elapsed_time") or 0.0)
        return cls(
            current_step=data.get("current_step", 0),
            total_steps=data.get("total_steps", 0),
            trajectories_processed=data.get("trajectories_processed", 0),
            trajectories_skipped=data.get("trajectories_skipped", 0),
            last_checkpoint_step=data.get("last_checkpoint_step", 0),
            errors_encountered=errors[-cls.MAX_ERRORS_IN_MEMORY:],
            start_time=time.time() - elapsed,
            total_errors_count=data.get("total_errors_count", len(errors)),
        )

    def log_status(self):
        """Log current training status"""
        logger.info(
            f"Training Progress: Step {self.current_step}/{self.total_steps} "
            f"({self.progress_pct:.1f}%) | "
            f"Trajectories: {self.trajectories_processed} processed, "
            f"{self.trajectories_skipped} skipped | "
            f"Errors: {self.total_errors_count} | "
            f"Elapsed: {self.elapsed_time:.0f}s"
        )
613
+
614
+
615
+ # ============================================================================
616
+ # Utility Functions
617
+ # ============================================================================
618
+
619
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """
    Divide ``numerator`` by ``denominator``, returning ``default`` instead of
    raising when the denominator compares equal to zero.
    """
    return default if denominator == 0 else numerator / denominator
624
+
625
+
626
def clamp(value: float, min_val: float, max_val: float) -> float:
    """
    Clamp value to range [min_val, max_val].

    The upper bound is applied first and the lower bound second, so when
    ``min_val > max_val`` the result is ``min_val``.
    """
    capped = min(max_val, value)
    return max(min_val, capped)
629
+
630
+
631
def require_env(name: str) -> str:
    """
    Return the value of environment variable ``name``.

    Raises:
        TrainingError: CONFIGURATION category, non-recoverable, when the
            variable is unset or set to an empty string.
    """
    value = os.getenv(name)
    if value:
        return value
    raise TrainingError(
        category=ErrorCategory.CONFIGURATION,
        message=f"Required environment variable '{name}' is not set",
        component="environment",
        recoverable=False,
    )
642
+
643
+
644
def get_env_or_default(name: str, default: str) -> str:
    """
    Return environment variable ``name``, or ``default`` when it is absent.

    A variable that is set to the empty string is returned as ``""``.
    """
    return os.environ.get(name, default)
647
+