@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,73 @@
1
+ #!/usr/bin/env bun
2
+ /**
3
+ * Test scoring directly
4
+ */
5
+
6
+ import { and, db, desc, eq, isNull, not, trajectories } from '@elizaos/db';
7
+ import { archetypeScoringService } from '../src/scoring';
8
+
9
/**
 * Score up to 10 unscored training trajectories with the `default` archetype
 * rubric, then print the five most recently judged trajectories.
 *
 * Exit codes: 0 when scoring completes (or there is nothing to score),
 * 1 when scoring or any database query fails, so callers/CI can detect it.
 */
async function main(): Promise<void> {
  console.log('Testing trajectory scoring...\n');

  // Candidate rows: no judge reward yet, marked as training data, and with a
  // steps payload that is neither the string 'null' nor an empty JSON array.
  const unscored = await db
    .select({ trajectoryId: trajectories.trajectoryId })
    .from(trajectories)
    .where(
      and(
        isNull(trajectories.aiJudgeReward),
        eq(trajectories.isTrainingData, true),
        not(eq(trajectories.stepsJson, 'null')),
        not(eq(trajectories.stepsJson, '[]'))
      )
    )
    .limit(10);

  console.log(`Found ${unscored.length} unscored trajectories`);

  if (unscored.length === 0) {
    console.log('No trajectories to score!');
    process.exit(0);
  }

  const ids = unscored.map((t) => t.trajectoryId);
  console.log('Trajectory IDs:', ids);

  console.log('\nAttempting to score...');

  try {
    const result = await archetypeScoringService.scoreByArchetype(
      'default',
      ids
    );
    console.log('\nResult:', result);

    // Show the most recently judged rows.
    // NOTE(review): this query is not restricted to `ids`, so trajectories
    // scored by earlier runs can appear here as well.
    const scored = await db
      .select({
        trajectoryId: trajectories.trajectoryId,
        aiJudgeReward: trajectories.aiJudgeReward,
        aiJudgeReasoning: trajectories.aiJudgeReasoning,
      })
      .from(trajectories)
      .where(not(isNull(trajectories.aiJudgeReward)))
      .orderBy(desc(trajectories.judgedAt))
      .limit(5);

    console.log('\nScored trajectories:', scored.length);
    if (scored.length > 0) {
      console.log('Sample scores:');
      for (const s of scored) {
        // Avoid printing the literal "undefined" when no reasoning was stored.
        const reasoning = s.aiJudgeReasoning?.substring(0, 50) ?? '';
        console.log(
          ` ${s.trajectoryId}: score=${s.aiJudgeReward}, reasoning=${reasoning}...`
        );
      }
    }
  } catch (error) {
    console.error('Scoring error:', error);
    // Report the failure through the exit code instead of exiting 0.
    process.exit(1);
  }

  // Exit explicitly so open database handles cannot keep the process alive.
  process.exit(0);
}

main().catch((error: unknown) => {
  // Failures outside the scoring try/catch (e.g. the initial query) must also
  // terminate the process with a non-zero status.
  console.error('Fatal error:', error);
  process.exit(1);
});
@@ -0,0 +1,209 @@
1
+ #!/usr/bin/env bun
2
+
3
+ /**
4
+ * Test Trained Model - TypeScript/Node
5
+ *
6
+ * Tests a trained model by:
7
+ * 1. Loading model from database or path
8
+ * 2. Running benchmark if available
9
+ * 3. Testing inference
10
+ * 4. Comparing to baseline
11
+ *
12
+ * Usage:
13
+ * bun run packages/training/scripts/test-trained-model.ts --model-id <id>
14
+ * bun run packages/training/scripts/test-trained-model.ts --model-path <path> --benchmark
15
+ */
16
+
17
+ import { db, eq, trainedModels } from '@elizaos/db';
18
+ import { BenchmarkService } from '../src/training/BenchmarkService';
19
+ import { logger } from '../src/utils/logger';
20
+
21
/**
 * Options parsed from the command line for a single model test run.
 */
interface TestConfig {
  /** Id of a model row in the `trainedModels` table; takes precedence over `modelPath`. */
  modelId?: string;
  /** Path used to build an ad-hoc model entry when no `modelId` is given. */
  modelPath?: string;
  /** Benchmark path forwarded to `BenchmarkService.benchmarkModel`. */
  benchmarkPath?: string;
  /** Run the benchmark step (only possible with `modelId`). */
  benchmark?: boolean;
  /** After benchmarking, compare the model against the baseline. */
  compareToBaseline?: boolean;
}
28
+
29
/**
 * Run the smoke tests for a trained model and log the results.
 *
 * Steps:
 *  1. Resolve the model: a `trainedModels` row when `modelId` is set,
 *     otherwise an ad-hoc entry pointing at `modelPath`.
 *  2. Check that the model's storage path exists (logs a warning, never fails).
 *  3. Optionally run the benchmark and baseline comparison; both require
 *     `modelId`.
 *  4. Check whether any `trainedModels` row exists for an inference test.
 *  5. Log a summary that includes the most recent benchmark score.
 *
 * @param config - Parsed command-line options.
 * @throws Error when neither `modelId` nor `modelPath` is set, when `modelId`
 *   is not found, or when the model has no storage path. Errors from the
 *   database query or the benchmark service propagate unchanged.
 */
async function testModel(config: TestConfig): Promise<void> {
  logger.info('Testing trained model', config);

  // Get model from database or path
  let model;
  if (config.modelId) {
    const result = await db
      .select()
      .from(trainedModels)
      .where(eq(trainedModels.modelId, config.modelId))
      .limit(1);

    model = result[0];

    if (!model) {
      throw new Error(`Model not found: ${config.modelId}`);
    }

    logger.info('Found model in database', {
      modelId: model.modelId,
      version: model.version,
      status: model.status,
      storagePath: model.storagePath,
    });
  } else if (config.modelPath) {
    // Create mock model entry for testing
    model = {
      modelId: `test-${Date.now()}`,
      version: 'test',
      status: 'ready' as const,
      storagePath: config.modelPath,
      benchmarkScore: null,
    };

    logger.info('Using model from path', {
      modelPath: config.modelPath,
    });
  } else {
    throw new Error('Must provide either --model-id or --model-path');
  }

  // Test 1: Model loading validation
  logger.info('='.repeat(60));
  logger.info('TEST 1: Model Loading');
  logger.info('='.repeat(60));

  if (!model.storagePath) {
    throw new Error('Model storage path not set');
  }

  // Check the path regardless of where it came from. Previously, passing
  // --model-path skipped the check and reported "validated" even for a path
  // that does not exist.
  // NOTE(review): Bun.file().exists() targets files; confirm how it treats
  // directory-style model checkpoints.
  const modelExists = await Bun.file(model.storagePath)
    .exists()
    .catch(() => false);
  if (modelExists) {
    logger.info('✅ Model path validated', {
      path: model.storagePath,
    });
  } else {
    logger.warn('Model file not found at storage path', {
      storagePath: model.storagePath,
    });
  }

  // The row was read before any benchmark ran. Track the freshest score so the
  // summary does not report a stale value.
  let latestScore: unknown = model.benchmarkScore;

  // Test 2: Benchmark if requested
  if (config.benchmark) {
    logger.info('='.repeat(60));
    logger.info('TEST 2: Running Benchmark');
    logger.info('='.repeat(60));

    if (config.modelId) {
      const benchmarkService = new BenchmarkService();
      const results = await benchmarkService.benchmarkModel(
        config.modelId,
        config.benchmarkPath
      );
      latestScore = results.benchmarkScore;

      logger.info('Benchmark Results:', {
        score: results.benchmarkScore,
        pnl: results.pnl,
        accuracy: results.accuracy,
        optimality: results.optimality,
      });

      // Compare to baseline if requested
      if (config.compareToBaseline) {
        const comparison = await benchmarkService.compareModels(config.modelId);
        logger.info('Comparison to Baseline:', {
          newScore: comparison.newScore,
          previousScore: comparison.previousScore,
          improvement: comparison.improvement,
          shouldDeploy: comparison.shouldDeploy,
          reason: comparison.reason,
        });
      }
    } else {
      logger.warn('Benchmark requires model-id (model must be in database)');
    }
  }

  // Test 3: Inference test (if we can get runtime)
  logger.info('='.repeat(60));
  logger.info('TEST 3: Inference Test');
  logger.info('='.repeat(60));

  try {
    // Only checks that at least one trainedModels row exists; no inference
    // is actually performed here.
    const testAgentResult = await db.select().from(trainedModels).limit(1);

    if (testAgentResult.length > 0) {
      logger.info('✅ Inference test setup available');
      logger.info('Run full benchmark to test inference with real agent');
    } else {
      logger.warn('No test agent available for inference test');
    }
  } catch (error) {
    logger.warn('Inference test skipped', {
      error: error instanceof Error ? error.message : String(error),
    });
  }

  // Summary
  logger.info('='.repeat(60));
  logger.info('TESTING COMPLETE');
  logger.info('='.repeat(60));
  logger.info('Model:', {
    id: model.modelId,
    version: model.version,
    status: model.status,
  });

  if (latestScore != null) {
    // Pass a context object, like every other logger call in this file.
    logger.info('Benchmark Score:', { score: latestScore });
  }
}
163
+
164
/**
 * CLI entry point: parse `process.argv` into a {@link TestConfig}, run
 * `testModel`, and terminate the process.
 *
 * Recognized flags: `--model-id <id>`, `--model-path <path>`, `--benchmark`,
 * `--benchmark-path <path>`, `--compare`. Unrecognized flags, and flags
 * missing their value, are reported on stderr and otherwise ignored.
 *
 * Exit codes: 0 on success, 1 on usage errors or when testing fails.
 */
async function main(): Promise<void> {
  const args = process.argv.slice(2);

  const config: TestConfig = {};

  for (let i = 0; i < args.length; i++) {
    const arg = args[i];

    if (arg === '--model-id' && i + 1 < args.length) {
      config.modelId = args[i + 1];
      i++;
    } else if (arg === '--model-path' && i + 1 < args.length) {
      config.modelPath = args[i + 1];
      i++;
    } else if (arg === '--benchmark') {
      config.benchmark = true;
    } else if (arg === '--benchmark-path' && i + 1 < args.length) {
      config.benchmarkPath = args[i + 1];
      i++;
    } else if (arg === '--compare') {
      config.compareToBaseline = true;
    } else {
      // Previously dropped silently, which hid typos and value-less flags.
      console.warn(`Ignoring unrecognized or incomplete argument: ${arg}`);
    }
  }

  if (!config.modelId && !config.modelPath) {
    console.error('Usage:');
    console.error(
      ' bun run test-trained-model.ts --model-id <id> [--benchmark] [--benchmark-path <path>] [--compare]'
    );
    console.error(
      ' bun run test-trained-model.ts --model-path <path> [--benchmark]'
    );
    process.exit(1);
  }

  // testModel only runs the baseline comparison inside the benchmark step.
  if (config.compareToBaseline && !config.benchmark) {
    console.warn('--compare has no effect without --benchmark');
  }

  try {
    await testModel(config);
  } catch (error) {
    logger.error('Testing failed', {
      error: error instanceof Error ? error.message : String(error),
    });
    process.exit(1);
  }

  // Exit explicitly on success as well, mirroring the failure path.
  // NOTE(review): this assumes an open database client could otherwise keep
  // the process alive; confirm whether @elizaos/db needs an explicit shutdown.
  process.exit(0);
}

void main();