@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Training Pipeline – public helpers
|
|
3
|
+
*
|
|
4
|
+
* IMPORTANT: All heavy modules (AutomationPipeline, ModelDeployer) are loaded
|
|
5
|
+
* lazily so that importing this file does NOT trigger a database connection.
|
|
6
|
+
* Consumers that only need types or lightweight utilities can import
|
|
7
|
+
* "@elizaos/training" without side-effects.
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import type {
|
|
11
|
+
AutomationStatus,
|
|
12
|
+
TrainingMonitoringStatus,
|
|
13
|
+
TrainingReadinessResult,
|
|
14
|
+
TrainingTriggerOptions,
|
|
15
|
+
TrainingTriggerResult,
|
|
16
|
+
} from './types';
|
|
17
|
+
import type {
|
|
18
|
+
DeploymentOptions,
|
|
19
|
+
DeploymentResult,
|
|
20
|
+
} from './ModelDeployer';
|
|
21
|
+
import type { AutomationPipeline } from './AutomationPipeline';
|
|
22
|
+
|
|
23
|
+
export type NextTrainingModelSelection = Awaited<
|
|
24
|
+
ReturnType<AutomationPipeline['getModelSelectionInfo']>
|
|
25
|
+
>;
|
|
26
|
+
|
|
27
|
+
// ---------------------------------------------------------------------------
|
|
28
|
+
// Lazy singletons – only resolved on first call to avoid DB side-effects at
|
|
29
|
+
// module-load time.
|
|
30
|
+
// ---------------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
let _pipeline: AutomationPipeline | null = null;
|
|
33
|
+
|
|
34
|
+
async function getPipeline(): Promise<AutomationPipeline> {
|
|
35
|
+
if (!_pipeline) {
|
|
36
|
+
const mod = await import('./AutomationPipeline');
|
|
37
|
+
_pipeline = mod.automationPipeline;
|
|
38
|
+
}
|
|
39
|
+
return _pipeline;
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
async function getDeployer() {
|
|
43
|
+
const mod = await import('./ModelDeployer');
|
|
44
|
+
return mod.modelDeployer;
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
/**
|
|
48
|
+
* Check whether the current trajectory set is ready for training.
|
|
49
|
+
*/
|
|
50
|
+
export async function checkTrainingReadiness(): Promise<TrainingReadinessResult> {
|
|
51
|
+
const pipeline = await getPipeline();
|
|
52
|
+
return pipeline.checkTrainingReadiness();
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
/**
|
|
56
|
+
* Trigger a new training job.
|
|
57
|
+
*/
|
|
58
|
+
export async function triggerTraining(
|
|
59
|
+
options: TrainingTriggerOptions = {}
|
|
60
|
+
): Promise<TrainingTriggerResult> {
|
|
61
|
+
const pipeline = await getPipeline();
|
|
62
|
+
return pipeline.triggerTraining(options);
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
/**
|
|
66
|
+
* Monitor a training batch by its batch id.
|
|
67
|
+
*/
|
|
68
|
+
export async function monitorTrainingJob(
|
|
69
|
+
batchId: string
|
|
70
|
+
): Promise<TrainingMonitoringStatus> {
|
|
71
|
+
const pipeline = await getPipeline();
|
|
72
|
+
return pipeline.monitorTraining(batchId);
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
/**
|
|
76
|
+
* Get summarized status for automation, jobs, and model health.
|
|
77
|
+
*/
|
|
78
|
+
export async function getAutomationPipelineStatus(): Promise<AutomationStatus> {
|
|
79
|
+
const pipeline = await getPipeline();
|
|
80
|
+
return pipeline.getStatus();
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/**
|
|
84
|
+
* Get model-selection metadata for the next run.
|
|
85
|
+
*/
|
|
86
|
+
export async function getNextTrainingModelSelection(): Promise<{
|
|
87
|
+
success: boolean;
|
|
88
|
+
selection: NextTrainingModelSelection['selection'];
|
|
89
|
+
summary: NextTrainingModelSelection['summary'];
|
|
90
|
+
}> {
|
|
91
|
+
const pipeline = await getPipeline();
|
|
92
|
+
return pipeline.getModelSelectionInfo();
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
/**
|
|
96
|
+
* Run benchmark and deploy the model only if it passes thresholds.
|
|
97
|
+
*/
|
|
98
|
+
export async function benchmarkAndMaybeDeployModel(
|
|
99
|
+
batchId: string,
|
|
100
|
+
autoDeploy = true
|
|
101
|
+
): Promise<{
|
|
102
|
+
benchmarked: boolean;
|
|
103
|
+
deployed: boolean;
|
|
104
|
+
reason?: string;
|
|
105
|
+
}> {
|
|
106
|
+
const pipeline = await getPipeline();
|
|
107
|
+
return pipeline.benchmarkAndDeploy(batchId, autoDeploy);
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
/**
|
|
111
|
+
* Deploy a specific model version using the deployment strategy options.
|
|
112
|
+
*/
|
|
113
|
+
export async function deployModelVersion(
|
|
114
|
+
options: DeploymentOptions
|
|
115
|
+
): Promise<DeploymentResult> {
|
|
116
|
+
const deployer = await getDeployer();
|
|
117
|
+
return deployer.deploy(options);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
/**
|
|
121
|
+
* Roll back from one version to another.
|
|
122
|
+
*/
|
|
123
|
+
export async function rollbackModelVersion(
|
|
124
|
+
currentVersion: string,
|
|
125
|
+
targetVersion: string
|
|
126
|
+
): Promise<DeploymentResult> {
|
|
127
|
+
const deployer = await getDeployer();
|
|
128
|
+
return deployer.rollback(currentVersion, targetVersion);
|
|
129
|
+
}
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Model Storage Service (Vercel Blob)
|
|
3
|
+
*
|
|
4
|
+
* Handles model versioning and storage using Vercel Blob.
|
|
5
|
+
* Stores trained models with metadata for easy deployment.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import { getTrainingDataAdapter } from '../../adapter';
|
|
9
|
+
import type { JsonValue } from '../../adapter';
|
|
10
|
+
import { del, list, put } from '@vercel/blob';
|
|
11
|
+
import fs from 'fs/promises';
|
|
12
|
+
import path from 'path';
|
|
13
|
+
import { logger } from '../../utils/logger';
|
|
14
|
+
|
|
15
|
+
export interface ModelMetadata {
|
|
16
|
+
trainingBatch?: string;
|
|
17
|
+
accuracy?: number;
|
|
18
|
+
avgReward?: number;
|
|
19
|
+
baseModel?: string;
|
|
20
|
+
modelIdPrefix?: string;
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
export interface ModelVersion {
|
|
24
|
+
version: string;
|
|
25
|
+
baseModel: string;
|
|
26
|
+
blobUrl: string;
|
|
27
|
+
size: number;
|
|
28
|
+
uploadedAt: Date;
|
|
29
|
+
metadata: ModelMetadata & Record<string, JsonValue | undefined>;
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
export class ModelStorageService {
|
|
33
|
+
private readonly blobPrefix = 'models/';
|
|
34
|
+
|
|
35
|
+
/**
|
|
36
|
+
* Upload trained model to Vercel Blob
|
|
37
|
+
*/
|
|
38
|
+
async uploadModel(options: {
|
|
39
|
+
version: string;
|
|
40
|
+
modelPath: string;
|
|
41
|
+
modelIdPrefix?: string;
|
|
42
|
+
metadata?: ModelVersion['metadata'];
|
|
43
|
+
}): Promise<ModelVersion> {
|
|
44
|
+
logger.info('Uploading model to Vercel Blob', {
|
|
45
|
+
version: options.version,
|
|
46
|
+
path: options.modelPath,
|
|
47
|
+
});
|
|
48
|
+
|
|
49
|
+
// Read model file
|
|
50
|
+
const modelData = await fs.readFile(options.modelPath);
|
|
51
|
+
const fileName = path.basename(options.modelPath);
|
|
52
|
+
|
|
53
|
+
// Upload to Vercel Blob
|
|
54
|
+
const blob = await put(
|
|
55
|
+
`${this.blobPrefix}${options.version}/${fileName}`,
|
|
56
|
+
modelData,
|
|
57
|
+
{
|
|
58
|
+
access: 'public', // Models can be publicly downloaded
|
|
59
|
+
addRandomSuffix: false,
|
|
60
|
+
}
|
|
61
|
+
);
|
|
62
|
+
|
|
63
|
+
// Upload metadata
|
|
64
|
+
await put(
|
|
65
|
+
`${this.blobPrefix}${options.version}/metadata.json`,
|
|
66
|
+
JSON.stringify(options.metadata || {}, null, 2),
|
|
67
|
+
{
|
|
68
|
+
access: 'public',
|
|
69
|
+
addRandomSuffix: false,
|
|
70
|
+
}
|
|
71
|
+
);
|
|
72
|
+
|
|
73
|
+
logger.info('Model uploaded to Vercel Blob', {
|
|
74
|
+
version: options.version,
|
|
75
|
+
url: blob.url,
|
|
76
|
+
size: (blob as { size?: number }).size || 0,
|
|
77
|
+
});
|
|
78
|
+
|
|
79
|
+
const metadataModelIdPrefix =
|
|
80
|
+
typeof options.metadata?.modelIdPrefix === 'string'
|
|
81
|
+
? options.metadata.modelIdPrefix
|
|
82
|
+
: undefined;
|
|
83
|
+
const modelIdPrefix =
|
|
84
|
+
options.modelIdPrefix ||
|
|
85
|
+
metadataModelIdPrefix ||
|
|
86
|
+
process.env.TRAINING_MODEL_ID_PREFIX ||
|
|
87
|
+
'eliza-agent';
|
|
88
|
+
|
|
89
|
+
// Save to database via adapter
|
|
90
|
+
const adapter = getTrainingDataAdapter();
|
|
91
|
+
await adapter.insertModel({
|
|
92
|
+
id: `model-${Date.now()}`,
|
|
93
|
+
modelId: `${modelIdPrefix}-${options.version}`,
|
|
94
|
+
version: options.version,
|
|
95
|
+
baseModel:
|
|
96
|
+
(options.metadata?.baseModel as string) || 'unsloth/Qwen3-4B-128K',
|
|
97
|
+
trainingBatch: (options.metadata?.trainingBatch as string) || null,
|
|
98
|
+
status: 'ready',
|
|
99
|
+
deployedAt: null,
|
|
100
|
+
archivedAt: null,
|
|
101
|
+
storagePath: blob.url,
|
|
102
|
+
benchmarkScore: null,
|
|
103
|
+
accuracy: (options.metadata?.accuracy as number) || null,
|
|
104
|
+
avgReward: (options.metadata?.avgReward as number) || null,
|
|
105
|
+
evalMetrics: null,
|
|
106
|
+
wandbRunId: null,
|
|
107
|
+
wandbArtifactId: null,
|
|
108
|
+
huggingFaceRepo: null,
|
|
109
|
+
agentsUsing: 0,
|
|
110
|
+
totalInferences: 0,
|
|
111
|
+
lastBenchmarked: null,
|
|
112
|
+
benchmarkCount: 0,
|
|
113
|
+
updatedAt: new Date(),
|
|
114
|
+
});
|
|
115
|
+
|
|
116
|
+
return {
|
|
117
|
+
version: options.version,
|
|
118
|
+
baseModel:
|
|
119
|
+
(options.metadata?.baseModel as string) || 'unsloth/Qwen3-4B-128K',
|
|
120
|
+
blobUrl: blob.url,
|
|
121
|
+
size: (blob as { size?: number }).size || 0,
|
|
122
|
+
uploadedAt: new Date(),
|
|
123
|
+
metadata: options.metadata || {},
|
|
124
|
+
};
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
/**
|
|
128
|
+
* Download model from Vercel Blob
|
|
129
|
+
*/
|
|
130
|
+
async downloadModel(version: string): Promise<{
|
|
131
|
+
modelData: Buffer;
|
|
132
|
+
metadata: ModelVersion['metadata'];
|
|
133
|
+
}> {
|
|
134
|
+
const adapter = getTrainingDataAdapter();
|
|
135
|
+
const model = await adapter.getModelByVersion(version);
|
|
136
|
+
|
|
137
|
+
if (!model) {
|
|
138
|
+
throw new Error(`Model version ${version} not found`);
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
// Download model file
|
|
142
|
+
const modelResponse = await fetch(model.storagePath);
|
|
143
|
+
const modelData = Buffer.from(await modelResponse.arrayBuffer());
|
|
144
|
+
|
|
145
|
+
// Download metadata
|
|
146
|
+
const metadataUrl = model.storagePath.replace(/\/[^/]+$/, '/metadata.json');
|
|
147
|
+
const metadataResponse = await fetch(metadataUrl);
|
|
148
|
+
const metadata =
|
|
149
|
+
(await metadataResponse.json()) as ModelVersion['metadata'];
|
|
150
|
+
|
|
151
|
+
return {
|
|
152
|
+
modelData,
|
|
153
|
+
metadata,
|
|
154
|
+
};
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
/**
|
|
158
|
+
* List all model versions
|
|
159
|
+
*/
|
|
160
|
+
async listModels(): Promise<ModelVersion[]> {
|
|
161
|
+
const { blobs } = await list({
|
|
162
|
+
prefix: this.blobPrefix,
|
|
163
|
+
});
|
|
164
|
+
|
|
165
|
+
// Group by version
|
|
166
|
+
interface BlobInfo {
|
|
167
|
+
url: string;
|
|
168
|
+
pathname: string;
|
|
169
|
+
size: number;
|
|
170
|
+
uploadedAt: string | Date;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
interface VersionData {
|
|
174
|
+
version: string;
|
|
175
|
+
blobs: BlobInfo[];
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
const versions = new Map<string, VersionData>();
|
|
179
|
+
|
|
180
|
+
for (const blob of blobs) {
|
|
181
|
+
const parts = blob.pathname.split('/');
|
|
182
|
+
const version = parts[1];
|
|
183
|
+
if (!version) continue;
|
|
184
|
+
|
|
185
|
+
if (!versions.has(version)) {
|
|
186
|
+
versions.set(version, {
|
|
187
|
+
version,
|
|
188
|
+
blobs: [],
|
|
189
|
+
});
|
|
190
|
+
}
|
|
191
|
+
// Convert uploadedAt to string if it's a Date
|
|
192
|
+
const blobInfo: BlobInfo = {
|
|
193
|
+
...blob,
|
|
194
|
+
uploadedAt:
|
|
195
|
+
blob.uploadedAt instanceof Date
|
|
196
|
+
? blob.uploadedAt.toISOString()
|
|
197
|
+
: blob.uploadedAt,
|
|
198
|
+
};
|
|
199
|
+
versions.get(version)!.blobs.push(blobInfo);
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
// Get metadata for each version
|
|
203
|
+
const models: ModelVersion[] = [];
|
|
204
|
+
|
|
205
|
+
for (const [version, data] of versions) {
|
|
206
|
+
const modelBlob = data.blobs.find(
|
|
207
|
+
(b: BlobInfo) =>
|
|
208
|
+
b.pathname.endsWith('.safetensors') || b.pathname.endsWith('.bin')
|
|
209
|
+
);
|
|
210
|
+
|
|
211
|
+
if (modelBlob) {
|
|
212
|
+
// Try to get metadata
|
|
213
|
+
let metadata: ModelVersion['metadata'] = {};
|
|
214
|
+
try {
|
|
215
|
+
const metadataBlob = data.blobs.find((b: BlobInfo) =>
|
|
216
|
+
b.pathname.endsWith('metadata.json')
|
|
217
|
+
);
|
|
218
|
+
if (metadataBlob) {
|
|
219
|
+
const response = await fetch(metadataBlob.url);
|
|
220
|
+
metadata = (await response.json()) as ModelVersion['metadata'];
|
|
221
|
+
}
|
|
222
|
+
} catch {
|
|
223
|
+
// No metadata, use defaults
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
models.push({
|
|
227
|
+
version,
|
|
228
|
+
baseModel: metadata.baseModel || 'unknown',
|
|
229
|
+
blobUrl: modelBlob.url,
|
|
230
|
+
size: modelBlob.size,
|
|
231
|
+
uploadedAt:
|
|
232
|
+
modelBlob.uploadedAt instanceof Date
|
|
233
|
+
? modelBlob.uploadedAt
|
|
234
|
+
: new Date(modelBlob.uploadedAt),
|
|
235
|
+
metadata,
|
|
236
|
+
});
|
|
237
|
+
}
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
return models.sort(
|
|
241
|
+
(a, b) => b.uploadedAt.getTime() - a.uploadedAt.getTime()
|
|
242
|
+
);
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
/**
|
|
246
|
+
* Delete model version
|
|
247
|
+
*/
|
|
248
|
+
async deleteModel(version: string): Promise<void> {
|
|
249
|
+
const { blobs } = await list({
|
|
250
|
+
prefix: `${this.blobPrefix}${version}/`,
|
|
251
|
+
});
|
|
252
|
+
|
|
253
|
+
for (const blob of blobs) {
|
|
254
|
+
await del(blob.url);
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
// Update database via adapter
|
|
258
|
+
const adapter = getTrainingDataAdapter();
|
|
259
|
+
const model = await adapter.getModelByVersion(version);
|
|
260
|
+
if (model) {
|
|
261
|
+
await adapter.updateModelStatus(model.modelId, 'archived', {
|
|
262
|
+
archivedAt: new Date(),
|
|
263
|
+
});
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
logger.info('Model deleted from Vercel Blob', { version });
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
/**
|
|
270
|
+
* Get latest model version
|
|
271
|
+
*/
|
|
272
|
+
async getLatestVersion(): Promise<ModelVersion | null> {
|
|
273
|
+
const models = await this.listModels();
|
|
274
|
+
return models[0] || null;
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
// Singleton
|
|
279
|
+
export const modelStorage = new ModelStorageService();
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Training Data Archiver (Vercel Blob)
|
|
3
|
+
*
|
|
4
|
+
* Archives training data (exported trajectories, RULER scores) to Vercel Blob
|
|
5
|
+
* for long-term storage and reproducibility.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import type { JsonValue } from '../../adapter';
|
|
9
|
+
import { del, list, put } from '@vercel/blob';
|
|
10
|
+
import fs from 'fs/promises';
|
|
11
|
+
import path from 'path';
|
|
12
|
+
import { logger } from '../../utils/logger';
|
|
13
|
+
|
|
14
|
+
export interface ArchivedWindow {
|
|
15
|
+
windowId: string;
|
|
16
|
+
trajectoryCount: number;
|
|
17
|
+
blobUrls: {
|
|
18
|
+
trajectories: string;
|
|
19
|
+
groups?: string;
|
|
20
|
+
rulerScores?: string;
|
|
21
|
+
metadata: string;
|
|
22
|
+
};
|
|
23
|
+
archivedAt: Date;
|
|
24
|
+
size: number;
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
export class TrainingDataArchiver {
|
|
28
|
+
private readonly blobPrefix = 'training-data/';
|
|
29
|
+
|
|
30
|
+
/**
|
|
31
|
+
* Archive training data for a window
|
|
32
|
+
*/
|
|
33
|
+
async archiveWindow(options: {
|
|
34
|
+
windowId: string;
|
|
35
|
+
trajectoriesPath: string;
|
|
36
|
+
groupsPath?: string;
|
|
37
|
+
rulerScoresPath?: string;
|
|
38
|
+
metadata?: Record<string, unknown>;
|
|
39
|
+
}): Promise<ArchivedWindow> {
|
|
40
|
+
logger.info('Archiving training data', { windowId: options.windowId });
|
|
41
|
+
|
|
42
|
+
const prefix = `${this.blobPrefix}${options.windowId}/`;
|
|
43
|
+
interface BlobUrls {
|
|
44
|
+
trajectories: string;
|
|
45
|
+
groups?: string;
|
|
46
|
+
rulerScores?: string;
|
|
47
|
+
metadata: string;
|
|
48
|
+
}
|
|
49
|
+
const urls: BlobUrls = {
|
|
50
|
+
trajectories: '',
|
|
51
|
+
metadata: '',
|
|
52
|
+
};
|
|
53
|
+
let totalSize = 0;
|
|
54
|
+
|
|
55
|
+
// Upload trajectories
|
|
56
|
+
const trajData = await fs.readFile(options.trajectoriesPath);
|
|
57
|
+
const trajBlob = await put(`${prefix}trajectories.jsonl`, trajData, {
|
|
58
|
+
access: 'public',
|
|
59
|
+
addRandomSuffix: false,
|
|
60
|
+
});
|
|
61
|
+
urls.trajectories = trajBlob.url;
|
|
62
|
+
totalSize += trajData.length;
|
|
63
|
+
|
|
64
|
+
// Upload groups if provided
|
|
65
|
+
if (options.groupsPath) {
|
|
66
|
+
const groupsData = await fs.readFile(options.groupsPath);
|
|
67
|
+
const groupsBlob = await put(`${prefix}groups.jsonl`, groupsData, {
|
|
68
|
+
access: 'public',
|
|
69
|
+
addRandomSuffix: false,
|
|
70
|
+
});
|
|
71
|
+
urls.groups = groupsBlob.url;
|
|
72
|
+
totalSize += groupsData.length;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
// Upload RULER scores if provided
|
|
76
|
+
if (options.rulerScoresPath) {
|
|
77
|
+
const scoresData = await fs.readFile(options.rulerScoresPath);
|
|
78
|
+
const scoresBlob = await put(`${prefix}ruler_scores.json`, scoresData, {
|
|
79
|
+
access: 'public',
|
|
80
|
+
addRandomSuffix: false,
|
|
81
|
+
});
|
|
82
|
+
urls.rulerScores = scoresBlob.url;
|
|
83
|
+
totalSize += scoresData.length;
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
// Upload metadata
|
|
87
|
+
const metadataJson = JSON.stringify(options.metadata || {}, null, 2);
|
|
88
|
+
const metadataBlob = await put(`${prefix}metadata.json`, metadataJson, {
|
|
89
|
+
access: 'public',
|
|
90
|
+
addRandomSuffix: false,
|
|
91
|
+
});
|
|
92
|
+
urls.metadata = metadataBlob.url;
|
|
93
|
+
totalSize += Buffer.byteLength(metadataJson, 'utf8');
|
|
94
|
+
|
|
95
|
+
logger.info('Training data archived', {
|
|
96
|
+
windowId: options.windowId,
|
|
97
|
+
size: totalSize,
|
|
98
|
+
});
|
|
99
|
+
|
|
100
|
+
return {
|
|
101
|
+
windowId: options.windowId,
|
|
102
|
+
trajectoryCount: (options.metadata?.trajectoryCount as number) || 0,
|
|
103
|
+
blobUrls: urls,
|
|
104
|
+
archivedAt: new Date(),
|
|
105
|
+
size: totalSize,
|
|
106
|
+
};
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
/**
|
|
110
|
+
* Retrieve archived training data
|
|
111
|
+
*/
|
|
112
|
+
async getWindowData(windowId: string): Promise<{
|
|
113
|
+
trajectories: string;
|
|
114
|
+
groups?: string;
|
|
115
|
+
rulerScores?: Record<string, JsonValue>;
|
|
116
|
+
metadata: Record<string, JsonValue>;
|
|
117
|
+
} | null> {
|
|
118
|
+
const prefix = `${this.blobPrefix}${windowId}/`;
|
|
119
|
+
const { blobs } = await list({ prefix });
|
|
120
|
+
|
|
121
|
+
if (blobs.length === 0) {
|
|
122
|
+
return null;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
interface WindowDataResult {
|
|
126
|
+
trajectories?: string;
|
|
127
|
+
groups?: string;
|
|
128
|
+
rulerScores?: Record<string, JsonValue>;
|
|
129
|
+
metadata?: Record<string, JsonValue>;
|
|
130
|
+
}
|
|
131
|
+
const result: WindowDataResult = {};
|
|
132
|
+
|
|
133
|
+
for (const blob of blobs) {
|
|
134
|
+
const response = await fetch(blob.url);
|
|
135
|
+
const filename = path.basename(blob.pathname);
|
|
136
|
+
|
|
137
|
+
if (filename === 'trajectories.jsonl') {
|
|
138
|
+
result.trajectories = await response.text();
|
|
139
|
+
} else if (filename === 'groups.jsonl') {
|
|
140
|
+
result.groups = await response.text();
|
|
141
|
+
} else if (filename === 'ruler_scores.json') {
|
|
142
|
+
result.rulerScores = (await response.json()) as Record<
|
|
143
|
+
string,
|
|
144
|
+
JsonValue
|
|
145
|
+
>;
|
|
146
|
+
} else if (filename === 'metadata.json') {
|
|
147
|
+
result.metadata = (await response.json()) as Record<string, JsonValue>;
|
|
148
|
+
}
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
// Ensure required fields are present
|
|
152
|
+
if (!result.trajectories || !result.metadata) {
|
|
153
|
+
return null;
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
return {
|
|
157
|
+
trajectories: result.trajectories,
|
|
158
|
+
groups: result.groups,
|
|
159
|
+
rulerScores: result.rulerScores,
|
|
160
|
+
metadata: result.metadata,
|
|
161
|
+
};
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
/**
|
|
165
|
+
* List all archived windows
|
|
166
|
+
*/
|
|
167
|
+
async listWindows(): Promise<string[]> {
|
|
168
|
+
const { blobs } = await list({ prefix: this.blobPrefix });
|
|
169
|
+
|
|
170
|
+
const windows = new Set<string>();
|
|
171
|
+
for (const blob of blobs) {
|
|
172
|
+
const parts = blob.pathname.split('/');
|
|
173
|
+
if (parts[1]) {
|
|
174
|
+
windows.add(parts[1]);
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
return Array.from(windows).sort().reverse();
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
/**
|
|
182
|
+
* Delete archived window
|
|
183
|
+
*/
|
|
184
|
+
async deleteWindow(windowId: string): Promise<void> {
|
|
185
|
+
const prefix = `${this.blobPrefix}${windowId}/`;
|
|
186
|
+
const { blobs } = await list({ prefix });
|
|
187
|
+
|
|
188
|
+
for (const blob of blobs) {
|
|
189
|
+
await del(blob.url);
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
logger.info('Deleted archived window', { windowId });
|
|
193
|
+
}
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
// Singleton
|
|
197
|
+
export const trainingDataArchiver = new TrainingDataArchiver();
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Storage Module
|
|
3
|
+
*
|
|
4
|
+
* Services for storing models and training data to Vercel Blob.
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
export {
|
|
8
|
+
ModelStorageService,
|
|
9
|
+
type ModelVersion,
|
|
10
|
+
modelStorage,
|
|
11
|
+
} from './ModelStorageService';
|
|
12
|
+
|
|
13
|
+
export {
|
|
14
|
+
type ArchivedWindow,
|
|
15
|
+
TrainingDataArchiver,
|
|
16
|
+
trainingDataArchiver,
|
|
17
|
+
} from './TrainingDataArchiver';
|