@mastra/longmemeval 0.0.0-add-libsql-changeset-20250910154739
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +919 -0
- package/DATA_DOWNLOAD_GUIDE.md +117 -0
- package/LICENSE.md +15 -0
- package/README.md +173 -0
- package/USAGE.md +105 -0
- package/package.json +67 -0
- package/scripts/download.ts +180 -0
- package/scripts/find-failed.ts +176 -0
- package/scripts/generate-embeddings.ts +56 -0
- package/scripts/generate-wm-templates.ts +296 -0
- package/scripts/setup.ts +60 -0
- package/src/__fixtures__/embeddings.json +2319 -0
- package/src/__fixtures__/test-dataset.json +82 -0
- package/src/cli.ts +690 -0
- package/src/commands/__tests__/prepare.test.ts +230 -0
- package/src/commands/__tests__/run.test.ts +403 -0
- package/src/commands/prepare.ts +793 -0
- package/src/commands/run.ts +553 -0
- package/src/config.ts +83 -0
- package/src/data/loader.ts +163 -0
- package/src/data/types.ts +61 -0
- package/src/embeddings/cached-openai-embedding-model.ts +227 -0
- package/src/embeddings/cached-openai-provider.ts +40 -0
- package/src/embeddings/index.ts +2 -0
- package/src/evaluation/__tests__/longmemeval-metric.test.ts +169 -0
- package/src/evaluation/longmemeval-metric.ts +173 -0
- package/src/retry-model.ts +60 -0
- package/src/storage/__tests__/benchmark-store.test.ts +280 -0
- package/src/storage/__tests__/benchmark-vector.test.ts +214 -0
- package/src/storage/benchmark-store.ts +540 -0
- package/src/storage/benchmark-vector.ts +234 -0
- package/src/storage/index.ts +2 -0
- package/src/test-utils/mock-embeddings.ts +54 -0
- package/src/test-utils/mock-model.ts +49 -0
- package/tests/data-loader.test.ts +96 -0
- package/tsconfig.json +18 -0
- package/vitest.config.ts +9 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import { readdir, readFile, rm, rmdir, unlink } from 'fs/promises';
import { existsSync } from 'fs';
import { join } from 'path';
import chalk from 'chalk';
|
|
5
|
+
|
|
6
|
+
/**
 * A prepared question whose `progress.json` is marked `failed: true`, as found by
 * `findFailedQuestions` under `<baseDir>/<dataset>/<memoryConfig>/<questionId>/`.
 */
interface FailedQuestion {
  /** Name of the dataset directory the question lives in. */
  dataset: string;
  /** Name of the memory-config directory beneath the dataset. */
  memoryConfig: string;
  /** Name of the question's own directory. */
  questionId: string;
  /** Filesystem path of the question directory (what `--delete` removes). */
  path: string;
  /** Error text recorded in progress.json, or a placeholder when none was recorded. */
  error: string;
  /** Failure timestamp recorded in progress.json, or a placeholder when none was recorded. */
  failedAt: string;
}
|
|
14
|
+
|
|
15
|
+
/**
 * Render a value read from a progress file as display text.
 * Falsy values become `fallback`, strings pass through unchanged, and anything else
 * (objects, numbers) is JSON-encoded so callers can always treat the result as a string.
 */
function toDisplayString(value: unknown, fallback: string): string {
  if (!value) return fallback;
  return typeof value === 'string' ? value : JSON.stringify(value);
}

/**
 * Scan a prepared-data tree laid out as
 * `<baseDir>/<dataset>/<memoryConfig>/<questionId>/progress.json` and collect every
 * question whose progress file has `failed: true`.
 *
 * Entries that are not directories are skipped. A progress file that cannot be read or
 * parsed is logged and skipped, so one bad file does not abort the whole scan.
 *
 * @param baseDir - Root of the prepared data; defaults to `./prepared-data` (relative to the CWD).
 * @returns The failed questions found; an empty array when `baseDir` does not exist.
 */
async function findFailedQuestions(baseDir: string = './prepared-data'): Promise<FailedQuestion[]> {
  const failed: FailedQuestion[] = [];

  console.log(chalk.gray(`Scanning directory: ${baseDir}`));

  if (!existsSync(baseDir)) {
    console.error(chalk.red(`Base directory not found: ${baseDir}`));
    return failed;
  }

  try {
    // Iterate through datasets
    const datasets = await readdir(baseDir);
    console.log(chalk.gray(`Found datasets: ${datasets.join(', ')}`));

    for (const dataset of datasets) {
      const datasetPath = join(baseDir, dataset);
      // readdir rejects for plain files; reuse this listing instead of reading the directory twice.
      const configs = await readdir(datasetPath).catch(() => null);
      if (!configs) continue;

      // Iterate through memory configs
      console.log(chalk.gray(` ${dataset} configs: ${configs.join(', ')}`));

      for (const config of configs) {
        const configPath = join(datasetPath, config);
        const questions = await readdir(configPath).catch(() => null);
        if (!questions) continue;

        // Iterate through questions
        console.log(chalk.gray(` ${config}: ${questions.length} questions`));

        let progressFound = 0;
        let failedFound = 0;

        for (const questionId of questions) {
          const questionPath = join(configPath, questionId);
          const progressPath = join(questionPath, 'progress.json');

          // Check if progress.json exists and has failed status
          if (!existsSync(progressPath)) continue;

          progressFound++;
          try {
            const progress: { failed?: unknown; error?: unknown; failedAt?: unknown } | null = JSON.parse(
              await readFile(progressPath, 'utf-8'),
            );

            if (progress?.failed === true) {
              failedFound++;
              failed.push({
                questionId,
                dataset,
                memoryConfig: config,
                // Coerce to text: FailedQuestion.error/failedAt are strings and callers call string methods on them.
                error: toDisplayString(progress.error, 'Unknown error'),
                failedAt: toDisplayString(progress.failedAt, 'Unknown time'),
                path: questionPath,
              });
            }
          } catch (e) {
            console.error(chalk.red(`Error reading progress for ${questionId}:`), e);
          }
        }

        if (progressFound > 0) {
          console.log(chalk.gray(` Progress files found: ${progressFound}, Failed: ${failedFound}`));
        }
      }
    }
  } catch (error) {
    console.error(chalk.red('Error scanning directories:'), error);
  }

  return failed;
}
|
|
89
|
+
|
|
90
|
+
/**
 * Recursively delete a question directory and everything beneath it.
 *
 * Delegates to `fs.promises.rm` with `recursive: true` instead of hand-walking the tree.
 * `force` is deliberately NOT set, so, like the manual readdir/unlink/rmdir walk, the
 * returned promise rejects when `path` does not exist and the caller can report it.
 *
 * @param path - Directory to remove.
 */
async function deleteQuestionDir(path: string): Promise<void> {
  await rm(path, { recursive: true });
}
|
|
105
|
+
|
|
106
|
+
/**
 * CLI entry point: report prepared-data questions whose `progress.json` is marked failed.
 *
 * Flags (all optional):
 * - `--dataset=<name>` / `--config=<name>`: only report questions from that dataset / memory config.
 * - `--delete`: remove each reported question directory so preparation can be retried.
 *
 * Scans the default `./prepared-data` directory (relative to the CWD).
 */
async function main() {
  const args = process.argv.slice(2);
  const shouldDelete = args.includes('--delete');
  // Value of `--<name>=<value>`; everything after the first '=' is kept so values containing '=' survive.
  const flagValue = (name: string): string | undefined => {
    const prefix = `--${name}=`;
    return args.find(arg => arg.startsWith(prefix))?.slice(prefix.length);
  };
  const dataset = flagValue('dataset');
  const config = flagValue('config');

  console.log(chalk.blue('\n🔍 Finding failed questions...\n'));

  const failed = await findFailedQuestions();

  // Filter by dataset/config if specified (an empty value means "no filter")
  const filtered = failed.filter(f => (!dataset || f.dataset === dataset) && (!config || f.memoryConfig === config));

  if (filtered.length === 0) {
    console.log(chalk.green('✅ No failed questions found!\n'));
    return;
  }

  // Group by dataset and config, keeping discovery order
  const grouped = new Map<string, FailedQuestion[]>();
  for (const f of filtered) {
    const key = `${f.dataset}/${f.memoryConfig}`;
    const bucket = grouped.get(key);
    if (bucket) {
      bucket.push(f);
    } else {
      grouped.set(key, [f]);
    }
  }

  // Display results
  console.log(chalk.red(`Found ${filtered.length} failed questions:\n`));

  for (const [group, questions] of grouped) {
    console.log(chalk.yellow(`\n${group}:`));

    for (const q of questions) {
      // String() guards against a non-string error value coming from a progress file.
      const errorText = String(q.error);
      const shortError = errorText.length > 100 ? `${errorText.substring(0, 100)}...` : errorText;
      console.log(chalk.gray(` - ${q.questionId}`));
      console.log(chalk.gray(` Error: ${shortError}`));
      console.log(chalk.gray(` Failed at: ${q.failedAt}`));
    }
  }

  if (shouldDelete) {
    console.log(chalk.yellow('\n⚠️ Deleting failed question directories...\n'));

    // Count only successful deletions so the summary does not overstate what happened.
    let deleted = 0;
    for (const q of filtered) {
      try {
        await deleteQuestionDir(q.path);
        deleted++;
        console.log(chalk.green(`✓ Deleted ${q.questionId}`));
      } catch (error) {
        console.error(chalk.red(`✗ Failed to delete ${q.questionId}:`), error);
      }
    }

    console.log(chalk.green(`\n✅ Deleted ${deleted} failed question directories\n`));
    if (deleted < filtered.length) {
      console.log(chalk.red(`✗ ${filtered.length - deleted} directories could not be deleted\n`));
    }
  } else {
    console.log(chalk.gray('\n💡 Tip: Use --delete to remove these directories and retry preparation'));
    console.log(chalk.gray(' Example: pnpm tsx scripts/find-failed.ts --delete'));
    console.log(
      chalk.gray(' Filter: pnpm tsx scripts/find-failed.ts --dataset=longmemeval_s --config=working-memory'),
    );
  }
}

main().catch(error => {
  console.error(error);
  // Exit non-zero so shells/CI notice the failure.
  process.exitCode = 1;
});
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
#!/usr/bin/env tsx
|
|
2
|
+
|
|
3
|
+
import { openai } from '@ai-sdk/openai';
|
|
4
|
+
import { writeFile, mkdir } from 'fs/promises';
|
|
5
|
+
import { join, dirname } from 'path';
|
|
6
|
+
import { fileURLToPath } from 'url';
|
|
7
|
+
|
|
8
|
+
// ES modules have no CommonJS `__dirname`; derive this script's directory from its file URL.
const scriptFile = fileURLToPath(import.meta.url);
const __dirname = dirname(scriptFile);
|
|
9
|
+
|
|
10
|
+
/**
 * Phrases to embed for the test fixtures: short conversational snippets about a favorite
 * color and a pet, plus follow-up questions and answers. Each entry becomes a key in the
 * generated embeddings.json, so the exact text (including punctuation) matters.
 */
const SAMPLE_TEXTS: string[] = [
  // Favorite color
  'My favorite color is blue',
  'I understand your favorite color is blue.',
  // Pet
  'I have a pet',
  'What kind of pet do you have?',
  'It is a cat named Fluffy',
  'Fluffy is a lovely name for a cat!',
  // Greeting
  'Hello',
  'Hi there!',
  // Recall questions and answers
  'What is my favorite color?',
  'What did I say about my pet?',
  'You have a cat named Fluffy',
  'Blue',
];
|
|
25
|
+
|
|
26
|
+
/**
 * Embed every entry of `SAMPLE_TEXTS` with OpenAI `text-embedding-3-small` and write the
 * resulting `{ [text]: vector }` map to `../src/__fixtures__/embeddings.json`, resolved
 * relative to this script's directory. The fixtures directory is created if missing.
 *
 * Texts are embedded one request at a time, in order, and each is logged as it is processed.
 */
async function generateEmbeddings() {
  console.log('🔧 Generating fixture embeddings...\n');

  const model = openai.embedding('text-embedding-3-small');
  const vectorsByText: Record<string, number[]> = {};

  for (const phrase of SAMPLE_TEXTS) {
    console.log(`Generating embedding for: "${phrase}"`);
    const { embeddings } = await model.doEmbed({ values: [phrase] });
    vectorsByText[phrase] = embeddings[0];
  }

  // Save embeddings to fixtures directory
  const fixturesDir = join(__dirname, '..', 'src', '__fixtures__');
  const outputPath = join(fixturesDir, 'embeddings.json');
  await mkdir(fixturesDir, { recursive: true });
  await writeFile(outputPath, JSON.stringify(vectorsByText, null, 2));

  const generatedCount = Object.keys(vectorsByText).length;
  console.log(`\n✅ Embeddings saved to: ${outputPath}`);
  console.log(`Generated ${generatedCount} embeddings`);
}
|
|
50
|
+
|
|
51
|
+
/**
 * True when this module is the process entry point (e.g. `tsx scripts/generate-embeddings.ts`)
 * rather than being imported. File-system paths are compared via `fileURLToPath`, rather than
 * building a `file://` string by hand, so percent-encoded characters (such as spaces) and
 * Windows drive paths match correctly.
 */
const isDirectRun = process.argv[1] !== undefined && fileURLToPath(import.meta.url) === process.argv[1];

// Run if called directly; exit non-zero on failure so callers/CI can detect it.
if (isDirectRun) {
  generateEmbeddings().catch(error => {
    console.error(error);
    process.exitCode = 1;
  });
}

export { generateEmbeddings };
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
import { Agent } from '@mastra/core/agent';
|
|
2
|
+
import { google } from '@ai-sdk/google';
|
|
3
|
+
import chalk from 'chalk';
|
|
4
|
+
import ora, { Ora } from 'ora';
|
|
5
|
+
import { existsSync } from 'fs';
|
|
6
|
+
import { readFile, writeFile, mkdir } from 'fs/promises';
|
|
7
|
+
import { join } from 'path';
|
|
8
|
+
|
|
9
|
+
import { DatasetLoader } from '../src/data/loader';
|
|
10
|
+
import type { LongMemEvalQuestion } from '../src/data/types';
|
|
11
|
+
|
|
12
|
+
/** One generated working-memory template, stored with the question it was generated for. */
interface WorkingMemoryTemplate {
  /** The generated instruction text. */
  template: string;
  /** ISO-8601 timestamp of when the template was generated. */
  generated_at: string;
  /** Question category copied from the dataset. */
  question_type: string;
  /** Question text copied from the dataset. */
  question: string;
  /** Expected answer copied from the dataset. */
  answer: string;
}

/** Contents of a template file: templates keyed by question id. */
type TemplateDatabase = Record<string, WorkingMemoryTemplate>;
|
|
23
|
+
|
|
24
|
+
/**
 * Whether the expected answer is a plain number (e.g. `3` or `"3"`), in which case the
 * generated template is told to keep a count. Blank answers are not numeric: a bare
 * `!isNaN(Number(answer))` check misclassifies them because `Number('') === 0`.
 */
function isNumericAnswer(answer: unknown): boolean {
  const text = String(answer ?? '').trim();
  return text !== '' && !Number.isNaN(Number(text));
}

/**
 * Ask a Gemini-backed agent (temperature 0) for a working-memory instruction tailored to one
 * question. The answer itself is not put in the prompt; it only decides whether the template
 * should emphasise keeping a count.
 *
 * @param question - The benchmark question to build a template for.
 * @returns The trimmed template text.
 * @throws Error when the model returns fewer than 50 characters. Model/API errors also propagate.
 */
async function generateTemplate(question: LongMemEvalQuestion): Promise<string> {
  const numericAnswer = isNumericAnswer(question.answer);

  // Create a simple agent for template generation
  const agent = new Agent({
    name: 'template-generator',
    instructions: `You are an expert at designing working memory templates for AI assistants.

Given a question and answer from a conversation history benchmark, generate a working memory instruction that would help an AI assistant extract and save the specific information needed to answer the question correctly.

The instruction should:
1. Be specific about what information to track
2. Use bullet points to organize different categories
3. Focus ONLY on information directly relevant to answering this specific question
4. Be concise but comprehensive
5. Do not be overly specific, the template should be generic enough to apply generally to the topic at hand, without revealing too much about the answer directly. Overly specific templates will invalidate the usefulness of the recorded information.
${numericAnswer ? '6. A number should be stored counting the relevant data' : '6. If the question involves keeping track of the count or number of something, make that clear in the template'}

Format your response as a clear instruction starting with "Pay close attention to the following information (current and past):"

Then list the specific categories and details to track using bullet points.`,
    model: google('gemini-2.5-flash-preview-04-17'),
  });

  const prompt = `Question Type: ${question.question_type}
Question: "${question.question}"

Generate a working memory instruction specifically tailored for capturing the information needed to answer this question.
If the question involves remembering a specific date or a specific location, make sure that's captured in the template.`;

  const result = await agent.generate(prompt, {
    temperature: 0,
  });

  const template = result.text.trim();

  // Validate that we got a non-empty response
  if (!template || template.length < 50) {
    throw new Error(`Generated template is too short or empty for question ${question.question_id}`);
  }

  return template;
}
|
|
65
|
+
|
|
66
|
+
/**
 * CLI entry point: generate a working-memory template for every question in a dataset that
 * does not have one yet.
 *
 * Usage: `[dataset] [concurrency]`. `dataset` defaults to `longmemeval_s`. `concurrency`
 * defaults to 100, and the default is also used for any value that does not parse to an
 * integer >= 1.
 *
 * Templates are written to `prepared-data/wm-templates/<dataset>.json` (under the CWD) after
 * every successful generation, so an interrupted run (SIGINT/SIGTERM) can be resumed by
 * re-running. Existing entries with an empty `template` are regenerated. Each question gets
 * up to 3 attempts.
 */
async function main() {
  const args = process.argv.slice(2);
  const dataset = args[0] || 'longmemeval_s';
  // Default to 100 concurrent generations. Values < 1 are rejected because with a
  // non-positive limit no worker would ever start.
  const parsedConcurrency = Number.parseInt(args[1] ?? '', 10);
  const concurrency = parsedConcurrency >= 1 ? parsedConcurrency : 100;
  const outputDir = join(process.cwd(), 'prepared-data', 'wm-templates');
  const outputPath = join(outputDir, `${dataset}.json`);

  console.log(chalk.blue('\n🧠 Generating Working Memory Templates\n'));
  console.log(chalk.gray(`Dataset: ${dataset}`));
  console.log(chalk.gray(`Concurrency: ${concurrency}`));
  console.log(chalk.gray(`Output: ${outputPath}`));

  // State shared with the signal handler
  let interrupted = false;
  let currentSpinner: Ora | null = null;
  const activeGenerations = new Map<string, Ora>();
  // Tail of the serialized save queue. It never rejects; see saveTemplates below.
  let pendingSave: Promise<void> = Promise.resolve();

  // Set up signal handlers for graceful shutdown
  const handleSignal = () => {
    if (interrupted) {
      // A second signal means "stop now", even if a save is still being flushed.
      process.exit(0);
    }
    interrupted = true;
    currentSpinner?.stop();
    activeGenerations.forEach(spinner => spinner.stop());
    // Exiting in the middle of writeFile would leave a truncated JSON file that the next run
    // cannot parse (it would then start fresh). Let the in-flight save finish first.
    void pendingSave.finally(() => {
      console.log(chalk.yellow('\n\n⚠️ Interrupted! Progress has been saved.'));
      console.log(chalk.gray(`Templates saved to: ${outputPath}`));
      process.exit(0);
    });
  };
  process.on('SIGINT', handleSignal);
  process.on('SIGTERM', handleSignal);

  // generateTemplate uses a Google Gemini model, so that provider's API key is the one required.
  if (!process.env.GOOGLE_GENERATIVE_AI_API_KEY) {
    console.error(chalk.red('Error: GOOGLE_GENERATIVE_AI_API_KEY environment variable is required'));
    process.exit(1);
  }

  // Load dataset
  const loader = new DatasetLoader();
  const spinner = ora('Loading dataset...').start();
  currentSpinner = spinner;
  // NOTE(review): the CLI string is passed through unchecked to the loader's dataset-name type.
  const questions = await loader.loadDataset(dataset as any);
  spinner.succeed(`Loaded ${questions.length} questions`);
  currentSpinner = null;

  // Load existing templates if they exist
  let templates: TemplateDatabase = {};
  if (existsSync(outputPath)) {
    const loadSpinner = ora('Loading existing templates...').start();
    currentSpinner = loadSpinner;
    try {
      const parsed: unknown = JSON.parse(await readFile(outputPath, 'utf-8'));
      if (parsed === null || typeof parsed !== 'object' || Array.isArray(parsed)) {
        throw new Error('Template file does not contain a JSON object');
      }
      templates = parsed as TemplateDatabase;
      loadSpinner.succeed(`Loaded ${Object.keys(templates).length} existing templates`);
      currentSpinner = null;

      // Count empty (or null) templates and remove them so they get regenerated
      const emptyTemplates = Object.entries(templates).filter(([, t]) => !t?.template);
      if (emptyTemplates.length > 0) {
        console.log(chalk.yellow(`⚠️ Found ${emptyTemplates.length} empty templates that will be regenerated`));
        emptyTemplates.forEach(([id]) => delete templates[id]);
      }
    } catch {
      // Reset so a partially validated file cannot leak into this run.
      templates = {};
      loadSpinner.warn('Could not load existing templates, starting fresh');
      currentSpinner = null;
    }
  }

  // Process questions
  const questionsToProcess = questions.filter(q => !templates[q.question_id]?.template);

  if (questionsToProcess.length === 0) {
    console.log(chalk.green('\n✅ All questions already have templates!'));
    return;
  }

  console.log(chalk.yellow(`\nGenerating templates for ${questionsToProcess.length} questions...\n`));

  let processed = 0;
  let errors = 0;
  let inProgress = 0;
  const questionQueue = [...questionsToProcess];

  // Create directory once
  await mkdir(outputDir, { recursive: true });

  // Serialize writes. Concurrent writeFile calls on one path can interleave, or an older
  // snapshot can finish last and drop newer templates. Each queued write serializes the
  // in-memory state at the moment it actually runs, so it is always the latest.
  const saveTemplates = (): Promise<void> => {
    const write = pendingSave.then(() => writeFile(outputPath, JSON.stringify(templates, null, 2)));
    pendingSave = write.catch(() => undefined);
    return write;
  };

  // Main progress spinner
  const mainSpinner = ora({
    text: `Processing: 0/${questionsToProcess.length} (0 in progress)`,
    spinner: 'dots',
  }).start();
  currentSpinner = mainSpinner;

  const updateProgress = () => {
    mainSpinner.text = `Processing: ${processed}/${questionsToProcess.length} (${inProgress} in progress, ${errors} failed)`;
  };

  // Process a single question with up to maxAttempts tries; counters are updated exactly once.
  const processQuestion = async (question: LongMemEvalQuestion): Promise<void> => {
    if (interrupted) return;

    const questionSpinner = ora({
      text: `${question.question_id}: Starting...`,
      prefixText: ' ',
      spinner: 'dots',
    }).start();

    activeGenerations.set(question.question_id, questionSpinner);
    inProgress++;
    updateProgress();

    const maxAttempts = 3;
    let lastError: unknown = null;
    // Tracked explicitly: success on the final attempt must not also be counted as a failure.
    let succeeded = false;

    for (let attempt = 1; attempt <= maxAttempts && !interrupted; attempt++) {
      questionSpinner.text =
        attempt > 1
          ? `${question.question_id}: Retry ${attempt}/${maxAttempts}...`
          : `${question.question_id}: Generating...`;
      try {
        const template = await generateTemplate(question);

        templates[question.question_id] = {
          template,
          generated_at: new Date().toISOString(),
          question_type: question.question_type,
          question: question.question,
          answer: question.answer,
        };

        // Save after each successful generation; a failed save counts as a failed attempt.
        await saveTemplates();
        succeeded = true;
        break;
      } catch (error) {
        lastError = error;
        if (attempt < maxAttempts && !interrupted) {
          // Add a small delay before retry
          await new Promise(resolve => setTimeout(resolve, 1000));
        }
      }
    }

    if (succeeded) {
      processed++;
      questionSpinner.succeed(`${question.question_id} (${question.question_type})`);
    } else if (interrupted) {
      questionSpinner.warn(`${question.question_id}: Interrupted`);
    } else {
      errors++;
      questionSpinner.fail(`${question.question_id}: ${lastError}`);
    }
    activeGenerations.delete(question.question_id);
    inProgress--;
    updateProgress();
  };

  // Fixed worker pool: each worker pulls the next question until the queue is empty or the run is interrupted.
  const runWorker = async (): Promise<void> => {
    while (!interrupted) {
      const question = questionQueue.shift();
      if (!question) return;
      await processQuestion(question).catch(err => {
        console.error(chalk.red(`Unexpected error processing ${question.question_id}:`), err);
      });
    }
  };

  const workerCount = Math.min(concurrency, questionQueue.length);
  await Promise.all(Array.from({ length: workerCount }, () => runWorker()));

  // Clean up spinners
  activeGenerations.forEach(s => s.stop());
  mainSpinner.succeed(`Completed: ${processed}/${questionsToProcess.length} (${errors} failed)`);
  currentSpinner = null;

  // Final summary
  console.log(chalk.blue('\n📊 Summary'));
  console.log(chalk.green(`✓ Successfully generated: ${processed} templates`));
  if (errors > 0) {
    console.log(chalk.red(`✗ Failed: ${errors} templates`));
  }
  console.log(chalk.gray(`Total templates: ${Object.keys(templates).length}`));
  console.log(chalk.gray(`Saved to: ${outputPath}`));
}
|
|
289
|
+
|
|
290
|
+
// Run the generator. On failure print the error plus usage, then exit non-zero.
// The defaults shown here must match main(): dataset longmemeval_s, concurrency 100.
main().catch(error => {
  console.error(chalk.red('\nError:'), error instanceof Error ? error.message : error);
  console.log(chalk.gray('\nUsage: pnpm generate-wm-templates [dataset] [concurrency]'));
  console.log(chalk.gray(' dataset: longmemeval_s (default)'));
  console.log(chalk.gray(' concurrency: number of parallel generations (default: 100)'));
  process.exit(1);
});
|
package/scripts/setup.ts
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
#!/usr/bin/env tsx
|
|
2
|
+
|
|
3
|
+
import { execSync } from 'child_process';
|
|
4
|
+
import { existsSync } from 'fs';
|
|
5
|
+
import { join } from 'path';
|
|
6
|
+
import chalk from 'chalk';
|
|
7
|
+
import ora from 'ora';
|
|
8
|
+
|
|
9
|
+
/** Directory where the dataset files are expected: `./data` under the current working directory. */
const DATA_DIR = join(process.cwd(), 'data');

/** Dataset files that must all exist in DATA_DIR for setup to count as complete. */
const EXPECTED_FILES = ['s', 'm', 'oracle'].map(variant => `longmemeval_${variant}.json`);
|
|
11
|
+
|
|
12
|
+
/**
 * One-shot project setup.
 *
 * If every file in EXPECTED_FILES already exists in DATA_DIR, only usage hints are printed.
 * Otherwise it runs `pnpm install` (output suppressed) and then `pnpm download`, and reports
 * whether all dataset files are now present, naming any that are still missing.
 *
 * @throws The `execSync` error when `pnpm install` fails. A failed download is reported, not thrown.
 */
async function setup() {
  console.log(chalk.blue('\n🚀 LongMemEval Setup\n'));

  // Expected dataset files not (yet) present in DATA_DIR
  const findMissingFiles = () => EXPECTED_FILES.filter(file => !existsSync(join(DATA_DIR, file)));

  // Check if already set up
  if (findMissingFiles().length === 0) {
    console.log(chalk.green('✓ All datasets are already downloaded'));
    console.log(chalk.gray('\nYou can run the benchmark with:'));
    console.log(chalk.cyan(' pnpm cli run --dataset longmemeval_s --model gpt-4o\n'));
    return;
  }

  // Install dependencies
  const spinner = ora('Installing dependencies...').start();
  try {
    execSync('pnpm install', { stdio: 'ignore' });
    spinner.succeed('Dependencies installed');
  } catch (error) {
    spinner.fail('Failed to install dependencies');
    throw error;
  }

  // Download datasets
  console.log(chalk.blue('\n📥 Downloading datasets...\n'));

  try {
    execSync('pnpm download', { stdio: 'inherit' });
  } catch {
    // Best-effort: stdio is inherited, so the download script's own error output is already shown.
    console.log(chalk.yellow('\n⚠️ Automatic download failed.'));
    // The guide shipped with this package is DATA_DOWNLOAD_GUIDE.md.
    console.log(chalk.yellow('Please check the DATA_DOWNLOAD_GUIDE.md for manual download instructions.\n'));
  }

  // Verify setup
  const stillMissing = findMissingFiles();

  if (stillMissing.length === 0) {
    console.log(chalk.green('\n✅ Setup complete!'));
    console.log(chalk.gray('\nYou can now run the benchmark:'));
    console.log(chalk.cyan(' pnpm cli run --dataset longmemeval_s --model gpt-4o'));
    console.log(chalk.gray('\nOr view available commands:'));
    console.log(chalk.cyan(' pnpm cli --help\n'));
  } else {
    console.log(chalk.yellow('\n⚠️ Setup incomplete. Please download the datasets manually.'));
    console.log(chalk.gray(`Missing in ${DATA_DIR}: ${stillMissing.join(', ')}`));
  }
}
|
|
58
|
+
|
|
59
|
+
// Run setup; on failure print the error and exit non-zero so shells/CI can detect it.
setup().catch(error => {
  console.error(error);
  process.exitCode = 1;
});
|