@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,647 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Error Recovery and Graceful Degradation for Training Pipeline
|
|
3
|
+
|
|
4
|
+
Provides utilities for handling failures gracefully:
|
|
5
|
+
- Database connection recovery
|
|
6
|
+
- Malformed data handling
|
|
7
|
+
- Service health monitoring
|
|
8
|
+
- Graceful shutdown
|
|
9
|
+
- Retry logic with backoff
|
|
10
|
+
|
|
11
|
+
Philosophy:
|
|
12
|
+
- Fail fast for programmer errors
|
|
13
|
+
- Recover gracefully from transient errors
|
|
14
|
+
- Always log what went wrong
|
|
15
|
+
- Never lose training progress silently
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import functools
|
|
20
|
+
import json
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import signal
|
|
24
|
+
import sys
|
|
25
|
+
import time
|
|
26
|
+
from dataclasses import dataclass, field
|
|
27
|
+
from enum import Enum
|
|
28
|
+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
T = TypeVar("T")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ============================================================================
|
|
36
|
+
# Error Categories
|
|
37
|
+
# ============================================================================
|
|
38
|
+
|
|
39
|
+
class ErrorCategory(Enum):
    """Buckets an exception can fall into; drives the handling strategy."""

    TRANSIENT = "transient"              # temporary — retrying is reasonable
    CONFIGURATION = "configuration"      # bad config — fix settings and restart
    DATA_VALIDATION = "data_validation"  # bad item — skip it and continue
    INFRASTRUCTURE = "infrastructure"    # external dependency is down
    FATAL = "fatal"                      # unrecoverable — stop the run


class TrainingError(Exception):
    """
    Exception enriched with classification metadata.

    Subclasses Exception so it participates in normal raise/except flow
    while carrying category/component/recoverable context alongside.
    """

    def __init__(
        self,
        category: ErrorCategory,
        message: str,
        component: str,
        recoverable: bool,
        details: Optional[Dict[str, Any]] = None,
        original_exception: Optional[Exception] = None,
    ):
        super().__init__(message)
        self.category = category
        self.message = message
        self.component = component
        self.recoverable = recoverable
        self.details = details or {}
        self.original_exception = original_exception

    def __str__(self) -> str:
        return f"[{self.category.value}] {self.component}: {self.message}"


# ============================================================================
# Error Classification
# ============================================================================

# Lowercased message substrings that map an error to a category.
# Checked in this order; first match wins.
_TRANSIENT_HINTS = (
    "connection refused",
    "connection reset",
    "timeout",
    "temporary failure",
    "service unavailable",
)
_CONFIG_HINTS = ("not set", "not configured", "invalid config", "missing required")
_DATA_HINTS = ("json", "parse", "decode", "invalid data", "schema", "validation")
_INFRA_HINTS = ("database", "redis", "cuda", "gpu", "out of memory")


def classify_error(exception: Exception) -> ErrorCategory:
    """
    Map an exception onto an ErrorCategory.

    Used to decide whether to retry, skip the offending item, or abort.
    Classification is heuristic: it inspects the exception type name and
    the lowercased message text for known substrings.
    """
    message = str(exception).lower()
    type_name = type(exception).__name__

    # Exception *type* names that signal transient network trouble.
    if any(marker in type_name for marker in ("Connection", "Timeout", "Network")):
        return ErrorCategory.TRANSIENT

    if any(hint in message for hint in _TRANSIENT_HINTS):
        return ErrorCategory.TRANSIENT
    if any(hint in message for hint in _CONFIG_HINTS):
        return ErrorCategory.CONFIGURATION
    if any(hint in message for hint in _DATA_HINTS):
        return ErrorCategory.DATA_VALIDATION
    if any(hint in message for hint in _INFRA_HINTS):
        return ErrorCategory.INFRASTRUCTURE

    # Anything unrecognized is treated as fatal by default.
    return ErrorCategory.FATAL


def is_recoverable(exception: Exception) -> bool:
    """Return True when the error's category is one worth retrying/skipping."""
    return classify_error(exception) in (
        ErrorCategory.TRANSIENT,
        ErrorCategory.DATA_VALIDATION,
    )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# ============================================================================
|
|
153
|
+
# Retry Logic
|
|
154
|
+
# ============================================================================
|
|
155
|
+
|
|
156
|
+
def with_retry(
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    retryable_exceptions: tuple = (Exception,),
) -> Callable:
    """
    Decorator adding exponential-backoff retries to a synchronous function.

    Args:
        max_attempts: Maximum number of attempts
        initial_delay: Initial delay between retries (seconds)
        max_delay: Maximum delay between retries (seconds)
        backoff_factor: Multiplier for delay after each attempt
        retryable_exceptions: Tuple of exceptions that trigger retry
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> T:
            wait = initial_delay
            last_exception = None

            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except retryable_exceptions as exc:
                    last_exception = exc

                    # Non-recoverable categories (config/infra/fatal) abort
                    # immediately — sleeping and retrying would not help.
                    if not is_recoverable(exc):
                        logger.error(f"{func.__name__} failed with non-recoverable error: {exc}")
                        raise

                    if attempt >= max_attempts:
                        # Budget exhausted; fall through to the final raise.
                        logger.error(
                            f"{func.__name__} failed after {max_attempts} attempts: {exc}"
                        )
                        continue

                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt}/{max_attempts}), "
                        f"retrying in {wait:.1f}s: {exc}"
                    )
                    time.sleep(wait)
                    wait = min(wait * backoff_factor, max_delay)

            if last_exception:
                raise last_exception
            # Only reachable if max_attempts < 1 left the loop body unexecuted.
            raise RuntimeError(f"{func.__name__} failed with no exception captured")

        return wrapper

    return decorator
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def with_retry_async(
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    retryable_exceptions: tuple = (Exception,),
) -> Callable:
    """Async version of retry decorator"""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            wait = initial_delay
            last_exception = None

            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except retryable_exceptions as exc:
                    last_exception = exc

                    # Non-recoverable categories abort immediately.
                    if not is_recoverable(exc):
                        logger.error(f"{func.__name__} failed with non-recoverable error: {exc}")
                        raise

                    if attempt >= max_attempts:
                        # Budget exhausted; fall through to the final raise.
                        logger.error(
                            f"{func.__name__} failed after {max_attempts} attempts: {exc}"
                        )
                        continue

                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt}/{max_attempts}), "
                        f"retrying in {wait:.1f}s: {exc}"
                    )
                    # Non-blocking sleep so the event loop keeps running.
                    await asyncio.sleep(wait)
                    wait = min(wait * backoff_factor, max_delay)

            if last_exception:
                raise last_exception
            # Only reachable if max_attempts < 1 left the loop body unexecuted.
            raise RuntimeError(f"{func.__name__} failed with no exception captured")

        return wrapper

    return decorator
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# ============================================================================
|
|
254
|
+
# Data Recovery
|
|
255
|
+
# ============================================================================
|
|
256
|
+
|
|
257
|
+
@dataclass
class RecoveryResult:
    """Result of data recovery attempt"""

    success: bool  # True when parsing/extraction worked (or input was empty)
    data: Any = None  # recovered value, or the fallback
    fallback_used: bool = False  # True when `data` came from the fallback
    errors: List[str] = field(default_factory=list)  # human-readable failure notes


def recover_json_parse(
    json_str: str,
    fallback: Any = None,
) -> RecoveryResult:
    """
    Attempt to parse JSON with fallback on failure.

    Args:
        json_str: JSON string to parse; empty/None yields the fallback
        fallback: Value to return on parse failure (defaults to [])

    Returns:
        RecoveryResult with parsed data or fallback
    """
    # Empty/None input is "nothing to parse", not an error.
    if not json_str:
        return RecoveryResult(
            success=True,
            data=fallback if fallback is not None else [],
            fallback_used=True,
        )

    try:
        return RecoveryResult(success=True, data=json.loads(json_str))
    except (TypeError, ValueError) as e:
        # ValueError covers json.JSONDecodeError (its subclass); TypeError
        # covers non-string input (e.g. a record that already holds a parsed
        # list/dict) — previously this case escaped and crashed the caller,
        # defeating the purpose of a recovery helper.
        return RecoveryResult(
            success=False,
            data=fallback if fallback is not None else [],
            fallback_used=True,
            errors=[f"JSON parse error: {e}"],
        )
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def recover_trajectory_archetype(
    trajectory: Dict[str, Any],
    default: str = "default",
) -> str:
    """
    Extract archetype from trajectory with fallback logic.

    Tries, in order:
    1. trajectory.archetype
    2. Each step's action.parameters.archetype
    3. Each step's action.result.archetype
    4. default

    Returns:
        Extracted or default archetype
    """
    # Trajectory-level field wins when present and truthy.
    archetype = trajectory.get("archetype")
    if archetype:
        return archetype

    # Fall back to scanning the JSON-encoded steps; both camelCase and
    # snake_case key variants are accepted.
    steps_json = trajectory.get("stepsJson", trajectory.get("steps_json", "[]"))
    result = recover_json_parse(steps_json, [])

    if result.success and result.data:
        for step in result.data:
            if not isinstance(step, dict):
                continue  # malformed step entries are skipped, not fatal
            action = step.get("action")
            if not isinstance(action, dict):
                # Guards against explicit JSON nulls / non-dict actions, which
                # .get("action", {}) would pass through and crash the lookup.
                continue

            params = action.get("parameters")
            if isinstance(params, dict) and params.get("archetype"):
                return params["archetype"]

            res = action.get("result")
            if isinstance(res, dict) and res.get("archetype"):
                return res["archetype"]

    return default
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def filter_valid_trajectories(
    trajectories: List[Dict[str, Any]],
    min_steps: int = 1,
    require_pnl: bool = False,
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Filter trajectories, separating valid from invalid.

    Invalid entries get a "_validation_errors" key added in place that
    lists what failed.

    Returns:
        Tuple of (valid_trajectories, invalid_trajectories)
    """

    def _problems(traj: Dict[str, Any]) -> List[str]:
        """Collect validation errors for a single trajectory."""
        found: List[str] = []

        # An id under either naming convention is required.
        if not traj.get("trajectoryId") and not traj.get("trajectory_id"):
            found.append("missing trajectoryId")

        # Steps must parse as JSON and meet the minimum count.
        steps_json = traj.get("stepsJson", traj.get("steps_json", "[]"))
        steps_result = recover_json_parse(steps_json, [])
        if not steps_result.success:
            found.append(f"invalid stepsJson: {steps_result.errors}")
        elif len(steps_result.data) < min_steps:
            found.append(f"insufficient steps: {len(steps_result.data)} < {min_steps}")

        # PnL is only mandatory when the caller asks for it.
        if require_pnl and traj.get("finalPnL", traj.get("final_pnl")) is None:
            found.append("missing finalPnL")

        return found

    valid: List[Dict[str, Any]] = []
    invalid: List[Dict[str, Any]] = []

    for traj in trajectories:
        errors = _problems(traj)
        if errors:
            traj["_validation_errors"] = errors
            invalid.append(traj)
        else:
            valid.append(traj)

    return valid, invalid
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
# ============================================================================
|
|
387
|
+
# Database Recovery
|
|
388
|
+
# ============================================================================
|
|
389
|
+
|
|
390
|
+
class DatabaseConnectionManager:
    """
    Database connection handling with automatic recovery.

    Responsibilities:
    - Pool creation with retry/backoff
    - Throttled health probing
    - Automatic reconnection on failure
    - Clean teardown
    """

    def __init__(
        self,
        database_url: str,
        pool_size: int = 5,
        max_retries: int = 3,
    ):
        self.database_url = database_url
        self.pool_size = pool_size
        self.max_retries = max_retries
        self._pool = None
        self._last_health_check = 0.0
        # Minimum seconds between real probes; calls inside this window
        # short-circuit to "healthy".
        self._health_check_interval = 30.0  # seconds

    async def get_pool(self):
        """Return the pool, lazily creating it on first use."""
        if self._pool is None:
            await self._create_pool()
        return self._pool

    @with_retry_async(max_attempts=3, initial_delay=2.0)
    async def _create_pool(self):
        """Create database connection pool with retry"""
        # Local import keeps asyncpg optional until a pool is actually needed.
        import asyncpg

        logger.info(f"Creating database connection pool (size={self.pool_size})...")
        self._pool = await asyncpg.create_pool(
            self.database_url,
            min_size=2,
            max_size=self.pool_size,
            command_timeout=60,
        )
        logger.info("Database connection pool created")

    async def health_check(self) -> bool:
        """Probe the database; throttled to one real check per interval."""
        now = time.time()
        if now - self._last_health_check < self._health_check_interval:
            # Probed recently — assume still healthy until the window expires.
            return True
        self._last_health_check = now

        if self._pool is None:
            return False

        try:
            async with self._pool.acquire() as conn:
                await conn.fetchval("SELECT 1")
        except Exception as e:
            logger.warning(f"Database health check failed: {e}")
            return False
        return True

    async def close(self):
        """Dispose of the pool if one exists."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Database connection pool closed")
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
# ============================================================================
|
|
464
|
+
# Graceful Shutdown
|
|
465
|
+
# ============================================================================
|
|
466
|
+
|
|
467
|
+
class GracefulShutdown:
    """
    Coordinates orderly shutdown of the training pipeline.

    Features:
    - Signal handling (SIGINT, SIGTERM)
    - Checkpoint saving before exit
    - A second signal forces an immediate exit
    - Usable as a context manager to install/restore handlers
    """

    def __init__(
        self,
        shutdown_timeout: float = 30.0,
        checkpoint_callback: Optional[Callable] = None,
    ):
        self.shutdown_timeout = shutdown_timeout
        self.checkpoint_callback = checkpoint_callback
        self._shutdown_requested = False
        self._original_handlers: Dict[int, Any] = {}

    @property
    def shutdown_requested(self) -> bool:
        """True once a shutdown signal has been received."""
        return self._shutdown_requested

    def install_handlers(self):
        """Route SIGINT/SIGTERM to this object, remembering prior handlers."""
        for sig in (signal.SIGINT, signal.SIGTERM):
            self._original_handlers[sig] = signal.getsignal(sig)
            signal.signal(sig, self._handle_signal)

        logger.debug("Graceful shutdown handlers installed")

    def restore_handlers(self):
        """Put back whatever handlers were in place before install_handlers()."""
        # popitem both iterates and empties the saved-handler map.
        while self._original_handlers:
            sig, handler = self._original_handlers.popitem()
            signal.signal(sig, handler)

        logger.debug("Original signal handlers restored")

    def _save_checkpoint(self):
        """Invoke the checkpoint callback, logging success or failure."""
        logger.info("Saving checkpoint before shutdown...")
        try:
            self.checkpoint_callback()
            logger.info("Checkpoint saved successfully")
        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

    def _handle_signal(self, signum: int, frame):
        """React to SIGINT/SIGTERM: request shutdown, checkpoint, or force-exit."""
        # A second signal while already shutting down means "exit now".
        if self._shutdown_requested:
            logger.warning("Forced shutdown requested - exiting immediately")
            sys.exit(1)

        sig_name = signal.Signals(signum).name
        logger.info(f"Received {sig_name} - initiating graceful shutdown...")
        self._shutdown_requested = True

        if self.checkpoint_callback:
            self._save_checkpoint()

    def __enter__(self):
        self.install_handlers()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore_handlers()
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
# ============================================================================
|
|
537
|
+
# Progress Tracking
|
|
538
|
+
# ============================================================================
|
|
539
|
+
|
|
540
|
+
@dataclass
class TrainingProgress:
    """Mutable record of a training run's state, serializable for recovery.

    Error strings are retained only up to MAX_ERRORS_IN_MEMORY; the
    total_errors_count counter keeps the true total across truncation.
    """

    # Cap on error strings kept in memory (prevents unbounded growth
    # during long runs).
    # NOTE(review): as a plain dataclass field this also appears in the
    # generated __init__/__eq__; kept as-is to preserve the constructor
    # signature for existing callers.
    MAX_ERRORS_IN_MEMORY: int = 200

    current_step: int = 0
    total_steps: int = 0
    trajectories_processed: int = 0
    trajectories_skipped: int = 0
    last_checkpoint_step: int = 0
    errors_encountered: List[str] = field(default_factory=list)
    start_time: float = field(default_factory=time.time)
    total_errors_count: int = 0  # Track total even when list is truncated

    def add_error(self, error: str) -> None:
        """Record an error string, discarding the oldest once over the cap."""
        self.errors_encountered.append(error)
        self.total_errors_count += 1
        # Rebind (rather than mutate) so the list never exceeds the cap.
        if len(self.errors_encountered) > self.MAX_ERRORS_IN_MEMORY:
            keep = self.MAX_ERRORS_IN_MEMORY
            self.errors_encountered = self.errors_encountered[-keep:]

    @property
    def elapsed_time(self) -> float:
        """Wall-clock seconds since this progress record was created."""
        now = time.time()
        return now - self.start_time

    @property
    def progress_pct(self) -> float:
        """Completion percentage in [0, 100]; 0.0 while total_steps is unset."""
        return 0.0 if self.total_steps == 0 else (self.current_step / self.total_steps) * 100

    def to_checkpoint(self) -> Dict[str, Any]:
        """Serialize progress into a plain dict suitable for checkpointing."""
        snapshot: Dict[str, Any] = {
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "trajectories_processed": self.trajectories_processed,
            "trajectories_skipped": self.trajectories_skipped,
            "last_checkpoint_step": self.last_checkpoint_step,
            # Persist only the most recent tail of the error log.
            "errors_encountered": self.errors_encountered[-100:],  # Keep last 100 in checkpoint
            "total_errors_count": self.total_errors_count,
            "elapsed_time": self.elapsed_time,
        }
        return snapshot

    @classmethod
    def from_checkpoint(cls, data: Dict[str, Any]) -> "TrainingProgress":
        """Rebuild a progress record from a checkpoint dict.

        Missing keys fall back to zero/empty defaults. start_time is reset
        to now, so elapsed_time restarts on resume.
        """
        restored_errors = data.get("errors_encountered", [])
        return cls(
            current_step=data.get("current_step", 0),
            total_steps=data.get("total_steps", 0),
            trajectories_processed=data.get("trajectories_processed", 0),
            trajectories_skipped=data.get("trajectories_skipped", 0),
            last_checkpoint_step=data.get("last_checkpoint_step", 0),
            errors_encountered=restored_errors,
            total_errors_count=data.get("total_errors_count", len(restored_errors)),
        )

    def log_status(self):
        """Emit a one-line INFO summary of the current run state."""
        logger.info(
            f"Training Progress: Step {self.current_step}/{self.total_steps} "
            f"({self.progress_pct:.1f}%) | "
            f"Trajectories: {self.trajectories_processed} processed, "
            f"{self.trajectories_skipped} skipped | "
            f"Errors: {self.total_errors_count} | "
            f"Elapsed: {self.elapsed_time:.0f}s"
        )
|
614
|
+
|
|
615
|
+
# ============================================================================
|
|
616
|
+
# Utility Functions
|
|
617
|
+
# ============================================================================
|
|
618
|
+
|
|
619
|
+
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Divide numerator by denominator, returning default when the denominator is zero."""
    return default if denominator == 0 else numerator / denominator
+
|
|
626
|
+
def clamp(value: float, min_val: float, max_val: float) -> float:
    """Restrict value to the inclusive range [min_val, max_val]."""
    # Apply the upper bound first, then the lower bound (min_val wins on
    # a degenerate range where min_val > max_val, matching max-of-min).
    upper_bounded = min(max_val, value)
    return max(min_val, upper_bounded)
|
631
|
+
def require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises a non-recoverable CONFIGURATION TrainingError when the variable
    is unset or empty.
    """
    value = os.getenv(name)
    if value:
        return value
    raise TrainingError(
        category=ErrorCategory.CONFIGURATION,
        message=f"Required environment variable '{name}' is not set",
        component="environment",
        recoverable=False,
    )
|
|
644
|
+
def get_env_or_default(name: str, default: str) -> str:
    """Return the environment variable's value, or *default* when it is unset."""
    value = os.getenv(name)
    return default if value is None else value