@mastra/rag 1.2.3-alpha.0 → 1.2.3-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +9 -0
- package/package.json +18 -5
- package/.turbo/turbo-build.log +0 -4
- package/docker-compose.yaml +0 -22
- package/eslint.config.js +0 -6
- package/src/document/document.test.ts +0 -2975
- package/src/document/document.ts +0 -335
- package/src/document/extractors/base.ts +0 -30
- package/src/document/extractors/index.ts +0 -5
- package/src/document/extractors/keywords.test.ts +0 -125
- package/src/document/extractors/keywords.ts +0 -126
- package/src/document/extractors/questions.test.ts +0 -120
- package/src/document/extractors/questions.ts +0 -111
- package/src/document/extractors/summary.test.ts +0 -107
- package/src/document/extractors/summary.ts +0 -122
- package/src/document/extractors/title.test.ts +0 -121
- package/src/document/extractors/title.ts +0 -185
- package/src/document/extractors/types.ts +0 -40
- package/src/document/index.ts +0 -2
- package/src/document/prompts/base.ts +0 -77
- package/src/document/prompts/format.ts +0 -9
- package/src/document/prompts/index.ts +0 -15
- package/src/document/prompts/prompt.ts +0 -60
- package/src/document/prompts/types.ts +0 -29
- package/src/document/schema/index.ts +0 -3
- package/src/document/schema/node.ts +0 -187
- package/src/document/schema/types.ts +0 -40
- package/src/document/transformers/character.ts +0 -267
- package/src/document/transformers/html.ts +0 -346
- package/src/document/transformers/json.ts +0 -536
- package/src/document/transformers/latex.ts +0 -11
- package/src/document/transformers/markdown.ts +0 -239
- package/src/document/transformers/semantic-markdown.ts +0 -227
- package/src/document/transformers/sentence.ts +0 -314
- package/src/document/transformers/text.ts +0 -158
- package/src/document/transformers/token.ts +0 -137
- package/src/document/transformers/transformer.ts +0 -5
- package/src/document/types.ts +0 -145
- package/src/document/validation.ts +0 -158
- package/src/graph-rag/index.test.ts +0 -235
- package/src/graph-rag/index.ts +0 -306
- package/src/index.ts +0 -8
- package/src/rerank/index.test.ts +0 -150
- package/src/rerank/index.ts +0 -198
- package/src/rerank/relevance/cohere/index.ts +0 -56
- package/src/rerank/relevance/index.ts +0 -3
- package/src/rerank/relevance/mastra-agent/index.ts +0 -32
- package/src/rerank/relevance/zeroentropy/index.ts +0 -26
- package/src/tools/README.md +0 -153
- package/src/tools/document-chunker.ts +0 -34
- package/src/tools/graph-rag.test.ts +0 -115
- package/src/tools/graph-rag.ts +0 -157
- package/src/tools/index.ts +0 -3
- package/src/tools/types.ts +0 -126
- package/src/tools/vector-query-database-config.test.ts +0 -190
- package/src/tools/vector-query.test.ts +0 -477
- package/src/tools/vector-query.ts +0 -171
- package/src/utils/convert-sources.ts +0 -43
- package/src/utils/default-settings.ts +0 -38
- package/src/utils/index.ts +0 -3
- package/src/utils/tool-schemas.ts +0 -38
- package/src/utils/vector-prompts.ts +0 -832
- package/src/utils/vector-search.ts +0 -130
- package/tsconfig.build.json +0 -9
- package/tsconfig.json +0 -5
- package/tsup.config.ts +0 -17
- package/vitest.config.ts +0 -8
|
@@ -1,477 +0,0 @@
|
|
|
1
|
-
import { RuntimeContext } from '@mastra/core/runtime-context';
|
|
2
|
-
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
3
|
-
import { rerank } from '../rerank';
|
|
4
|
-
import { vectorQuerySearch } from '../utils';
|
|
5
|
-
import { createVectorQueryTool } from './vector-query';
|
|
6
|
-
|
|
7
|
-
// Replace only `vectorQuerySearch` in '../utils'; every other export keeps its real implementation.
// The stub resolves to a single hit whose metadata is `{ text: 'foo' }`.
vi.mock('../utils', async importOriginal => {
  const actual = await importOriginal<typeof import('../utils')>();
  return {
    ...actual,
    vectorQuerySearch: vi.fn().mockResolvedValue({ results: [{ metadata: { text: 'foo' }, vector: [1, 2, 3] }] }),
  };
});

// Replace only `rerank` in '../rerank'; every other export keeps its real implementation.
// `score` and `details` sit next to `result` (not inside it), matching the shape used by the
// per-test override in the 'handles reranker from runtimeContext' test.
vi.mock('../rerank', async importOriginal => {
  const actual = await importOriginal<typeof import('../rerank')>();
  return {
    ...actual,
    rerank: vi.fn().mockResolvedValue([
      {
        result: { id: '1', metadata: { text: 'bar' }, score: 1 },
        score: 1,
        details: { semantic: 1, vector: 1, position: 1 },
      },
    ]),
  };
});
|
|
26
|
-
|
|
27
|
-
// Test suite for createVectorQueryTool. `vectorQuerySearch` and `rerank` are module mocks,
// so these tests only verify how the tool resolves its parameters and shapes its output.
describe('createVectorQueryTool', () => {
  const mockModel = { name: 'test-model' } as any;
  const mockMastra = {
    vectors: {
      testStore: {
        // Mock vector store methods
      },
      anotherStore: {
        // Mock vector store methods
      },
    },
    // Returns an object keyed by the requested store name so assertions can tell which store was resolved.
    getVector: vi.fn((storeName: string) => ({
      [storeName]: {
        // Mock vector store methods
      },
    })),
    logger: {
      debug: vi.fn(),
      warn: vi.fn(),
      info: vi.fn(),
      error: vi.fn(),
    },
    getLogger: vi.fn(() => ({
      debug: vi.fn(),
      warn: vi.fn(),
      info: vi.fn(),
      error: vi.fn(),
    })),
  };

  beforeEach(() => {
    vi.clearAllMocks();
  });

  describe('input schema validation', () => {
    it('should handle filter permissively when enableFilter is false', () => {
      // Create tool with enableFilter set to false
      const tool = createVectorQueryTool({
        vectorStoreName: 'testStore',
        indexName: 'testIndex',
        model: mockModel,
        enableFilter: false,
      });

      // Get the Zod schema
      const schema = tool.inputSchema;

      // Test with no filter (should be valid)
      const validInput = {
        queryText: 'test query',
        topK: 5,
      };
      expect(() => schema?.parse(validInput)).not.toThrow();

      // Test with filter: the extra `filter` property must be accepted, not rejected
      const inputWithFilter = {
        ...validInput,
        filter: '{"field": "value"}',
      };
      expect(() => schema?.parse(inputWithFilter)).not.toThrow();
    });

    it('should handle filter when enableFilter is true', () => {
      const tool = createVectorQueryTool({
        vectorStoreName: 'testStore',
        indexName: 'testIndex',
        model: mockModel,
        enableFilter: true,
      });

      // Get the Zod schema
      const schema = tool.inputSchema;

      // Test various filter inputs that should coerce to string
      const testCases = [
        // String inputs
        { filter: '{"field": "value"}' },
        { filter: '{}' },
        { filter: 'simple-string' },
        // Empty
        { filter: '' },
        // Non-string inputs
        { filter: { field: 'value' } },
        { filter: {} },
        { filter: 123 },
        { filter: null },
        { filter: undefined },
      ];

      for (const { filter } of testCases) {
        const input = { queryText: 'test query', topK: 5, filter };

        // Parsing must succeed for every input...
        expect(() => schema?.parse(input)).not.toThrow();

        // ...and the parsed filter must always come back as a string
        const result = schema?.parse(input);
        expect(typeof result?.filter).toBe('string');
      }
    });

    it('should not reject unexpected properties in both modes', () => {
      // Test with enableFilter false
      const toolWithoutFilter = createVectorQueryTool({
        vectorStoreName: 'testStore',
        indexName: 'testIndex',
        model: mockModel,
        enableFilter: false,
      });

      // Should accept the unexpected property
      expect(() =>
        toolWithoutFilter.inputSchema?.parse({
          queryText: 'test query',
          topK: 5,
          unexpectedProp: 'value',
        }),
      ).not.toThrow();

      // Test with enableFilter true
      const toolWithFilter = createVectorQueryTool({
        vectorStoreName: 'testStore',
        indexName: 'testIndex',
        model: mockModel,
        enableFilter: true,
      });

      // Should accept the unexpected property alongside a valid filter
      expect(() =>
        toolWithFilter.inputSchema?.parse({
          queryText: 'test query',
          topK: 5,
          filter: '{}',
          unexpectedProp: 'value',
        }),
      ).not.toThrow();
    });
  });

  describe('execute function', () => {
    it('should not process filter when enableFilter is false', async () => {
      const runtimeContext = new RuntimeContext();

      // Create tool with enableFilter set to false
      const tool = createVectorQueryTool({
        vectorStoreName: 'testStore',
        indexName: 'testIndex',
        model: mockModel,
        enableFilter: false,
      });

      // Execute with no filter
      await tool.execute?.({
        context: {
          queryText: 'test query',
          topK: 5,
        },
        mastra: mockMastra as any,
        runtimeContext,
      });

      // Check that vectorQuerySearch was called with undefined queryFilter
      expect(vectorQuerySearch).toHaveBeenCalledWith(
        expect.objectContaining({
          queryFilter: undefined,
        }),
      );
    });

    it('should process filter when enableFilter is true and filter is provided', async () => {
      const runtimeContext = new RuntimeContext();
      // Create tool with enableFilter set to true
      const tool = createVectorQueryTool({
        vectorStoreName: 'testStore',
        indexName: 'testIndex',
        model: mockModel,
        enableFilter: true,
      });

      const filterJson = '{"field": "value"}';

      // Execute with filter
      await tool.execute?.({
        context: {
          queryText: 'test query',
          topK: 5,
          filter: filterJson,
        },
        mastra: mockMastra as any,
        runtimeContext,
      });

      // Check that vectorQuerySearch was called with the parsed filter
      expect(vectorQuerySearch).toHaveBeenCalledWith(
        expect.objectContaining({
          queryFilter: { field: 'value' },
        }),
      );
    });

    it('should handle string filters correctly', async () => {
      const runtimeContext = new RuntimeContext();
      // Create tool with enableFilter set to true
      const tool = createVectorQueryTool({
        vectorStoreName: 'testStore',
        indexName: 'testIndex',
        model: mockModel,
        enableFilter: true,
      });

      const stringFilter = 'string-filter';

      // Execute with string filter
      await tool.execute?.({
        context: {
          queryText: 'test query',
          topK: 5,
          filter: stringFilter,
        },
        mastra: mockMastra as any,
        runtimeContext,
      });

      // Since this is not a valid filter, it should be ignored
      expect(vectorQuerySearch).toHaveBeenCalledWith(
        expect.objectContaining({
          queryFilter: undefined,
        }),
      );
    });

    it('Returns early when no Mastra server or vector store is provided', async () => {
      const tool = createVectorQueryTool({
        id: 'test',
        model: mockModel,
        indexName: 'testIndex',
        vectorStoreName: 'testStore',
      });

      const runtimeContext = new RuntimeContext();
      const result = await tool.execute({
        context: { queryText: 'foo', topK: 1 },
        runtimeContext,
      });

      expect(result).toEqual({ relevantContext: [], sources: [] });
      expect(vectorQuerySearch).not.toHaveBeenCalled();
    });

    it('works without a mastra server if a vector store is passed as an argument', async () => {
      const testStore = {
        testStore: {},
      };
      const tool = createVectorQueryTool({
        id: 'test',
        model: mockModel,
        indexName: 'testIndex',
        vectorStoreName: 'testStore',
        vectorStore: testStore as any,
      });

      const runtimeContext = new RuntimeContext();
      const result = await tool.execute({
        context: { queryText: 'foo', topK: 1 },
        runtimeContext,
      });

      expect(result.relevantContext[0]).toEqual({ text: 'foo' });
      expect(vectorQuerySearch).toHaveBeenCalledWith(
        expect.objectContaining({
          databaseConfig: undefined,
          indexName: 'testIndex',
          vectorStore: {
            testStore: {},
          },
          queryText: 'foo',
          model: mockModel,
          queryFilter: undefined,
          topK: 1,
        }),
      );
    });

    it('prefers the passed vector store over one from a passed Mastra server', async () => {
      const thirdStore = {
        thirdStore: {},
      };
      const tool = createVectorQueryTool({
        id: 'test',
        model: mockModel,
        indexName: 'testIndex',
        vectorStoreName: 'thirdStore',
        vectorStore: thirdStore as any,
      });

      const runtimeContext = new RuntimeContext();
      const result = await tool.execute({
        context: { queryText: 'foo', topK: 1 },
        mastra: mockMastra as any,
        runtimeContext,
      });

      expect(result.relevantContext[0]).toEqual({ text: 'foo' });
      expect(vectorQuerySearch).toHaveBeenCalledWith(
        expect.objectContaining({
          databaseConfig: undefined,
          indexName: 'testIndex',
          vectorStore: {
            thirdStore: {},
          },
          queryText: 'foo',
          model: mockModel,
          queryFilter: undefined,
          topK: 1,
        }),
      );
    });
  });

  describe('runtimeContext', () => {
    it('calls vectorQuerySearch with runtimeContext params', async () => {
      const tool = createVectorQueryTool({
        id: 'test',
        model: mockModel,
        indexName: 'testIndex',
        vectorStoreName: 'testStore',
      });
      const runtimeContext = new RuntimeContext();
      runtimeContext.set('indexName', 'anotherIndex');
      runtimeContext.set('vectorStoreName', 'anotherStore');
      runtimeContext.set('topK', 3);
      runtimeContext.set('filter', { foo: 'bar' });
      runtimeContext.set('includeVectors', true);
      runtimeContext.set('includeSources', false);
      const result = await tool.execute({
        context: { queryText: 'foo', topK: 6 },
        mastra: mockMastra as any,
        runtimeContext,
      });
      expect(result.relevantContext.length).toBeGreaterThan(0);
      expect(result.sources).toEqual([]); // includeSources false
      expect(vectorQuerySearch).toHaveBeenCalledWith(
        expect.objectContaining({
          indexName: 'anotherIndex',
          vectorStore: {
            anotherStore: {},
          },
          queryText: 'foo',
          model: mockModel,
          queryFilter: { foo: 'bar' },
          topK: 3,
          includeVectors: true,
        }),
      );
    });

    it('handles reranker from runtimeContext', async () => {
      const tool = createVectorQueryTool({
        id: 'test',
        model: mockModel,
        indexName: 'testIndex',
        vectorStoreName: 'testStore',
      });
      const runtimeContext = new RuntimeContext();
      runtimeContext.set('indexName', 'testIndex');
      runtimeContext.set('vectorStoreName', 'testStore');
      runtimeContext.set('reranker', { model: 'reranker-model', options: { topK: 1 } });
      // Mock rerank
      vi.mocked(rerank).mockResolvedValue([
        {
          result: { id: '1', metadata: { text: 'bar' }, score: 1 },
          score: 1,
          details: { semantic: 1, vector: 1, position: 1 },
        },
      ]);
      const result = await tool.execute({
        context: { queryText: 'foo', topK: 1 },
        mastra: mockMastra as any,
        runtimeContext,
      });
      expect(result.relevantContext[0]).toEqual({ text: 'bar' });
    });
  });

  describe('providerOptions', () => {
    it('should pass providerOptions to vectorQuerySearch', async () => {
      const tool = createVectorQueryTool({
        indexName: 'testIndex',
        model: mockModel,
        vectorStoreName: 'testStore',
        providerOptions: { google: { outputDimensionality: 1536 } },
      });

      await tool.execute({
        context: { queryText: 'foo', topK: 10 },
        mastra: mockMastra as any,
        runtimeContext: new RuntimeContext(),
      });

      expect(vectorQuerySearch).toHaveBeenCalledWith({
        indexName: 'testIndex',
        vectorStore: { testStore: {} },
        queryText: 'foo',
        model: mockModel,
        queryFilter: undefined,
        topK: 10,
        includeVectors: false,
        databaseConfig: undefined,
        providerOptions: { google: { outputDimensionality: 1536 } },
      });
    });

    it('should allow providerOptions override via runtimeContext', async () => {
      const tool = createVectorQueryTool({
        indexName: 'testIndex',
        model: mockModel,
        vectorStoreName: 'testStore',
        providerOptions: { google: { outputDimensionality: 1536 } },
      });

      const runtimeContext = new RuntimeContext();
      runtimeContext.set('providerOptions', { google: { outputDimensionality: 768 } });

      await tool.execute({
        context: { queryText: 'foo', topK: 10 },
        mastra: mockMastra as any,
        runtimeContext,
      });

      expect(vectorQuerySearch).toHaveBeenCalledWith({
        indexName: 'testIndex',
        vectorStore: { testStore: {} },
        queryText: 'foo',
        model: mockModel,
        queryFilter: undefined,
        topK: 10,
        includeVectors: false,
        databaseConfig: undefined,
        providerOptions: { google: { outputDimensionality: 768 } },
      });
    });
  });
});
|
|
@@ -1,171 +0,0 @@
|
|
|
1
|
-
import { createTool } from '@mastra/core/tools';
|
|
2
|
-
import type { MastraVector, MastraEmbeddingModel } from '@mastra/core/vector';
|
|
3
|
-
import { z } from 'zod';
|
|
4
|
-
|
|
5
|
-
import { rerank, rerankWithScorer } from '../rerank';
|
|
6
|
-
import type { RerankConfig, RerankResult } from '../rerank';
|
|
7
|
-
import { vectorQuerySearch, defaultVectorQueryDescription, filterSchema, outputSchema, baseSchema } from '../utils';
|
|
8
|
-
import type { RagTool } from '../utils';
|
|
9
|
-
import { convertToSources } from '../utils/convert-sources';
|
|
10
|
-
import type { VectorQueryToolOptions } from './types';
|
|
11
|
-
|
|
12
|
-
/**
 * Build a tool that embeds `queryText`, searches a vector index and returns the matching
 * chunks' metadata (`relevantContext`) plus, optionally, the full retrieval objects (`sources`).
 *
 * Most settings are resolved per execution, with values set on `runtimeContext` taking
 * precedence over the values in `options` (index name, store name, topK, filter, reranker,
 * database config, model, provider options, includeVectors, includeSources).
 *
 * The vector store comes from `options.vectorStore` when that key is present; otherwise it is
 * looked up on the `mastra` instance by name.
 *
 * @param options Tool configuration; `enableFilter` selects the filter-aware input schema.
 * @returns The created tool. Its `execute` resolves to `{ relevantContext: [], sources: [] }`
 *   when no vector store can be resolved or when the search/rerank step throws.
 * @throws From `execute`, if no index name can be resolved (checked before the try/catch).
 */
export const createVectorQueryTool = (options: VectorQueryToolOptions) => {
  const { id, description } = options;
  const storeName = options['vectorStoreName'] ? options.vectorStoreName : 'DirectVectorStore';

  const toolId = id || `VectorQuery ${storeName} ${options.indexName} Tool`;
  const toolDescription = description || defaultVectorQueryDescription();
  const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();

  return createTool({
    id: toolId,
    description: toolDescription,
    inputSchema,
    outputSchema,
    execute: async ({ context, mastra, runtimeContext }) => {
      const indexName: string = runtimeContext.get('indexName') ?? options.indexName;
      // A directly supplied vector store pins the store name; otherwise runtimeContext may override it.
      const vectorStoreName: string =
        'vectorStore' in options ? storeName : (runtimeContext.get('vectorStoreName') ?? storeName);
      const includeVectors: boolean = runtimeContext.get('includeVectors') ?? options.includeVectors ?? false;
      const includeSources: boolean = runtimeContext.get('includeSources') ?? options.includeSources ?? true;
      const reranker: RerankConfig = runtimeContext.get('reranker') ?? options.reranker;
      const databaseConfig = runtimeContext.get('databaseConfig') ?? options.databaseConfig;
      const model: MastraEmbeddingModel<string> = runtimeContext.get('model') ?? options.model;
      const providerOptions: Record<string, Record<string, any>> | undefined =
        runtimeContext.get('providerOptions') ?? options.providerOptions;

      if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
      if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`); // won't fire

      // Both values may arrive as strings (a JSON filter, a numeric string) or as already
      // structured values, so they stay `unknown` until narrowed below.
      const topK: unknown = runtimeContext.get('topK') ?? context.topK ?? 10;
      const filter: unknown = runtimeContext.get('filter') ?? context.filter;
      const queryText = context.queryText;
      // A filter set on runtimeContext turns filtering on even when the tool was created without it.
      const enableFilter = !!runtimeContext.get('filter') || (options.enableFilter ?? false);

      const logger = mastra?.getLogger();
      if (!logger) {
        console.warn(
          '[VectorQueryTool] Logger not initialized: no debug or error logs will be recorded for this tool execution.',
        );
      }
      logger?.debug('[VectorQueryTool] execute called with:', { queryText, topK, filter, databaseConfig });

      try {
        // Accept a number or a numeric string; anything else falls back to 10.
        const topKValue =
          typeof topK === 'number' && !Number.isNaN(topK)
            ? topK
            : typeof topK === 'string' && !Number.isNaN(Number(topK))
              ? Number(topK)
              : 10;

        let vectorStore: MastraVector | undefined = undefined;
        if ('vectorStore' in options) {
          vectorStore = options.vectorStore;
        } else if (mastra) {
          vectorStore = mastra.getVector(vectorStoreName);
        }
        if (!vectorStore) {
          logger?.error('Vector store not found', { vectorStoreName });
          return { relevantContext: [], sources: [] };
        }

        // Get relevant chunks from the vector database
        let queryFilter = {};
        if (enableFilter && filter) {
          try {
            const parsed: unknown = typeof filter === 'string' ? JSON.parse(filter) : filter;
            if (typeof parsed === 'object' && parsed !== null) {
              queryFilter = parsed;
            } else {
              // Valid JSON that is not an object (e.g. '"abc"' or '42') is not a usable filter;
              // without this guard a parsed string would be forwarded because it has index keys.
              logger?.warn('Filter did not resolve to an object, using empty filter', { filter });
            }
          } catch (error) {
            // Log the error and use empty object
            logger?.warn('Failed to parse filter as JSON, using empty filter', { filter, error });
          }
        }
        logger?.debug('Prepared vector query parameters', { queryText, topK: topKValue, queryFilter, databaseConfig });

        const { results } = await vectorQuerySearch({
          indexName,
          vectorStore,
          queryText,
          model,
          // An empty filter is sent as `undefined` rather than `{}`.
          queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
          topK: topKValue,
          includeVectors,
          databaseConfig,
          providerOptions,
        });
        logger?.debug('vectorQuerySearch returned results', { count: results.length });

        if (reranker) {
          logger?.debug('Reranking results', { rerankerModel: reranker.model, rerankerOptions: reranker.options });

          let rerankedResults: RerankResult[] = [];

          // A model object exposing `getRelevanceScore` is used as a scorer directly;
          // any other model goes through `rerank`.
          if (typeof reranker.model === 'object' && 'getRelevanceScore' in reranker.model) {
            rerankedResults = await rerankWithScorer({
              results,
              query: queryText,
              scorer: reranker.model,
              options: {
                ...reranker.options,
                topK: reranker.options?.topK || topKValue,
              },
            });
          } else {
            rerankedResults = await rerank(results, queryText, reranker.model, {
              ...reranker.options,
              topK: reranker.options?.topK || topKValue,
            });
          }

          logger?.debug('Reranking complete', { rerankedCount: rerankedResults.length });

          const relevantChunks = rerankedResults.map(({ result }) => result?.metadata);

          logger?.debug('Returning reranked relevant context chunks', { count: relevantChunks.length });

          const sources = includeSources ? convertToSources(rerankedResults) : [];

          return { relevantContext: relevantChunks, sources };
        }

        const relevantChunks = results.map(result => result?.metadata);

        logger?.debug('Returning relevant context chunks', { count: relevantChunks.length });
        // `sources` exposes the full retrieval objects
        const sources = includeSources ? convertToSources(results) : [];
        return {
          relevantContext: relevantChunks,
          sources,
        };
      } catch (err) {
        // Failures are reported through the logger (when one exists) and surfaced as an empty result.
        logger?.error('Unexpected error in VectorQueryTool execute', {
          error: err,
          errorMessage: err instanceof Error ? err.message : String(err),
          errorStack: err instanceof Error ? err.stack : undefined,
        });
        return { relevantContext: [], sources: [] };
      }
    },
    // Use any for output schema as the structure of the output causes type inference issues
  }) as RagTool<typeof inputSchema, any>;
};
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
import type { QueryResult } from '@mastra/core/vector';
|
|
2
|
-
import type { RankedNode } from '../graph-rag';
|
|
3
|
-
import type { RerankResult } from '../rerank';
|
|
4
|
-
|
|
5
|
-
type SourceInput = QueryResult | RankedNode | RerankResult;
|
|
6
|
-
|
|
7
|
-
/**
 * Convert an array of source inputs (QueryResult, RankedNode, or RerankResult) to an array of sources.
 *
 * Each input is normalised to `{ id, vector, score, metadata, document }`. The variant is
 * detected structurally: an input with a `content` key is treated as a RankedNode, one with a
 * `result` key as a RerankResult, and anything else as a QueryResult. A missing vector or
 * embedding becomes `[]` and missing document text becomes `''`.
 *
 * @param results Array of source inputs to convert.
 * @returns Array of sources, in the same order as `results`.
 */
export const convertToSources = (results: SourceInput[]) => {
  return results.map(result => {
    // RankedNode: the text lives in `content` and the vector in `embedding`
    if ('content' in result) {
      return {
        id: result.id,
        vector: result.embedding ?? [],
        score: result.score,
        metadata: result.metadata,
        document: result.content ?? '',
      };
    }
    // RerankResult: the hit is nested under `result`; the score is the top-level rerank score
    if ('result' in result) {
      const hit = result.result;
      return {
        id: hit.id,
        vector: hit.vector ?? [],
        score: result.score,
        metadata: hit.metadata,
        document: hit.document ?? '',
      };
    }
    // QueryResult
    return {
      id: result.id,
      vector: result.vector ?? [],
      score: result.score,
      metadata: result.metadata,
      document: result.document ?? '',
    };
  });
};
|