@mastra/longmemeval 0.0.0-add-libsql-changeset-20250910154739
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +919 -0
- package/DATA_DOWNLOAD_GUIDE.md +117 -0
- package/LICENSE.md +15 -0
- package/README.md +173 -0
- package/USAGE.md +105 -0
- package/package.json +67 -0
- package/scripts/download.ts +180 -0
- package/scripts/find-failed.ts +176 -0
- package/scripts/generate-embeddings.ts +56 -0
- package/scripts/generate-wm-templates.ts +296 -0
- package/scripts/setup.ts +60 -0
- package/src/__fixtures__/embeddings.json +2319 -0
- package/src/__fixtures__/test-dataset.json +82 -0
- package/src/cli.ts +690 -0
- package/src/commands/__tests__/prepare.test.ts +230 -0
- package/src/commands/__tests__/run.test.ts +403 -0
- package/src/commands/prepare.ts +793 -0
- package/src/commands/run.ts +553 -0
- package/src/config.ts +83 -0
- package/src/data/loader.ts +163 -0
- package/src/data/types.ts +61 -0
- package/src/embeddings/cached-openai-embedding-model.ts +227 -0
- package/src/embeddings/cached-openai-provider.ts +40 -0
- package/src/embeddings/index.ts +2 -0
- package/src/evaluation/__tests__/longmemeval-metric.test.ts +169 -0
- package/src/evaluation/longmemeval-metric.ts +173 -0
- package/src/retry-model.ts +60 -0
- package/src/storage/__tests__/benchmark-store.test.ts +280 -0
- package/src/storage/__tests__/benchmark-vector.test.ts +214 -0
- package/src/storage/benchmark-store.ts +540 -0
- package/src/storage/benchmark-vector.ts +234 -0
- package/src/storage/index.ts +2 -0
- package/src/test-utils/mock-embeddings.ts +54 -0
- package/src/test-utils/mock-model.ts +49 -0
- package/tests/data-loader.test.ts +96 -0
- package/tsconfig.json +18 -0
- package/vitest.config.ts +9 -0
|
@@ -0,0 +1,553 @@
|
|
|
1
|
+
import { Agent } from '@mastra/core/agent';
|
|
2
|
+
import { Memory } from '@mastra/memory';
|
|
3
|
+
import { openai } from '@ai-sdk/openai';
|
|
4
|
+
import { cachedOpenAI } from '../embeddings/cached-openai-provider';
|
|
5
|
+
import chalk from 'chalk';
|
|
6
|
+
import ora, { Ora } from 'ora';
|
|
7
|
+
import { join } from 'path';
|
|
8
|
+
import { readdir, readFile, mkdir, writeFile } from 'fs/promises';
|
|
9
|
+
import { existsSync } from 'fs';
|
|
10
|
+
|
|
11
|
+
import { BenchmarkStore, BenchmarkVectorStore } from '../storage';
|
|
12
|
+
import { LongMemEvalMetric } from '../evaluation/longmemeval-metric';
|
|
13
|
+
import type { EvaluationResult, BenchmarkMetrics, QuestionType, MemoryConfigType, DatasetType } from '../data/types';
|
|
14
|
+
import { getMemoryOptions } from '../config';
|
|
15
|
+
import { makeRetryModel } from '../retry-model';
|
|
16
|
+
|
|
17
|
+
/**
 * Configuration for a single LongMemEval benchmark run.
 */
export interface RunOptions {
  /** Which LongMemEval dataset variant to evaluate against. */
  dataset: DatasetType;
  /** Memory configuration name; resolved via getMemoryOptions in ../config. */
  memoryConfig: MemoryConfigType;
  /** Model identifier for the run (logged and persisted into metrics.json). */
  model: string;
  /** Root directory of prepared data; defaults to './prepared-data'. */
  preparedDataDir?: string;
  /** Root directory for run output; defaults to './results'. */
  outputDir?: string;
  /** If set, evaluate only the first N prepared questions. */
  subset?: number;
  /** Number of questions evaluated in parallel; defaults to 5. */
  concurrency?: number;
  /** If set, evaluate only the single question with this id (overrides subset). */
  questionId?: string;
}
|
|
27
|
+
|
|
28
|
+
/**
 * Shape of the per-question `meta.json` written by the prepare step.
 */
interface PreparedQuestionMeta {
  /** Unique question id; ids ending in '_abs' denote abstention questions. */
  questionId: string;
  /** LongMemEval question category (cast to QuestionType when aggregating). */
  questionType: string;
  /** Resource (user) id the prepared conversation history belongs to. */
  resourceId: string;
  /** Thread ids of the prepared history sessions for this question. */
  threadIds: string[];
  /** Memory configuration the data was prepared under. */
  memoryConfig: string;
  /** The question posed to the agent. */
  question: string;
  /** Gold answer used by the LLM judge. */
  answer: string;
  /** Session ids containing the supporting evidence, when available. */
  evidenceSessionIds?: string[];
  /** Date the question is asked "on", injected as system context when present. */
  questionDate?: string;
}
|
|
39
|
+
|
|
40
|
+
// Retry-wrapped gpt-4o model handle shared within this file (see ../retry-model).
const retry4o = makeRetryModel(openai('gpt-4o'));
|
|
41
|
+
|
|
42
|
+
export class RunCommand {
|
|
43
|
+
private preparedDataDir: string;
|
|
44
|
+
private outputDir: string;
|
|
45
|
+
|
|
46
|
+
constructor() {
|
|
47
|
+
this.preparedDataDir = './prepared-data';
|
|
48
|
+
this.outputDir = './results';
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
async run(options: RunOptions): Promise<BenchmarkMetrics> {
|
|
52
|
+
const runId = `run_${Date.now()}`;
|
|
53
|
+
const runDir = join(options.outputDir || this.outputDir, options.memoryConfig, runId);
|
|
54
|
+
await mkdir(runDir, { recursive: true });
|
|
55
|
+
|
|
56
|
+
console.log(chalk.blue(`\nš Starting LongMemEval benchmark run: ${runId}\n`));
|
|
57
|
+
console.log(chalk.gray(`Dataset: ${options.dataset}`));
|
|
58
|
+
console.log(chalk.gray(`Model: ${options.model}`));
|
|
59
|
+
console.log(chalk.gray(`Memory Config: ${options.memoryConfig}`));
|
|
60
|
+
if (options.subset) {
|
|
61
|
+
console.log(chalk.gray(`Subset: ${options.subset} questions`));
|
|
62
|
+
}
|
|
63
|
+
console.log();
|
|
64
|
+
|
|
65
|
+
const preparedDir = join(options.preparedDataDir || this.preparedDataDir, options.dataset, options.memoryConfig);
|
|
66
|
+
|
|
67
|
+
if (!existsSync(preparedDir)) {
|
|
68
|
+
throw new Error(`Prepared data not found at: ${preparedDir}\nPlease run 'longmemeval prepare' first.`);
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
// Load prepared questions
|
|
72
|
+
const spinner = ora('Loading prepared data...').start();
|
|
73
|
+
const questionDirs = await readdir(preparedDir);
|
|
74
|
+
const preparedQuestions: PreparedQuestionMeta[] = [];
|
|
75
|
+
|
|
76
|
+
let skippedCount = 0;
|
|
77
|
+
let failedCount = 0;
|
|
78
|
+
for (const questionDir of questionDirs) {
|
|
79
|
+
const questionPath = join(preparedDir, questionDir);
|
|
80
|
+
const metaPath = join(questionPath, 'meta.json');
|
|
81
|
+
const progressPath = join(questionPath, 'progress.json');
|
|
82
|
+
|
|
83
|
+
// Check if question has been prepared
|
|
84
|
+
if (existsSync(metaPath)) {
|
|
85
|
+
// Check if there's an incomplete or failed preparation
|
|
86
|
+
if (existsSync(progressPath)) {
|
|
87
|
+
const progress = JSON.parse(await readFile(progressPath, 'utf-8'));
|
|
88
|
+
if (!progress.completed) {
|
|
89
|
+
skippedCount++;
|
|
90
|
+
continue; // Skip this question as it's still being prepared
|
|
91
|
+
}
|
|
92
|
+
if (progress.failed) {
|
|
93
|
+
failedCount++;
|
|
94
|
+
continue; // Skip this question as it failed to prepare
|
|
95
|
+
}
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
const meta = JSON.parse(await readFile(metaPath, 'utf-8'));
|
|
99
|
+
preparedQuestions.push(meta);
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
spinner.succeed(
|
|
104
|
+
`Loaded ${preparedQuestions.length} prepared questions${skippedCount > 0 || failedCount > 0 ? ` (${skippedCount} incomplete, ${failedCount} failed)` : ''}`,
|
|
105
|
+
);
|
|
106
|
+
|
|
107
|
+
if (skippedCount > 0) {
|
|
108
|
+
console.log(
|
|
109
|
+
chalk.yellow(
|
|
110
|
+
`\nā ļø ${skippedCount} question${skippedCount > 1 ? 's' : ''} skipped due to incomplete preparation.`,
|
|
111
|
+
),
|
|
112
|
+
);
|
|
113
|
+
console.log(chalk.gray(` Run 'prepare' command to complete preparation.\n`));
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
if (failedCount > 0) {
|
|
117
|
+
console.log(
|
|
118
|
+
chalk.red(`\nā ļø ${failedCount} question${failedCount > 1 ? 's' : ''} skipped due to failed preparation.`),
|
|
119
|
+
);
|
|
120
|
+
console.log(chalk.gray(` Check error logs and re-run 'prepare' command.\n`));
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
// Filter by questionId if specified
|
|
124
|
+
let questionsToProcess = preparedQuestions;
|
|
125
|
+
if (options.questionId) {
|
|
126
|
+
questionsToProcess = preparedQuestions.filter(q => q.questionId === options.questionId);
|
|
127
|
+
if (questionsToProcess.length === 0) {
|
|
128
|
+
throw new Error(`Question with ID "${options.questionId}" not found in prepared data`);
|
|
129
|
+
}
|
|
130
|
+
console.log(chalk.yellow(`\nFocusing on question: ${options.questionId}\n`));
|
|
131
|
+
} else if (options.subset) {
|
|
132
|
+
// Apply subset if requested
|
|
133
|
+
questionsToProcess = preparedQuestions.slice(0, options.subset);
|
|
134
|
+
console.log(
|
|
135
|
+
chalk.gray(`\nApplying subset: ${options.subset} questions from ${preparedQuestions.length} total\n`),
|
|
136
|
+
);
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
console.log(
|
|
140
|
+
chalk.yellow(`\nEvaluating ${questionsToProcess.length} question${questionsToProcess.length !== 1 ? 's' : ''}\n`),
|
|
141
|
+
);
|
|
142
|
+
|
|
143
|
+
// Process questions with concurrency control
|
|
144
|
+
const results: EvaluationResult[] = [];
|
|
145
|
+
const concurrency = options.concurrency || 5;
|
|
146
|
+
const questionSpinner = ora('Evaluating questions...').start();
|
|
147
|
+
|
|
148
|
+
let completedCount = 0;
|
|
149
|
+
let inProgressCount = 0;
|
|
150
|
+
const startTime = Date.now();
|
|
151
|
+
|
|
152
|
+
// Track active evaluations
|
|
153
|
+
const activeEvaluations = new Map<number, { questionId: string; status: string }>();
|
|
154
|
+
|
|
155
|
+
// Function to update progress display
|
|
156
|
+
let lastText = '';
|
|
157
|
+
const updateProgress = () => {
|
|
158
|
+
const elapsed = Math.round((Date.now() - startTime) / 1000);
|
|
159
|
+
const rate = elapsed > 0 ? completedCount / elapsed : 0;
|
|
160
|
+
const remaining = rate > 0 ? Math.round((questionsToProcess.length - completedCount) / rate) : 0;
|
|
161
|
+
|
|
162
|
+
let progressText = `Overall: ${completedCount}/${questionsToProcess.length} (${inProgressCount} in progress, ${Math.round(rate * 60)} q/min, ~${remaining}s remaining)`;
|
|
163
|
+
|
|
164
|
+
if (activeEvaluations.size > 0 && concurrency > 1) {
|
|
165
|
+
progressText += '\n\nActive evaluations:';
|
|
166
|
+
|
|
167
|
+
// Sort active evaluations by completion status
|
|
168
|
+
const sortedEvaluations = Array.from(activeEvaluations.entries())
|
|
169
|
+
.map(([index, info]) => {
|
|
170
|
+
// Assign progress based on status
|
|
171
|
+
let progress = 0;
|
|
172
|
+
if (info.status.includes('Querying agent')) progress = 0.75;
|
|
173
|
+
else if (info.status.includes('Loading vector')) progress = 0.5;
|
|
174
|
+
else if (info.status.includes('Loading data')) progress = 0.25;
|
|
175
|
+
else if (info.status.includes('Starting')) progress = 0.0;
|
|
176
|
+
|
|
177
|
+
return { index, info, progress };
|
|
178
|
+
})
|
|
179
|
+
.sort((a, b) => b.progress - a.progress); // Sort by most complete first
|
|
180
|
+
|
|
181
|
+
sortedEvaluations.forEach(({ index, info, progress }) => {
|
|
182
|
+
const percentage = (progress * 100).toFixed(0);
|
|
183
|
+
progressText += `\n [${index + 1}] ${info.questionId} - ${info.status} (${percentage}%)`;
|
|
184
|
+
});
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
if (lastText !== progressText) {
|
|
188
|
+
lastText = progressText;
|
|
189
|
+
questionSpinner.text = progressText;
|
|
190
|
+
}
|
|
191
|
+
};
|
|
192
|
+
|
|
193
|
+
// Create a queue of questions to evaluate
|
|
194
|
+
const questionQueue = [...questionsToProcess];
|
|
195
|
+
|
|
196
|
+
// Function to process next question from queue
|
|
197
|
+
const processNextQuestion = async (slotIndex: number): Promise<EvaluationResult[]> => {
|
|
198
|
+
const workerResults: EvaluationResult[] = [];
|
|
199
|
+
|
|
200
|
+
while (questionQueue.length > 0) {
|
|
201
|
+
const meta = questionQueue.shift();
|
|
202
|
+
if (!meta) break;
|
|
203
|
+
|
|
204
|
+
inProgressCount++;
|
|
205
|
+
activeEvaluations.set(slotIndex, { questionId: meta.questionId, status: 'Starting...' });
|
|
206
|
+
// Don't update progress here - let the periodic timer handle it
|
|
207
|
+
|
|
208
|
+
const result = await this.evaluateQuestion(
|
|
209
|
+
meta,
|
|
210
|
+
preparedDir,
|
|
211
|
+
retry4o.model,
|
|
212
|
+
options,
|
|
213
|
+
concurrency > 1
|
|
214
|
+
? {
|
|
215
|
+
updateStatus: (status: string) => {
|
|
216
|
+
activeEvaluations.set(slotIndex, { questionId: meta.questionId, status });
|
|
217
|
+
},
|
|
218
|
+
}
|
|
219
|
+
: questionSpinner,
|
|
220
|
+
);
|
|
221
|
+
|
|
222
|
+
completedCount++;
|
|
223
|
+
inProgressCount--;
|
|
224
|
+
activeEvaluations.delete(slotIndex);
|
|
225
|
+
|
|
226
|
+
// Log result when running concurrently
|
|
227
|
+
if (concurrency > 1) {
|
|
228
|
+
// Temporarily clear the spinner to log cleanly
|
|
229
|
+
questionSpinner.clear();
|
|
230
|
+
|
|
231
|
+
console.log(
|
|
232
|
+
chalk.blue(`ā¶ ${meta.questionId}`),
|
|
233
|
+
chalk.gray(`(${meta.questionType})`),
|
|
234
|
+
chalk[result.is_correct ? 'green' : 'red'](`${result.is_correct ? 'ā' : 'ā'}`),
|
|
235
|
+
chalk.gray(`${((Date.now() - startTime) / 1000).toFixed(1)}s`),
|
|
236
|
+
);
|
|
237
|
+
if (!result.is_correct) {
|
|
238
|
+
console.log(chalk.gray(` Q: "${meta.question}"`));
|
|
239
|
+
console.log(chalk.gray(` A: "${result.hypothesis}"`));
|
|
240
|
+
console.log(chalk.yellow(` Expected: "${meta.answer}"`));
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
// Re-render the spinner
|
|
244
|
+
questionSpinner.render();
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
// Don't update progress here - let the periodic timer handle it
|
|
248
|
+
workerResults.push(result);
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
return workerResults;
|
|
252
|
+
};
|
|
253
|
+
|
|
254
|
+
// Set up periodic progress updates
|
|
255
|
+
const progressInterval = setInterval(updateProgress, 500);
|
|
256
|
+
|
|
257
|
+
// Create worker slots
|
|
258
|
+
const workers = Array.from({ length: concurrency }, (_, i) => processNextQuestion(i));
|
|
259
|
+
|
|
260
|
+
// Wait for all workers to complete and collect results
|
|
261
|
+
const workerResults = await Promise.all(workers);
|
|
262
|
+
|
|
263
|
+
// Process results from all workers
|
|
264
|
+
for (const workerResultArray of workerResults) {
|
|
265
|
+
results.push(...workerResultArray);
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
// Clear the interval
|
|
269
|
+
clearInterval(progressInterval);
|
|
270
|
+
|
|
271
|
+
questionSpinner.succeed(`Evaluated ${results.length} questions`);
|
|
272
|
+
|
|
273
|
+
// Calculate metrics
|
|
274
|
+
console.log(chalk.blue('\nš Calculating metrics...\n'));
|
|
275
|
+
const metrics = this.calculateMetrics(results);
|
|
276
|
+
|
|
277
|
+
// Save results
|
|
278
|
+
await this.saveResults(runDir, results, metrics, options);
|
|
279
|
+
|
|
280
|
+
// Display results
|
|
281
|
+
this.displayMetrics(metrics, options);
|
|
282
|
+
|
|
283
|
+
return metrics;
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
private async evaluateQuestion(
|
|
287
|
+
meta: PreparedQuestionMeta,
|
|
288
|
+
preparedDir: string,
|
|
289
|
+
modelProvider: any,
|
|
290
|
+
options: RunOptions,
|
|
291
|
+
spinner?: Ora | { updateStatus: (status: string) => void },
|
|
292
|
+
): Promise<EvaluationResult> {
|
|
293
|
+
const questionStart = Date.now();
|
|
294
|
+
|
|
295
|
+
// Update status
|
|
296
|
+
const updateStatus = (status: string) => {
|
|
297
|
+
if (spinner && 'updateStatus' in spinner) {
|
|
298
|
+
spinner.updateStatus(status);
|
|
299
|
+
} else if (spinner && 'text' in spinner) {
|
|
300
|
+
spinner.text = status;
|
|
301
|
+
}
|
|
302
|
+
};
|
|
303
|
+
|
|
304
|
+
updateStatus(`Loading data for ${meta.questionId}...`);
|
|
305
|
+
|
|
306
|
+
// Load the prepared storage and vector store
|
|
307
|
+
const questionDir = join(preparedDir, meta.questionId);
|
|
308
|
+
const benchmarkStore = new BenchmarkStore('read');
|
|
309
|
+
const benchmarkVectorStore = new BenchmarkVectorStore('read');
|
|
310
|
+
|
|
311
|
+
await benchmarkStore.init();
|
|
312
|
+
await benchmarkStore.hydrate(join(questionDir, 'db.json'));
|
|
313
|
+
|
|
314
|
+
// Hydrate vector store if it exists
|
|
315
|
+
const vectorPath = join(questionDir, 'vector.json');
|
|
316
|
+
if (existsSync(vectorPath)) {
|
|
317
|
+
await benchmarkVectorStore.hydrate(vectorPath);
|
|
318
|
+
updateStatus(`Loading vector embeddings for ${meta.questionId}...`);
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
const memoryOptions = getMemoryOptions(options.memoryConfig);
|
|
322
|
+
|
|
323
|
+
// Create memory with the hydrated stores
|
|
324
|
+
const memory = new Memory({
|
|
325
|
+
storage: benchmarkStore,
|
|
326
|
+
vector: benchmarkVectorStore,
|
|
327
|
+
embedder: cachedOpenAI.embedding('text-embedding-3-small'),
|
|
328
|
+
options: memoryOptions.options,
|
|
329
|
+
});
|
|
330
|
+
|
|
331
|
+
// Create agent with the specified model
|
|
332
|
+
const agentInstructions = `You are a helpful assistant with access to extensive conversation history.
|
|
333
|
+
When answering questions, carefully review the conversation history to identify and use any relevant user preferences, interests, or specific details they have mentioned.
|
|
334
|
+
For example, if the user previously mentioned they prefer a specific software, tool, or approach, tailor your recommendations to match their stated preferences.
|
|
335
|
+
Be specific rather than generic when the user has expressed clear preferences in past conversations. If there is a clear preference, focus in on that, and do not add additional irrelevant information.`;
|
|
336
|
+
|
|
337
|
+
const agent = new Agent({
|
|
338
|
+
name: 'longmemeval-agent',
|
|
339
|
+
model: modelProvider,
|
|
340
|
+
instructions: agentInstructions,
|
|
341
|
+
memory,
|
|
342
|
+
});
|
|
343
|
+
|
|
344
|
+
// Create a fresh thread for the evaluation question
|
|
345
|
+
const evalThreadId = `eval_${meta.questionId}_${Date.now()}`;
|
|
346
|
+
|
|
347
|
+
updateStatus(`${meta.threadIds.length} sessions, ${options.memoryConfig}`);
|
|
348
|
+
|
|
349
|
+
const response = await agent.generate(meta.question, {
|
|
350
|
+
threadId: evalThreadId,
|
|
351
|
+
resourceId: meta.resourceId,
|
|
352
|
+
temperature: 0,
|
|
353
|
+
context: meta.questionDate ? [{ role: 'system', content: `Todays date is ${meta.questionDate}` }] : undefined,
|
|
354
|
+
});
|
|
355
|
+
|
|
356
|
+
const evalAgent = new Agent({
|
|
357
|
+
name: 'longmemeval-metric-agent',
|
|
358
|
+
model: retry4o.model,
|
|
359
|
+
instructions: 'You are an evaluation assistant. Answer questions precisely and concisely.',
|
|
360
|
+
});
|
|
361
|
+
|
|
362
|
+
const metric = new LongMemEvalMetric({
|
|
363
|
+
agent: evalAgent,
|
|
364
|
+
questionType: meta.questionType as any,
|
|
365
|
+
isAbstention: meta.questionId.endsWith('_abs'),
|
|
366
|
+
});
|
|
367
|
+
|
|
368
|
+
const input = JSON.stringify({
|
|
369
|
+
question: meta.question,
|
|
370
|
+
answer: meta.answer,
|
|
371
|
+
});
|
|
372
|
+
|
|
373
|
+
const result = await metric.measure(input, response.text);
|
|
374
|
+
const isCorrect = result.score === 1;
|
|
375
|
+
|
|
376
|
+
const elapsed = ((Date.now() - questionStart) / 1000).toFixed(1);
|
|
377
|
+
|
|
378
|
+
const isOraSpinner = spinner && 'clear' in spinner;
|
|
379
|
+
if (isOraSpinner) {
|
|
380
|
+
console.log(
|
|
381
|
+
chalk.blue(`ā¶ ${meta.questionId}`),
|
|
382
|
+
chalk.gray(`(${meta.questionType})`),
|
|
383
|
+
chalk[isCorrect ? 'green' : 'red'](`${isCorrect ? 'ā' : 'ā'}`),
|
|
384
|
+
chalk.gray(`${elapsed}s`),
|
|
385
|
+
);
|
|
386
|
+
if (!isCorrect) {
|
|
387
|
+
console.log(chalk.gray(` Q: "${meta.question}"`));
|
|
388
|
+
console.log(chalk.gray(` A: "${response.text}"`));
|
|
389
|
+
console.log(chalk.yellow(` Expected: "${meta.answer}"`));
|
|
390
|
+
}
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
return {
|
|
394
|
+
question_id: meta.questionId,
|
|
395
|
+
hypothesis: response.text,
|
|
396
|
+
question_type: meta.questionType as QuestionType,
|
|
397
|
+
is_correct: isCorrect,
|
|
398
|
+
};
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
private async saveResults(
|
|
402
|
+
runDir: string,
|
|
403
|
+
results: EvaluationResult[],
|
|
404
|
+
metrics: BenchmarkMetrics,
|
|
405
|
+
options: RunOptions,
|
|
406
|
+
): Promise<void> {
|
|
407
|
+
// Save raw results
|
|
408
|
+
const resultsPath = join(runDir, 'results.jsonl');
|
|
409
|
+
const resultsContent = results.map(r => JSON.stringify(r)).join('\n');
|
|
410
|
+
await writeFile(resultsPath, resultsContent);
|
|
411
|
+
|
|
412
|
+
// Save metrics
|
|
413
|
+
const metricsPath = join(runDir, 'metrics.json');
|
|
414
|
+
const metricsData = {
|
|
415
|
+
...metrics,
|
|
416
|
+
config: {
|
|
417
|
+
dataset: options.dataset,
|
|
418
|
+
model: options.model,
|
|
419
|
+
memoryConfig: options.memoryConfig,
|
|
420
|
+
subset: options.subset,
|
|
421
|
+
},
|
|
422
|
+
timestamp: new Date().toISOString(),
|
|
423
|
+
};
|
|
424
|
+
await writeFile(metricsPath, JSON.stringify(metricsData, null, 2));
|
|
425
|
+
|
|
426
|
+
console.log(chalk.gray(`\nResults saved to: ${runDir}`));
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
private calculateMetrics(results: EvaluationResult[]): BenchmarkMetrics {
|
|
430
|
+
const metrics: BenchmarkMetrics = {
|
|
431
|
+
overall_accuracy: 0,
|
|
432
|
+
accuracy_by_type: {} as Record<QuestionType, { correct: number; total: number; accuracy: number }>,
|
|
433
|
+
abstention_accuracy: 0,
|
|
434
|
+
total_questions: results.length,
|
|
435
|
+
correct_answers: 0,
|
|
436
|
+
abstention_correct: 0,
|
|
437
|
+
abstention_total: 0,
|
|
438
|
+
};
|
|
439
|
+
|
|
440
|
+
// Calculate overall metrics
|
|
441
|
+
for (const result of results) {
|
|
442
|
+
if (result.is_correct) {
|
|
443
|
+
metrics.correct_answers++;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
// Track by question type
|
|
447
|
+
if (result.question_type) {
|
|
448
|
+
const type = result.question_type;
|
|
449
|
+
if (!metrics.accuracy_by_type[type]) {
|
|
450
|
+
metrics.accuracy_by_type[type] = { correct: 0, total: 0, accuracy: 0 };
|
|
451
|
+
}
|
|
452
|
+
metrics.accuracy_by_type[type].total++;
|
|
453
|
+
if (result.is_correct) {
|
|
454
|
+
metrics.accuracy_by_type[type].correct++;
|
|
455
|
+
}
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
// Track abstention separately
|
|
459
|
+
if (result.question_id.endsWith('_abs')) {
|
|
460
|
+
metrics.abstention_total = (metrics.abstention_total || 0) + 1;
|
|
461
|
+
if (result.is_correct) {
|
|
462
|
+
metrics.abstention_correct = (metrics.abstention_correct || 0) + 1;
|
|
463
|
+
}
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
// Calculate per-type accuracies first
|
|
468
|
+
for (const type in metrics.accuracy_by_type) {
|
|
469
|
+
const typeMetrics = metrics.accuracy_by_type[type as QuestionType];
|
|
470
|
+
if (typeMetrics) {
|
|
471
|
+
typeMetrics.accuracy = typeMetrics.total > 0 ? typeMetrics.correct / typeMetrics.total : 0;
|
|
472
|
+
}
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
if (metrics.abstention_total && metrics.abstention_total > 0) {
|
|
476
|
+
metrics.abstention_accuracy = (metrics.abstention_correct || 0) / metrics.abstention_total;
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
// Calculate overall accuracy as average of all question type accuracies (excluding abstention)
|
|
480
|
+
const allTypeAccuracies = Object.values(metrics.accuracy_by_type).map(t => t.accuracy);
|
|
481
|
+
|
|
482
|
+
// Debug: Log the exact values being averaged
|
|
483
|
+
console.log('\nDebug - Question type accuracies being averaged:');
|
|
484
|
+
Object.entries(metrics.accuracy_by_type).forEach(([type, data]) => {
|
|
485
|
+
console.log(` ${type}: ${data.accuracy} (${data.correct}/${data.total})`);
|
|
486
|
+
});
|
|
487
|
+
console.log(` Sum: ${allTypeAccuracies.reduce((sum, acc) => sum + acc, 0)}`);
|
|
488
|
+
console.log(` Count: ${allTypeAccuracies.length}`);
|
|
489
|
+
|
|
490
|
+
metrics.overall_accuracy =
|
|
491
|
+
allTypeAccuracies.length > 0
|
|
492
|
+
? allTypeAccuracies.reduce((sum, acc) => sum + acc, 0) / allTypeAccuracies.length
|
|
493
|
+
: 0;
|
|
494
|
+
|
|
495
|
+
console.log(` Calculated overall: ${metrics.overall_accuracy}`);
|
|
496
|
+
console.log(` As percentage: ${(metrics.overall_accuracy * 100).toFixed(10)}%\n`);
|
|
497
|
+
|
|
498
|
+
return metrics;
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
private displayMetrics(metrics: BenchmarkMetrics, options?: RunOptions): void {
|
|
502
|
+
console.log(chalk.bold('\nš Benchmark Results\n'));
|
|
503
|
+
|
|
504
|
+
// Display configuration if provided
|
|
505
|
+
if (options) {
|
|
506
|
+
console.log(chalk.bold('Configuration:\n'));
|
|
507
|
+
console.log(chalk.gray('Dataset:'), chalk.cyan(options.dataset));
|
|
508
|
+
console.log(chalk.gray('Model:'), chalk.cyan(options.model));
|
|
509
|
+
console.log(chalk.gray('Memory Config:'), chalk.cyan(options.memoryConfig));
|
|
510
|
+
if (options.subset) {
|
|
511
|
+
console.log(chalk.gray('Subset:'), chalk.cyan(`${options.subset} questions`));
|
|
512
|
+
}
|
|
513
|
+
// Get terminal width
|
|
514
|
+
const terminalWidth = process.stdout.columns || 80;
|
|
515
|
+
const lineWidth = Math.min(terminalWidth - 1, 60);
|
|
516
|
+
console.log(chalk.gray('ā'.repeat(lineWidth)));
|
|
517
|
+
console.log();
|
|
518
|
+
}
|
|
519
|
+
|
|
520
|
+
// Question type breakdown
|
|
521
|
+
console.log(chalk.bold('Accuracy by Question Type:'));
|
|
522
|
+
|
|
523
|
+
// Sort question types alphabetically
|
|
524
|
+
const sortedTypes = Object.entries(metrics.accuracy_by_type).sort(([a], [b]) => a.localeCompare(b));
|
|
525
|
+
|
|
526
|
+
// Display regular question types
|
|
527
|
+
for (const [type, typeMetrics] of sortedTypes) {
|
|
528
|
+
const { correct, total, accuracy } = typeMetrics;
|
|
529
|
+
const typeColor = accuracy >= 0.8 ? 'green' : accuracy >= 0.6 ? 'yellow' : 'red';
|
|
530
|
+
|
|
531
|
+
// Create a simple progress bar
|
|
532
|
+
const barLength = 20;
|
|
533
|
+
const filledLength = Math.round(accuracy * barLength);
|
|
534
|
+
const bar = 'ā'.repeat(filledLength) + 'ā'.repeat(barLength - filledLength);
|
|
535
|
+
|
|
536
|
+
console.log(
|
|
537
|
+
chalk.gray(` ${type.padEnd(25)}:`),
|
|
538
|
+
chalk[typeColor](`${(accuracy * 100).toFixed(1).padStart(5)}%`),
|
|
539
|
+
chalk.gray(`[${bar}]`),
|
|
540
|
+
chalk.gray(`(${correct}/${total})`),
|
|
541
|
+
);
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
console.log();
|
|
545
|
+
const accuracyColor =
|
|
546
|
+
metrics.overall_accuracy >= 0.8 ? 'green' : metrics.overall_accuracy >= 0.6 ? 'yellow' : 'red';
|
|
547
|
+
console.log(
|
|
548
|
+
chalk.bold('Overall Accuracy:'),
|
|
549
|
+
chalk[accuracyColor](`${(metrics.overall_accuracy * 100).toFixed(2)}%`),
|
|
550
|
+
chalk.gray(`(average of ${Object.keys(metrics.accuracy_by_type).length} question types)`),
|
|
551
|
+
);
|
|
552
|
+
}
|
|
553
|
+
}
|
package/src/config.ts
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import { MemoryConfigOptions } from './data/types';
|
|
2
|
+
|
|
3
|
+
const semanticRecall = {
|
|
4
|
+
topK: 10,
|
|
5
|
+
messageRange: 2,
|
|
6
|
+
scope: 'resource',
|
|
7
|
+
} as const;
|
|
8
|
+
|
|
9
|
+
const lastMessages = 10;
|
|
10
|
+
|
|
11
|
+
export function getMemoryOptions(memoryConfig: string): MemoryConfigOptions {
|
|
12
|
+
switch (memoryConfig) {
|
|
13
|
+
case 'semantic-recall':
|
|
14
|
+
return {
|
|
15
|
+
type: 'semantic-recall',
|
|
16
|
+
options: {
|
|
17
|
+
lastMessages,
|
|
18
|
+
semanticRecall,
|
|
19
|
+
workingMemory: { enabled: false },
|
|
20
|
+
},
|
|
21
|
+
};
|
|
22
|
+
|
|
23
|
+
case 'working-memory':
|
|
24
|
+
return {
|
|
25
|
+
type: 'working-memory',
|
|
26
|
+
options: {
|
|
27
|
+
lastMessages,
|
|
28
|
+
semanticRecall: false,
|
|
29
|
+
workingMemory: {
|
|
30
|
+
enabled: true,
|
|
31
|
+
scope: 'resource',
|
|
32
|
+
version: 'vnext',
|
|
33
|
+
},
|
|
34
|
+
},
|
|
35
|
+
};
|
|
36
|
+
|
|
37
|
+
// tailored means a custom working memory template is passed in per-question - to align with how working memory is intended to be used to track specific relevant information.
|
|
38
|
+
case 'working-memory-tailored':
|
|
39
|
+
return {
|
|
40
|
+
type: 'working-memory',
|
|
41
|
+
options: {
|
|
42
|
+
lastMessages,
|
|
43
|
+
semanticRecall: false,
|
|
44
|
+
workingMemory: {
|
|
45
|
+
enabled: true,
|
|
46
|
+
scope: 'resource',
|
|
47
|
+
version: 'vnext',
|
|
48
|
+
},
|
|
49
|
+
},
|
|
50
|
+
};
|
|
51
|
+
|
|
52
|
+
// Combined means semantic recall + working memory
|
|
53
|
+
case 'combined':
|
|
54
|
+
return {
|
|
55
|
+
type: 'combined',
|
|
56
|
+
options: {
|
|
57
|
+
lastMessages,
|
|
58
|
+
semanticRecall,
|
|
59
|
+
workingMemory: {
|
|
60
|
+
enabled: true,
|
|
61
|
+
scope: 'resource',
|
|
62
|
+
},
|
|
63
|
+
},
|
|
64
|
+
};
|
|
65
|
+
|
|
66
|
+
case 'combined-tailored':
|
|
67
|
+
return {
|
|
68
|
+
type: 'combined-tailored',
|
|
69
|
+
options: {
|
|
70
|
+
lastMessages,
|
|
71
|
+
semanticRecall,
|
|
72
|
+
workingMemory: {
|
|
73
|
+
enabled: true,
|
|
74
|
+
scope: 'resource',
|
|
75
|
+
version: 'vnext',
|
|
76
|
+
},
|
|
77
|
+
},
|
|
78
|
+
};
|
|
79
|
+
|
|
80
|
+
default:
|
|
81
|
+
throw new Error(`Unknown memory config: ${memoryConfig}`);
|
|
82
|
+
}
|
|
83
|
+
}
|