@elizaos/training 2.0.0-alpha.13 → 2.0.0-alpha.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +2 -2
- package/research-output/training-runs/training-run-1773726941205.json +38 -0
- package/scripts/rank_trajectories.ts +0 -1
- package/scripts/run_task_benchmark.ts +4 -11
- package/src/adapter.ts +96 -49
- package/src/archetypes/ArchetypeConfigService.ts +188 -185
- package/src/archetypes/derive-archetype.ts +47 -47
- package/src/archetypes/index.ts +2 -2
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +70 -70
- package/src/benchmark/BenchmarkChartGenerator.ts +70 -69
- package/src/benchmark/BenchmarkDataGenerator.ts +136 -136
- package/src/benchmark/BenchmarkDataViewer.ts +32 -30
- package/src/benchmark/BenchmarkHistoryService.ts +13 -12
- package/src/benchmark/BenchmarkRunner.ts +87 -83
- package/src/benchmark/BenchmarkValidator.ts +48 -46
- package/src/benchmark/FastEvalRunner.ts +17 -16
- package/src/benchmark/MetricsValidator.ts +20 -21
- package/src/benchmark/MetricsVisualizer.ts +92 -85
- package/src/benchmark/ModelBenchmarkService.ts +90 -82
- package/src/benchmark/ModelRegistry.ts +44 -44
- package/src/benchmark/RulerBenchmarkIntegration.ts +24 -24
- package/src/benchmark/SimulationA2AInterface.ts +118 -118
- package/src/benchmark/SimulationEngine.ts +51 -51
- package/src/benchmark/TaskRunner.ts +87 -79
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +80 -80
- package/src/benchmark/__tests__/HeadToHead.test.ts +26 -26
- package/src/benchmark/index.ts +27 -27
- package/src/benchmark/parseSimulationMetrics.ts +32 -32
- package/src/benchmark/simulation-types.ts +10 -10
- package/src/dependencies.ts +34 -34
- package/src/generation/TrajectoryGenerator.ts +39 -37
- package/src/generation/index.ts +1 -1
- package/src/huggingface/HuggingFaceDatasetUploader.ts +72 -72
- package/src/huggingface/HuggingFaceIntegrationService.ts +59 -53
- package/src/huggingface/HuggingFaceModelUploader.ts +60 -59
- package/src/huggingface/index.ts +6 -6
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +32 -32
- package/src/index.ts +27 -27
- package/src/init-training.ts +6 -6
- package/src/metrics/TrajectoryMetricsExtractor.ts +70 -71
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +182 -182
- package/src/metrics/index.ts +2 -2
- package/src/rubrics/__tests__/index.test.ts +73 -73
- package/src/rubrics/ass-kisser.ts +6 -6
- package/src/rubrics/degen.ts +6 -6
- package/src/rubrics/goody-twoshoes.ts +6 -6
- package/src/rubrics/index.ts +50 -50
- package/src/rubrics/information-trader.ts +6 -6
- package/src/rubrics/infosec.ts +6 -6
- package/src/rubrics/liar.ts +6 -6
- package/src/rubrics/perps-trader.ts +6 -6
- package/src/rubrics/researcher.ts +6 -6
- package/src/rubrics/scammer.ts +6 -6
- package/src/rubrics/social-butterfly.ts +7 -7
- package/src/rubrics/super-predictor.ts +6 -6
- package/src/rubrics/trader.ts +5 -5
- package/src/scoring/ArchetypeScoringService.ts +56 -54
- package/src/scoring/JudgePromptBuilder.ts +96 -96
- package/src/scoring/LLMJudgeCache.ts +26 -23
- package/src/scoring/index.ts +3 -3
- package/src/training/AutomationPipeline.ts +149 -140
- package/src/training/BenchmarkService.ts +49 -45
- package/src/training/ConfigValidator.ts +38 -32
- package/src/training/MarketOutcomesTracker.ts +22 -12
- package/src/training/ModelDeployer.ts +15 -15
- package/src/training/ModelFetcher.ts +7 -7
- package/src/training/ModelSelectionService.ts +32 -32
- package/src/training/ModelUsageVerifier.ts +31 -24
- package/src/training/MultiModelOrchestrator.ts +44 -44
- package/src/training/RLModelConfig.ts +57 -57
- package/src/training/RewardBackpropagationService.ts +18 -17
- package/src/training/RulerScoringService.ts +73 -72
- package/src/training/TrainingMonitor.ts +29 -29
- package/src/training/TrajectoryRecorder.ts +25 -27
- package/src/training/__tests__/TrajectoryRecorder.test.ts +105 -105
- package/src/training/index.ts +36 -36
- package/src/training/logRLConfig.ts +7 -7
- package/src/training/pipeline.ts +13 -16
- package/src/training/storage/ModelStorageService.ts +32 -32
- package/src/training/storage/TrainingDataArchiver.ts +21 -21
- package/src/training/storage/index.ts +2 -2
- package/src/training/types.ts +6 -6
- package/src/training/window-utils.ts +14 -14
- package/src/utils/index.ts +7 -7
- package/src/utils/logger.ts +5 -5
- package/src/utils/snowflake.ts +1 -1
- package/src/utils/synthetic-detector.ts +7 -7
|
@@ -4,8 +4,8 @@
|
|
|
4
4
|
* Fetches trained RL models from the database for inference.
|
|
5
5
|
*/
|
|
6
6
|
|
|
7
|
-
import { getTrainingDataAdapter } from
|
|
8
|
-
import { logger } from
|
|
7
|
+
import { getTrainingDataAdapter } from "../adapter";
|
|
8
|
+
import { logger } from "../utils/logger";
|
|
9
9
|
|
|
10
10
|
export interface ModelArtifact {
|
|
11
11
|
version: string;
|
|
@@ -33,7 +33,7 @@ export async function getLatestRLModel(): Promise<ModelArtifact | null> {
|
|
|
33
33
|
}
|
|
34
34
|
|
|
35
35
|
// Skip models that aren't ready or deployed
|
|
36
|
-
if (model.status !==
|
|
36
|
+
if (model.status !== "ready" && model.status !== "deployed") {
|
|
37
37
|
return null;
|
|
38
38
|
}
|
|
39
39
|
|
|
@@ -41,23 +41,23 @@ export async function getLatestRLModel(): Promise<ModelArtifact | null> {
|
|
|
41
41
|
|
|
42
42
|
if (!rlModelId || rlModelId.trim().length === 0) {
|
|
43
43
|
logger.error(
|
|
44
|
-
|
|
44
|
+
"Model has no storagePath or modelId",
|
|
45
45
|
{
|
|
46
46
|
modelId: model.modelId,
|
|
47
47
|
storagePath: model.storagePath,
|
|
48
48
|
},
|
|
49
|
-
|
|
49
|
+
"ModelFetcher",
|
|
50
50
|
);
|
|
51
51
|
return null;
|
|
52
52
|
}
|
|
53
53
|
|
|
54
54
|
if (!model.baseModel || model.baseModel.trim().length === 0) {
|
|
55
55
|
logger.error(
|
|
56
|
-
|
|
56
|
+
"Model has no baseModel",
|
|
57
57
|
{
|
|
58
58
|
modelId: model.modelId,
|
|
59
59
|
},
|
|
60
|
-
|
|
60
|
+
"ModelFetcher",
|
|
61
61
|
);
|
|
62
62
|
return null;
|
|
63
63
|
}
|
|
@@ -7,13 +7,13 @@
|
|
|
7
7
|
* 3. Performance of previous models
|
|
8
8
|
*/
|
|
9
9
|
|
|
10
|
-
import { getTrainingDataAdapter } from
|
|
11
|
-
import { logger } from
|
|
10
|
+
import { getTrainingDataAdapter } from "../adapter";
|
|
11
|
+
import { logger } from "../utils/logger";
|
|
12
12
|
|
|
13
13
|
export interface ModelSelectionResult {
|
|
14
14
|
modelId: string;
|
|
15
15
|
modelPath: string;
|
|
16
|
-
strategy:
|
|
16
|
+
strategy: "base" | "continue" | "force_first";
|
|
17
17
|
reason: string;
|
|
18
18
|
metadata?: {
|
|
19
19
|
bundleCount?: number;
|
|
@@ -32,7 +32,7 @@ export interface TrainingBundle {
|
|
|
32
32
|
export class ModelSelectionService {
|
|
33
33
|
/** Default base model - uses Qwen3-4B-128K (4B params, 128K context). Scale up via MODEL_TIER or AVAILABLE_VRAM_GB env vars */
|
|
34
34
|
private readonly BASE_MODEL =
|
|
35
|
-
process.env.BASE_MODEL ||
|
|
35
|
+
process.env.BASE_MODEL || "unsloth/Qwen3-4B-128K";
|
|
36
36
|
private readonly BUNDLE_THRESHOLD = 1000;
|
|
37
37
|
private readonly MIN_BUNDLES_FOR_TRAINING = 100;
|
|
38
38
|
private readonly MAX_TRAINING_EXAMPLES = 2000;
|
|
@@ -61,9 +61,9 @@ export class ModelSelectionService {
|
|
|
61
61
|
*/
|
|
62
62
|
async selectBaseModel(): Promise<ModelSelectionResult> {
|
|
63
63
|
logger.info(
|
|
64
|
-
|
|
64
|
+
"Selecting base model for training...",
|
|
65
65
|
undefined,
|
|
66
|
-
|
|
66
|
+
"ModelSelectionService",
|
|
67
67
|
);
|
|
68
68
|
|
|
69
69
|
// Count available training bundles (always fetch for accurate metrics)
|
|
@@ -74,15 +74,15 @@ export class ModelSelectionService {
|
|
|
74
74
|
|
|
75
75
|
if (forceFirst) {
|
|
76
76
|
logger.info(
|
|
77
|
-
|
|
77
|
+
"No models exist - forcing first model creation",
|
|
78
78
|
undefined,
|
|
79
|
-
|
|
79
|
+
"ModelSelectionService",
|
|
80
80
|
);
|
|
81
81
|
return {
|
|
82
82
|
modelId: this.BASE_MODEL,
|
|
83
83
|
modelPath: this.BASE_MODEL,
|
|
84
|
-
strategy:
|
|
85
|
-
reason:
|
|
84
|
+
strategy: "force_first",
|
|
85
|
+
reason: "No trained models exist - creating first model from base",
|
|
86
86
|
metadata: {
|
|
87
87
|
baseModel: this.BASE_MODEL,
|
|
88
88
|
bundleCount, // Use actual count, not 0
|
|
@@ -92,14 +92,14 @@ export class ModelSelectionService {
|
|
|
92
92
|
logger.info(
|
|
93
93
|
`Found ${bundleCount} training bundles`,
|
|
94
94
|
undefined,
|
|
95
|
-
|
|
95
|
+
"ModelSelectionService",
|
|
96
96
|
);
|
|
97
97
|
|
|
98
98
|
// Not enough data yet
|
|
99
99
|
if (bundleCount < this.MIN_BUNDLES_FOR_TRAINING) {
|
|
100
100
|
throw new Error(
|
|
101
101
|
`Insufficient training data: ${bundleCount} bundles ` +
|
|
102
|
-
`(need ${this.MIN_BUNDLES_FOR_TRAINING} minimum)
|
|
102
|
+
`(need ${this.MIN_BUNDLES_FOR_TRAINING} minimum)`,
|
|
103
103
|
);
|
|
104
104
|
}
|
|
105
105
|
|
|
@@ -108,12 +108,12 @@ export class ModelSelectionService {
|
|
|
108
108
|
logger.info(
|
|
109
109
|
`Bundle count ${bundleCount} < ${this.BUNDLE_THRESHOLD} - using base model`,
|
|
110
110
|
undefined,
|
|
111
|
-
|
|
111
|
+
"ModelSelectionService",
|
|
112
112
|
);
|
|
113
113
|
return {
|
|
114
114
|
modelId: this.BASE_MODEL,
|
|
115
115
|
modelPath: this.BASE_MODEL,
|
|
116
|
-
strategy:
|
|
116
|
+
strategy: "base",
|
|
117
117
|
reason: `Training from base model (${bundleCount} bundles < ${this.BUNDLE_THRESHOLD} threshold)`,
|
|
118
118
|
metadata: {
|
|
119
119
|
bundleCount,
|
|
@@ -127,15 +127,15 @@ export class ModelSelectionService {
|
|
|
127
127
|
|
|
128
128
|
if (!bestModel) {
|
|
129
129
|
logger.warn(
|
|
130
|
-
|
|
130
|
+
"No best model found despite bundle threshold - using base model",
|
|
131
131
|
undefined,
|
|
132
|
-
|
|
132
|
+
"ModelSelectionService",
|
|
133
133
|
);
|
|
134
134
|
return {
|
|
135
135
|
modelId: this.BASE_MODEL,
|
|
136
136
|
modelPath: this.BASE_MODEL,
|
|
137
|
-
strategy:
|
|
138
|
-
reason:
|
|
137
|
+
strategy: "base",
|
|
138
|
+
reason: "No previous models available - using base model",
|
|
139
139
|
metadata: {
|
|
140
140
|
bundleCount,
|
|
141
141
|
baseModel: this.BASE_MODEL,
|
|
@@ -149,7 +149,7 @@ export class ModelSelectionService {
|
|
|
149
149
|
bestModelId: bestModel.modelId,
|
|
150
150
|
bestScore: bestModel.benchmarkScore,
|
|
151
151
|
},
|
|
152
|
-
|
|
152
|
+
"ModelSelectionService",
|
|
153
153
|
);
|
|
154
154
|
|
|
155
155
|
// Use storagePath for model path (e.g., HuggingFace URL)
|
|
@@ -158,8 +158,8 @@ export class ModelSelectionService {
|
|
|
158
158
|
return {
|
|
159
159
|
modelId: bestModel.modelId,
|
|
160
160
|
modelPath: modelStoragePath,
|
|
161
|
-
strategy:
|
|
162
|
-
reason: `Continuing from best model (score: ${bestModel.benchmarkScore?.toFixed(3) ||
|
|
161
|
+
strategy: "continue",
|
|
162
|
+
reason: `Continuing from best model (score: ${bestModel.benchmarkScore?.toFixed(3) || "N/A"})`,
|
|
163
163
|
metadata: {
|
|
164
164
|
bundleCount,
|
|
165
165
|
bestModelScore: bestModel.benchmarkScore || undefined,
|
|
@@ -185,22 +185,22 @@ export class ModelSelectionService {
|
|
|
185
185
|
|
|
186
186
|
if (!model) {
|
|
187
187
|
logger.warn(
|
|
188
|
-
|
|
188
|
+
"No benchmarked models found",
|
|
189
189
|
undefined,
|
|
190
|
-
|
|
190
|
+
"ModelSelectionService",
|
|
191
191
|
);
|
|
192
192
|
return null;
|
|
193
193
|
}
|
|
194
194
|
|
|
195
195
|
logger.info(
|
|
196
|
-
|
|
196
|
+
"Found best performing model",
|
|
197
197
|
{
|
|
198
198
|
modelId: model.modelId,
|
|
199
199
|
version: model.version,
|
|
200
200
|
benchmarkScore: model.benchmarkScore,
|
|
201
201
|
avgReward: model.avgReward,
|
|
202
202
|
},
|
|
203
|
-
|
|
203
|
+
"ModelSelectionService",
|
|
204
204
|
);
|
|
205
205
|
|
|
206
206
|
return model;
|
|
@@ -278,13 +278,13 @@ export class ModelSelectionService {
|
|
|
278
278
|
*/
|
|
279
279
|
async getTrainingTrajectories(limit?: number | null) {
|
|
280
280
|
const result = await getTrainingDataAdapter().getTrainingTrajectories(
|
|
281
|
-
limit ?? undefined
|
|
281
|
+
limit ?? undefined,
|
|
282
282
|
);
|
|
283
283
|
|
|
284
284
|
logger.info(
|
|
285
285
|
`Retrieved ${result.length} trajectories for training`,
|
|
286
286
|
{ limit, available: result.length },
|
|
287
|
-
|
|
287
|
+
"ModelSelectionService",
|
|
288
288
|
);
|
|
289
289
|
|
|
290
290
|
return result;
|
|
@@ -316,15 +316,15 @@ export class ModelSelectionService {
|
|
|
316
316
|
const trainedModelCount = await this.countTrainedModels();
|
|
317
317
|
const bestModel = await this.getBestPerformingModel();
|
|
318
318
|
|
|
319
|
-
let recommendation =
|
|
319
|
+
let recommendation = "";
|
|
320
320
|
if (trainedModelCount === 0) {
|
|
321
|
-
recommendation =
|
|
321
|
+
recommendation = "Force first model creation";
|
|
322
322
|
} else if (bundleCount < this.MIN_BUNDLES_FOR_TRAINING) {
|
|
323
|
-
recommendation =
|
|
323
|
+
recommendation = "Not ready - need more data";
|
|
324
324
|
} else if (bundleCount < this.BUNDLE_THRESHOLD) {
|
|
325
|
-
recommendation =
|
|
325
|
+
recommendation = "Train from base model";
|
|
326
326
|
} else {
|
|
327
|
-
recommendation =
|
|
327
|
+
recommendation = "Train from best performing model";
|
|
328
328
|
}
|
|
329
329
|
|
|
330
330
|
return {
|
|
@@ -5,14 +5,14 @@
|
|
|
5
5
|
* Provides assertions and logging for model usage verification.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
import {
|
|
9
|
-
import type { IAgentRuntimeLike } from
|
|
10
|
-
import { logger } from
|
|
8
|
+
import { getLlmLogAdapter, getTrainingDataAdapter } from "../adapter";
|
|
9
|
+
import type { IAgentRuntimeLike } from "../dependencies";
|
|
10
|
+
import { logger } from "../utils/logger";
|
|
11
11
|
|
|
12
12
|
export interface ModelUsageStats {
|
|
13
13
|
agentId: string;
|
|
14
14
|
modelUsed: string;
|
|
15
|
-
modelSource:
|
|
15
|
+
modelSource: "groq" | "claude" | "openai" | "unknown";
|
|
16
16
|
inferenceCount: number;
|
|
17
17
|
}
|
|
18
18
|
|
|
@@ -36,7 +36,7 @@ export class ModelUsageVerifier {
|
|
|
36
36
|
*/
|
|
37
37
|
static async verifyAgentModelUsage(
|
|
38
38
|
agentUserId: string,
|
|
39
|
-
runtime: IAgentRuntimeLike
|
|
39
|
+
runtime: IAgentRuntimeLike,
|
|
40
40
|
): Promise<ModelUsageStats> {
|
|
41
41
|
const character = (runtime as Record<string, unknown>).character as
|
|
42
42
|
| { settings?: Record<string, unknown> }
|
|
@@ -45,30 +45,31 @@ export class ModelUsageVerifier {
|
|
|
45
45
|
|
|
46
46
|
// Check for different model providers
|
|
47
47
|
const groqModel = String(
|
|
48
|
-
settings?.GROQ_LARGE_MODEL || settings?.GROQ_SMALL_MODEL ||
|
|
48
|
+
settings?.GROQ_LARGE_MODEL || settings?.GROQ_SMALL_MODEL || "",
|
|
49
49
|
);
|
|
50
|
-
const claudeModel = String(settings?.CLAUDE_MODEL ||
|
|
51
|
-
const openaiModel = String(settings?.OPENAI_MODEL ||
|
|
50
|
+
const claudeModel = String(settings?.CLAUDE_MODEL || "");
|
|
51
|
+
const openaiModel = String(settings?.OPENAI_MODEL || "");
|
|
52
52
|
|
|
53
53
|
let modelUsed: string;
|
|
54
|
-
let modelSource:
|
|
54
|
+
let modelSource: "groq" | "claude" | "openai" | "unknown";
|
|
55
55
|
|
|
56
56
|
if (claudeModel) {
|
|
57
57
|
modelUsed = claudeModel;
|
|
58
|
-
modelSource =
|
|
58
|
+
modelSource = "claude";
|
|
59
59
|
} else if (openaiModel) {
|
|
60
60
|
modelUsed = openaiModel;
|
|
61
|
-
modelSource =
|
|
61
|
+
modelSource = "openai";
|
|
62
62
|
} else if (groqModel) {
|
|
63
63
|
modelUsed = groqModel;
|
|
64
|
-
modelSource =
|
|
64
|
+
modelSource = "groq";
|
|
65
65
|
} else {
|
|
66
|
-
modelUsed =
|
|
67
|
-
modelSource =
|
|
66
|
+
modelUsed = "unknown";
|
|
67
|
+
modelSource = "unknown";
|
|
68
68
|
}
|
|
69
69
|
|
|
70
70
|
// Count inferences from logs (using trajectoryId)
|
|
71
|
-
const trajectoryIds =
|
|
71
|
+
const trajectoryIds =
|
|
72
|
+
await getTrainingDataAdapter().getTrajectoryIdsByAgent(agentUserId);
|
|
72
73
|
|
|
73
74
|
const twentyFourHoursAgo = new Date(Date.now() - 24 * 60 * 60 * 1000);
|
|
74
75
|
|
|
@@ -77,7 +78,7 @@ export class ModelUsageVerifier {
|
|
|
77
78
|
if (llmAdapter && trajectoryIds.length > 0) {
|
|
78
79
|
inferenceCount = await llmAdapter.countRecentLLMCalls(
|
|
79
80
|
trajectoryIds,
|
|
80
|
-
twentyFourHoursAgo
|
|
81
|
+
twentyFourHoursAgo,
|
|
81
82
|
);
|
|
82
83
|
}
|
|
83
84
|
|
|
@@ -94,7 +95,7 @@ export class ModelUsageVerifier {
|
|
|
94
95
|
*/
|
|
95
96
|
static async verifyMultipleAgents(
|
|
96
97
|
agentUserIds: string[],
|
|
97
|
-
runtimes: Map<string, IAgentRuntimeLike
|
|
98
|
+
runtimes: Map<string, IAgentRuntimeLike>,
|
|
98
99
|
): Promise<VerificationResult> {
|
|
99
100
|
const details: ModelUsageStats[] = [];
|
|
100
101
|
const errors: string[] = [];
|
|
@@ -106,7 +107,10 @@ export class ModelUsageVerifier {
|
|
|
106
107
|
continue;
|
|
107
108
|
}
|
|
108
109
|
|
|
109
|
-
const stats = await
|
|
110
|
+
const stats = await ModelUsageVerifier.verifyAgentModelUsage(
|
|
111
|
+
agentId,
|
|
112
|
+
runtime,
|
|
113
|
+
);
|
|
110
114
|
details.push(stats);
|
|
111
115
|
}
|
|
112
116
|
|
|
@@ -123,25 +127,28 @@ export class ModelUsageVerifier {
|
|
|
123
127
|
*/
|
|
124
128
|
static async assertModelUsage(
|
|
125
129
|
agentUserId: string,
|
|
126
|
-
runtime: IAgentRuntimeLike
|
|
130
|
+
runtime: IAgentRuntimeLike,
|
|
127
131
|
): Promise<void> {
|
|
128
|
-
const stats = await
|
|
132
|
+
const stats = await ModelUsageVerifier.verifyAgentModelUsage(
|
|
133
|
+
agentUserId,
|
|
134
|
+
runtime,
|
|
135
|
+
);
|
|
129
136
|
|
|
130
|
-
if (stats.modelSource ===
|
|
137
|
+
if (stats.modelSource === "unknown") {
|
|
131
138
|
throw new Error(
|
|
132
139
|
`Agent ${agentUserId} has no configured model. ` +
|
|
133
|
-
`Using: ${stats.modelUsed}
|
|
140
|
+
`Using: ${stats.modelUsed}`,
|
|
134
141
|
);
|
|
135
142
|
}
|
|
136
143
|
|
|
137
144
|
logger.info(
|
|
138
|
-
|
|
145
|
+
"Model usage verified",
|
|
139
146
|
{
|
|
140
147
|
agentId: agentUserId,
|
|
141
148
|
model: stats.modelUsed,
|
|
142
149
|
source: stats.modelSource,
|
|
143
150
|
},
|
|
144
|
-
|
|
151
|
+
"ModelUsageVerifier",
|
|
145
152
|
);
|
|
146
153
|
}
|
|
147
154
|
|
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
* - Real vLLM/OpenAI-compatible API integration
|
|
12
12
|
*/
|
|
13
13
|
|
|
14
|
-
import { logger } from
|
|
14
|
+
import { logger } from "../utils/logger";
|
|
15
15
|
import {
|
|
16
16
|
getModelForArchetype as getArchetypeModel,
|
|
17
17
|
getMultiModelConfig,
|
|
@@ -20,7 +20,7 @@ import {
|
|
|
20
20
|
type ModelTier,
|
|
21
21
|
type MultiModelConfig,
|
|
22
22
|
type QuantizationMode,
|
|
23
|
-
} from
|
|
23
|
+
} from "./RLModelConfig";
|
|
24
24
|
|
|
25
25
|
/**
|
|
26
26
|
* Loaded model state
|
|
@@ -107,9 +107,9 @@ export class MultiModelOrchestrator {
|
|
|
107
107
|
|
|
108
108
|
constructor(config: OrchestratorConfig) {
|
|
109
109
|
this.config = {
|
|
110
|
-
vllmBaseUrl: process.env.VLLM_BASE_URL ||
|
|
110
|
+
vllmBaseUrl: process.env.VLLM_BASE_URL || "http://localhost:9001",
|
|
111
111
|
fallbackApiUrl:
|
|
112
|
-
process.env.GROQ_API_URL ||
|
|
112
|
+
process.env.GROQ_API_URL || "https://api.groq.com/openai/v1",
|
|
113
113
|
fallbackApiKey: process.env.GROQ_API_KEY,
|
|
114
114
|
inferenceTimeoutMs: 30000,
|
|
115
115
|
...config,
|
|
@@ -117,7 +117,7 @@ export class MultiModelOrchestrator {
|
|
|
117
117
|
this.multiModelConfig = getMultiModelConfig(config.availableVramGb);
|
|
118
118
|
|
|
119
119
|
logger.info(
|
|
120
|
-
|
|
120
|
+
"MultiModelOrchestrator initialized",
|
|
121
121
|
{
|
|
122
122
|
availableVram: `${config.availableVramGb}GB`,
|
|
123
123
|
maxConcurrentModels: this.multiModelConfig.maxConcurrentModels,
|
|
@@ -126,7 +126,7 @@ export class MultiModelOrchestrator {
|
|
|
126
126
|
vllmUrl: this.config.vllmBaseUrl,
|
|
127
127
|
hasFallback: !!this.config.fallbackApiKey,
|
|
128
128
|
},
|
|
129
|
-
|
|
129
|
+
"MultiModelOrchestrator",
|
|
130
130
|
);
|
|
131
131
|
}
|
|
132
132
|
|
|
@@ -150,9 +150,9 @@ export class MultiModelOrchestrator {
|
|
|
150
150
|
|
|
151
151
|
if (this.vllmAvailable) {
|
|
152
152
|
logger.info(
|
|
153
|
-
|
|
153
|
+
"vLLM server is available",
|
|
154
154
|
{ url: this.config.vllmBaseUrl },
|
|
155
|
-
|
|
155
|
+
"MultiModelOrchestrator",
|
|
156
156
|
);
|
|
157
157
|
}
|
|
158
158
|
|
|
@@ -161,9 +161,9 @@ export class MultiModelOrchestrator {
|
|
|
161
161
|
clearTimeout(timeout);
|
|
162
162
|
this.vllmAvailable = false;
|
|
163
163
|
logger.warn(
|
|
164
|
-
|
|
164
|
+
"vLLM server not available, will use fallback",
|
|
165
165
|
{ url: this.config.vllmBaseUrl },
|
|
166
|
-
|
|
166
|
+
"MultiModelOrchestrator",
|
|
167
167
|
);
|
|
168
168
|
return false;
|
|
169
169
|
}
|
|
@@ -187,7 +187,7 @@ export class MultiModelOrchestrator {
|
|
|
187
187
|
quantization: this.config.defaultQuantization,
|
|
188
188
|
vramGb: getVramRequirement(
|
|
189
189
|
this.config.defaultTier,
|
|
190
|
-
this.config.defaultQuantization
|
|
190
|
+
this.config.defaultQuantization,
|
|
191
191
|
),
|
|
192
192
|
};
|
|
193
193
|
}
|
|
@@ -237,7 +237,7 @@ export class MultiModelOrchestrator {
|
|
|
237
237
|
freedVram: `${model.vramUsageGb}GB`,
|
|
238
238
|
currentUsage: `${this.currentVramUsageGb}GB`,
|
|
239
239
|
},
|
|
240
|
-
|
|
240
|
+
"MultiModelOrchestrator",
|
|
241
241
|
);
|
|
242
242
|
}
|
|
243
243
|
}
|
|
@@ -263,7 +263,7 @@ export class MultiModelOrchestrator {
|
|
|
263
263
|
if (!this.canLoadModel(modelInfo.vramGb)) {
|
|
264
264
|
throw new Error(
|
|
265
265
|
`Cannot load model for ${archetype}: insufficient VRAM. ` +
|
|
266
|
-
`Required: ${modelInfo.vramGb}GB, Available: ${this.config.availableVramGb - this.currentVramUsageGb}GB
|
|
266
|
+
`Required: ${modelInfo.vramGb}GB, Available: ${this.config.availableVramGb - this.currentVramUsageGb}GB`,
|
|
267
267
|
);
|
|
268
268
|
}
|
|
269
269
|
|
|
@@ -288,7 +288,7 @@ export class MultiModelOrchestrator {
|
|
|
288
288
|
totalVramUsed: `${this.currentVramUsageGb}GB`,
|
|
289
289
|
modelsLoaded: this.loadedModels.size,
|
|
290
290
|
},
|
|
291
|
-
|
|
291
|
+
"MultiModelOrchestrator",
|
|
292
292
|
);
|
|
293
293
|
|
|
294
294
|
return loadedModel;
|
|
@@ -302,32 +302,32 @@ export class MultiModelOrchestrator {
|
|
|
302
302
|
prompt: string,
|
|
303
303
|
systemPrompt: string,
|
|
304
304
|
maxTokens: number,
|
|
305
|
-
temperature: number
|
|
305
|
+
temperature: number,
|
|
306
306
|
): Promise<CompletionResponse> {
|
|
307
307
|
const controller = new AbortController();
|
|
308
308
|
const timeout = setTimeout(
|
|
309
309
|
() => controller.abort(),
|
|
310
|
-
this.config.inferenceTimeoutMs
|
|
310
|
+
this.config.inferenceTimeoutMs,
|
|
311
311
|
);
|
|
312
312
|
|
|
313
313
|
const response = await fetch(
|
|
314
314
|
`${this.config.vllmBaseUrl}/v1/chat/completions`,
|
|
315
315
|
{
|
|
316
|
-
method:
|
|
316
|
+
method: "POST",
|
|
317
317
|
headers: {
|
|
318
|
-
|
|
318
|
+
"Content-Type": "application/json",
|
|
319
319
|
},
|
|
320
320
|
body: JSON.stringify({
|
|
321
321
|
model: modelId,
|
|
322
322
|
messages: [
|
|
323
|
-
{ role:
|
|
324
|
-
{ role:
|
|
323
|
+
{ role: "system", content: systemPrompt },
|
|
324
|
+
{ role: "user", content: prompt },
|
|
325
325
|
],
|
|
326
326
|
max_tokens: maxTokens,
|
|
327
327
|
temperature,
|
|
328
328
|
}),
|
|
329
329
|
signal: controller.signal,
|
|
330
|
-
}
|
|
330
|
+
},
|
|
331
331
|
);
|
|
332
332
|
|
|
333
333
|
clearTimeout(timeout);
|
|
@@ -347,42 +347,42 @@ export class MultiModelOrchestrator {
|
|
|
347
347
|
prompt: string,
|
|
348
348
|
systemPrompt: string,
|
|
349
349
|
maxTokens: number,
|
|
350
|
-
temperature: number
|
|
350
|
+
temperature: number,
|
|
351
351
|
): Promise<CompletionResponse> {
|
|
352
352
|
if (!this.config.fallbackApiKey) {
|
|
353
353
|
throw new Error(
|
|
354
|
-
|
|
354
|
+
"No fallback API key configured. Set GROQ_API_KEY environment variable.",
|
|
355
355
|
);
|
|
356
356
|
}
|
|
357
357
|
|
|
358
358
|
const controller = new AbortController();
|
|
359
359
|
const timeout = setTimeout(
|
|
360
360
|
() => controller.abort(),
|
|
361
|
-
this.config.inferenceTimeoutMs
|
|
361
|
+
this.config.inferenceTimeoutMs,
|
|
362
362
|
);
|
|
363
363
|
|
|
364
364
|
// Use a fast model for fallback
|
|
365
|
-
const fallbackModel =
|
|
365
|
+
const fallbackModel = "llama-3.1-8b-instant";
|
|
366
366
|
|
|
367
367
|
const response = await fetch(
|
|
368
368
|
`${this.config.fallbackApiUrl}/chat/completions`,
|
|
369
369
|
{
|
|
370
|
-
method:
|
|
370
|
+
method: "POST",
|
|
371
371
|
headers: {
|
|
372
|
-
|
|
372
|
+
"Content-Type": "application/json",
|
|
373
373
|
Authorization: `Bearer ${this.config.fallbackApiKey}`,
|
|
374
374
|
},
|
|
375
375
|
body: JSON.stringify({
|
|
376
376
|
model: fallbackModel,
|
|
377
377
|
messages: [
|
|
378
|
-
{ role:
|
|
379
|
-
{ role:
|
|
378
|
+
{ role: "system", content: systemPrompt },
|
|
379
|
+
{ role: "user", content: prompt },
|
|
380
380
|
],
|
|
381
381
|
max_tokens: maxTokens,
|
|
382
382
|
temperature,
|
|
383
383
|
}),
|
|
384
384
|
signal: controller.signal,
|
|
385
|
-
}
|
|
385
|
+
},
|
|
386
386
|
);
|
|
387
387
|
|
|
388
388
|
clearTimeout(timeout);
|
|
@@ -390,7 +390,7 @@ export class MultiModelOrchestrator {
|
|
|
390
390
|
if (!response.ok) {
|
|
391
391
|
const error = await response.text();
|
|
392
392
|
throw new Error(
|
|
393
|
-
`Fallback API request failed: ${response.status} - ${error}
|
|
393
|
+
`Fallback API request failed: ${response.status} - ${error}`,
|
|
394
394
|
);
|
|
395
395
|
}
|
|
396
396
|
|
|
@@ -401,7 +401,7 @@ export class MultiModelOrchestrator {
|
|
|
401
401
|
* Run inference for an archetype
|
|
402
402
|
*/
|
|
403
403
|
async inference(
|
|
404
|
-
request: ModelInferenceRequest
|
|
404
|
+
request: ModelInferenceRequest,
|
|
405
405
|
): Promise<ModelInferenceResult> {
|
|
406
406
|
const startTime = Date.now();
|
|
407
407
|
|
|
@@ -427,7 +427,7 @@ export class MultiModelOrchestrator {
|
|
|
427
427
|
request.prompt,
|
|
428
428
|
systemPrompt,
|
|
429
429
|
maxTokens,
|
|
430
|
-
temperature
|
|
430
|
+
temperature,
|
|
431
431
|
);
|
|
432
432
|
} else {
|
|
433
433
|
// Fall back to Groq/OpenAI
|
|
@@ -435,12 +435,12 @@ export class MultiModelOrchestrator {
|
|
|
435
435
|
request.prompt,
|
|
436
436
|
systemPrompt,
|
|
437
437
|
maxTokens,
|
|
438
|
-
temperature
|
|
438
|
+
temperature,
|
|
439
439
|
);
|
|
440
440
|
}
|
|
441
441
|
|
|
442
442
|
const latencyMs = Date.now() - startTime;
|
|
443
|
-
const response = completion.choices[0]?.message.content ||
|
|
443
|
+
const response = completion.choices[0]?.message.content || "";
|
|
444
444
|
const tokensGenerated = completion.usage?.completion_tokens || 0;
|
|
445
445
|
|
|
446
446
|
logger.debug(
|
|
@@ -451,7 +451,7 @@ export class MultiModelOrchestrator {
|
|
|
451
451
|
tokensGenerated,
|
|
452
452
|
usedVllm: vllmAvailable,
|
|
453
453
|
},
|
|
454
|
-
|
|
454
|
+
"MultiModelOrchestrator",
|
|
455
455
|
);
|
|
456
456
|
|
|
457
457
|
return {
|
|
@@ -469,12 +469,12 @@ export class MultiModelOrchestrator {
|
|
|
469
469
|
logger.error(
|
|
470
470
|
`Inference failed for ${request.archetype}`,
|
|
471
471
|
{ error: errorMessage, latencyMs },
|
|
472
|
-
|
|
472
|
+
"MultiModelOrchestrator",
|
|
473
473
|
);
|
|
474
474
|
|
|
475
475
|
return {
|
|
476
476
|
archetype: request.archetype,
|
|
477
|
-
response:
|
|
477
|
+
response: "",
|
|
478
478
|
modelId: model.modelId,
|
|
479
479
|
latencyMs,
|
|
480
480
|
tokensGenerated: 0,
|
|
@@ -487,7 +487,7 @@ export class MultiModelOrchestrator {
|
|
|
487
487
|
* Batch inference for multiple archetypes
|
|
488
488
|
*/
|
|
489
489
|
async batchInference(
|
|
490
|
-
requests: ModelInferenceRequest[]
|
|
490
|
+
requests: ModelInferenceRequest[],
|
|
491
491
|
): Promise<ModelInferenceResult[]> {
|
|
492
492
|
// Group requests by archetype for efficient batching
|
|
493
493
|
const byArchetype = new Map<string, ModelInferenceRequest[]>();
|
|
@@ -509,7 +509,7 @@ export class MultiModelOrchestrator {
|
|
|
509
509
|
for (let i = 0; i < archetypeRequests.length; i += batchSize) {
|
|
510
510
|
const batch = archetypeRequests.slice(i, i + batchSize);
|
|
511
511
|
const batchResults = await Promise.all(
|
|
512
|
-
batch.map((req) => this.inference(req))
|
|
512
|
+
batch.map((req) => this.inference(req)),
|
|
513
513
|
);
|
|
514
514
|
results.push(...batchResults);
|
|
515
515
|
}
|
|
@@ -555,7 +555,7 @@ export class MultiModelOrchestrator {
|
|
|
555
555
|
unloadAll(): void {
|
|
556
556
|
this.loadedModels.clear();
|
|
557
557
|
this.currentVramUsageGb = 0;
|
|
558
|
-
logger.info(
|
|
558
|
+
logger.info("Unloaded all models", {}, "MultiModelOrchestrator");
|
|
559
559
|
}
|
|
560
560
|
|
|
561
561
|
/**
|
|
@@ -570,11 +570,11 @@ export class MultiModelOrchestrator {
|
|
|
570
570
|
* Create a multi-model orchestrator with sensible defaults for RTX 5090 (16GB)
|
|
571
571
|
*/
|
|
572
572
|
export function createMultiModelOrchestrator(
|
|
573
|
-
vramGb = 16
|
|
573
|
+
vramGb = 16,
|
|
574
574
|
): MultiModelOrchestrator {
|
|
575
575
|
return new MultiModelOrchestrator({
|
|
576
576
|
availableVramGb: vramGb,
|
|
577
|
-
defaultTier:
|
|
578
|
-
defaultQuantization:
|
|
577
|
+
defaultTier: "small",
|
|
578
|
+
defaultQuantization: "4bit",
|
|
579
579
|
});
|
|
580
580
|
}
|