@elizaos/training 2.0.0-alpha.77 → 2.0.0-alpha.78
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +2 -2
- package/.turbo/turbo-lint.log +0 -3
- package/.turbo/turbo-typecheck.log +0 -1
- package/dist/.tsbuildinfo +0 -1
- package/dist/adapter.js +0 -59
- package/dist/archetypes/ArchetypeConfigService.js +0 -510
- package/dist/archetypes/derive-archetype.js +0 -196
- package/dist/archetypes/index.js +0 -7
- package/dist/benchmark/ArchetypeMatchupBenchmark.js +0 -547
- package/dist/benchmark/BenchmarkChartGenerator.js +0 -632
- package/dist/benchmark/BenchmarkDataGenerator.js +0 -825
- package/dist/benchmark/BenchmarkDataViewer.js +0 -197
- package/dist/benchmark/BenchmarkHistoryService.js +0 -135
- package/dist/benchmark/BenchmarkRunner.js +0 -483
- package/dist/benchmark/BenchmarkValidator.js +0 -158
- package/dist/benchmark/FastEvalRunner.js +0 -133
- package/dist/benchmark/MetricsValidator.js +0 -104
- package/dist/benchmark/MetricsVisualizer.js +0 -775
- package/dist/benchmark/ModelBenchmarkService.js +0 -433
- package/dist/benchmark/ModelRegistry.js +0 -122
- package/dist/benchmark/RulerBenchmarkIntegration.js +0 -168
- package/dist/benchmark/SimulationA2AInterface.js +0 -683
- package/dist/benchmark/SimulationEngine.js +0 -522
- package/dist/benchmark/TaskRunner.js +0 -60
- package/dist/benchmark/__tests__/BenchmarkRunner.test.js +0 -409
- package/dist/benchmark/__tests__/HeadToHead.test.js +0 -105
- package/dist/benchmark/index.js +0 -23
- package/dist/benchmark/parseSimulationMetrics.js +0 -86
- package/dist/benchmark/simulation-types.js +0 -1
- package/dist/dependencies.js +0 -197
- package/dist/generation/TrajectoryGenerator.js +0 -244
- package/dist/generation/index.js +0 -6
- package/dist/huggingface/HuggingFaceDatasetUploader.js +0 -463
- package/dist/huggingface/HuggingFaceIntegrationService.js +0 -272
- package/dist/huggingface/HuggingFaceModelUploader.js +0 -385
- package/dist/huggingface/index.js +0 -9
- package/dist/huggingface/shared/HuggingFaceUploadUtil.js +0 -144
- package/dist/index.js +0 -41
- package/dist/init-training.js +0 -43
- package/dist/metrics/TrajectoryMetricsExtractor.js +0 -523
- package/dist/metrics/__tests__/TrajectoryMetricsExtractor.test.js +0 -628
- package/dist/metrics/index.js +0 -7
- package/dist/metrics/types.js +0 -21
- package/dist/rubrics/__tests__/index.test.js +0 -150
- package/dist/rubrics/ass-kisser.js +0 -83
- package/dist/rubrics/degen.js +0 -78
- package/dist/rubrics/goody-twoshoes.js +0 -82
- package/dist/rubrics/index.js +0 -184
- package/dist/rubrics/information-trader.js +0 -82
- package/dist/rubrics/infosec.js +0 -99
- package/dist/rubrics/liar.js +0 -102
- package/dist/rubrics/perps-trader.js +0 -85
- package/dist/rubrics/researcher.js +0 -79
- package/dist/rubrics/scammer.js +0 -80
- package/dist/rubrics/social-butterfly.js +0 -71
- package/dist/rubrics/super-predictor.js +0 -95
- package/dist/rubrics/trader.js +0 -65
- package/dist/scoring/ArchetypeScoringService.js +0 -301
- package/dist/scoring/JudgePromptBuilder.js +0 -401
- package/dist/scoring/LLMJudgeCache.js +0 -263
- package/dist/scoring/index.js +0 -8
- package/dist/training/AutomationPipeline.js +0 -714
- package/dist/training/BenchmarkService.js +0 -370
- package/dist/training/ConfigValidator.js +0 -153
- package/dist/training/MarketOutcomesTracker.js +0 -142
- package/dist/training/ModelDeployer.js +0 -128
- package/dist/training/ModelFetcher.js +0 -48
- package/dist/training/ModelSelectionService.js +0 -248
- package/dist/training/ModelUsageVerifier.js +0 -106
- package/dist/training/MultiModelOrchestrator.js +0 -349
- package/dist/training/RLModelConfig.js +0 -295
- package/dist/training/RewardBackpropagationService.js +0 -117
- package/dist/training/RulerScoringService.js +0 -450
- package/dist/training/TrainingMonitor.js +0 -108
- package/dist/training/TrajectoryRecorder.js +0 -281
- package/dist/training/__tests__/TrajectoryRecorder.test.js +0 -363
- package/dist/training/index.js +0 -30
- package/dist/training/logRLConfig.js +0 -29
- package/dist/training/pipeline.js +0 -80
- package/dist/training/storage/ModelStorageService.js +0 -190
- package/dist/training/storage/TrainingDataArchiver.js +0 -136
- package/dist/training/storage/index.js +0 -7
- package/dist/training/types.js +0 -6
- package/dist/training/window-utils.js +0 -100
- package/dist/utils/index.js +0 -73
- package/dist/utils/logger.js +0 -55
- package/dist/utils/snowflake.js +0 -15
- package/dist/utils/synthetic-detector.js +0 -67
- package/vitest.config.ts +0 -8
|
@@ -1,349 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Multi-Model Orchestrator
|
|
3
|
-
*
|
|
4
|
-
* Manages loading and inference for multiple archetype-specific models
|
|
5
|
-
* within VRAM constraints. Optimized for 16GB GPUs (RTX 5090).
|
|
6
|
-
*
|
|
7
|
-
* Strategy:
|
|
8
|
-
* - Use 4-bit quantization to fit 4+ models in 16GB VRAM
|
|
9
|
-
* - LRU cache for model loading/unloading
|
|
10
|
-
* - Batch inference per archetype for efficiency
|
|
11
|
-
* - Real vLLM/OpenAI-compatible API integration
|
|
12
|
-
*/
|
|
13
|
-
import { logger } from "../utils/logger";
|
|
14
|
-
import { getModelForArchetype as getArchetypeModel, getMultiModelConfig, getQuantizedModelName, getVramRequirement, } from "./RLModelConfig";
|
|
15
|
-
/**
 * Multi-Model Orchestrator
 *
 * Keeps a bookkeeping record per archetype of which model is considered
 * "loaded" and how much VRAM it is budgeted for, evicting the least recently
 * used record when a new one would exceed the configured VRAM budget.
 * Inference requests are sent to a vLLM server when it answers the
 * availability probe, otherwise to an OpenAI-compatible fallback API.
 */
export class MultiModelOrchestrator {
    /** Effective configuration: built-in defaults overlaid with the caller's config. */
    config;
    /** Result of getMultiModelConfig() for the configured VRAM budget. */
    multiModelConfig;
    /** Bookkeeping records keyed by archetype. */
    loadedModels = new Map();
    /** Sum of vramUsageGb over every record in loadedModels. */
    currentVramUsageGb = 0;
    /** Cached result of the vLLM probe; null means "not probed yet". */
    vllmAvailable = null;
    /**
     * @param {object} config - Must provide availableVramGb, defaultTier and
     *   defaultQuantization; may override vllmBaseUrl, fallbackApiUrl,
     *   fallbackApiKey and inferenceTimeoutMs.
     */
    constructor(config) {
        this.config = {
            vllmBaseUrl: process.env.VLLM_BASE_URL || "http://localhost:9001",
            fallbackApiUrl: process.env.GROQ_API_URL || "https://api.groq.com/openai/v1",
            fallbackApiKey: process.env.GROQ_API_KEY,
            inferenceTimeoutMs: 30000,
            ...config,
        };
        this.multiModelConfig = getMultiModelConfig(config.availableVramGb);
        logger.info("MultiModelOrchestrator initialized", {
            availableVram: `${config.availableVramGb}GB`,
            maxConcurrentModels: this.multiModelConfig.maxConcurrentModels,
            quantization: this.multiModelConfig.quantization,
            tier: this.multiModelConfig.modelTier,
            vllmUrl: this.config.vllmBaseUrl,
            hasFallback: !!this.config.fallbackApiKey,
        }, "MultiModelOrchestrator");
    }
    /**
     * Probe the vLLM server once (5 second timeout) and cache the answer.
     * Call resetAvailabilityCheck() to force a new probe.
     *
     * @returns {Promise<boolean>} true when GET /v1/models answered with an OK status.
     */
    async checkVllmAvailability() {
        if (this.vllmAvailable !== null) {
            return this.vllmAvailable;
        }
        const controller = new AbortController();
        const timeout = setTimeout(() => controller.abort(), 5000);
        try {
            const response = await fetch(`${this.config.vllmBaseUrl}/v1/models`, {
                signal: controller.signal,
            });
            this.vllmAvailable = response.ok;
            if (this.vllmAvailable) {
                logger.info("vLLM server is available", { url: this.config.vllmBaseUrl }, "MultiModelOrchestrator");
            }
            return this.vllmAvailable;
        }
        catch {
            this.vllmAvailable = false;
            logger.warn("vLLM server not available, will use fallback", { url: this.config.vllmBaseUrl }, "MultiModelOrchestrator");
            return false;
        }
        finally {
            // Runs on both paths so the abort timer never outlives the probe.
            clearTimeout(timeout);
        }
    }
    /**
     * Resolve which model an archetype should use.
     *
     * A registered archetype-specific model supplies the model id (modelPath
     * is preferred over modelId); otherwise the quantized base model for the
     * default tier is used. Tier, quantization and VRAM figures always come
     * from the configured defaults.
     *
     * @param {string} archetype
     * @returns {{modelId: string, tier: string, quantization: string, vramGb: number}}
     */
    getModelForArchetype(archetype) {
        const archetypeModel = getArchetypeModel(archetype);
        const tier = this.config.defaultTier;
        const quantization = this.config.defaultQuantization;
        if (archetypeModel) {
            return {
                modelId: archetypeModel.modelPath || archetypeModel.modelId,
                tier,
                quantization,
                vramGb: getVramRequirement(tier, quantization),
            };
        }
        const modelId = getQuantizedModelName(tier, quantization);
        const vramGb = getVramRequirement(tier, quantization);
        return { modelId, tier, quantization, vramGb };
    }
    /**
     * @param {number} vramRequired - VRAM in GB the candidate model needs.
     * @returns {boolean} true when the remaining budget covers vramRequired.
     */
    canLoadModel(vramRequired) {
        const availableVram = this.config.availableVramGb - this.currentVramUsageGb;
        return availableVram >= vramRequired;
    }
    /**
     * Drop the record with the smallest lastUsed timestamp and release its
     * VRAM budget.
     *
     * The candidate is seeded from the first entry rather than compared
     * against a sentinel, so any key (including an empty string) can be
     * evicted and a non-empty map always shrinks by one.
     *
     * @returns {boolean} true when a record was evicted, false when nothing was loaded.
     */
    evictLRUModel() {
        let lruEntry = null;
        for (const entry of this.loadedModels) {
            if (lruEntry === null || entry[1].lastUsed < lruEntry[1].lastUsed) {
                lruEntry = entry;
            }
        }
        if (lruEntry === null) {
            return false;
        }
        const [lruArchetype, model] = lruEntry;
        this.currentVramUsageGb -= model.vramUsageGb;
        this.loadedModels.delete(lruArchetype);
        logger.info(`Evicted model for archetype: ${lruArchetype}`, {
            freedVram: `${model.vramUsageGb}GB`,
            currentUsage: `${this.currentVramUsageGb}GB`,
        }, "MultiModelOrchestrator");
        return true;
    }
    /**
     * Ensure a bookkeeping record exists for the archetype, evicting least
     * recently used records until the model fits the VRAM budget.
     *
     * @param {string} archetype
     * @returns {Promise<object>} The (possibly pre-existing) record.
     * @throws {Error} When the model does not fit even with nothing else loaded.
     */
    async loadModelForArchetype(archetype) {
        const existing = this.loadedModels.get(archetype);
        if (existing) {
            existing.lastUsed = Date.now();
            return existing;
        }
        const modelInfo = this.getModelForArchetype(archetype);
        // Each successful eviction removes one record, so this loop terminates.
        while (!this.canLoadModel(modelInfo.vramGb) && this.loadedModels.size > 0) {
            this.evictLRUModel();
        }
        if (!this.canLoadModel(modelInfo.vramGb)) {
            throw new Error(`Cannot load model for ${archetype}: insufficient VRAM. ` +
                `Required: ${modelInfo.vramGb}GB, Available: ${this.config.availableVramGb - this.currentVramUsageGb}GB`);
        }
        const loadedModel = {
            archetype,
            modelId: modelInfo.modelId,
            tier: modelInfo.tier,
            quantization: modelInfo.quantization,
            vramUsageGb: modelInfo.vramGb,
            lastUsed: Date.now(),
            inferenceCount: 0,
        };
        this.loadedModels.set(archetype, loadedModel);
        this.currentVramUsageGb += modelInfo.vramGb;
        logger.info(`Loaded model for archetype: ${archetype}`, {
            modelId: modelInfo.modelId,
            vramUsed: `${modelInfo.vramGb}GB`,
            totalVramUsed: `${this.currentVramUsageGb}GB`,
            modelsLoaded: this.loadedModels.size,
        }, "MultiModelOrchestrator");
        return loadedModel;
    }
    /**
     * POST a two-message (system + user) chat completion and return the
     * parsed JSON body. Shared by callVllm and callFallbackApi.
     *
     * The abort timer covers the wait for the response headers and is
     * cleared whether fetch resolves or rejects.
     */
    async #postChatCompletion(url, extraHeaders, model, prompt, systemPrompt, maxTokens, temperature, failureLabel) {
        const controller = new AbortController();
        const timeout = setTimeout(() => controller.abort(), this.config.inferenceTimeoutMs);
        let response;
        try {
            response = await fetch(url, {
                method: "POST",
                headers: {
                    "Content-Type": "application/json",
                    ...extraHeaders,
                },
                body: JSON.stringify({
                    model,
                    messages: [
                        { role: "system", content: systemPrompt },
                        { role: "user", content: prompt },
                    ],
                    max_tokens: maxTokens,
                    temperature,
                }),
                signal: controller.signal,
            });
        }
        finally {
            clearTimeout(timeout);
        }
        if (!response.ok) {
            const error = await response.text();
            throw new Error(`${failureLabel} request failed: ${response.status} - ${error}`);
        }
        return response.json();
    }
    /**
     * Call the vLLM server's /v1/chat/completions endpoint.
     *
     * @returns {Promise<object>} Parsed completion body.
     * @throws {Error} On a non-OK status, network failure or timeout.
     */
    async callVllm(modelId, prompt, systemPrompt, maxTokens, temperature) {
        return this.#postChatCompletion(`${this.config.vllmBaseUrl}/v1/chat/completions`, {}, modelId, prompt, systemPrompt, maxTokens, temperature, "vLLM");
    }
    /**
     * Call the fallback API's /chat/completions endpoint with a fixed model.
     *
     * @returns {Promise<object>} Parsed completion body.
     * @throws {Error} When no fallback API key is configured, or on a non-OK
     *   status, network failure or timeout.
     */
    async callFallbackApi(prompt, systemPrompt, maxTokens, temperature) {
        if (!this.config.fallbackApiKey) {
            throw new Error("No fallback API key configured. Set GROQ_API_KEY environment variable.");
        }
        // Use a fast model for fallback
        const fallbackModel = "llama-3.1-8b-instant";
        return this.#postChatCompletion(`${this.config.fallbackApiUrl}/chat/completions`, { Authorization: `Bearer ${this.config.fallbackApiKey}` }, fallbackModel, prompt, systemPrompt, maxTokens, temperature, "Fallback API");
    }
    /**
     * Run one inference for request.archetype.
     *
     * Request failures are not thrown: they are logged and reported through
     * the `error` field of the result, with an empty response. Only a VRAM
     * failure from loadModelForArchetype propagates as an exception.
     *
     * @param {{archetype: string, prompt: string, systemPrompt?: string, maxTokens?: number, temperature?: number}} request
     * @returns {Promise<object>} archetype, response, modelId, latencyMs, tokensGenerated and, on failure, error.
     */
    async inference(request) {
        const startTime = Date.now();
        // Ensure model is loaded (for VRAM tracking)
        const model = await this.loadModelForArchetype(request.archetype);
        model.inferenceCount++;
        const systemPrompt = request.systemPrompt ||
            `You are an AI agent with the ${request.archetype} archetype. Respond appropriately to the given situation.`;
        const maxTokens = request.maxTokens || 512;
        const temperature = request.temperature ?? 0.7;
        try {
            const vllmAvailable = await this.checkVllmAvailability();
            const completion = vllmAvailable
                ? await this.callVllm(model.modelId, request.prompt, systemPrompt, maxTokens, temperature)
                : await this.callFallbackApi(request.prompt, systemPrompt, maxTokens, temperature);
            const latencyMs = Date.now() - startTime;
            const response = completion.choices[0]?.message.content || "";
            const tokensGenerated = completion.usage?.completion_tokens || 0;
            logger.debug(`Inference completed for ${request.archetype}`, {
                modelId: model.modelId,
                latencyMs,
                tokensGenerated,
                usedVllm: vllmAvailable,
            }, "MultiModelOrchestrator");
            return {
                archetype: request.archetype,
                response,
                modelId: model.modelId,
                latencyMs,
                tokensGenerated,
            };
        }
        catch (error) {
            const latencyMs = Date.now() - startTime;
            const errorMessage = error instanceof Error ? error.message : String(error);
            logger.error(`Inference failed for ${request.archetype}`, { error: errorMessage, latencyMs }, "MultiModelOrchestrator");
            return {
                archetype: request.archetype,
                response: "",
                modelId: model.modelId,
                latencyMs,
                tokensGenerated: 0,
                error: errorMessage,
            };
        }
    }
    /**
     * Run many requests, grouped by archetype and processed in chunks of
     * five concurrent calls per archetype.
     *
     * Results are ordered by archetype group (in order of first appearance),
     * not by the position of each request in the input array.
     *
     * @param {Array<object>} requests
     * @returns {Promise<Array<object>>}
     */
    async batchInference(requests) {
        const byArchetype = new Map();
        for (const req of requests) {
            const existing = byArchetype.get(req.archetype) || [];
            existing.push(req);
            byArchetype.set(req.archetype, existing);
        }
        const results = [];
        for (const [archetype, archetypeRequests] of byArchetype) {
            // Load model once per archetype
            await this.loadModelForArchetype(archetype);
            const batchSize = 5;
            for (let i = 0; i < archetypeRequests.length; i += batchSize) {
                const batch = archetypeRequests.slice(i, i + batchSize);
                const batchResults = await Promise.all(batch.map((req) => this.inference(req)));
                results.push(...batchResults);
            }
        }
        return results;
    }
    /**
     * Snapshot of the bookkeeping state.
     *
     * @returns {object} loadedModels, totalVramUsed, availableVram,
     *   maxConcurrentModels and the cached vllmAvailable flag.
     */
    getStatus() {
        const loadedModels = Array.from(this.loadedModels.values()).map((m) => ({
            archetype: m.archetype,
            modelId: m.modelId,
            vramGb: m.vramUsageGb,
            inferenceCount: m.inferenceCount,
        }));
        return {
            loadedModels,
            totalVramUsed: this.currentVramUsageGb,
            availableVram: this.config.availableVramGb - this.currentVramUsageGb,
            maxConcurrentModels: this.multiModelConfig.maxConcurrentModels,
            vllmAvailable: this.vllmAvailable,
        };
    }
    /**
     * Forget every record and reset the VRAM counter to zero.
     */
    unloadAll() {
        this.loadedModels.clear();
        this.currentVramUsageGb = 0;
        logger.info("Unloaded all models", {}, "MultiModelOrchestrator");
    }
    /**
     * Discard the cached vLLM probe result so the next inference probes again.
     */
    resetAvailabilityCheck() {
        this.vllmAvailable = null;
    }
}
|
|
340
|
-
/**
 * Build a MultiModelOrchestrator using the "small" tier with 4-bit
 * quantization.
 *
 * @param {number} [vramGb=16] - VRAM budget in GB.
 * @returns {MultiModelOrchestrator}
 */
export function createMultiModelOrchestrator(vramGb = 16) {
    const orchestratorConfig = {
        availableVramGb: vramGb,
        defaultTier: "small",
        defaultQuantization: "4bit",
    };
    return new MultiModelOrchestrator(orchestratorConfig);
}
|
|
@@ -1,295 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* RL Model Configuration
|
|
3
|
-
*
|
|
4
|
-
* Controls when and how RL-trained models are used for inference.
|
|
5
|
-
* Designed to be:
|
|
6
|
-
* - Enabled by default in local development
|
|
7
|
-
* - Disabled by default in production
|
|
8
|
-
* - Easy to toggle via environment variables
|
|
9
|
-
* - Scalable to larger models when more memory is available
|
|
10
|
-
* - Support quantized models for efficient multi-model loading
|
|
11
|
-
*/
|
|
12
|
-
/**
 * Model tiers keyed by size class (small, medium, large, xlarge).
 *
 * Each entry carries the full-precision model id, the 4-bit and 8-bit
 * variant ids, the parameter-count label, the context length in tokens and
 * the minimum VRAM (GB) for full precision, 4-bit and 8-bit respectively.
 */
export const MODEL_TIERS = (() => {
    const contextTokens = 131072; // 128K context
    // All ids share the same naming scheme, so derive them from the size label.
    const describeTier = (label, params, [fullVramGb, vramGb4bit, vramGb8bit]) => {
        const model = `unsloth/Qwen3-${params}-128K`;
        return {
            name: `${label} (${params})`,
            model,
            quantizedModel4bit: `${model}-bnb-4bit`,
            quantizedModel8bit: `${model}-GGUF`,
            params,
            context: contextTokens,
            minVramGb: fullVramGb,
            minVramGb4bit: vramGb4bit,
            minVramGb8bit: vramGb8bit,
        };
    };
    return {
        small: describeTier("Small", "4B", [8, 3, 5]),
        medium: describeTier("Medium", "8B", [16, 5, 9]),
        large: describeTier("Large", "14B", [24, 8, 14]),
        xlarge: describeTier("XLarge", "32B", [48, 16, 28]),
    };
})();
|
|
63
|
-
/**
 * Choose how many "small" tier, 4-bit models to keep loaded for a given
 * VRAM budget.
 *
 * @param {number} vramGb - Available VRAM in GB.
 * @returns {{totalVramGb: number, maxConcurrentModels: number, quantization: string, modelTier: string}}
 *   4 models from 16 GB, 3 from 12 GB, 2 from 8 GB, otherwise 1.
 */
export function getMultiModelConfig(vramGb) {
    // [minimum VRAM in GB, concurrent model count], highest threshold first.
    const capacityLadder = [
        [16, 4],
        [12, 3],
        [8, 2],
    ];
    const matchedRung = capacityLadder.find(([minimumVramGb]) => vramGb >= minimumVramGb);
    return {
        totalVramGb: vramGb,
        maxConcurrentModels: matchedRung ? matchedRung[1] : 1,
        quantization: "4bit",
        modelTier: "small",
    };
}
|
|
107
|
-
/**
 * Look up the model id for a tier under a quantization mode.
 *
 * @param {string} tier - Key of MODEL_TIERS.
 * @param {string} quantization - "4bit" or "8bit"; any other value selects
 *   the full-precision model.
 * @returns {string} The quantized id, or the full-precision id when the tier
 *   has no variant for that mode.
 */
export function getQuantizedModelName(tier, quantization) {
    const { model, quantizedModel4bit, quantizedModel8bit } = MODEL_TIERS[tier];
    if (quantization === "4bit") {
        return quantizedModel4bit || model;
    }
    if (quantization === "8bit") {
        return quantizedModel8bit || model;
    }
    return model;
}
|
|
121
|
-
/**
 * Look up the minimum VRAM (GB) for a tier under a quantization mode.
 *
 * @param {string} tier - Key of MODEL_TIERS.
 * @param {string} quantization - "4bit" or "8bit"; any other value selects
 *   the full-precision requirement.
 * @returns {number}
 */
export function getVramRequirement(tier, quantization) {
    const { minVramGb, minVramGb4bit, minVramGb8bit } = MODEL_TIERS[tier];
    if (quantization === "4bit") {
        return minVramGb4bit;
    }
    if (quantization === "8bit") {
        return minVramGb8bit;
    }
    return minVramGb;
}
|
|
135
|
-
/**
 * Registry of trained models per archetype.
 * Maps normalized archetype key -> best model config registered so far.
 */
const archetypeModelRegistry = new Map();
/**
 * Normalize an archetype name into its registry key: lower-cased, trimmed,
 * underscores replaced by hyphens. This is the same transformation the
 * lookup functions apply before reading the registry.
 */
function normalizeArchetypeKey(archetype) {
    return archetype.toLowerCase().trim().replace(/_/g, "-");
}
/**
 * Register a trained model for an archetype.
 *
 * The config is stored when nothing is registered for the archetype yet, or
 * when it carries a benchmarkScore and the existing entry either has none or
 * has a lower one. A score of 0 counts as a real score. The entry is keyed
 * by the normalized archetype name so lookups that normalize their argument
 * can find it; the stored config object itself is left untouched.
 *
 * @param {{archetype: string, modelId: string, benchmarkScore?: number}} config
 */
export function registerArchetypeModel(config) {
    const registryKey = normalizeArchetypeKey(config.archetype);
    const existing = archetypeModelRegistry.get(registryKey);
    const hasScore = config.benchmarkScore != null;
    const beatsExisting = existing !== undefined &&
        hasScore &&
        (existing.benchmarkScore == null ||
            config.benchmarkScore > existing.benchmarkScore);
    if (existing === undefined || beatsExisting) {
        archetypeModelRegistry.set(registryKey, config);
        console.log(`📦 Registered model for archetype '${config.archetype}': ${config.modelId}`);
    }
}
|
|
153
|
-
/**
 * Fetch the registered model config for an archetype.
 *
 * The name is lower-cased, trimmed and has underscores turned into hyphens
 * before the registry is consulted.
 *
 * @param {string} archetype
 * @returns {object|null} The registered config, or null when there is none.
 */
export function getModelForArchetype(archetype) {
    const registryKey = archetype.toLowerCase().trim().split("_").join("-");
    const registered = archetypeModelRegistry.get(registryKey);
    return registered ? registered : null;
}
|
|
161
|
-
/**
 * List every registered archetype model config, in registration order.
 *
 * @returns {Array<object>} A new array; mutating it does not affect the registry.
 */
export function getAllArchetypeModels() {
    return [...archetypeModelRegistry.values()];
}
|
|
167
|
-
/**
 * Tell whether a model is registered for an archetype.
 *
 * The name is lower-cased, trimmed and has underscores turned into hyphens
 * before the registry is consulted.
 *
 * @param {string} archetype
 * @returns {boolean}
 */
export function hasArchetypeModel(archetype) {
    const registryKey = archetype.toLowerCase().trim().split("_").join("-");
    return archetypeModelRegistry.has(registryKey);
}
|
|
174
|
-
/**
 * Empty the archetype model registry.
 *
 * @returns {void}
 */
export function clearArchetypeModels() {
    // Map#clear removes every entry in place; the registry object is reused.
    archetypeModelRegistry.clear();
}
|
|
180
|
-
/**
 * Pick the largest tier whose full-precision VRAM requirement fits.
 *
 * @param {number} vramGb - Available VRAM in GB.
 * @returns {string} "xlarge", "large", "medium", or "small" when nothing larger fits.
 */
export function getModelTierForVram(vramGb) {
    // Checked from the most demanding tier down; "small" is the unconditional floor.
    for (const candidateTier of ["xlarge", "large", "medium"]) {
        if (vramGb >= MODEL_TIERS[candidateTier].minVramGb) {
            return candidateTier;
        }
    }
    return "small";
}
|
|
192
|
-
/**
 * Return the full-precision model id for a tier.
 *
 * @param {string} tier - Key of MODEL_TIERS.
 * @returns {string}
 */
export function getModelForTier(tier) {
    const { model } = MODEL_TIERS[tier];
    return model;
}
|
|
198
|
-
/**
 * Parse a base-10 integer from an environment variable value.
 * Returns `fallback` when the value is unset, empty, or not a number.
 */
function parseIntegerEnv(rawValue, fallback) {
    if (!rawValue) {
        return fallback;
    }
    const parsed = parseInt(rawValue, 10);
    return Number.isNaN(parsed) ? fallback : parsed;
}
/**
 * Build the RL model configuration from environment variables.
 *
 * Variables read: NODE_ENV, USE_RL_MODEL, MODEL_TIER, AVAILABLE_VRAM_GB
 * (default 16), MODEL_QUANTIZATION (default "4bit"), BASE_MODEL,
 * ATROPOS_API_URL, VLLM_PORT (default 9001), RL_MODEL_VERSION and
 * RL_FALLBACK_TO_BASE. Unparseable numeric values fall back to their
 * defaults instead of producing NaN.
 *
 * @returns {object} enabled, atroposApiUrl, vllmPort, modelVersion,
 *   fallbackToBase, baseModel, modelTier, availableVramGb, quantization and
 *   multiModelConfig.
 */
export function getRLModelConfig() {
    const isProduction = process.env.NODE_ENV === "production";
    // USE_RL_MODEL wins when set; otherwise enabled everywhere except production.
    const explicitFlag = process.env.USE_RL_MODEL;
    const enabled = explicitFlag ? explicitFlag === "true" : !isProduction;
    const availableVramGb = parseIntegerEnv(process.env.AVAILABLE_VRAM_GB, 16);
    const quantization = process.env.MODEL_QUANTIZATION || "4bit"; // Default to 4-bit for efficiency
    const multiModelConfig = getMultiModelConfig(availableVramGb);
    // Accept MODEL_TIER only when it names a tier defined on MODEL_TIERS
    // itself; inherited keys such as "toString" are not tiers.
    const explicitTier = process.env.MODEL_TIER;
    const isKnownTier = Boolean(explicitTier) &&
        Object.prototype.hasOwnProperty.call(MODEL_TIERS, explicitTier);
    const modelTier = isKnownTier ? explicitTier : multiModelConfig.modelTier;
    // Use explicit BASE_MODEL if set, otherwise use quantized tier-based model
    const baseModel = process.env.BASE_MODEL || getQuantizedModelName(modelTier, quantization);
    return {
        enabled,
        atroposApiUrl: process.env.ATROPOS_API_URL || "http://localhost:8000",
        vllmPort: parseIntegerEnv(process.env.VLLM_PORT, 9001),
        modelVersion: process.env.RL_MODEL_VERSION, // Optional: pin to specific version
        fallbackToBase: process.env.RL_FALLBACK_TO_BASE !== "false", // Default: true
        baseModel,
        modelTier,
        availableVramGb,
        quantization,
        multiModelConfig,
    };
}
|
|
243
|
-
/**
 * Report whether RL models are enabled and an Atropos API URL is configured.
 *
 * Emits a console warning when RL models are enabled but the URL is missing.
 *
 * @returns {boolean}
 */
export function isRLModelAvailable() {
    const { enabled, atroposApiUrl } = getRLModelConfig();
    if (!enabled) {
        return false;
    }
    // The Atropos API is where RL models are fetched from, so it is mandatory.
    const hasAtroposUrl = Boolean(atroposApiUrl);
    if (!hasAtroposUrl) {
        console.warn("RL models enabled but Atropos API URL missing. Set ATROPOS_API_URL.");
    }
    return hasAtroposUrl;
}
|
|
258
|
-
/**
 * Print the effective RL model configuration to the console, together with
 * details of the selected tier and the per-model VRAM requirement.
 *
 * @returns {void}
 */
export function logRLModelConfig() {
    const config = getRLModelConfig();
    const available = isRLModelAvailable();
    const { modelTier, quantization } = config;
    const { name: tierName, params: tierParams, context: contextWindow } = MODEL_TIERS[modelTier];
    const vramPerModelGb = getVramRequirement(modelTier, quantization);
    const summary = {
        enabled: config.enabled,
        available,
        atroposConfigured: !!config.atroposApiUrl,
        vllmPort: config.vllmPort,
        pinnedVersion: config.modelVersion || "latest",
        fallbackEnabled: config.fallbackToBase,
        baseModel: config.baseModel,
        modelTier,
        tierName,
        tierParams,
        contextWindow,
        availableVramGb: config.availableVramGb || "auto",
        quantization,
        vramPerModel: `${vramPerModelGb}GB`,
        maxConcurrentModels: config.multiModelConfig.maxConcurrentModels,
    };
    console.log("🤖 RL Model Configuration:", summary);
}
|
|
284
|
-
/**
 * List the configuration object of every tier in MODEL_TIERS, in
 * declaration order.
 *
 * @returns {Array<object>}
 */
export function getAvailableModelTiers() {
    return Object.keys(MODEL_TIERS).map((tierKey) => MODEL_TIERS[tierKey]);
}
|
|
290
|
-
/**
 * Tell whether a tier's full-precision model fits in the given VRAM.
 *
 * @param {string} tier - Key of MODEL_TIERS.
 * @param {number} vramGb - Available VRAM in GB.
 * @returns {boolean}
 */
export function isTierAvailable(tier, vramGb) {
    const { minVramGb } = MODEL_TIERS[tier];
    return vramGb >= minVramGb;
}
|