@elizaos/training 2.0.0-alpha.21 → 2.0.0-alpha.22
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-lint.log +2 -0
- package/.turbo/turbo-typecheck.log +1 -0
- package/dist/.tsbuildinfo +1 -0
- package/dist/adapter.js +59 -0
- package/dist/archetypes/ArchetypeConfigService.js +510 -0
- package/dist/archetypes/derive-archetype.js +196 -0
- package/dist/archetypes/index.js +7 -0
- package/dist/benchmark/ArchetypeMatchupBenchmark.js +547 -0
- package/dist/benchmark/BenchmarkChartGenerator.js +632 -0
- package/dist/benchmark/BenchmarkDataGenerator.js +825 -0
- package/dist/benchmark/BenchmarkDataViewer.js +197 -0
- package/dist/benchmark/BenchmarkHistoryService.js +135 -0
- package/dist/benchmark/BenchmarkRunner.js +483 -0
- package/dist/benchmark/BenchmarkValidator.js +158 -0
- package/dist/benchmark/FastEvalRunner.js +133 -0
- package/dist/benchmark/MetricsValidator.js +104 -0
- package/dist/benchmark/MetricsVisualizer.js +775 -0
- package/dist/benchmark/ModelBenchmarkService.js +433 -0
- package/dist/benchmark/ModelRegistry.js +122 -0
- package/dist/benchmark/RulerBenchmarkIntegration.js +168 -0
- package/dist/benchmark/SimulationA2AInterface.js +683 -0
- package/dist/benchmark/SimulationEngine.js +522 -0
- package/dist/benchmark/TaskRunner.js +60 -0
- package/dist/benchmark/__tests__/BenchmarkRunner.test.js +409 -0
- package/dist/benchmark/__tests__/HeadToHead.test.js +105 -0
- package/dist/benchmark/index.js +23 -0
- package/dist/benchmark/parseSimulationMetrics.js +86 -0
- package/dist/benchmark/simulation-types.js +1 -0
- package/dist/dependencies.js +197 -0
- package/dist/generation/TrajectoryGenerator.js +244 -0
- package/dist/generation/index.js +6 -0
- package/dist/huggingface/HuggingFaceDatasetUploader.js +463 -0
- package/dist/huggingface/HuggingFaceIntegrationService.js +272 -0
- package/dist/huggingface/HuggingFaceModelUploader.js +385 -0
- package/dist/huggingface/index.js +9 -0
- package/dist/huggingface/shared/HuggingFaceUploadUtil.js +144 -0
- package/dist/index.js +41 -0
- package/dist/init-training.js +43 -0
- package/dist/metrics/TrajectoryMetricsExtractor.js +523 -0
- package/dist/metrics/__tests__/TrajectoryMetricsExtractor.test.js +628 -0
- package/dist/metrics/index.js +7 -0
- package/dist/metrics/types.js +21 -0
- package/dist/rubrics/__tests__/index.test.js +150 -0
- package/dist/rubrics/ass-kisser.js +83 -0
- package/dist/rubrics/degen.js +78 -0
- package/dist/rubrics/goody-twoshoes.js +82 -0
- package/dist/rubrics/index.js +184 -0
- package/dist/rubrics/information-trader.js +82 -0
- package/dist/rubrics/infosec.js +99 -0
- package/dist/rubrics/liar.js +102 -0
- package/dist/rubrics/perps-trader.js +85 -0
- package/dist/rubrics/researcher.js +79 -0
- package/dist/rubrics/scammer.js +80 -0
- package/dist/rubrics/social-butterfly.js +71 -0
- package/dist/rubrics/super-predictor.js +95 -0
- package/dist/rubrics/trader.js +65 -0
- package/dist/scoring/ArchetypeScoringService.js +301 -0
- package/dist/scoring/JudgePromptBuilder.js +401 -0
- package/dist/scoring/LLMJudgeCache.js +263 -0
- package/dist/scoring/index.js +8 -0
- package/dist/training/AutomationPipeline.js +714 -0
- package/dist/training/BenchmarkService.js +370 -0
- package/dist/training/ConfigValidator.js +153 -0
- package/dist/training/MarketOutcomesTracker.js +142 -0
- package/dist/training/ModelDeployer.js +128 -0
- package/dist/training/ModelFetcher.js +48 -0
- package/dist/training/ModelSelectionService.js +248 -0
- package/dist/training/ModelUsageVerifier.js +106 -0
- package/dist/training/MultiModelOrchestrator.js +349 -0
- package/dist/training/RLModelConfig.js +295 -0
- package/dist/training/RewardBackpropagationService.js +117 -0
- package/dist/training/RulerScoringService.js +450 -0
- package/dist/training/TrainingMonitor.js +108 -0
- package/dist/training/TrajectoryRecorder.js +281 -0
- package/dist/training/__tests__/TrajectoryRecorder.test.js +363 -0
- package/dist/training/index.js +30 -0
- package/dist/training/logRLConfig.js +29 -0
- package/dist/training/pipeline.js +80 -0
- package/dist/training/storage/ModelStorageService.js +190 -0
- package/dist/training/storage/TrainingDataArchiver.js +136 -0
- package/dist/training/storage/index.js +7 -0
- package/dist/training/types.js +6 -0
- package/dist/training/window-utils.js +100 -0
- package/dist/utils/index.js +73 -0
- package/dist/utils/logger.js +55 -0
- package/dist/utils/snowflake.js +15 -0
- package/dist/utils/synthetic-detector.js +67 -0
- package/package.json +2 -2
- package/research-output/training-runs/training-run-1773742857616.json +38 -0
- package/research-output/training-runs/training-run-1773742946977.json +38 -0
- package/research-output/training-runs/training-run-1773743278891.json +38 -0
- package/research-output/training-runs/training-run-1773743409754.json +38 -0
- package/research-output/training-runs/training-run-1773743651086.json +38 -0
- package/research-output/training-runs/training-run-1773743782883.json +38 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Model Deployer Service
|
|
3
|
+
*
|
|
4
|
+
* Automatically deploys trained models to agents.
|
|
5
|
+
* Handles gradual rollout and rollback if needed.
|
|
6
|
+
*/
|
|
7
|
+
import { getTrainingDataAdapter } from "../adapter";
|
|
8
|
+
import { getAgentRuntimeManager } from "../dependencies";
|
|
9
|
+
import { logger } from "../utils/logger";
|
|
10
|
+
export class ModelDeployer {
    /**
     * Deployment records keyed by deployment id. They live only in this Map;
     * nothing in this class persists or evicts them.
     */
    deploymentStatus = new Map();
    /**
     * Deploy a model version to agents.
     *
     * Steps:
     * 1. Look the model up by version.
     * 2. Select target agents through the training-data adapter.
     * 3. Mark the model "deployed".
     * 4. Reset each target agent's runtime so it picks up the new model.
     *
     * Per-agent reset failures are counted and logged; they do not abort the rollout.
     *
     * @param options - `{ modelVersion, strategy, rolloutPercentage?, testAgentIds? }`.
     *   A strategy of "immediate" is forwarded to the adapter as "all".
     * @returns `{ success, agentsUpdated, deploymentId, error? }`.
     *   `success` is true only when every runtime reset succeeded, and
     *   `agentsUpdated` is the number of runtimes actually reset.
     * @throws Error if the model version does not exist. Errors from
     *   `updateModelStatus` / `getAgentRuntimeManager` are rethrown after the
     *   deployment record is marked "failed".
     */
    async deploy(options) {
        const da = getTrainingDataAdapter();
        logger.info("Starting model deployment", {
            version: options.modelVersion,
            strategy: options.strategy,
        });
        const model = await da.getModelByVersion(options.modelVersion);
        if (!model) {
            throw new Error(`Model ${options.modelVersion} not found`);
        }
        const strategy = options.strategy === "immediate" ? "all" : options.strategy;
        const targetAgents = await da.getAgentUsers({
            strategy,
            rolloutPercentage: options.rolloutPercentage,
            testAgentIds: options.testAgentIds,
        });
        logger.info(`Deploying to ${targetAgents.length} agents`);
        // A bare millisecond timestamp collides when two deployments start in
        // the same millisecond, and the second would overwrite the first
        // record. Append a numeric suffix until the id is unused. The check and
        // the set below run synchronously, so no other deploy can interleave.
        const baseId = `deploy-${Date.now()}`;
        let deploymentId = baseId;
        for (let suffix = 1; this.deploymentStatus.has(deploymentId); suffix++) {
            deploymentId = `${baseId}-${suffix}`;
        }
        const startedAt = new Date();
        let runtimesReset = 0;
        let runtimeResetFailures = 0;
        // Writes the current counters into the deployment record.
        const recordStatus = (status, completedAt) => {
            this.deploymentStatus.set(deploymentId, {
                deploymentId,
                modelVersion: options.modelVersion,
                status,
                agentsUpdated: runtimesReset,
                agentsFailed: runtimeResetFailures,
                performance: {
                    rolloutSuccessRate: targetAgents.length > 0
                        ? runtimesReset / targetAgents.length
                        : 0,
                    runtimeResetFailures,
                },
                startedAt,
                completedAt,
            });
        };
        recordStatus("in_progress", null);
        try {
            await da.updateModelStatus(model.modelId, "deployed", {
                deployedAt: new Date(),
                agentsUsing: targetAgents.length,
            });
            // Clear agent runtimes so they pick up the new model. A failure is
            // counted per agent rather than aborting the whole rollout.
            const runtimeManager = getAgentRuntimeManager();
            for (const agent of targetAgents) {
                try {
                    await runtimeManager.resetRuntime(agent.id);
                    runtimesReset++;
                }
                catch (err) {
                    runtimeResetFailures++;
                    logger.warn("Failed to reset runtime for agent", {
                        agentId: agent.id,
                        error: err instanceof Error ? err.message : String(err),
                    });
                }
            }
        }
        catch (err) {
            // Without this the record would stay "in_progress" forever.
            recordStatus("failed", new Date());
            throw err;
        }
        const degraded = runtimeResetFailures > 0;
        const summary = {
            version: options.modelVersion,
            agentsUpdated: runtimesReset,
            agentsFailed: runtimeResetFailures,
            deploymentId,
            runtimesReset,
        };
        if (degraded) {
            logger.warn("Model deployed with runtime reset failures", summary);
        }
        else {
            logger.info("Model deployed successfully", summary);
        }
        recordStatus(degraded ? "degraded" : "deployed", new Date());
        return {
            success: !degraded,
            agentsUpdated: runtimesReset,
            deploymentId,
            error: degraded
                ? `${runtimeResetFailures} agent runtimes failed to reset`
                : undefined,
        };
    }
    /**
     * Roll back by redeploying `targetVersion` with the "immediate" strategy.
     *
     * `currentVersion` is only used for logging.
     *
     * @returns The result of `deploy` for `targetVersion`.
     */
    async rollback(currentVersion, targetVersion) {
        logger.info("Rolling back model", {
            from: currentVersion,
            to: targetVersion,
        });
        return await this.deploy({
            modelVersion: targetVersion,
            strategy: "immediate",
        });
    }
    /**
     * Get a snapshot of a deployment recorded by this instance.
     *
     * @returns `{ status, agentsUpdated, agentsFailed, performance }`, or
     *   `null` if this instance has no record for `deploymentId`.
     */
    async getDeploymentStatus(deploymentId) {
        const status = this.deploymentStatus.get(deploymentId);
        if (!status)
            return null;
        return {
            status: status.status,
            agentsUpdated: status.agentsUpdated,
            agentsFailed: status.agentsFailed,
            performance: {
                rolloutSuccessRate: status.performance.rolloutSuccessRate,
                runtimeResetFailures: status.performance.runtimeResetFailures,
            },
        };
    }
}
|
|
127
|
+
// Singleton
|
|
128
|
+
// Shared module-level instance. Its deployment records are held only in the
// in-memory `deploymentStatus` Map, so they are not persisted by this class.
export const modelDeployer = new ModelDeployer();
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Model Fetcher
|
|
3
|
+
*
|
|
4
|
+
* Fetches trained RL models from the database for inference.
|
|
5
|
+
*/
|
|
6
|
+
import { getTrainingDataAdapter } from "../adapter";
|
|
7
|
+
import { logger } from "../utils/logger";
|
|
8
|
+
/**
 * Get the latest RL model from the database for inference.
 *
 * Asks the training-data adapter for its latest model. It returns
 * inference metadata only when that model is "ready" or "deployed" and has
 * a usable identifier and base model.
 *
 * NOTE(review): the original SQL filtered to status IN ('ready', 'deployed')
 * before picking the latest row. The adapter call here returns the latest
 * model regardless of status. If that model is in any other status, this
 * returns null even when an older ready model exists.
 *
 * @returns `{ version, modelId, modelPath, metadata }` or `null`.
 *   `modelId` and `modelPath` are both the trimmed storage path, or the
 *   trimmed modelId when the storage path is missing or blank.
 */
export async function getLatestRLModel() {
    const adapter = getTrainingDataAdapter();
    const model = await adapter.getLatestModel();
    if (!model) {
        return null;
    }
    // Skip models that aren't ready or deployed
    if (model.status !== "ready" && model.status !== "deployed") {
        return null;
    }
    // Prefer the storage path (e.g. a HuggingFace repo) but fall back to the
    // modelId when the path is missing OR blank. Previously a whitespace-only
    // storagePath shadowed a valid modelId and the model was rejected.
    const isNonBlankString = (value) => typeof value === "string" && value.trim().length > 0;
    const rlModelId = [model.storagePath, model.modelId].find(isNonBlankString)?.trim();
    if (!rlModelId) {
        logger.error("Model has no storagePath or modelId", {
            modelId: model.modelId,
            storagePath: model.storagePath,
        }, "ModelFetcher");
        return null;
    }
    if (!isNonBlankString(model.baseModel)) {
        logger.error("Model has no baseModel", {
            modelId: model.modelId,
        }, "ModelFetcher");
        return null;
    }
    return {
        version: model.version,
        modelId: rlModelId,
        modelPath: rlModelId,
        metadata: {
            avgReward: model.avgReward ?? undefined,
            benchmarkScore: model.benchmarkScore ?? undefined,
            baseModel: model.baseModel,
            trainedAt: model.createdAt,
        },
    };
}
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Model Selection Service
|
|
3
|
+
*
|
|
4
|
+
* Determines which base model to use for training based on:
|
|
5
|
+
* 1. Number of available training bundles
|
|
6
|
+
* 2. Existence of trained models
|
|
7
|
+
* 3. Performance of previous models
|
|
8
|
+
*/
|
|
9
|
+
import { getTrainingDataAdapter } from "../adapter";
|
|
10
|
+
import { logger } from "../utils/logger";
|
|
11
|
+
export class ModelSelectionService {
    /**
     * Default base model: `process.env.BASE_MODEL` if set, otherwise
     * "unsloth/Qwen3-4B-128K". Read once, when the instance is constructed.
     *
     * NOTE(review): an earlier comment said the model scales via the
     * MODEL_TIER / AVAILABLE_VRAM_GB env vars. This class does not read
     * them; confirm whether another module does.
     */
    BASE_MODEL = process.env.BASE_MODEL || "unsloth/Qwen3-4B-128K";
    /** At or above this many bundles, training continues from the best model. */
    BUNDLE_THRESHOLD = 1000;
    /** Below this many bundles, `selectBaseModel` refuses to train (unless no models exist). */
    MIN_BUNDLES_FOR_TRAINING = 100;
    /** Cap returned by `getTrainingDataLimit` once BUNDLE_THRESHOLD is reached. */
    MAX_TRAINING_EXAMPLES = 2000;
    /**
     * Select the base model for the next training run.
     *
     * Decision tree:
     * 1. No models exist? → force first model from base (bundle minimum not enforced)
     * 2. < MIN_BUNDLES_FOR_TRAINING bundles? → throw
     * 3. < BUNDLE_THRESHOLD bundles? → train from base model
     * 4. ≥ BUNDLE_THRESHOLD bundles? → continue from best benchmarked model,
     *    falling back to the base model if none exists
     *
     * @returns `{ modelId, modelPath, strategy, reason, metadata }`.
     * @throws Error if there are fewer than MIN_BUNDLES_FOR_TRAINING bundles
     *   and at least one model already exists.
     *
     * @example
     * ```typescript
     * const result = await modelSelectionService.selectBaseModel();
     * console.log(`Strategy: ${result.strategy}`);
     * console.log(`Model: ${result.modelPath}`);
     * ```
     */
    async selectBaseModel() {
        logger.info("Selecting base model for training...", undefined, "ModelSelectionService");
        // Fetched even on the force-first path so metadata reports the real count.
        const bundleCount = await this.countTrainingBundles();
        const forceFirst = await this.shouldForceFirstModel();
        if (forceFirst) {
            logger.info("No models exist - forcing first model creation", undefined, "ModelSelectionService");
            return {
                modelId: this.BASE_MODEL,
                modelPath: this.BASE_MODEL,
                strategy: "force_first",
                reason: "No trained models exist - creating first model from base",
                metadata: {
                    baseModel: this.BASE_MODEL,
                    bundleCount,
                },
            };
        }
        logger.info(`Found ${bundleCount} training bundles`, undefined, "ModelSelectionService");
        if (bundleCount < this.MIN_BUNDLES_FOR_TRAINING) {
            throw new Error(`Insufficient training data: ${bundleCount} bundles ` +
                `(need ${this.MIN_BUNDLES_FOR_TRAINING} minimum)`);
        }
        if (bundleCount < this.BUNDLE_THRESHOLD) {
            logger.info(`Bundle count ${bundleCount} < ${this.BUNDLE_THRESHOLD} - using base model`, undefined, "ModelSelectionService");
            return {
                modelId: this.BASE_MODEL,
                modelPath: this.BASE_MODEL,
                strategy: "base",
                reason: `Training from base model (${bundleCount} bundles < ${this.BUNDLE_THRESHOLD} threshold)`,
                metadata: {
                    bundleCount,
                    baseModel: this.BASE_MODEL,
                },
            };
        }
        const bestModel = await this.getBestPerformingModel();
        if (!bestModel) {
            logger.warn("No best model found despite bundle threshold - using base model", undefined, "ModelSelectionService");
            return {
                modelId: this.BASE_MODEL,
                modelPath: this.BASE_MODEL,
                strategy: "base",
                reason: "No previous models available - using base model",
                metadata: {
                    bundleCount,
                    baseModel: this.BASE_MODEL,
                },
            };
        }
        logger.info(`Bundle count ${bundleCount} ≥ ${this.BUNDLE_THRESHOLD} - continuing from best model`, {
            bestModelId: bestModel.modelId,
            bestScore: bestModel.benchmarkScore,
        }, "ModelSelectionService");
        // Use storagePath for model path (e.g., HuggingFace URL)
        const modelStoragePath = bestModel.storagePath || bestModel.modelId;
        return {
            modelId: bestModel.modelId,
            modelPath: modelStoragePath,
            strategy: "continue",
            reason: `Continuing from best model (score: ${bestModel.benchmarkScore?.toFixed(3) ?? "N/A"})`,
            metadata: {
                bundleCount,
                // `??`, not `||`: a benchmark score of 0 is a real score.
                bestModelScore: bestModel.benchmarkScore ?? undefined,
                baseModel: bestModel.baseModel,
            },
        };
    }
    /**
     * Get the best benchmarked model via `getBestBenchmarkedModel()` on the
     * training-data adapter.
     *
     * @returns The model record, or null (with a warning logged) if the
     *   adapter returns none.
     */
    async getBestPerformingModel() {
        const model = await getTrainingDataAdapter().getBestBenchmarkedModel();
        if (!model) {
            logger.warn("No benchmarked models found", undefined, "ModelSelectionService");
            return null;
        }
        logger.info("Found best performing model", {
            modelId: model.modelId,
            version: model.version,
            benchmarkScore: model.benchmarkScore,
            avgReward: model.avgReward,
        }, "ModelSelectionService");
        return model;
    }
    /**
     * Count available training bundles via the adapter's
     * `countScoredTrajectoriesReady()`.
     *
     * NOTE(review): per the previous documentation, a bundle is a scored,
     * not-yet-used training trajectory with valid steps. Confirm that
     * definition against the adapter implementation.
     *
     * @returns Number of available training bundles.
     */
    async countTrainingBundles() {
        return await getTrainingDataAdapter().countScoredTrajectoriesReady();
    }
    /**
     * @returns True when `countTrainedModels()` reports zero models, i.e. the
     *   first model should be created from the base model.
     */
    async shouldForceFirstModel() {
        const modelCount = await this.countTrainedModels();
        return modelCount === 0;
    }
    /**
     * Count existing trained models via the adapter's `countActiveModels()`.
     */
    async countTrainedModels() {
        return await getTrainingDataAdapter().countActiveModels();
    }
    /**
     * Decide how many trajectories to use for training.
     *
     * @returns null (use all available) below BUNDLE_THRESHOLD bundles,
     *   otherwise MAX_TRAINING_EXAMPLES.
     */
    async getTrainingDataLimit() {
        const bundleCount = await this.countTrainingBundles();
        if (bundleCount < this.BUNDLE_THRESHOLD) {
            return null; // Use all available
        }
        return this.MAX_TRAINING_EXAMPLES;
    }
    /**
     * Fetch training trajectories from the adapter.
     *
     * NOTE(review): the previous documentation said results are unused,
     * scored trajectories ordered newest first. That filtering and ordering
     * is delegated to the adapter; confirm it there.
     *
     * @param limit - Optional maximum count; null/undefined means no limit.
     * @returns Array of training trajectories.
     */
    async getTrainingTrajectories(limit) {
        const result = await getTrainingDataAdapter().getTrainingTrajectories(limit ?? undefined);
        logger.info(`Retrieved ${result.length} trajectories for training`, { limit, available: result.length }, "ModelSelectionService");
        return result;
    }
    /**
     * Summarize the current model-selection state for logging/monitoring.
     *
     * @returns `{ bundleCount, trainedModelCount, bestModel, bestScore, recommendation }`.
     *
     * @example
     * ```typescript
     * const summary = await modelSelectionService.getSelectionSummary();
     * console.log(`Bundles: ${summary.bundleCount}`);
     * console.log(`Recommendation: ${summary.recommendation}`);
     * ```
     */
    async getSelectionSummary() {
        // The three lookups are independent, so run them concurrently.
        const [bundleCount, trainedModelCount, bestModel] = await Promise.all([
            this.countTrainingBundles(),
            this.countTrainedModels(),
            this.getBestPerformingModel(),
        ]);
        let recommendation = "";
        if (trainedModelCount === 0) {
            recommendation = "Force first model creation";
        }
        else if (bundleCount < this.MIN_BUNDLES_FOR_TRAINING) {
            recommendation = "Not ready - need more data";
        }
        else if (bundleCount < this.BUNDLE_THRESHOLD) {
            recommendation = "Train from base model";
        }
        else {
            recommendation = "Train from best performing model";
        }
        return {
            bundleCount,
            trainedModelCount,
            bestModel: bestModel?.modelId || null,
            // `??`, not `||`: a benchmark score of 0 must not be reported as null.
            bestScore: bestModel?.benchmarkScore ?? null,
            recommendation,
        };
    }
}
|
|
247
|
+
// Export singleton instance
|
|
248
|
+
// Shared instance. BASE_MODEL is read from process.env once, when this
// instance is constructed at module load; later env changes are not seen.
export const modelSelectionService = new ModelSelectionService();
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Model Usage Verifier
|
|
3
|
+
*
|
|
4
|
+
* Verifies that agents are using the correct models.
|
|
5
|
+
* Provides assertions and logging for model usage verification.
|
|
6
|
+
*/
|
|
7
|
+
import { getLlmLogAdapter, getTrainingDataAdapter } from "../adapter";
|
|
8
|
+
import { logger } from "../utils/logger";
|
|
9
|
+
export class ModelUsageVerifier {
    /**
     * Report which model an agent's runtime is configured to use.
     *
     * Reads `runtime.character.settings`. Precedence is CLAUDE_MODEL, then
     * OPENAI_MODEL, then GROQ_LARGE_MODEL / GROQ_SMALL_MODEL; with none set,
     * model and source are "unknown". `inferenceCount` is the number of LLM
     * calls in the last 24 hours across the agent's trajectories. It is 0 when
     * no LLM log adapter is available or the agent has no trajectories.
     *
     * @param agentUserId - Agent identifier.
     * @param runtime - Agent runtime to inspect.
     * @returns `{ agentId, modelUsed, modelSource, inferenceCount }`.
     */
    static async verifyAgentModelUsage(agentUserId, runtime) {
        const character = runtime.character;
        const settings = character?.settings;
        // Check for different model providers
        const groqModel = String(settings?.GROQ_LARGE_MODEL || settings?.GROQ_SMALL_MODEL || "");
        const claudeModel = String(settings?.CLAUDE_MODEL || "");
        const openaiModel = String(settings?.OPENAI_MODEL || "");
        let modelUsed;
        let modelSource;
        if (claudeModel) {
            modelUsed = claudeModel;
            modelSource = "claude";
        }
        else if (openaiModel) {
            modelUsed = openaiModel;
            modelSource = "openai";
        }
        else if (groqModel) {
            modelUsed = groqModel;
            modelSource = "groq";
        }
        else {
            modelUsed = "unknown";
            modelSource = "unknown";
        }
        // Count inferences from logs (using trajectoryId)
        const trajectoryIds = await getTrainingDataAdapter().getTrajectoryIdsByAgent(agentUserId);
        const twentyFourHoursAgo = new Date(Date.now() - 24 * 60 * 60 * 1000);
        let inferenceCount = 0;
        const llmAdapter = getLlmLogAdapter();
        if (llmAdapter && trajectoryIds.length > 0) {
            inferenceCount = await llmAdapter.countRecentLLMCalls(trajectoryIds, twentyFourHoursAgo);
        }
        return {
            agentId: agentUserId,
            modelUsed,
            modelSource,
            inferenceCount,
        };
    }
    /**
     * Verify several agents, collecting per-agent problems instead of failing
     * the whole batch.
     *
     * @param agentUserIds - Agent ids to check.
     * @param runtimes - Map from agent id to runtime.
     * @returns `{ success, agentsChecked, details, errors }`.
     *   `success` is true when at least one agent was verified. `errors`
     *   holds one message per agent that had no runtime or whose
     *   verification threw.
     */
    static async verifyMultipleAgents(agentUserIds, runtimes) {
        const details = [];
        const errors = [];
        for (const agentId of agentUserIds) {
            const runtime = runtimes.get(agentId);
            if (!runtime) {
                errors.push(`Runtime not found for agent ${agentId}`);
                continue;
            }
            // One agent's failure (e.g. a DB lookup error) must not discard
            // the results already collected for the others.
            try {
                const stats = await ModelUsageVerifier.verifyAgentModelUsage(agentId, runtime);
                details.push(stats);
            }
            catch (err) {
                errors.push(`Failed to verify agent ${agentId}: ${err instanceof Error ? err.message : String(err)}`);
            }
        }
        return {
            success: details.length > 0,
            agentsChecked: details.length,
            details,
            errors,
        };
    }
    /**
     * Assert that an agent has a configured model.
     *
     * @throws Error if the agent's model source is "unknown".
     */
    static async assertModelUsage(agentUserId, runtime) {
        const stats = await ModelUsageVerifier.verifyAgentModelUsage(agentUserId, runtime);
        if (stats.modelSource === "unknown") {
            throw new Error(`Agent ${agentUserId} has no configured model. ` +
                `Using: ${stats.modelUsed}`);
        }
        logger.info("Model usage verified", {
            agentId: agentUserId,
            model: stats.modelUsed,
            source: stats.modelSource,
        }, "ModelUsageVerifier");
    }
    /**
     * @returns `{ totalAgents }`: the number of agents returned by the
     *   adapter's `getAgentUsers()`.
     */
    static async getModelUsageSummary() {
        const agents = await getTrainingDataAdapter().getAgentUsers();
        return {
            totalAgents: agents.length,
        };
    }
}
|