@mastra/longmemeval 0.0.0-add-libsql-changeset-20250910154739
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +919 -0
- package/DATA_DOWNLOAD_GUIDE.md +117 -0
- package/LICENSE.md +15 -0
- package/README.md +173 -0
- package/USAGE.md +105 -0
- package/package.json +67 -0
- package/scripts/download.ts +180 -0
- package/scripts/find-failed.ts +176 -0
- package/scripts/generate-embeddings.ts +56 -0
- package/scripts/generate-wm-templates.ts +296 -0
- package/scripts/setup.ts +60 -0
- package/src/__fixtures__/embeddings.json +2319 -0
- package/src/__fixtures__/test-dataset.json +82 -0
- package/src/cli.ts +690 -0
- package/src/commands/__tests__/prepare.test.ts +230 -0
- package/src/commands/__tests__/run.test.ts +403 -0
- package/src/commands/prepare.ts +793 -0
- package/src/commands/run.ts +553 -0
- package/src/config.ts +83 -0
- package/src/data/loader.ts +163 -0
- package/src/data/types.ts +61 -0
- package/src/embeddings/cached-openai-embedding-model.ts +227 -0
- package/src/embeddings/cached-openai-provider.ts +40 -0
- package/src/embeddings/index.ts +2 -0
- package/src/evaluation/__tests__/longmemeval-metric.test.ts +169 -0
- package/src/evaluation/longmemeval-metric.ts +173 -0
- package/src/retry-model.ts +60 -0
- package/src/storage/__tests__/benchmark-store.test.ts +280 -0
- package/src/storage/__tests__/benchmark-vector.test.ts +214 -0
- package/src/storage/benchmark-store.ts +540 -0
- package/src/storage/benchmark-vector.ts +234 -0
- package/src/storage/index.ts +2 -0
- package/src/test-utils/mock-embeddings.ts +54 -0
- package/src/test-utils/mock-model.ts +49 -0
- package/tests/data-loader.test.ts +96 -0
- package/tsconfig.json +18 -0
- package/vitest.config.ts +9 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import { readFile } from 'fs/promises';
|
|
2
|
+
import { join } from 'path';
|
|
3
|
+
import type { LongMemEvalQuestion } from './types';
|
|
4
|
+
|
|
5
|
+
/** Dataset files `DatasetLoader` can read; each maps to `<dataDir>/<name>.json`. */
export type DatasetName = 'longmemeval_s' | 'longmemeval_m' | 'longmemeval_oracle' | 'sample_data';

/**
 * Reads LongMemEval question files from a local directory and offers simple
 * filtering and summary helpers on top of them.
 */
export class DatasetLoader {
  private readonly dataDir: string;

  /**
   * @param dataDir Directory holding the `<dataset>.json` files. When omitted (or empty),
   *   defaults to `data/` under the current working directory, i.e. relative to where
   *   the command is run.
   */
  constructor(dataDir?: string) {
    this.dataDir = dataDir || join(process.cwd(), 'data');
  }

  /**
   * Load a LongMemEval dataset from `<dataDir>/<dataset>.json` and sanity-check it.
   *
   * @returns The parsed list of questions.
   * @throws Error with download instructions when the file does not exist.
   * @throws Error from `validateDataset` when the JSON does not look like a dataset.
   * @throws Any other read or JSON parse error, rethrown unchanged.
   */
  async loadDataset(dataset: DatasetName): Promise<LongMemEvalQuestion[]> {
    const filePath = join(this.dataDir, `${dataset}.json`);

    let fileContent: string;
    try {
      fileContent = await readFile(filePath, 'utf-8');
    } catch (error: unknown) {
      const code = typeof error === 'object' && error !== null ? (error as { code?: unknown }).code : undefined;
      if (code === 'ENOENT') {
        throw new Error(
          `Dataset file not found: ${filePath}\n` +
            `Please download the LongMemEval dataset from https://drive.google.com/file/d/1zJgtYRFhOh5zDQzzatiddfjYhFSnyQ80/view ` +
            `and extract it to ${this.dataDir}`,
        );
      }
      throw error;
    }

    const data: unknown = JSON.parse(fileContent);
    this.validateDataset(data);
    return data as LongMemEvalQuestion[];
  }

  /**
   * Load the first `limit` questions of a dataset (in file order), e.g. for quick test runs.
   * A negative `limit` yields an empty list instead of slice's "all but the last N".
   */
  async loadSubset(dataset: DatasetName, limit: number): Promise<LongMemEvalQuestion[]> {
    const fullDataset = await this.loadDataset(dataset);
    return fullDataset.slice(0, Math.max(0, limit));
  }

  /** Load only the questions whose `question_type` equals `questionType`. */
  async loadByType(dataset: DatasetName, questionType: string): Promise<LongMemEvalQuestion[]> {
    const fullDataset = await this.loadDataset(dataset);
    return fullDataset.filter(q => q.question_type === questionType);
  }

  /**
   * Summarize a dataset in a single pass over its questions.
   *
   * Averages are 0 (not NaN) when there is nothing to divide by. Questions whose id ends
   * in `_abs` are counted as abstention questions. `totalTokensEstimate` assumes roughly
   * 4 characters per token, rounded up per turn.
   */
  async getDatasetStats(dataset: DatasetName) {
    const data = await this.loadDataset(dataset);

    const questionsByType: Record<string, number> = {};
    let abstentionQuestions = 0;
    let totalSessions = 0;
    let totalTurns = 0;
    let totalTokensEstimate = 0;

    for (const question of data) {
      const type = question.question_type;
      questionsByType[type] = (questionsByType[type] ?? 0) + 1;

      if (question.question_id.endsWith('_abs')) {
        abstentionQuestions++;
      }

      totalSessions += question.haystack_sessions.length;
      for (const session of question.haystack_sessions) {
        totalTurns += session.length;
        for (const turn of session) {
          totalTokensEstimate += Math.ceil(turn.content.length / 4);
        }
      }
    }

    return {
      totalQuestions: data.length,
      questionsByType,
      abstentionQuestions,
      avgSessionsPerQuestion: data.length > 0 ? totalSessions / data.length : 0,
      avgTurnsPerSession: totalSessions > 0 ? totalTurns / totalSessions : 0,
      totalTokensEstimate,
    };
  }

  /**
   * Cheap structural smoke test. Only the first question, and its first session and
   * turn, is inspected; this is not full validation.
   *
   * @throws Error describing the first problem found.
   */
  private validateDataset(data: unknown): void {
    if (!Array.isArray(data)) {
      throw new Error('Dataset must be an array of questions');
    }

    if (data.length === 0) {
      throw new Error('Dataset is empty');
    }

    const sample: unknown = data[0];
    // Guard before using `in`, which throws a TypeError on null/primitives.
    if (typeof sample !== 'object' || sample === null) {
      throw new Error('Each question must be an object');
    }

    const requiredFields = [
      'question_id',
      'question_type',
      'question',
      'answer',
      'question_date',
      'haystack_session_ids',
      'haystack_dates',
      'haystack_sessions',
      'answer_session_ids',
    ];

    for (const field of requiredFields) {
      if (!(field in sample)) {
        throw new Error(`Missing required field: ${field}`);
      }
    }

    const sessions = (sample as { haystack_sessions?: unknown }).haystack_sessions;
    if (!Array.isArray(sessions)) {
      throw new Error('haystack_sessions must be an array');
    }
    if (sessions.length === 0) {
      return;
    }

    const firstSession: unknown = sessions[0];
    if (!Array.isArray(firstSession)) {
      throw new Error('Each session must be an array of turns');
    }
    if (firstSession.length === 0) {
      return;
    }

    const firstTurn: unknown = firstSession[0];
    if (typeof firstTurn !== 'object' || firstTurn === null) {
      throw new Error('Each turn must have role and content fields');
    }
    const { role, content } = firstTurn as { role?: unknown; content?: unknown };
    // An empty message is still a well-formed turn, so only require content to be a string.
    if (!role || typeof content !== 'string') {
      throw new Error('Each turn must have role and content fields');
    }
  }
}
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import { MemoryConfig } from '@mastra/core/memory';
|
|
2
|
+
|
|
3
|
+
/** Category of a LongMemEval question, as stored in its `question_type` field. */
export type QuestionType =
  | 'single-session-user'
  | 'single-session-assistant'
  | 'single-session-preference'
  | 'multi-session'
  | 'knowledge-update'
  | 'temporal-reasoning';

/** A single chat message inside one haystack session. */
export interface Turn {
  /** Who sent the message. */
  role: 'assistant' | 'user';
  /** Message text. */
  content: string;
  // NOTE(review): optional per-turn flag; presumably marks turns that contain answer evidence — confirm.
  has_answer?: boolean;
}

/** One benchmark item: a question, its expected answer, and the chat history to search. */
export interface LongMemEvalQuestion {
  question_id: string;
  question_type: QuestionType;
  question: string;
  answer: string;
  question_date: string;
  // NOTE(review): the haystack_* arrays appear to be index-aligned, one entry per session — confirm.
  haystack_session_ids: Array<string>;
  haystack_dates: Array<string>;
  /** Sessions of the history, each an ordered list of turns. */
  haystack_sessions: Array<Array<Turn>>;
  answer_session_ids: Array<string>;
}

/** Outcome recorded for one evaluated question. */
export interface EvaluationResult {
  question_id: string;
  // presumably the answer produced by the system under test — verify against callers.
  hypothesis: string;
  autoeval_label?: boolean;
  question_type?: QuestionType;
  is_correct?: boolean;
}
|
|
36
|
+
|
|
37
|
+
/** The three published LongMemEval dataset variants (does not include `'sample_data'`). */
export type DatasetType = 'longmemeval_s' | 'longmemeval_m' | 'longmemeval_oracle';

/** Named memory setups the benchmark can be run against. */
export type MemoryConfigType =
  | 'semantic-recall'
  | 'working-memory'
  | 'working-memory-tailored'
  | 'combined'
  | 'combined-tailored';

/** A memory setup name paired with the Mastra `MemoryConfig` it stands for. */
export interface MemoryConfigOptions {
  type: MemoryConfigType;
  options: MemoryConfig;
}

/** Aggregate scores for a benchmark run. */
export interface BenchmarkMetrics {
  overall_accuracy: number;
  /** Per-category tallies; a category is absent when no question of that type was scored. */
  accuracy_by_type: {
    [K in QuestionType]?: { correct: number; total: number; accuracy: number };
  };
  abstention_accuracy: number;
  session_recall_accuracy?: number;
  turn_recall_accuracy?: number;
  total_questions: number;
  correct_answers: number;
  abstention_correct?: number;
  abstention_total?: number;
}
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import { existsSync, mkdirSync, readFileSync, renameSync, unlinkSync, writeFileSync } from 'fs';
import { join } from 'path';

import { OpenAIEmbedding } from '@ai-sdk/openai';
import { EmbeddingModelV2, TooManyEmbeddingValuesForCallError } from '@ai-sdk/provider';
import { xxh3 } from '@node-rs/xxhash';
import { Mutex } from 'async-mutex';
|
|
7
|
+
|
|
8
|
+
/**
 * Module-wide counters for embedding-cache activity. `CachedOpenAIEmbeddingModel`
 * increments these directly, so they aggregate across every instance.
 */
export const embeddingCacheStats = {
  cacheHits: 0,
  cacheMisses: 0,
  cacheWrites: 0,
  /** Set all three counters back to zero. */
  reset(): void {
    this.cacheWrites = 0;
    this.cacheMisses = 0;
    this.cacheHits = 0;
  },
};
|
|
19
|
+
|
|
20
|
+
/**
 * `EmbeddingModelV2` wrapper that memoizes a delegate model's embeddings.
 *
 * Each value is looked up first in an in-process `Map`, then in a JSON file under
 * `cacheDir` (sharded into subdirectories named after the first two hex characters
 * of the cache key). Only values missing from both are sent to the delegate. Their
 * vectors are stored in memory immediately and written to disk in the background.
 * Hits, misses and writes are tallied in `embeddingCacheStats`.
 */
export class CachedOpenAIEmbeddingModel implements EmbeddingModelV2<string> {
  readonly specificationVersion = 'v2';
  readonly modelId: string;
  readonly maxEmbeddingsPerCall = 2048;
  readonly supportsParallelCalls = true;

  private readonly cacheDir: string;
  private readonly delegate: EmbeddingModelV2<string>;
  private readonly memoryCache = new Map<string, number[]>();
  // Serializes this instance's cache-file reads and writes.
  private readonly fileOperationMutex = new Mutex();

  /** Provider name reported by the wrapped model. */
  get provider(): string {
    return this.delegate.provider;
  }

  /**
   * @param delegate Model that computes embeddings on a cache miss.
   * @param cacheDir Root directory for cache files; created if it does not exist.
   */
  constructor(delegate: EmbeddingModelV2<string>, cacheDir: string = './embedding-cache') {
    this.delegate = delegate;
    this.modelId = delegate.modelId;
    this.cacheDir = cacheDir;

    if (!existsSync(this.cacheDir)) {
      mkdirSync(this.cacheDir, { recursive: true });
    }

    this.loadMemoryCache();
  }

  /** 128-bit XXH3 hash of `<modelId>:<value>`, rendered as 32 zero-padded hex characters. */
  private getCacheKey(value: string): string {
    return xxh3.xxh128(`${this.modelId}:${value}`).toString(16).padStart(32, '0');
  }

  /** Shard directory for `key`; splitting by prefix keeps per-directory file counts small. */
  private getCacheSubdir(key: string): string {
    return join(this.cacheDir, key.substring(0, 2));
  }

  /** Path of the cache file for `key`. Pure path computation: creates nothing on disk. */
  private getCachePath(key: string): string {
    return join(this.getCacheSubdir(key), `${key}.json`);
  }

  /** Hook for preloading; currently a no-op, so the memory cache fills lazily. */
  private loadMemoryCache(): void {
    // Could preload from disk, ideally lazily or with a size limit.
  }

  /**
   * Look up `value` in memory, then on disk. Disk hits are promoted into memory.
   * Unreadable or malformed cache files are deleted and treated as misses.
   *
   * @returns The cached vector, or `null` on a miss.
   */
  private async getCachedEmbedding(value: string): Promise<number[] | null> {
    const key = this.getCacheKey(value);

    const inMemory = this.memoryCache.get(key);
    if (inMemory) {
      embeddingCacheStats.cacheHits++;
      return inMemory;
    }

    const cachePath = this.getCachePath(key);
    const fromDisk = await this.fileOperationMutex.runExclusive(async (): Promise<number[] | null> => {
      // Another caller may have filled the memory cache while we waited for the lock.
      const raced = this.memoryCache.get(key);
      if (raced) {
        return raced;
      }

      if (!existsSync(cachePath)) {
        return null;
      }

      try {
        const cached: unknown = JSON.parse(readFileSync(cachePath, 'utf-8'));
        const embedding = (cached as { embedding?: unknown } | null)?.embedding;
        if (!Array.isArray(embedding)) {
          throw new Error('cache entry has no embedding array');
        }
        return embedding as number[];
      } catch {
        console.warn(`Corrupted cache file ${cachePath}, deleting...`);
        try {
          unlinkSync(cachePath);
        } catch {
          // Best effort: a leftover file is retried (and deleted again) on the next lookup.
        }
        return null;
      }
    });

    if (fromDisk) {
      this.memoryCache.set(key, fromDisk);
      embeddingCacheStats.cacheHits++;
      return fromDisk;
    }

    embeddingCacheStats.cacheMisses++;
    return null;
  }

  /**
   * Record a freshly computed embedding. The memory copy is stored synchronously.
   * The disk copy is written fire-and-forget via a temp file plus rename, so the final
   * file appears in one step. Disk failures are ignored; the memory copy still serves
   * this process.
   */
  private cacheEmbedding(value: string, embedding: number[]): void {
    const key = this.getCacheKey(value);

    this.memoryCache.set(key, embedding);
    embeddingCacheStats.cacheWrites++;

    void this.fileOperationMutex
      .runExclusive(async () => {
        const cachePath = this.getCachePath(key);
        const tempPath = `${cachePath}.tmp`;
        try {
          // Shard directories are created only when something is actually written.
          mkdirSync(this.getCacheSubdir(key), { recursive: true });
          const data = JSON.stringify({
            value,
            embedding,
            modelId: this.modelId,
            timestamp: new Date().toISOString(),
          });
          writeFileSync(tempPath, data);
          renameSync(tempPath, cachePath);
        } catch {
          try {
            unlinkSync(tempPath);
          } catch {
            // Nothing to clean up, or cleanup failed; the write is abandoned either way.
          }
        }
      })
      .catch(() => {
        // Any remaining rejection is ignored: disk caching is best effort.
      });
  }

  /**
   * Embed `values`, serving cached vectors where possible and sending only the misses
   * to the delegate in one call. Output order matches `values`.
   *
   * `usage` and `response` describe the delegate call. When every value was cached they
   * are `{ tokens: 0 }` and empty objects.
   *
   * @throws TooManyEmbeddingValuesForCallError when `values.length` exceeds `maxEmbeddingsPerCall`.
   */
  async doEmbed({
    values,
    headers,
    abortSignal,
    providerOptions,
  }: Parameters<EmbeddingModelV2<string>['doEmbed']>[0]): Promise<
    Awaited<ReturnType<EmbeddingModelV2<string>['doEmbed']>>
  > {
    if (values.length > this.maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: this.maxEmbeddingsPerCall,
        values,
      });
    }

    const embeddings: number[][] = [];
    const uncachedValues: string[] = [];
    const uncachedIndices: number[] = [];

    for (let i = 0; i < values.length; i++) {
      const cached = await this.getCachedEmbedding(values[i]);
      if (cached) {
        embeddings[i] = cached;
      } else {
        uncachedValues.push(values[i]);
        uncachedIndices.push(i);
      }
    }

    let usage = { tokens: 0 };
    let responseHeaders: Record<string, string> = {};
    let rawValue: unknown = {};

    if (uncachedValues.length > 0) {
      const result = await this.delegate.doEmbed({
        values: uncachedValues,
        headers,
        abortSignal,
        providerOptions,
      });

      // Store each new vector and put it back at its original position.
      for (let i = 0; i < uncachedValues.length; i++) {
        const embedding = result.embeddings[i];
        this.cacheEmbedding(uncachedValues[i], embedding);
        embeddings[uncachedIndices[i]] = embedding;
      }

      usage = result.usage || { tokens: 0 };
      responseHeaders = result.response?.headers || {};
      rawValue = result.response?.body || {};
    } else {
      // Yield to the event loop so fully-cached calls do not starve other work.
      await new Promise(resolve => setImmediate(resolve));
    }

    return {
      embeddings,
      usage,
      response: { headers: responseHeaders, body: rawValue },
    };
  }
}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import { openai as originalOpenAI, createOpenAI } from '@ai-sdk/openai';
|
|
2
|
+
import { OpenAIProvider } from '@ai-sdk/openai';
|
|
3
|
+
import { CachedOpenAIEmbeddingModel } from './cached-openai-embedding-model';
|
|
4
|
+
import { join } from 'path';
|
|
5
|
+
|
|
6
|
+
/** Settings accepted by `createCachedOpenAI`. */
export interface CachedOpenAIOptions {
  /**
   * Root directory for cached embedding files. When set it is used as-is for every model;
   * when omitted, `<cwd>/.embedding-cache/<modelId>` is used.
   */
  cacheDir?: string;
  /** Forwarded to `createOpenAI`. */
  apiKey?: string;
  /** Forwarded to `createOpenAI`. */
  baseURL?: string;
  /** Forwarded to `createOpenAI`. */
  headers?: Record<string, string>;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Create an OpenAI provider whose embedding models are wrapped in
 * `CachedOpenAIEmbeddingModel`. All other provider members pass through unchanged.
 *
 * Every embedding-model factory (`embedding`, `textEmbedding`, `textEmbeddingModel`)
 * is intercepted, so no entry point bypasses the cache. Repeated requests for the same
 * model id return the same cached model, which shares its in-memory cache and
 * file-write lock.
 *
 * @param options Connection settings forwarded to `createOpenAI`, plus an optional
 *   `cacheDir`. When `cacheDir` is omitted, `<cwd>/.embedding-cache/<modelId>` is used.
 */
export function createCachedOpenAI(options: CachedOpenAIOptions = {}) {
  const provider = createOpenAI({
    apiKey: options.apiKey,
    baseURL: options.baseURL,
    headers: options.headers,
  });

  // Provider members that build embedding models and should therefore go through the cache.
  const embeddingFactories = new Set<PropertyKey>(['embedding', 'textEmbedding', 'textEmbeddingModel']);

  // One cached model per model id, so lookups do not each start with a cold cache.
  const cachedModels = new Map<string, CachedOpenAIEmbeddingModel>();

  return new Proxy(provider, {
    get(target, prop, receiver) {
      const original = Reflect.get(target, prop, receiver);
      if (!embeddingFactories.has(prop) || typeof original !== 'function') {
        return original;
      }

      return (modelId: string) => {
        let model = cachedModels.get(modelId);
        if (!model) {
          const cacheDir = options.cacheDir || join(process.cwd(), '.embedding-cache', modelId);
          model = new CachedOpenAIEmbeddingModel(original.call(target, modelId), cacheDir);
          cachedModels.set(modelId, model);
        }
        return model;
      };
    },
  });
}

// Default cached OpenAI provider, configured from the environment like `createOpenAI()`.
export const cachedOpenAI = createCachedOpenAI();
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
|
2
|
+
import { LongMemEvalMetric, createLongMemEvalMetric } from '../longmemeval-metric';
|
|
3
|
+
import { Agent } from '@mastra/core/agent';
|
|
4
|
+
|
|
5
|
+
// Stand-in for the evaluator Agent. It inspects the prompt text and replies the way a
// grader model would: "yes", or "no: <reason>".
const mockAgent = {
  generate: vi.fn().mockImplementation(async messages => {
    const prompt = messages[0].content;
    const reply = (text: string) => ({ text });

    const asksCorrectness = prompt.includes('Is the model response correct?');
    if (asksCorrectness) {
      const answeredBlue = prompt.includes('Model Response: Blue') && prompt.includes('Correct Answer: Blue');
      return answeredBlue ? reply('yes') : reply('no: the model did not provide the correct answer');
    }

    const asksAbstention = prompt.includes('Does the model correctly identify the question as unanswerable?');
    if (asksAbstention) {
      const declined = prompt.includes('cannot answer') || prompt.includes("don't have that information");
      return declined ? reply('yes') : reply('no: the model attempted to answer an unanswerable question');
    }

    // Any other prompt is graded as correct.
    return reply('yes');
  }),
} as unknown as Agent;
|
|
42
|
+
|
|
43
|
+
describe('LongMemEvalMetric', () => {
  // `mockAgent` is module-level and shared by every test. Clearing call history (the
  // implementation is kept) stops call assertions from seeing earlier tests' calls.
  beforeEach(() => {
    vi.clearAllMocks();
  });

  describe('constructor', () => {
    it('should throw error when agent is not provided', () => {
      expect(() => {
        new LongMemEvalMetric({
          questionType: 'single-session-user',
        } as any);
      }).toThrow('Agent instance is required for LongMemEvalMetric');
    });
  });

  describe('measure', () => {
    it('should return score 1 for correct answer', async () => {
      const metric = new LongMemEvalMetric({
        agent: mockAgent,
        questionType: 'single-session-user',
      });

      const input = JSON.stringify({
        question: 'What is my favorite color?',
        answer: 'Blue',
      });
      const output = 'Blue';

      const result = await metric.measure(input, output);

      expect(mockAgent.generate).toHaveBeenCalled();
      expect(result.score).toBe(1);
      expect(result.info?.questionType).toBe('single-session-user');
      expect(result.info?.evaluatorResponse).toBe('yes');
    });

    it('should return score 0 for incorrect answer', async () => {
      const metric = new LongMemEvalMetric({
        agent: mockAgent,
        questionType: 'single-session-user',
      });

      const input = JSON.stringify({
        question: 'What is my favorite color?',
        answer: 'Blue',
      });
      const output = 'Red';

      const result = await metric.measure(input, output);

      expect(result.score).toBe(0);
      expect(result.info?.reason).toBe('the model did not provide the correct answer');
    });

    it('should handle abstention questions correctly', async () => {
      const metric = new LongMemEvalMetric({
        agent: mockAgent,
        questionType: 'single-session-user',
        isAbstention: true,
      });

      const input = JSON.stringify({
        question: 'What is my favorite food?',
        answer: 'This question cannot be answered based on the conversation history',
      });
      const output = 'I cannot answer that question based on our conversation history.';

      const result = await metric.measure(input, output);

      expect(result.score).toBe(1);
      expect(result.info?.isAbstention).toBe(true);
    });

    it('should handle different question types', () => {
      const temporalMetric = createLongMemEvalMetric('temporal-reasoning', mockAgent);
      const knowledgeMetric = createLongMemEvalMetric('knowledge-update', mockAgent);
      const preferenceMetric = createLongMemEvalMetric('single-session-preference', mockAgent);

      expect(temporalMetric).toBeInstanceOf(LongMemEvalMetric);
      expect(knowledgeMetric).toBeInstanceOf(LongMemEvalMetric);
      expect(preferenceMetric).toBeInstanceOf(LongMemEvalMetric);
    });

    it('should accept unknown question type at construction but reject it in measure', async () => {
      // The constructor does not validate the question type...
      expect(() => {
        new LongMemEvalMetric({
          agent: mockAgent,
          questionType: 'invalid-type' as any,
        });
      }).not.toThrow();

      // ...the error surfaces when measure builds the prompt.
      const metric = new LongMemEvalMetric({
        agent: mockAgent,
        questionType: 'invalid-type' as any,
      });

      const input = JSON.stringify({
        question: 'Test question',
        answer: 'Test answer',
      });

      await expect(metric.measure(input, 'Test output')).rejects.toThrow('Unknown question type: invalid-type');
    });

    it('should parse evaluator response correctly', async () => {
      const metric = new LongMemEvalMetric({
        agent: mockAgent,
        questionType: 'single-session-user',
      });

      const input = JSON.stringify({
        question: 'What is my name?',
        answer: 'John',
      });
      const output = "I don't know your name";

      const result = await metric.measure(input, output);

      expect(result.score).toBe(0);
      expect(result.info?.evaluatorResponse).toContain('no');
      // The mock replies "no: <reason>"; the metric should surface just the reason.
      expect(result.info?.reason).toBe('the model did not provide the correct answer');
    });
  });

  describe('createLongMemEvalMetric', () => {
    it('should create metric with correct configuration', () => {
      const metric = createLongMemEvalMetric('multi-session', mockAgent);

      expect(metric).toBeInstanceOf(LongMemEvalMetric);
    });
  });
});
|