@mastra/rag 0.10.1-alpha.0 → 0.10.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,13 +3,8 @@ import type { EmbeddingModel } from 'ai';
3
3
  import { z } from 'zod';
4
4
 
5
5
  import { GraphRAG } from '../graph-rag';
6
- import {
7
- vectorQuerySearch,
8
- defaultGraphRagDescription,
9
- filterDescription,
10
- topKDescription,
11
- queryTextDescription,
12
- } from '../utils';
6
+ import { vectorQuerySearch, defaultGraphRagDescription, filterSchema, outputSchema, baseSchema } from '../utils';
7
+ import type { RagTool } from '../utils';
13
8
  import { convertToSources } from '../utils/convert-sources';
14
9
 
15
10
  export const createGraphRAGTool = ({
@@ -47,38 +42,12 @@ export const createGraphRAGTool = ({
47
42
  const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
48
43
  let isInitialized = false;
49
44
 
50
- const baseSchema = {
51
- queryText: z.string().describe(queryTextDescription),
52
- topK: z.coerce.number().describe(topKDescription),
53
- };
54
- const inputSchema = enableFilter
55
- ? z
56
- .object({
57
- ...baseSchema,
58
- filter: z.coerce.string().describe(filterDescription),
59
- })
60
- .passthrough()
61
- : z.object(baseSchema).passthrough();
45
+ const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
46
+
62
47
  return createTool({
63
48
  id: toolId,
64
49
  inputSchema,
65
- // Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
66
- // Each source contains all information needed to reference
67
- // the original document, chunk, and similarity score.
68
- outputSchema: z.object({
69
- // Array of metadata or content for compatibility with prior usage
70
- relevantContext: z.any(),
71
- // Array of full retrieval result objects
72
- sources: z.array(
73
- z.object({
74
- id: z.string(), // Unique chunk/document identifier
75
- metadata: z.any(), // All metadata fields (document ID, etc.)
76
- vector: z.array(z.number()), // Embedding vector (if available)
77
- score: z.number(), // Similarity score for this retrieval
78
- document: z.string(), // Full chunk/document text (if available)
79
- }),
80
- ),
81
- }),
50
+ outputSchema,
82
51
  description: toolDescription,
83
52
  execute: async ({ context: { queryText, topK, filter }, mastra }) => {
84
53
  const logger = mastra?.getLogger();
@@ -188,5 +157,6 @@ export const createGraphRAGTool = ({
188
157
  return { relevantContext: [], sources: [] };
189
158
  }
190
159
  },
191
- });
160
+ // Use any for output schema as the structure of the output causes type inference issues
161
+ }) as RagTool<typeof inputSchema, any>;
192
162
  };
@@ -2,23 +2,13 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
2
2
  import { vectorQuerySearch } from '../utils';
3
3
  import { createVectorQueryTool } from './vector-query';
4
4
 
5
- // Mock dependencies
6
- vi.mock('@mastra/core/tools', () => ({
7
- createTool: vi.fn(({ inputSchema, execute }) => ({
8
- inputSchema,
9
- execute,
10
- // Return a simplified version of the tool for testing
11
- __inputSchema: inputSchema,
12
- })),
13
- }));
14
-
15
- vi.mock('../utils', () => ({
16
- vectorQuerySearch: vi.fn().mockResolvedValue({ results: [] }),
17
- defaultVectorQueryDescription: () => 'Default vector query description',
18
- queryTextDescription: 'Query text description',
19
- filterDescription: 'Filter description',
20
- topKDescription: 'Top K description',
21
- }));
5
+ vi.mock('../utils', async importOriginal => {
6
+ const actual: any = await importOriginal();
7
+ return {
8
+ ...actual,
9
+ vectorQuerySearch: vi.fn().mockResolvedValue({ results: [] }),
10
+ };
11
+ });
22
12
 
23
13
  describe('createVectorQueryTool', () => {
24
14
  const mockModel = { name: 'test-model' } as any;
@@ -60,21 +50,21 @@ describe('createVectorQueryTool', () => {
60
50
  });
61
51
 
62
52
  // Get the Zod schema
63
- const schema = tool.__inputSchema;
53
+ const schema = tool.inputSchema;
64
54
 
65
55
  // Test with no filter (should be valid)
66
56
  const validInput = {
67
57
  queryText: 'test query',
68
58
  topK: 5,
69
59
  };
70
- expect(() => schema.parse(validInput)).not.toThrow();
60
+ expect(() => schema?.parse(validInput)).not.toThrow();
71
61
 
72
62
  // Test with filter (should throw - unexpected property)
73
63
  const inputWithFilter = {
74
64
  ...validInput,
75
65
  filter: '{"field": "value"}',
76
66
  };
77
- expect(() => schema.parse(inputWithFilter)).not.toThrow();
67
+ expect(() => schema?.parse(inputWithFilter)).not.toThrow();
78
68
  });
79
69
 
80
70
  it('should handle filter when enableFilter is true', () => {
@@ -86,7 +76,7 @@ describe('createVectorQueryTool', () => {
86
76
  });
87
77
 
88
78
  // Get the Zod schema
89
- const schema = tool.__inputSchema;
79
+ const schema = tool.inputSchema;
90
80
 
91
81
  // Test various filter inputs that should coerce to string
92
82
  const testCases = [
@@ -105,7 +95,7 @@ describe('createVectorQueryTool', () => {
105
95
 
106
96
  testCases.forEach(({ filter }) => {
107
97
  expect(() =>
108
- schema.parse({
98
+ schema?.parse({
109
99
  queryText: 'test query',
110
100
  topK: 5,
111
101
  filter,
@@ -115,12 +105,12 @@ describe('createVectorQueryTool', () => {
115
105
 
116
106
  // Verify that all parsed values are strings
117
107
  testCases.forEach(({ filter }) => {
118
- const result = schema.parse({
108
+ const result = schema?.parse({
119
109
  queryText: 'test query',
120
110
  topK: 5,
121
111
  filter,
122
112
  });
123
- expect(typeof result.filter).toBe('string');
113
+ expect(typeof result?.filter).toBe('string');
124
114
  });
125
115
  });
126
116
 
@@ -135,7 +125,7 @@ describe('createVectorQueryTool', () => {
135
125
 
136
126
  // Should reject unexpected property
137
127
  expect(() =>
138
- toolWithoutFilter.__inputSchema.parse({
128
+ toolWithoutFilter.inputSchema?.parse({
139
129
  queryText: 'test query',
140
130
  topK: 5,
141
131
  unexpectedProp: 'value',
@@ -152,7 +142,7 @@ describe('createVectorQueryTool', () => {
152
142
 
153
143
  // Should reject unexpected property even with valid filter
154
144
  expect(() =>
155
- toolWithFilter.__inputSchema.parse({
145
+ toolWithFilter.inputSchema?.parse({
156
146
  queryText: 'test query',
157
147
  topK: 5,
158
148
  filter: '{}',
@@ -179,6 +169,7 @@ describe('createVectorQueryTool', () => {
179
169
  topK: 5,
180
170
  },
181
171
  mastra: mockMastra as any,
172
+ runtimeContext: {} as any,
182
173
  });
183
174
 
184
175
  // Check that vectorQuerySearch was called with undefined queryFilter
@@ -208,6 +199,7 @@ describe('createVectorQueryTool', () => {
208
199
  filter: filterJson,
209
200
  },
210
201
  mastra: mockMastra as any,
202
+ runtimeContext: {} as any,
211
203
  });
212
204
 
213
205
  // Check that vectorQuerySearch was called with the parsed filter
@@ -237,6 +229,7 @@ describe('createVectorQueryTool', () => {
237
229
  filter: stringFilter,
238
230
  },
239
231
  mastra: mockMastra as any,
232
+ runtimeContext: {} as any,
240
233
  });
241
234
 
242
235
  // Since this is not a valid filter, it should be ignored
@@ -4,13 +4,8 @@ import { z } from 'zod';
4
4
 
5
5
  import { rerank } from '../rerank';
6
6
  import type { RerankConfig } from '../rerank';
7
- import {
8
- vectorQuerySearch,
9
- defaultVectorQueryDescription,
10
- filterDescription,
11
- topKDescription,
12
- queryTextDescription,
13
- } from '../utils';
7
+ import { vectorQuerySearch, defaultVectorQueryDescription, filterSchema, outputSchema, baseSchema } from '../utils';
8
+ import type { RagTool } from '../utils';
14
9
  import { convertToSources } from '../utils/convert-sources';
15
10
 
16
11
  export const createVectorQueryTool = ({
@@ -36,39 +31,11 @@ export const createVectorQueryTool = ({
36
31
  }) => {
37
32
  const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
38
33
  const toolDescription = description || defaultVectorQueryDescription();
39
- // Create base schema with required fields
40
- const baseSchema = {
41
- queryText: z.string().describe(queryTextDescription),
42
- topK: z.coerce.number().describe(topKDescription),
43
- };
44
- const inputSchema = enableFilter
45
- ? z
46
- .object({
47
- ...baseSchema,
48
- filter: z.coerce.string().describe(filterDescription),
49
- })
50
- .passthrough()
51
- : z.object(baseSchema).passthrough();
34
+ const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
52
35
  return createTool({
53
36
  id: toolId,
54
37
  inputSchema,
55
- // Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
56
- // Each source contains all information needed to reference
57
- // the original document, chunk, and similarity score.
58
- outputSchema: z.object({
59
- // Array of metadata or content for compatibility with prior usage
60
- relevantContext: z.any(),
61
- // Array of full retrieval result objects
62
- sources: z.array(
63
- z.object({
64
- id: z.string(), // Unique chunk/document identifier
65
- metadata: z.any(), // All metadata fields (document ID, etc.)
66
- vector: z.array(z.number()), // Embedding vector (if available)
67
- score: z.number(), // Similarity score for this retrieval
68
- document: z.string(), // Full chunk/document text (if available)
69
- }),
70
- ),
71
- }),
38
+ outputSchema,
72
39
  description: toolDescription,
73
40
  execute: async ({ context: { queryText, topK, filter }, mastra }) => {
74
41
  const logger = mastra?.getLogger();
@@ -167,5 +134,6 @@ export const createVectorQueryTool = ({
167
134
  return { relevantContext: [], sources: [] };
168
135
  }
169
136
  },
170
- });
137
+ // Use any for output schema as the structure of the output causes type inference issues
138
+ }) as RagTool<typeof inputSchema, any>;
171
139
  };
@@ -1,2 +1,3 @@
1
1
  export * from './vector-search';
2
2
  export * from './default-settings';
3
+ export * from './tool-schemas';
@@ -0,0 +1,38 @@
1
+ import type { Tool } from '@mastra/core/tools';
2
+ import { z } from 'zod';
3
+ import { queryTextDescription, topKDescription, filterDescription } from './default-settings';
4
+
5
+ export const baseSchema = {
6
+ queryText: z.string().describe(queryTextDescription),
7
+ topK: z.coerce.number().describe(topKDescription),
8
+ };
9
+
10
+ // Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
11
+ // Each source contains all information needed to reference
12
+ // the original document, chunk, and similarity score.
13
+ export const outputSchema = z.object({
14
+ // Array of metadata or content for compatibility with prior usage
15
+ relevantContext: z.any(),
16
+ // Array of full retrieval result objects
17
+ sources: z.array(
18
+ z.object({
19
+ id: z.string(), // Unique chunk/document identifier
20
+ metadata: z.any(), // All metadata fields (document ID, etc.)
21
+ vector: z.array(z.number()), // Embedding vector (if available)
22
+ score: z.number(), // Similarity score for this retrieval
23
+ document: z.string(), // Full chunk/document text (if available)
24
+ }),
25
+ ),
26
+ });
27
+
28
+ export const filterSchema = z.object({
29
+ ...baseSchema,
30
+ filter: z.coerce.string().describe(filterDescription),
31
+ });
32
+
33
+ export type RagTool<
34
+ TInput extends z.ZodType<any, z.ZodTypeDef, any> | undefined,
35
+ TOutput extends z.ZodType<any, z.ZodTypeDef, any> | undefined,
36
+ > = Tool<TInput, TOutput> & {
37
+ execute: NonNullable<Tool<TInput, TOutput>['execute']>;
38
+ };