@mastra/rag 0.10.1-alpha.0 → 0.10.1-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +6 -0
- package/dist/_tsup-dts-rollup.d.cts +101 -224
- package/dist/_tsup-dts-rollup.d.ts +101 -224
- package/dist/index.cjs +33 -60
- package/dist/index.js +33 -60
- package/package.json +1 -1
- package/src/tools/graph-rag.ts +7 -37
- package/src/tools/vector-query.test.ts +19 -26
- package/src/tools/vector-query.ts +6 -38
- package/src/utils/index.ts +1 -0
- package/src/utils/tool-schemas.ts +38 -0
package/src/tools/graph-rag.ts
CHANGED
|
@@ -3,13 +3,8 @@ import type { EmbeddingModel } from 'ai';
|
|
|
3
3
|
import { z } from 'zod';
|
|
4
4
|
|
|
5
5
|
import { GraphRAG } from '../graph-rag';
|
|
6
|
-
import {
|
|
7
|
-
|
|
8
|
-
defaultGraphRagDescription,
|
|
9
|
-
filterDescription,
|
|
10
|
-
topKDescription,
|
|
11
|
-
queryTextDescription,
|
|
12
|
-
} from '../utils';
|
|
6
|
+
import { vectorQuerySearch, defaultGraphRagDescription, filterSchema, outputSchema, baseSchema } from '../utils';
|
|
7
|
+
import type { RagTool } from '../utils';
|
|
13
8
|
import { convertToSources } from '../utils/convert-sources';
|
|
14
9
|
|
|
15
10
|
export const createGraphRAGTool = ({
|
|
@@ -47,38 +42,12 @@ export const createGraphRAGTool = ({
|
|
|
47
42
|
const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
|
|
48
43
|
let isInitialized = false;
|
|
49
44
|
|
|
50
|
-
const
|
|
51
|
-
|
|
52
|
-
topK: z.coerce.number().describe(topKDescription),
|
|
53
|
-
};
|
|
54
|
-
const inputSchema = enableFilter
|
|
55
|
-
? z
|
|
56
|
-
.object({
|
|
57
|
-
...baseSchema,
|
|
58
|
-
filter: z.coerce.string().describe(filterDescription),
|
|
59
|
-
})
|
|
60
|
-
.passthrough()
|
|
61
|
-
: z.object(baseSchema).passthrough();
|
|
45
|
+
const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
46
|
+
|
|
62
47
|
return createTool({
|
|
63
48
|
id: toolId,
|
|
64
49
|
inputSchema,
|
|
65
|
-
|
|
66
|
-
// Each source contains all information needed to reference
|
|
67
|
-
// the original document, chunk, and similarity score.
|
|
68
|
-
outputSchema: z.object({
|
|
69
|
-
// Array of metadata or content for compatibility with prior usage
|
|
70
|
-
relevantContext: z.any(),
|
|
71
|
-
// Array of full retrieval result objects
|
|
72
|
-
sources: z.array(
|
|
73
|
-
z.object({
|
|
74
|
-
id: z.string(), // Unique chunk/document identifier
|
|
75
|
-
metadata: z.any(), // All metadata fields (document ID, etc.)
|
|
76
|
-
vector: z.array(z.number()), // Embedding vector (if available)
|
|
77
|
-
score: z.number(), // Similarity score for this retrieval
|
|
78
|
-
document: z.string(), // Full chunk/document text (if available)
|
|
79
|
-
}),
|
|
80
|
-
),
|
|
81
|
-
}),
|
|
50
|
+
outputSchema,
|
|
82
51
|
description: toolDescription,
|
|
83
52
|
execute: async ({ context: { queryText, topK, filter }, mastra }) => {
|
|
84
53
|
const logger = mastra?.getLogger();
|
|
@@ -188,5 +157,6 @@ export const createGraphRAGTool = ({
|
|
|
188
157
|
return { relevantContext: [], sources: [] };
|
|
189
158
|
}
|
|
190
159
|
},
|
|
191
|
-
|
|
160
|
+
// Use any for output schema as the structure of the output causes type inference issues
|
|
161
|
+
}) as RagTool<typeof inputSchema, any>;
|
|
192
162
|
};
|
|
package/src/tools/vector-query.test.ts
CHANGED
|
@@ -2,23 +2,13 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
|
2
2
|
import { vectorQuerySearch } from '../utils';
|
|
3
3
|
import { createVectorQueryTool } from './vector-query';
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
})),
|
|
13
|
-
}));
|
|
14
|
-
|
|
15
|
-
vi.mock('../utils', () => ({
|
|
16
|
-
vectorQuerySearch: vi.fn().mockResolvedValue({ results: [] }),
|
|
17
|
-
defaultVectorQueryDescription: () => 'Default vector query description',
|
|
18
|
-
queryTextDescription: 'Query text description',
|
|
19
|
-
filterDescription: 'Filter description',
|
|
20
|
-
topKDescription: 'Top K description',
|
|
21
|
-
}));
|
|
5
|
+
vi.mock('../utils', async importOriginal => {
|
|
6
|
+
const actual: any = await importOriginal();
|
|
7
|
+
return {
|
|
8
|
+
...actual,
|
|
9
|
+
vectorQuerySearch: vi.fn().mockResolvedValue({ results: [] }),
|
|
10
|
+
};
|
|
11
|
+
});
|
|
22
12
|
|
|
23
13
|
describe('createVectorQueryTool', () => {
|
|
24
14
|
const mockModel = { name: 'test-model' } as any;
|
|
@@ -60,21 +50,21 @@ describe('createVectorQueryTool', () => {
|
|
|
60
50
|
});
|
|
61
51
|
|
|
62
52
|
// Get the Zod schema
|
|
63
|
-
const schema = tool.
|
|
53
|
+
const schema = tool.inputSchema;
|
|
64
54
|
|
|
65
55
|
// Test with no filter (should be valid)
|
|
66
56
|
const validInput = {
|
|
67
57
|
queryText: 'test query',
|
|
68
58
|
topK: 5,
|
|
69
59
|
};
|
|
70
|
-
expect(() => schema
|
|
60
|
+
expect(() => schema?.parse(validInput)).not.toThrow();
|
|
71
61
|
|
|
72
62
|
// Test with filter (should throw - unexpected property)
|
|
73
63
|
const inputWithFilter = {
|
|
74
64
|
...validInput,
|
|
75
65
|
filter: '{"field": "value"}',
|
|
76
66
|
};
|
|
77
|
-
expect(() => schema
|
|
67
|
+
expect(() => schema?.parse(inputWithFilter)).not.toThrow();
|
|
78
68
|
});
|
|
79
69
|
|
|
80
70
|
it('should handle filter when enableFilter is true', () => {
|
|
@@ -86,7 +76,7 @@ describe('createVectorQueryTool', () => {
|
|
|
86
76
|
});
|
|
87
77
|
|
|
88
78
|
// Get the Zod schema
|
|
89
|
-
const schema = tool.
|
|
79
|
+
const schema = tool.inputSchema;
|
|
90
80
|
|
|
91
81
|
// Test various filter inputs that should coerce to string
|
|
92
82
|
const testCases = [
|
|
@@ -105,7 +95,7 @@ describe('createVectorQueryTool', () => {
|
|
|
105
95
|
|
|
106
96
|
testCases.forEach(({ filter }) => {
|
|
107
97
|
expect(() =>
|
|
108
|
-
schema
|
|
98
|
+
schema?.parse({
|
|
109
99
|
queryText: 'test query',
|
|
110
100
|
topK: 5,
|
|
111
101
|
filter,
|
|
@@ -115,12 +105,12 @@ describe('createVectorQueryTool', () => {
|
|
|
115
105
|
|
|
116
106
|
// Verify that all parsed values are strings
|
|
117
107
|
testCases.forEach(({ filter }) => {
|
|
118
|
-
const result = schema
|
|
108
|
+
const result = schema?.parse({
|
|
119
109
|
queryText: 'test query',
|
|
120
110
|
topK: 5,
|
|
121
111
|
filter,
|
|
122
112
|
});
|
|
123
|
-
expect(typeof result
|
|
113
|
+
expect(typeof result?.filter).toBe('string');
|
|
124
114
|
});
|
|
125
115
|
});
|
|
126
116
|
|
|
@@ -135,7 +125,7 @@ describe('createVectorQueryTool', () => {
|
|
|
135
125
|
|
|
136
126
|
// Should reject unexpected property
|
|
137
127
|
expect(() =>
|
|
138
|
-
toolWithoutFilter.
|
|
128
|
+
toolWithoutFilter.inputSchema?.parse({
|
|
139
129
|
queryText: 'test query',
|
|
140
130
|
topK: 5,
|
|
141
131
|
unexpectedProp: 'value',
|
|
@@ -152,7 +142,7 @@ describe('createVectorQueryTool', () => {
|
|
|
152
142
|
|
|
153
143
|
// Should reject unexpected property even with valid filter
|
|
154
144
|
expect(() =>
|
|
155
|
-
toolWithFilter.
|
|
145
|
+
toolWithFilter.inputSchema?.parse({
|
|
156
146
|
queryText: 'test query',
|
|
157
147
|
topK: 5,
|
|
158
148
|
filter: '{}',
|
|
@@ -179,6 +169,7 @@ describe('createVectorQueryTool', () => {
|
|
|
179
169
|
topK: 5,
|
|
180
170
|
},
|
|
181
171
|
mastra: mockMastra as any,
|
|
172
|
+
runtimeContext: {} as any,
|
|
182
173
|
});
|
|
183
174
|
|
|
184
175
|
// Check that vectorQuerySearch was called with undefined queryFilter
|
|
@@ -208,6 +199,7 @@ describe('createVectorQueryTool', () => {
|
|
|
208
199
|
filter: filterJson,
|
|
209
200
|
},
|
|
210
201
|
mastra: mockMastra as any,
|
|
202
|
+
runtimeContext: {} as any,
|
|
211
203
|
});
|
|
212
204
|
|
|
213
205
|
// Check that vectorQuerySearch was called with the parsed filter
|
|
@@ -237,6 +229,7 @@ describe('createVectorQueryTool', () => {
|
|
|
237
229
|
filter: stringFilter,
|
|
238
230
|
},
|
|
239
231
|
mastra: mockMastra as any,
|
|
232
|
+
runtimeContext: {} as any,
|
|
240
233
|
});
|
|
241
234
|
|
|
242
235
|
// Since this is not a valid filter, it should be ignored
|
|
package/src/tools/vector-query.ts
CHANGED
|
@@ -4,13 +4,8 @@ import { z } from 'zod';
|
|
|
4
4
|
|
|
5
5
|
import { rerank } from '../rerank';
|
|
6
6
|
import type { RerankConfig } from '../rerank';
|
|
7
|
-
import {
|
|
8
|
-
|
|
9
|
-
defaultVectorQueryDescription,
|
|
10
|
-
filterDescription,
|
|
11
|
-
topKDescription,
|
|
12
|
-
queryTextDescription,
|
|
13
|
-
} from '../utils';
|
|
7
|
+
import { vectorQuerySearch, defaultVectorQueryDescription, filterSchema, outputSchema, baseSchema } from '../utils';
|
|
8
|
+
import type { RagTool } from '../utils';
|
|
14
9
|
import { convertToSources } from '../utils/convert-sources';
|
|
15
10
|
|
|
16
11
|
export const createVectorQueryTool = ({
|
|
@@ -36,39 +31,11 @@ export const createVectorQueryTool = ({
|
|
|
36
31
|
}) => {
|
|
37
32
|
const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
|
|
38
33
|
const toolDescription = description || defaultVectorQueryDescription();
|
|
39
|
-
|
|
40
|
-
const baseSchema = {
|
|
41
|
-
queryText: z.string().describe(queryTextDescription),
|
|
42
|
-
topK: z.coerce.number().describe(topKDescription),
|
|
43
|
-
};
|
|
44
|
-
const inputSchema = enableFilter
|
|
45
|
-
? z
|
|
46
|
-
.object({
|
|
47
|
-
...baseSchema,
|
|
48
|
-
filter: z.coerce.string().describe(filterDescription),
|
|
49
|
-
})
|
|
50
|
-
.passthrough()
|
|
51
|
-
: z.object(baseSchema).passthrough();
|
|
34
|
+
const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
52
35
|
return createTool({
|
|
53
36
|
id: toolId,
|
|
54
37
|
inputSchema,
|
|
55
|
-
|
|
56
|
-
// Each source contains all information needed to reference
|
|
57
|
-
// the original document, chunk, and similarity score.
|
|
58
|
-
outputSchema: z.object({
|
|
59
|
-
// Array of metadata or content for compatibility with prior usage
|
|
60
|
-
relevantContext: z.any(),
|
|
61
|
-
// Array of full retrieval result objects
|
|
62
|
-
sources: z.array(
|
|
63
|
-
z.object({
|
|
64
|
-
id: z.string(), // Unique chunk/document identifier
|
|
65
|
-
metadata: z.any(), // All metadata fields (document ID, etc.)
|
|
66
|
-
vector: z.array(z.number()), // Embedding vector (if available)
|
|
67
|
-
score: z.number(), // Similarity score for this retrieval
|
|
68
|
-
document: z.string(), // Full chunk/document text (if available)
|
|
69
|
-
}),
|
|
70
|
-
),
|
|
71
|
-
}),
|
|
38
|
+
outputSchema,
|
|
72
39
|
description: toolDescription,
|
|
73
40
|
execute: async ({ context: { queryText, topK, filter }, mastra }) => {
|
|
74
41
|
const logger = mastra?.getLogger();
|
|
@@ -167,5 +134,6 @@ export const createVectorQueryTool = ({
|
|
|
167
134
|
return { relevantContext: [], sources: [] };
|
|
168
135
|
}
|
|
169
136
|
},
|
|
170
|
-
|
|
137
|
+
// Use any for output schema as the structure of the output causes type inference issues
|
|
138
|
+
}) as RagTool<typeof inputSchema, any>;
|
|
171
139
|
};
|
package/src/utils/tool-schemas.ts
CHANGED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import type { Tool } from '@mastra/core/tools';
|
|
2
|
+
import { z } from 'zod';
|
|
3
|
+
import { queryTextDescription, topKDescription, filterDescription } from './default-settings';
|
|
4
|
+
|
|
5
|
+
export const baseSchema = {
|
|
6
|
+
queryText: z.string().describe(queryTextDescription),
|
|
7
|
+
topK: z.coerce.number().describe(topKDescription),
|
|
8
|
+
};
|
|
9
|
+
|
|
10
|
+
// Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
|
|
11
|
+
// Each source contains all information needed to reference
|
|
12
|
+
// the original document, chunk, and similarity score.
|
|
13
|
+
export const outputSchema = z.object({
|
|
14
|
+
// Array of metadata or content for compatibility with prior usage
|
|
15
|
+
relevantContext: z.any(),
|
|
16
|
+
// Array of full retrieval result objects
|
|
17
|
+
sources: z.array(
|
|
18
|
+
z.object({
|
|
19
|
+
id: z.string(), // Unique chunk/document identifier
|
|
20
|
+
metadata: z.any(), // All metadata fields (document ID, etc.)
|
|
21
|
+
vector: z.array(z.number()), // Embedding vector (if available)
|
|
22
|
+
score: z.number(), // Similarity score for this retrieval
|
|
23
|
+
document: z.string(), // Full chunk/document text (if available)
|
|
24
|
+
}),
|
|
25
|
+
),
|
|
26
|
+
});
|
|
27
|
+
|
|
28
|
+
export const filterSchema = z.object({
|
|
29
|
+
...baseSchema,
|
|
30
|
+
filter: z.coerce.string().describe(filterDescription),
|
|
31
|
+
});
|
|
32
|
+
|
|
33
|
+
export type RagTool<
|
|
34
|
+
TInput extends z.ZodType<any, z.ZodTypeDef, any> | undefined,
|
|
35
|
+
TOutput extends z.ZodType<any, z.ZodTypeDef, any> | undefined,
|
|
36
|
+
> = Tool<TInput, TOutput> & {
|
|
37
|
+
execute: NonNullable<Tool<TInput, TOutput>['execute']>;
|
|
38
|
+
};
|