@elizaos/training 2.0.0-alpha.13 → 2.0.0-alpha.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +2 -2
- package/research-output/training-runs/training-run-1773726941205.json +38 -0
- package/scripts/rank_trajectories.ts +0 -1
- package/scripts/run_task_benchmark.ts +4 -11
- package/src/adapter.ts +96 -49
- package/src/archetypes/ArchetypeConfigService.ts +188 -185
- package/src/archetypes/derive-archetype.ts +47 -47
- package/src/archetypes/index.ts +2 -2
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +70 -70
- package/src/benchmark/BenchmarkChartGenerator.ts +70 -69
- package/src/benchmark/BenchmarkDataGenerator.ts +136 -136
- package/src/benchmark/BenchmarkDataViewer.ts +32 -30
- package/src/benchmark/BenchmarkHistoryService.ts +13 -12
- package/src/benchmark/BenchmarkRunner.ts +87 -83
- package/src/benchmark/BenchmarkValidator.ts +48 -46
- package/src/benchmark/FastEvalRunner.ts +17 -16
- package/src/benchmark/MetricsValidator.ts +20 -21
- package/src/benchmark/MetricsVisualizer.ts +92 -85
- package/src/benchmark/ModelBenchmarkService.ts +90 -82
- package/src/benchmark/ModelRegistry.ts +44 -44
- package/src/benchmark/RulerBenchmarkIntegration.ts +24 -24
- package/src/benchmark/SimulationA2AInterface.ts +118 -118
- package/src/benchmark/SimulationEngine.ts +51 -51
- package/src/benchmark/TaskRunner.ts +87 -79
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +80 -80
- package/src/benchmark/__tests__/HeadToHead.test.ts +26 -26
- package/src/benchmark/index.ts +27 -27
- package/src/benchmark/parseSimulationMetrics.ts +32 -32
- package/src/benchmark/simulation-types.ts +10 -10
- package/src/dependencies.ts +34 -34
- package/src/generation/TrajectoryGenerator.ts +39 -37
- package/src/generation/index.ts +1 -1
- package/src/huggingface/HuggingFaceDatasetUploader.ts +72 -72
- package/src/huggingface/HuggingFaceIntegrationService.ts +59 -53
- package/src/huggingface/HuggingFaceModelUploader.ts +60 -59
- package/src/huggingface/index.ts +6 -6
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +32 -32
- package/src/index.ts +27 -27
- package/src/init-training.ts +6 -6
- package/src/metrics/TrajectoryMetricsExtractor.ts +70 -71
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +182 -182
- package/src/metrics/index.ts +2 -2
- package/src/rubrics/__tests__/index.test.ts +73 -73
- package/src/rubrics/ass-kisser.ts +6 -6
- package/src/rubrics/degen.ts +6 -6
- package/src/rubrics/goody-twoshoes.ts +6 -6
- package/src/rubrics/index.ts +50 -50
- package/src/rubrics/information-trader.ts +6 -6
- package/src/rubrics/infosec.ts +6 -6
- package/src/rubrics/liar.ts +6 -6
- package/src/rubrics/perps-trader.ts +6 -6
- package/src/rubrics/researcher.ts +6 -6
- package/src/rubrics/scammer.ts +6 -6
- package/src/rubrics/social-butterfly.ts +7 -7
- package/src/rubrics/super-predictor.ts +6 -6
- package/src/rubrics/trader.ts +5 -5
- package/src/scoring/ArchetypeScoringService.ts +56 -54
- package/src/scoring/JudgePromptBuilder.ts +96 -96
- package/src/scoring/LLMJudgeCache.ts +26 -23
- package/src/scoring/index.ts +3 -3
- package/src/training/AutomationPipeline.ts +149 -140
- package/src/training/BenchmarkService.ts +49 -45
- package/src/training/ConfigValidator.ts +38 -32
- package/src/training/MarketOutcomesTracker.ts +22 -12
- package/src/training/ModelDeployer.ts +15 -15
- package/src/training/ModelFetcher.ts +7 -7
- package/src/training/ModelSelectionService.ts +32 -32
- package/src/training/ModelUsageVerifier.ts +31 -24
- package/src/training/MultiModelOrchestrator.ts +44 -44
- package/src/training/RLModelConfig.ts +57 -57
- package/src/training/RewardBackpropagationService.ts +18 -17
- package/src/training/RulerScoringService.ts +73 -72
- package/src/training/TrainingMonitor.ts +29 -29
- package/src/training/TrajectoryRecorder.ts +25 -27
- package/src/training/__tests__/TrajectoryRecorder.test.ts +105 -105
- package/src/training/index.ts +36 -36
- package/src/training/logRLConfig.ts +7 -7
- package/src/training/pipeline.ts +13 -16
- package/src/training/storage/ModelStorageService.ts +32 -32
- package/src/training/storage/TrainingDataArchiver.ts +21 -21
- package/src/training/storage/index.ts +2 -2
- package/src/training/types.ts +6 -6
- package/src/training/window-utils.ts +14 -14
- package/src/utils/index.ts +7 -7
- package/src/utils/logger.ts +5 -5
- package/src/utils/snowflake.ts +1 -1
- package/src/utils/synthetic-detector.ts +7 -7
|
@@ -13,13 +13,13 @@
|
|
|
13
13
|
/**
|
|
14
14
|
* Quantization modes for model loading
|
|
15
15
|
*/
|
|
16
|
-
export type QuantizationMode =
|
|
16
|
+
export type QuantizationMode = "none" | "4bit" | "8bit";
|
|
17
17
|
|
|
18
18
|
/**
|
|
19
19
|
* Model tiers for scaling based on available resources
|
|
20
20
|
* Supports automatic selection based on GPU memory
|
|
21
21
|
*/
|
|
22
|
-
export type ModelTier =
|
|
22
|
+
export type ModelTier = "small" | "medium" | "large" | "xlarge";
|
|
23
23
|
|
|
24
24
|
export interface ModelTierConfig {
|
|
25
25
|
name: string;
|
|
@@ -40,44 +40,44 @@ export interface ModelTierConfig {
|
|
|
40
40
|
*/
|
|
41
41
|
export const MODEL_TIERS: Record<ModelTier, ModelTierConfig> = {
|
|
42
42
|
small: {
|
|
43
|
-
name:
|
|
44
|
-
model:
|
|
45
|
-
quantizedModel4bit:
|
|
46
|
-
quantizedModel8bit:
|
|
47
|
-
params:
|
|
43
|
+
name: "Small (4B)",
|
|
44
|
+
model: "unsloth/Qwen3-4B-128K",
|
|
45
|
+
quantizedModel4bit: "unsloth/Qwen3-4B-128K-bnb-4bit",
|
|
46
|
+
quantizedModel8bit: "unsloth/Qwen3-4B-128K-GGUF",
|
|
47
|
+
params: "4B",
|
|
48
48
|
context: 131072, // 128K context
|
|
49
49
|
minVramGb: 8,
|
|
50
50
|
minVramGb4bit: 3,
|
|
51
51
|
minVramGb8bit: 5,
|
|
52
52
|
},
|
|
53
53
|
medium: {
|
|
54
|
-
name:
|
|
55
|
-
model:
|
|
56
|
-
quantizedModel4bit:
|
|
57
|
-
quantizedModel8bit:
|
|
58
|
-
params:
|
|
54
|
+
name: "Medium (8B)",
|
|
55
|
+
model: "unsloth/Qwen3-8B-128K",
|
|
56
|
+
quantizedModel4bit: "unsloth/Qwen3-8B-128K-bnb-4bit",
|
|
57
|
+
quantizedModel8bit: "unsloth/Qwen3-8B-128K-GGUF",
|
|
58
|
+
params: "8B",
|
|
59
59
|
context: 131072, // 128K context
|
|
60
60
|
minVramGb: 16,
|
|
61
61
|
minVramGb4bit: 5,
|
|
62
62
|
minVramGb8bit: 9,
|
|
63
63
|
},
|
|
64
64
|
large: {
|
|
65
|
-
name:
|
|
66
|
-
model:
|
|
67
|
-
quantizedModel4bit:
|
|
68
|
-
quantizedModel8bit:
|
|
69
|
-
params:
|
|
65
|
+
name: "Large (14B)",
|
|
66
|
+
model: "unsloth/Qwen3-14B-128K",
|
|
67
|
+
quantizedModel4bit: "unsloth/Qwen3-14B-128K-bnb-4bit",
|
|
68
|
+
quantizedModel8bit: "unsloth/Qwen3-14B-128K-GGUF",
|
|
69
|
+
params: "14B",
|
|
70
70
|
context: 131072, // 128K context
|
|
71
71
|
minVramGb: 24,
|
|
72
72
|
minVramGb4bit: 8,
|
|
73
73
|
minVramGb8bit: 14,
|
|
74
74
|
},
|
|
75
75
|
xlarge: {
|
|
76
|
-
name:
|
|
77
|
-
model:
|
|
78
|
-
quantizedModel4bit:
|
|
79
|
-
quantizedModel8bit:
|
|
80
|
-
params:
|
|
76
|
+
name: "XLarge (32B)",
|
|
77
|
+
model: "unsloth/Qwen3-32B-128K",
|
|
78
|
+
quantizedModel4bit: "unsloth/Qwen3-32B-128K-bnb-4bit",
|
|
79
|
+
quantizedModel8bit: "unsloth/Qwen3-32B-128K-GGUF",
|
|
80
|
+
params: "32B",
|
|
81
81
|
context: 131072, // 128K context
|
|
82
82
|
minVramGb: 48,
|
|
83
83
|
minVramGb4bit: 16,
|
|
@@ -111,32 +111,32 @@ export function getMultiModelConfig(vramGb: number): MultiModelConfig {
|
|
|
111
111
|
return {
|
|
112
112
|
totalVramGb: vramGb,
|
|
113
113
|
maxConcurrentModels: 4,
|
|
114
|
-
quantization:
|
|
115
|
-
modelTier:
|
|
114
|
+
quantization: "4bit",
|
|
115
|
+
modelTier: "small",
|
|
116
116
|
};
|
|
117
117
|
} else if (vramGb >= 12) {
|
|
118
118
|
// 12GB: Can run 3x 4B models (4-bit)
|
|
119
119
|
return {
|
|
120
120
|
totalVramGb: vramGb,
|
|
121
121
|
maxConcurrentModels: 3,
|
|
122
|
-
quantization:
|
|
123
|
-
modelTier:
|
|
122
|
+
quantization: "4bit",
|
|
123
|
+
modelTier: "small",
|
|
124
124
|
};
|
|
125
125
|
} else if (vramGb >= 8) {
|
|
126
126
|
// 8GB: Can run 2x 4B models (4-bit)
|
|
127
127
|
return {
|
|
128
128
|
totalVramGb: vramGb,
|
|
129
129
|
maxConcurrentModels: 2,
|
|
130
|
-
quantization:
|
|
131
|
-
modelTier:
|
|
130
|
+
quantization: "4bit",
|
|
131
|
+
modelTier: "small",
|
|
132
132
|
};
|
|
133
133
|
}
|
|
134
134
|
// Less than 8GB: Single model only
|
|
135
135
|
return {
|
|
136
136
|
totalVramGb: vramGb,
|
|
137
137
|
maxConcurrentModels: 1,
|
|
138
|
-
quantization:
|
|
139
|
-
modelTier:
|
|
138
|
+
quantization: "4bit",
|
|
139
|
+
modelTier: "small",
|
|
140
140
|
};
|
|
141
141
|
}
|
|
142
142
|
|
|
@@ -145,14 +145,14 @@ export function getMultiModelConfig(vramGb: number): MultiModelConfig {
|
|
|
145
145
|
*/
|
|
146
146
|
export function getQuantizedModelName(
|
|
147
147
|
tier: ModelTier,
|
|
148
|
-
quantization: QuantizationMode
|
|
148
|
+
quantization: QuantizationMode,
|
|
149
149
|
): string {
|
|
150
150
|
const tierConfig = MODEL_TIERS[tier];
|
|
151
151
|
|
|
152
152
|
switch (quantization) {
|
|
153
|
-
case
|
|
153
|
+
case "4bit":
|
|
154
154
|
return tierConfig.quantizedModel4bit || tierConfig.model;
|
|
155
|
-
case
|
|
155
|
+
case "8bit":
|
|
156
156
|
return tierConfig.quantizedModel8bit || tierConfig.model;
|
|
157
157
|
default:
|
|
158
158
|
return tierConfig.model;
|
|
@@ -164,14 +164,14 @@ export function getQuantizedModelName(
|
|
|
164
164
|
*/
|
|
165
165
|
export function getVramRequirement(
|
|
166
166
|
tier: ModelTier,
|
|
167
|
-
quantization: QuantizationMode
|
|
167
|
+
quantization: QuantizationMode,
|
|
168
168
|
): number {
|
|
169
169
|
const tierConfig = MODEL_TIERS[tier];
|
|
170
170
|
|
|
171
171
|
switch (quantization) {
|
|
172
|
-
case
|
|
172
|
+
case "4bit":
|
|
173
173
|
return tierConfig.minVramGb4bit;
|
|
174
|
-
case
|
|
174
|
+
case "8bit":
|
|
175
175
|
return tierConfig.minVramGb8bit;
|
|
176
176
|
default:
|
|
177
177
|
return tierConfig.minVramGb;
|
|
@@ -229,7 +229,7 @@ export function registerArchetypeModel(config: ArchetypeModelConfig): void {
|
|
|
229
229
|
) {
|
|
230
230
|
archetypeModelRegistry.set(config.archetype, config);
|
|
231
231
|
console.log(
|
|
232
|
-
`📦 Registered model for archetype '${config.archetype}': ${config.modelId}
|
|
232
|
+
`📦 Registered model for archetype '${config.archetype}': ${config.modelId}`,
|
|
233
233
|
);
|
|
234
234
|
}
|
|
235
235
|
}
|
|
@@ -239,9 +239,9 @@ export function registerArchetypeModel(config: ArchetypeModelConfig): void {
|
|
|
239
239
|
* Falls back to base model if no archetype-specific model exists
|
|
240
240
|
*/
|
|
241
241
|
export function getModelForArchetype(
|
|
242
|
-
archetype: string
|
|
242
|
+
archetype: string,
|
|
243
243
|
): ArchetypeModelConfig | null {
|
|
244
|
-
const normalized = archetype.toLowerCase().trim().replace(/_/g,
|
|
244
|
+
const normalized = archetype.toLowerCase().trim().replace(/_/g, "-");
|
|
245
245
|
return archetypeModelRegistry.get(normalized) || null;
|
|
246
246
|
}
|
|
247
247
|
|
|
@@ -256,7 +256,7 @@ export function getAllArchetypeModels(): ArchetypeModelConfig[] {
|
|
|
256
256
|
* Check if an archetype has a trained model
|
|
257
257
|
*/
|
|
258
258
|
export function hasArchetypeModel(archetype: string): boolean {
|
|
259
|
-
const normalized = archetype.toLowerCase().trim().replace(/_/g,
|
|
259
|
+
const normalized = archetype.toLowerCase().trim().replace(/_/g, "-");
|
|
260
260
|
return archetypeModelRegistry.has(normalized);
|
|
261
261
|
}
|
|
262
262
|
|
|
@@ -271,10 +271,10 @@ export function clearArchetypeModels(): void {
|
|
|
271
271
|
* Get the appropriate model tier based on available VRAM
|
|
272
272
|
*/
|
|
273
273
|
export function getModelTierForVram(vramGb: number): ModelTier {
|
|
274
|
-
if (vramGb >= MODEL_TIERS.xlarge.minVramGb) return
|
|
275
|
-
if (vramGb >= MODEL_TIERS.large.minVramGb) return
|
|
276
|
-
if (vramGb >= MODEL_TIERS.medium.minVramGb) return
|
|
277
|
-
return
|
|
274
|
+
if (vramGb >= MODEL_TIERS.xlarge.minVramGb) return "xlarge";
|
|
275
|
+
if (vramGb >= MODEL_TIERS.large.minVramGb) return "large";
|
|
276
|
+
if (vramGb >= MODEL_TIERS.medium.minVramGb) return "medium";
|
|
277
|
+
return "small";
|
|
278
278
|
}
|
|
279
279
|
|
|
280
280
|
/**
|
|
@@ -288,8 +288,8 @@ export function getModelForTier(tier: ModelTier): string {
|
|
|
288
288
|
* Get RL model configuration from environment
|
|
289
289
|
*/
|
|
290
290
|
export function getRLModelConfig(): RLModelConfig {
|
|
291
|
-
const isProduction = process.env.NODE_ENV ===
|
|
292
|
-
const isLocal = process.env.NODE_ENV ===
|
|
291
|
+
const isProduction = process.env.NODE_ENV === "production";
|
|
292
|
+
const isLocal = process.env.NODE_ENV === "development" || !isProduction;
|
|
293
293
|
|
|
294
294
|
// Explicit enable/disable flag
|
|
295
295
|
const explicitFlag = process.env.USE_RL_MODEL;
|
|
@@ -297,7 +297,7 @@ export function getRLModelConfig(): RLModelConfig {
|
|
|
297
297
|
// Determine if enabled:
|
|
298
298
|
// - If USE_RL_MODEL is explicitly set, use that value
|
|
299
299
|
// - Otherwise, enabled in local, disabled in production
|
|
300
|
-
const enabled = explicitFlag ? explicitFlag ===
|
|
300
|
+
const enabled = explicitFlag ? explicitFlag === "true" : isLocal;
|
|
301
301
|
|
|
302
302
|
// Check for explicit tier or VRAM override
|
|
303
303
|
const explicitTier = process.env.MODEL_TIER as ModelTier | undefined;
|
|
@@ -309,13 +309,13 @@ export function getRLModelConfig(): RLModelConfig {
|
|
|
309
309
|
const explicitQuant = process.env.MODEL_QUANTIZATION as
|
|
310
310
|
| QuantizationMode
|
|
311
311
|
| undefined;
|
|
312
|
-
const quantization: QuantizationMode = explicitQuant ||
|
|
312
|
+
const quantization: QuantizationMode = explicitQuant || "4bit"; // Default to 4-bit for efficiency
|
|
313
313
|
|
|
314
314
|
// Get multi-model config based on available VRAM
|
|
315
315
|
const multiModelConfig = getMultiModelConfig(explicitVram);
|
|
316
316
|
|
|
317
317
|
// Determine tier: explicit tier > tier from multi-model config > default small
|
|
318
|
-
let modelTier: ModelTier =
|
|
318
|
+
let modelTier: ModelTier = "small";
|
|
319
319
|
if (explicitTier && MODEL_TIERS[explicitTier]) {
|
|
320
320
|
modelTier = explicitTier;
|
|
321
321
|
} else {
|
|
@@ -328,10 +328,10 @@ export function getRLModelConfig(): RLModelConfig {
|
|
|
328
328
|
|
|
329
329
|
return {
|
|
330
330
|
enabled,
|
|
331
|
-
atroposApiUrl: process.env.ATROPOS_API_URL ||
|
|
332
|
-
vllmPort: parseInt(process.env.VLLM_PORT ||
|
|
331
|
+
atroposApiUrl: process.env.ATROPOS_API_URL || "http://localhost:8000",
|
|
332
|
+
vllmPort: parseInt(process.env.VLLM_PORT || "9001", 10),
|
|
333
333
|
modelVersion: process.env.RL_MODEL_VERSION, // Optional: pin to specific version
|
|
334
|
-
fallbackToBase: process.env.RL_FALLBACK_TO_BASE !==
|
|
334
|
+
fallbackToBase: process.env.RL_FALLBACK_TO_BASE !== "false", // Default: true
|
|
335
335
|
baseModel,
|
|
336
336
|
modelTier,
|
|
337
337
|
availableVramGb: explicitVram,
|
|
@@ -353,7 +353,7 @@ export function isRLModelAvailable(): boolean {
|
|
|
353
353
|
// Need Atropos API URL to fetch RL models
|
|
354
354
|
if (!config.atroposApiUrl) {
|
|
355
355
|
console.warn(
|
|
356
|
-
|
|
356
|
+
"RL models enabled but Atropos API URL missing. Set ATROPOS_API_URL.",
|
|
357
357
|
);
|
|
358
358
|
return false;
|
|
359
359
|
}
|
|
@@ -370,22 +370,22 @@ export function logRLModelConfig(): void {
|
|
|
370
370
|
const tierConfig = MODEL_TIERS[config.modelTier];
|
|
371
371
|
const vramPerModel = getVramRequirement(
|
|
372
372
|
config.modelTier,
|
|
373
|
-
config.quantization
|
|
373
|
+
config.quantization,
|
|
374
374
|
);
|
|
375
375
|
|
|
376
|
-
console.log(
|
|
376
|
+
console.log("🤖 RL Model Configuration:", {
|
|
377
377
|
enabled: config.enabled,
|
|
378
378
|
available,
|
|
379
379
|
atroposConfigured: !!config.atroposApiUrl,
|
|
380
380
|
vllmPort: config.vllmPort,
|
|
381
|
-
pinnedVersion: config.modelVersion ||
|
|
381
|
+
pinnedVersion: config.modelVersion || "latest",
|
|
382
382
|
fallbackEnabled: config.fallbackToBase,
|
|
383
383
|
baseModel: config.baseModel,
|
|
384
384
|
modelTier: config.modelTier,
|
|
385
385
|
tierName: tierConfig.name,
|
|
386
386
|
tierParams: tierConfig.params,
|
|
387
387
|
contextWindow: tierConfig.context,
|
|
388
|
-
availableVramGb: config.availableVramGb ||
|
|
388
|
+
availableVramGb: config.availableVramGb || "auto",
|
|
389
389
|
quantization: config.quantization,
|
|
390
390
|
vramPerModel: `${vramPerModel}GB`,
|
|
391
391
|
maxConcurrentModels: config.multiModelConfig.maxConcurrentModels,
|
|
@@ -5,10 +5,10 @@
|
|
|
5
5
|
* This allows the RL model to learn from actual results, not just immediate actions.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
import {
|
|
9
|
-
import { logger } from
|
|
10
|
-
import { MarketOutcomesTracker } from
|
|
11
|
-
import type { TrajectoryStep } from
|
|
8
|
+
import { getMarketDataAdapter, getTrainingDataAdapter } from "../adapter";
|
|
9
|
+
import { logger } from "../utils/logger";
|
|
10
|
+
import { MarketOutcomesTracker } from "./MarketOutcomesTracker";
|
|
11
|
+
import type { TrajectoryStep } from "./types";
|
|
12
12
|
|
|
13
13
|
export class RewardBackpropagationService {
|
|
14
14
|
private outcomesTracker: MarketOutcomesTracker;
|
|
@@ -21,17 +21,18 @@ export class RewardBackpropagationService {
|
|
|
21
21
|
* Update rewards for trajectories in a window when outcomes become known
|
|
22
22
|
*/
|
|
23
23
|
async updateRewardsForWindow(windowId: string): Promise<number> {
|
|
24
|
-
logger.info(
|
|
24
|
+
logger.info("Updating rewards for window", { windowId });
|
|
25
25
|
|
|
26
26
|
// Get outcomes for this window
|
|
27
27
|
const outcomes = await this.outcomesTracker.getWindowOutcomes(windowId);
|
|
28
28
|
if (!outcomes) {
|
|
29
|
-
logger.info(
|
|
29
|
+
logger.info("No outcomes found for window", { windowId });
|
|
30
30
|
return 0;
|
|
31
31
|
}
|
|
32
32
|
|
|
33
33
|
// Get all trajectories for this window (filter to training data)
|
|
34
|
-
const allTrajectories =
|
|
34
|
+
const allTrajectories =
|
|
35
|
+
await getTrainingDataAdapter().getTrajectoriesByWindow(windowId);
|
|
35
36
|
const trajectoriesResult = allTrajectories.filter((t) => t.isTrainingData);
|
|
36
37
|
|
|
37
38
|
let updated = 0;
|
|
@@ -50,9 +51,9 @@ export class RewardBackpropagationService {
|
|
|
50
51
|
|
|
51
52
|
// Check if this step involved trading
|
|
52
53
|
if (
|
|
53
|
-
step.action.actionType.includes(
|
|
54
|
-
step.action.actionType.includes(
|
|
55
|
-
step.action.actionType.includes(
|
|
54
|
+
step.action.actionType.includes("TRADING") ||
|
|
55
|
+
step.action.actionType.includes("BUY") ||
|
|
56
|
+
step.action.actionType.includes("SELL")
|
|
56
57
|
) {
|
|
57
58
|
// Extract market ID from action parameters
|
|
58
59
|
const marketId = step.action.parameters?.marketId as
|
|
@@ -63,14 +64,14 @@ export class RewardBackpropagationService {
|
|
|
63
64
|
if (marketId) {
|
|
64
65
|
// Check prediction market outcome
|
|
65
66
|
const prediction = outcomes.predictions.find(
|
|
66
|
-
(p) => p.marketId === marketId
|
|
67
|
+
(p) => p.marketId === marketId,
|
|
67
68
|
);
|
|
68
69
|
if (prediction) {
|
|
69
70
|
// Calculate reward based on whether trade was correct
|
|
70
71
|
const side = step.action.parameters?.side as string | undefined;
|
|
71
72
|
const isCorrect =
|
|
72
|
-
(side ===
|
|
73
|
-
(side ===
|
|
73
|
+
(side === "YES" && prediction.outcome === "YES") ||
|
|
74
|
+
(side === "NO" && prediction.outcome === "NO");
|
|
74
75
|
|
|
75
76
|
// Reward: +1 for correct, -1 for incorrect (normalized)
|
|
76
77
|
updatedReward = isCorrect ? 1.0 : -1.0;
|
|
@@ -86,9 +87,9 @@ export class RewardBackpropagationService {
|
|
|
86
87
|
// Reward based on whether position direction matched price movement
|
|
87
88
|
// Long position: positive reward if price went up
|
|
88
89
|
// Short position: positive reward if price went down
|
|
89
|
-
if (side ===
|
|
90
|
+
if (side === "long") {
|
|
90
91
|
updatedReward = Math.max(-1, Math.min(1, priceChange / 10)); // Normalize to -1 to 1
|
|
91
|
-
} else if (side ===
|
|
92
|
+
} else if (side === "short") {
|
|
92
93
|
updatedReward = Math.max(-1, Math.min(1, -priceChange / 10)); // Inverted for short
|
|
93
94
|
}
|
|
94
95
|
}
|
|
@@ -107,13 +108,13 @@ export class RewardBackpropagationService {
|
|
|
107
108
|
await getTrainingDataAdapter().updateTrajectoryRewards(
|
|
108
109
|
traj.id,
|
|
109
110
|
JSON.stringify(steps),
|
|
110
|
-
totalReward
|
|
111
|
+
totalReward,
|
|
111
112
|
);
|
|
112
113
|
updated++;
|
|
113
114
|
}
|
|
114
115
|
}
|
|
115
116
|
|
|
116
|
-
logger.info(
|
|
117
|
+
logger.info("Updated rewards for trajectories", {
|
|
117
118
|
windowId,
|
|
118
119
|
updated,
|
|
119
120
|
total: trajectoriesResult.length,
|