@mastra/rag 0.10.1-alpha.1 → 0.10.2-alpha.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,23 +1,23 @@
1
1
 
2
- > @mastra/rag@0.10.1-alpha.1 build /home/runner/work/mastra/mastra/packages/rag
2
+ > @mastra/rag@0.10.2-alpha.0 build /home/runner/work/mastra/mastra/packages/rag
3
3
  > tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
4
4
 
5
5
  CLI Building entry: src/index.ts
6
6
  CLI Using tsconfig: tsconfig.json
7
7
  CLI tsup v8.4.0
8
8
  TSC Build start
9
- TSC ⚡️ Build success in 15915ms
9
+ TSC ⚡️ Build success in 13387ms
10
10
  DTS Build start
11
11
  CLI Target: es2022
12
12
  Analysis will use the bundled TypeScript version 5.8.3
13
13
  Writing package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.ts
14
14
  Analysis will use the bundled TypeScript version 5.8.3
15
15
  Writing package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.cts
16
- DTS ⚡️ Build success in 14168ms
16
+ DTS ⚡️ Build success in 10836ms
17
17
  CLI Cleaning output folder
18
18
  ESM Build start
19
19
  CJS Build start
20
- CJS dist/index.cjs 239.65 KB
21
- CJS ⚡️ Build success in 4626ms
22
- ESM dist/index.js 237.94 KB
23
- ESM ⚡️ Build success in 4627ms
20
+ ESM dist/index.js 239.65 KB
21
+ ESM ⚡️ Build success in 3608ms
22
+ CJS dist/index.cjs 241.36 KB
23
+ CJS ⚡️ Build success in 3609ms
package/CHANGELOG.md CHANGED
@@ -1,5 +1,31 @@
1
1
  # @mastra/rag
2
2
 
3
+ ## 0.10.2-alpha.0
4
+
5
+ ### Patch Changes
6
+
7
+ - f0d559f: Fix peerdeps for alpha channel
8
+ - Updated dependencies [1e8bb40]
9
+ - @mastra/core@0.10.2-alpha.2
10
+
11
+ ## 0.10.1
12
+
13
+ ### Patch Changes
14
+
15
+ - 8784cef: Changed `stripHeaders` in markdown chunking so that, when true, headers are correctly stripped from the output
16
+ - f56fd29: Added return types for the vector-query and graph-rag tools
17
+ - Updated dependencies [d70b807]
18
+ - Updated dependencies [6d16390]
19
+ - Updated dependencies [1e4a421]
20
+ - Updated dependencies [200d0da]
21
+ - Updated dependencies [bf5f17b]
22
+ - Updated dependencies [5343f93]
23
+ - Updated dependencies [38aee50]
24
+ - Updated dependencies [5c41100]
25
+ - Updated dependencies [d6a759b]
26
+ - Updated dependencies [6015bdf]
27
+ - @mastra/core@0.10.1
28
+
3
29
  ## 0.10.1-alpha.1
4
30
 
5
31
  ### Patch Changes
@@ -174,21 +174,7 @@ export { createDocumentChunkerTool }
174
174
  export { createDocumentChunkerTool as createDocumentChunkerTool_alias_1 }
175
175
  export { createDocumentChunkerTool as createDocumentChunkerTool_alias_2 }
176
176
 
177
- declare const createGraphRAGTool: ({ vectorStoreName, indexName, model, enableFilter, includeSources, graphOptions, id, description, }: {
178
- vectorStoreName: string;
179
- indexName: string;
180
- model: EmbeddingModel<string>;
181
- enableFilter?: boolean;
182
- includeSources?: boolean;
183
- graphOptions?: {
184
- dimension?: number;
185
- randomWalkSteps?: number;
186
- restartProb?: number;
187
- threshold?: number;
188
- };
189
- id?: string;
190
- description?: string;
191
- }) => RagTool<z.ZodObject<{
177
+ declare const createGraphRAGTool: (options: GraphRagToolOptions) => RagTool<z.ZodObject<{
192
178
  filter: z.ZodString;
193
179
  queryText: z.ZodString;
194
180
  topK: z.ZodNumber;
@@ -214,17 +200,7 @@ export { createGraphRAGTool }
214
200
  export { createGraphRAGTool as createGraphRAGTool_alias_1 }
215
201
  export { createGraphRAGTool as createGraphRAGTool_alias_2 }
216
202
 
217
- declare const createVectorQueryTool: ({ vectorStoreName, indexName, model, enableFilter, includeVectors, includeSources, reranker, id, description, }: {
218
- vectorStoreName: string;
219
- indexName: string;
220
- model: EmbeddingModel<string>;
221
- enableFilter?: boolean;
222
- includeVectors?: boolean;
223
- includeSources?: boolean;
224
- reranker?: RerankConfig;
225
- id?: string;
226
- description?: string;
227
- }) => RagTool<z.ZodObject<{
203
+ declare const createVectorQueryTool: (options: VectorQueryToolOptions) => RagTool<z.ZodObject<{
228
204
  filter: z.ZodString;
229
205
  queryText: z.ZodString;
230
206
  topK: z.ZodNumber;
@@ -250,6 +226,25 @@ export { createVectorQueryTool }
250
226
  export { createVectorQueryTool as createVectorQueryTool_alias_1 }
251
227
  export { createVectorQueryTool as createVectorQueryTool_alias_2 }
252
228
 
229
+ /**
230
+ * Default options for GraphRAG
231
+ * @default
232
+ * ```json
233
+ * {
234
+ * "dimension": 1536,
235
+ * "randomWalkSteps": 100,
236
+ * "restartProb": 0.15,
237
+ * "threshold": 0.7
238
+ * }
239
+ * ```
240
+ */
241
+ export declare const defaultGraphOptions: {
242
+ dimension: number;
243
+ randomWalkSteps: number;
244
+ restartProb: number;
245
+ threshold: number;
246
+ };
247
+
253
248
  declare const defaultGraphRagDescription: () => string;
254
249
  export { defaultGraphRagDescription }
255
250
  export { defaultGraphRagDescription as defaultGraphRagDescription_alias_1 }
@@ -403,6 +398,22 @@ declare class GraphRAG {
403
398
  export { GraphRAG }
404
399
  export { GraphRAG as GraphRAG_alias_1 }
405
400
 
401
+ export declare type GraphRagToolOptions = {
402
+ id?: string;
403
+ description?: string;
404
+ indexName: string;
405
+ vectorStoreName: string;
406
+ model: EmbeddingModel<string>;
407
+ enableFilter?: boolean;
408
+ includeSources?: boolean;
409
+ graphOptions?: {
410
+ dimension?: number;
411
+ randomWalkSteps?: number;
412
+ restartProb?: number;
413
+ threshold?: number;
414
+ };
415
+ };
416
+
406
417
  export declare class HTMLHeaderTransformer {
407
418
  private headersToSplitOn;
408
419
  private returnEachElement;
@@ -1165,6 +1176,18 @@ declare interface VectorQuerySearchResult {
1165
1176
  queryEmbedding: number[];
1166
1177
  }
1167
1178
 
1179
+ export declare type VectorQueryToolOptions = {
1180
+ id?: string;
1181
+ description?: string;
1182
+ indexName: string;
1183
+ vectorStoreName: string;
1184
+ model: EmbeddingModel<string>;
1185
+ enableFilter?: boolean;
1186
+ includeVectors?: boolean;
1187
+ includeSources?: boolean;
1188
+ reranker?: RerankConfig;
1189
+ };
1190
+
1168
1191
  declare type WeightConfig = {
1169
1192
  semantic?: number;
1170
1193
  vector?: number;
@@ -174,21 +174,7 @@ export { createDocumentChunkerTool }
174
174
  export { createDocumentChunkerTool as createDocumentChunkerTool_alias_1 }
175
175
  export { createDocumentChunkerTool as createDocumentChunkerTool_alias_2 }
176
176
 
177
- declare const createGraphRAGTool: ({ vectorStoreName, indexName, model, enableFilter, includeSources, graphOptions, id, description, }: {
178
- vectorStoreName: string;
179
- indexName: string;
180
- model: EmbeddingModel<string>;
181
- enableFilter?: boolean;
182
- includeSources?: boolean;
183
- graphOptions?: {
184
- dimension?: number;
185
- randomWalkSteps?: number;
186
- restartProb?: number;
187
- threshold?: number;
188
- };
189
- id?: string;
190
- description?: string;
191
- }) => RagTool<z.ZodObject<{
177
+ declare const createGraphRAGTool: (options: GraphRagToolOptions) => RagTool<z.ZodObject<{
192
178
  filter: z.ZodString;
193
179
  queryText: z.ZodString;
194
180
  topK: z.ZodNumber;
@@ -214,17 +200,7 @@ export { createGraphRAGTool }
214
200
  export { createGraphRAGTool as createGraphRAGTool_alias_1 }
215
201
  export { createGraphRAGTool as createGraphRAGTool_alias_2 }
216
202
 
217
- declare const createVectorQueryTool: ({ vectorStoreName, indexName, model, enableFilter, includeVectors, includeSources, reranker, id, description, }: {
218
- vectorStoreName: string;
219
- indexName: string;
220
- model: EmbeddingModel<string>;
221
- enableFilter?: boolean;
222
- includeVectors?: boolean;
223
- includeSources?: boolean;
224
- reranker?: RerankConfig;
225
- id?: string;
226
- description?: string;
227
- }) => RagTool<z.ZodObject<{
203
+ declare const createVectorQueryTool: (options: VectorQueryToolOptions) => RagTool<z.ZodObject<{
228
204
  filter: z.ZodString;
229
205
  queryText: z.ZodString;
230
206
  topK: z.ZodNumber;
@@ -250,6 +226,25 @@ export { createVectorQueryTool }
250
226
  export { createVectorQueryTool as createVectorQueryTool_alias_1 }
251
227
  export { createVectorQueryTool as createVectorQueryTool_alias_2 }
252
228
 
229
+ /**
230
+ * Default options for GraphRAG
231
+ * @default
232
+ * ```json
233
+ * {
234
+ * "dimension": 1536,
235
+ * "randomWalkSteps": 100,
236
+ * "restartProb": 0.15,
237
+ * "threshold": 0.7
238
+ * }
239
+ * ```
240
+ */
241
+ export declare const defaultGraphOptions: {
242
+ dimension: number;
243
+ randomWalkSteps: number;
244
+ restartProb: number;
245
+ threshold: number;
246
+ };
247
+
253
248
  declare const defaultGraphRagDescription: () => string;
254
249
  export { defaultGraphRagDescription }
255
250
  export { defaultGraphRagDescription as defaultGraphRagDescription_alias_1 }
@@ -403,6 +398,22 @@ declare class GraphRAG {
403
398
  export { GraphRAG }
404
399
  export { GraphRAG as GraphRAG_alias_1 }
405
400
 
401
+ export declare type GraphRagToolOptions = {
402
+ id?: string;
403
+ description?: string;
404
+ indexName: string;
405
+ vectorStoreName: string;
406
+ model: EmbeddingModel<string>;
407
+ enableFilter?: boolean;
408
+ includeSources?: boolean;
409
+ graphOptions?: {
410
+ dimension?: number;
411
+ randomWalkSteps?: number;
412
+ restartProb?: number;
413
+ threshold?: number;
414
+ };
415
+ };
416
+
406
417
  export declare class HTMLHeaderTransformer {
407
418
  private headersToSplitOn;
408
419
  private returnEachElement;
@@ -1165,6 +1176,18 @@ declare interface VectorQuerySearchResult {
1165
1176
  queryEmbedding: number[];
1166
1177
  }
1167
1178
 
1179
+ export declare type VectorQueryToolOptions = {
1180
+ id?: string;
1181
+ description?: string;
1182
+ indexName: string;
1183
+ vectorStoreName: string;
1184
+ model: EmbeddingModel<string>;
1185
+ enableFilter?: boolean;
1186
+ includeVectors?: boolean;
1187
+ includeSources?: boolean;
1188
+ reranker?: RerankConfig;
1189
+ };
1190
+
1168
1191
  declare type WeightConfig = {
1169
1192
  semantic?: number;
1170
1193
  vector?: number;
package/dist/index.cjs CHANGED
@@ -6334,33 +6334,43 @@ var convertToSources = (results) => {
6334
6334
  });
6335
6335
  };
6336
6336
 
6337
+ // src/tools/types.ts
6338
+ var defaultGraphOptions = {
6339
+ dimension: 1536,
6340
+ randomWalkSteps: 100,
6341
+ restartProb: 0.15,
6342
+ threshold: 0.7
6343
+ };
6344
+
6337
6345
  // src/tools/graph-rag.ts
6338
- var createGraphRAGTool = ({
6339
- vectorStoreName,
6340
- indexName,
6341
- model,
6342
- enableFilter = false,
6343
- includeSources = true,
6344
- graphOptions = {
6345
- dimension: 1536,
6346
- randomWalkSteps: 100,
6347
- restartProb: 0.15,
6348
- threshold: 0.7
6349
- },
6350
- id,
6351
- description
6352
- }) => {
6353
- const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
6346
+ var createGraphRAGTool = (options) => {
6347
+ const { model, id, description } = options;
6348
+ const toolId = id || `GraphRAG ${options.vectorStoreName} ${options.indexName} Tool`;
6354
6349
  const toolDescription = description || defaultGraphRagDescription();
6350
+ const graphOptions = {
6351
+ ...defaultGraphOptions,
6352
+ ...options.graphOptions || {}
6353
+ };
6355
6354
  const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
6356
6355
  let isInitialized = false;
6357
- const inputSchema = enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
6356
+ const inputSchema = options.enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
6358
6357
  return tools.createTool({
6359
6358
  id: toolId,
6360
6359
  inputSchema,
6361
6360
  outputSchema,
6362
6361
  description: toolDescription,
6363
- execute: async ({ context: { queryText, topK, filter }, mastra }) => {
6362
+ execute: async ({ context, mastra, runtimeContext }) => {
6363
+ const indexName = runtimeContext.get("indexName") ?? options.indexName;
6364
+ const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
6365
+ if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
6366
+ if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
6367
+ const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
6368
+ const randomWalkSteps = runtimeContext.get("randomWalkSteps") ?? graphOptions.randomWalkSteps;
6369
+ const restartProb = runtimeContext.get("restartProb") ?? graphOptions.restartProb;
6370
+ const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
6371
+ const filter = runtimeContext.get("filter") ?? context.filter;
6372
+ const queryText = context.queryText;
6373
+ const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
6364
6374
  const logger = mastra?.getLogger();
6365
6375
  if (!logger) {
6366
6376
  console.warn(
@@ -6426,8 +6436,8 @@ var createGraphRAGTool = ({
6426
6436
  const rerankedResults = graphRag.query({
6427
6437
  query: queryEmbedding,
6428
6438
  topK: topKValue,
6429
- randomWalkSteps: graphOptions.randomWalkSteps,
6430
- restartProb: graphOptions.restartProb
6439
+ randomWalkSteps,
6440
+ restartProb
6431
6441
  });
6432
6442
  if (logger) {
6433
6443
  logger.debug("GraphRAG query returned results", { count: rerankedResults.length });
@@ -6455,26 +6465,28 @@ var createGraphRAGTool = ({
6455
6465
  // Use any for output schema as the structure of the output causes type inference issues
6456
6466
  });
6457
6467
  };
6458
- var createVectorQueryTool = ({
6459
- vectorStoreName,
6460
- indexName,
6461
- model,
6462
- enableFilter = false,
6463
- includeVectors = false,
6464
- includeSources = true,
6465
- reranker,
6466
- id,
6467
- description
6468
- }) => {
6469
- const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
6468
+ var createVectorQueryTool = (options) => {
6469
+ const { model, id, description } = options;
6470
+ const toolId = id || `VectorQuery ${options.vectorStoreName} ${options.indexName} Tool`;
6470
6471
  const toolDescription = description || defaultVectorQueryDescription();
6471
- const inputSchema = enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
6472
+ const inputSchema = options.enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
6472
6473
  return tools.createTool({
6473
6474
  id: toolId,
6475
+ description: toolDescription,
6474
6476
  inputSchema,
6475
6477
  outputSchema,
6476
- description: toolDescription,
6477
- execute: async ({ context: { queryText, topK, filter }, mastra }) => {
6478
+ execute: async ({ context, mastra, runtimeContext }) => {
6479
+ const indexName = runtimeContext.get("indexName") ?? options.indexName;
6480
+ const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
6481
+ const includeVectors = runtimeContext.get("includeVectors") ?? options.includeVectors ?? false;
6482
+ const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
6483
+ const reranker = runtimeContext.get("reranker") ?? options.reranker;
6484
+ if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
6485
+ if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
6486
+ const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
6487
+ const filter = runtimeContext.get("filter") ?? context.filter;
6488
+ const queryText = context.queryText;
6489
+ const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
6478
6490
  const logger = mastra?.getLogger();
6479
6491
  if (!logger) {
6480
6492
  console.warn(
package/dist/index.js CHANGED
@@ -6332,33 +6332,43 @@ var convertToSources = (results) => {
6332
6332
  });
6333
6333
  };
6334
6334
 
6335
+ // src/tools/types.ts
6336
+ var defaultGraphOptions = {
6337
+ dimension: 1536,
6338
+ randomWalkSteps: 100,
6339
+ restartProb: 0.15,
6340
+ threshold: 0.7
6341
+ };
6342
+
6335
6343
  // src/tools/graph-rag.ts
6336
- var createGraphRAGTool = ({
6337
- vectorStoreName,
6338
- indexName,
6339
- model,
6340
- enableFilter = false,
6341
- includeSources = true,
6342
- graphOptions = {
6343
- dimension: 1536,
6344
- randomWalkSteps: 100,
6345
- restartProb: 0.15,
6346
- threshold: 0.7
6347
- },
6348
- id,
6349
- description
6350
- }) => {
6351
- const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
6344
+ var createGraphRAGTool = (options) => {
6345
+ const { model, id, description } = options;
6346
+ const toolId = id || `GraphRAG ${options.vectorStoreName} ${options.indexName} Tool`;
6352
6347
  const toolDescription = description || defaultGraphRagDescription();
6348
+ const graphOptions = {
6349
+ ...defaultGraphOptions,
6350
+ ...options.graphOptions || {}
6351
+ };
6353
6352
  const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
6354
6353
  let isInitialized = false;
6355
- const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
6354
+ const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
6356
6355
  return createTool({
6357
6356
  id: toolId,
6358
6357
  inputSchema,
6359
6358
  outputSchema,
6360
6359
  description: toolDescription,
6361
- execute: async ({ context: { queryText, topK, filter }, mastra }) => {
6360
+ execute: async ({ context, mastra, runtimeContext }) => {
6361
+ const indexName = runtimeContext.get("indexName") ?? options.indexName;
6362
+ const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
6363
+ if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
6364
+ if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
6365
+ const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
6366
+ const randomWalkSteps = runtimeContext.get("randomWalkSteps") ?? graphOptions.randomWalkSteps;
6367
+ const restartProb = runtimeContext.get("restartProb") ?? graphOptions.restartProb;
6368
+ const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
6369
+ const filter = runtimeContext.get("filter") ?? context.filter;
6370
+ const queryText = context.queryText;
6371
+ const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
6362
6372
  const logger = mastra?.getLogger();
6363
6373
  if (!logger) {
6364
6374
  console.warn(
@@ -6424,8 +6434,8 @@ var createGraphRAGTool = ({
6424
6434
  const rerankedResults = graphRag.query({
6425
6435
  query: queryEmbedding,
6426
6436
  topK: topKValue,
6427
- randomWalkSteps: graphOptions.randomWalkSteps,
6428
- restartProb: graphOptions.restartProb
6437
+ randomWalkSteps,
6438
+ restartProb
6429
6439
  });
6430
6440
  if (logger) {
6431
6441
  logger.debug("GraphRAG query returned results", { count: rerankedResults.length });
@@ -6453,26 +6463,28 @@ var createGraphRAGTool = ({
6453
6463
  // Use any for output schema as the structure of the output causes type inference issues
6454
6464
  });
6455
6465
  };
6456
- var createVectorQueryTool = ({
6457
- vectorStoreName,
6458
- indexName,
6459
- model,
6460
- enableFilter = false,
6461
- includeVectors = false,
6462
- includeSources = true,
6463
- reranker,
6464
- id,
6465
- description
6466
- }) => {
6467
- const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
6466
+ var createVectorQueryTool = (options) => {
6467
+ const { model, id, description } = options;
6468
+ const toolId = id || `VectorQuery ${options.vectorStoreName} ${options.indexName} Tool`;
6468
6469
  const toolDescription = description || defaultVectorQueryDescription();
6469
- const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
6470
+ const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
6470
6471
  return createTool({
6471
6472
  id: toolId,
6473
+ description: toolDescription,
6472
6474
  inputSchema,
6473
6475
  outputSchema,
6474
- description: toolDescription,
6475
- execute: async ({ context: { queryText, topK, filter }, mastra }) => {
6476
+ execute: async ({ context, mastra, runtimeContext }) => {
6477
+ const indexName = runtimeContext.get("indexName") ?? options.indexName;
6478
+ const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
6479
+ const includeVectors = runtimeContext.get("includeVectors") ?? options.includeVectors ?? false;
6480
+ const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
6481
+ const reranker = runtimeContext.get("reranker") ?? options.reranker;
6482
+ if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
6483
+ if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
6484
+ const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
6485
+ const filter = runtimeContext.get("filter") ?? context.filter;
6486
+ const queryText = context.queryText;
6487
+ const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
6476
6488
  const logger = mastra?.getLogger();
6477
6489
  if (!logger) {
6478
6490
  console.warn(
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@mastra/rag",
3
- "version": "0.10.1-alpha.1",
3
+ "version": "0.10.2-alpha.0",
4
4
  "description": "",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -30,7 +30,7 @@
30
30
  },
31
31
  "peerDependencies": {
32
32
  "ai": "^4.0.0",
33
- "@mastra/core": "^0.10.0"
33
+ "@mastra/core": "^0.10.0-alpha.0"
34
34
  },
35
35
  "devDependencies": {
36
36
  "@ai-sdk/cohere": "latest",
@@ -44,8 +44,8 @@
44
44
  "tsup": "^8.4.0",
45
45
  "typescript": "^5.8.2",
46
46
  "vitest": "^3.1.2",
47
- "@internal/lint": "0.0.6",
48
- "@mastra/core": "0.10.1-alpha.0"
47
+ "@internal/lint": "0.0.7",
48
+ "@mastra/core": "0.10.2-alpha.2"
49
49
  },
50
50
  "keywords": [
51
51
  "rag",
@@ -0,0 +1,115 @@
1
+ import { RuntimeContext } from '@mastra/core/runtime-context';
2
+ import { describe, it, expect, vi, beforeEach } from 'vitest';
3
+ import { GraphRAG } from '../graph-rag';
4
+ import { vectorQuerySearch } from '../utils';
5
+ import { createGraphRAGTool } from './graph-rag';
6
+
7
// Stub the vector search: the tool always sees two fixed hits and a fixed query embedding.
// Fixtures live inside the factory because vi.mock factories are hoisted above module scope.
vi.mock('../utils', async importOriginal => {
  const realUtils: any = await importOriginal();
  const searchResponse = {
    results: [
      { metadata: { text: 'foo' }, vector: [1, 2, 3] },
      { metadata: { text: 'bar' }, vector: [4, 5, 6] },
    ],
    queryEmbedding: [1, 2, 3],
  };
  return { ...realUtils, vectorQuerySearch: vi.fn().mockResolvedValue(searchResponse) };
});

// Replace GraphRAG with a stub whose query() always returns the same two nodes.
vi.mock('../graph-rag', async importOriginal => {
  const realGraphRag: any = await importOriginal();
  const buildStubInstance = () => ({
    createGraph: vi.fn(),
    query: vi.fn(() => [
      { content: 'foo', metadata: { text: 'foo' } },
      { content: 'bar', metadata: { text: 'bar' } },
    ]),
  });
  return { ...realGraphRag, GraphRAG: vi.fn().mockImplementation(buildStubInstance) };
});

const mockModel = { name: 'test-model' } as any;

const mockMastra = {
  // Keyed by the requested store name so assertions can tell which store was resolved.
  getVector: vi.fn(storeName => ({ [storeName]: {} })),
  getLogger: vi.fn(() => ({ debug: vi.fn(), warn: vi.fn(), info: vi.fn(), error: vi.fn() })),
};

describe('createGraphRAGTool', () => {
  beforeEach(() => {
    vi.clearAllMocks();
  });

  it('validates input schema', () => {
    const tool = createGraphRAGTool({
      id: 'test',
      model: mockModel,
      vectorStoreName: 'testStore',
      indexName: 'testIndex',
    });
    const schema = tool.inputSchema;

    expect(() => schema?.parse({ queryText: 'foo', topK: 10 })).not.toThrow();
    expect(() => schema?.parse({})).toThrow();
  });

  describe('runtimeContext', () => {
    it('calls vectorQuerySearch and GraphRAG with runtimeContext params', async () => {
      const tool = createGraphRAGTool({
        id: 'test',
        model: mockModel,
        indexName: 'testIndex',
        vectorStoreName: 'testStore',
      });

      // Each of these is expected to take precedence over the creation-time options / call context.
      const runtimeContext = new RuntimeContext();
      runtimeContext.set('indexName', 'anotherIndex');
      runtimeContext.set('vectorStoreName', 'anotherStore');
      runtimeContext.set('topK', 5);
      runtimeContext.set('filter', { foo: 'bar' });
      runtimeContext.set('randomWalkSteps', 99);
      runtimeContext.set('restartProb', 0.42);

      const { relevantContext, sources } = await tool.execute({
        context: { queryText: 'foo', topK: 2 },
        mastra: mockMastra as any,
        runtimeContext,
      });

      expect(relevantContext).toEqual(['foo', 'bar']);
      expect(sources.length).toBe(2);
      expect(vectorQuerySearch).toHaveBeenCalledWith(
        expect.objectContaining({
          indexName: 'anotherIndex',
          vectorStore: { anotherStore: {} },
          queryText: 'foo',
          model: mockModel,
          queryFilter: { foo: 'bar' },
          topK: 5,
          includeVectors: true,
        }),
      );

      // The stubbed GraphRAG instance must have been built, populated, and queried.
      expect(GraphRAG).toHaveBeenCalled();
      const graphRagStub = (GraphRAG as any).mock.results[0].value;
      expect(graphRagStub.createGraph).toHaveBeenCalled();
      expect(graphRagStub.query).toHaveBeenCalledWith(
        expect.objectContaining({ query: [1, 2, 3], topK: 5, randomWalkSteps: 99, restartProb: 0.42 }),
      );
    });
  });
});
@@ -1,55 +1,47 @@
1
1
  import { createTool } from '@mastra/core/tools';
2
- import type { EmbeddingModel } from 'ai';
3
2
  import { z } from 'zod';
4
3
 
5
4
  import { GraphRAG } from '../graph-rag';
6
5
  import { vectorQuerySearch, defaultGraphRagDescription, filterSchema, outputSchema, baseSchema } from '../utils';
7
6
  import type { RagTool } from '../utils';
8
7
  import { convertToSources } from '../utils/convert-sources';
8
+ import type { GraphRagToolOptions } from './types';
9
+ import { defaultGraphOptions } from './types';
9
10
 
10
- export const createGraphRAGTool = ({
11
- vectorStoreName,
12
- indexName,
13
- model,
14
- enableFilter = false,
15
- includeSources = true,
16
- graphOptions = {
17
- dimension: 1536,
18
- randomWalkSteps: 100,
19
- restartProb: 0.15,
20
- threshold: 0.7,
21
- },
22
- id,
23
- description,
24
- }: {
25
- vectorStoreName: string;
26
- indexName: string;
27
- model: EmbeddingModel<string>;
28
- enableFilter?: boolean;
29
- includeSources?: boolean;
30
- graphOptions?: {
31
- dimension?: number;
32
- randomWalkSteps?: number;
33
- restartProb?: number;
34
- threshold?: number;
35
- };
36
- id?: string;
37
- description?: string;
38
- }) => {
39
- const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
11
+ export const createGraphRAGTool = (options: GraphRagToolOptions) => {
12
+ const { model, id, description } = options;
13
+
14
+ const toolId = id || `GraphRAG ${options.vectorStoreName} ${options.indexName} Tool`;
40
15
  const toolDescription = description || defaultGraphRagDescription();
16
+ const graphOptions = {
17
+ ...defaultGraphOptions,
18
+ ...(options.graphOptions || {}),
19
+ };
41
20
  // Initialize GraphRAG
42
21
  const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
43
22
  let isInitialized = false;
44
23
 
45
- const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
24
+ const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
46
25
 
47
26
  return createTool({
48
27
  id: toolId,
49
28
  inputSchema,
50
29
  outputSchema,
51
30
  description: toolDescription,
52
- execute: async ({ context: { queryText, topK, filter }, mastra }) => {
31
+ execute: async ({ context, mastra, runtimeContext }) => {
32
+ const indexName: string = runtimeContext.get('indexName') ?? options.indexName;
33
+ const vectorStoreName: string = runtimeContext.get('vectorStoreName') ?? options.vectorStoreName;
34
+ if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
35
+ if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
36
+ const includeSources: boolean = runtimeContext.get('includeSources') ?? options.includeSources ?? true;
37
+ const randomWalkSteps: number | undefined = runtimeContext.get('randomWalkSteps') ?? graphOptions.randomWalkSteps;
38
+ const restartProb: number | undefined = runtimeContext.get('restartProb') ?? graphOptions.restartProb;
39
+ const topK: number = runtimeContext.get('topK') ?? context.topK ?? 10;
40
+ const filter: Record<string, any> = runtimeContext.get('filter') ?? context.filter;
41
+ const queryText = context.queryText;
42
+
43
+ const enableFilter = !!runtimeContext.get('filter') || (options.enableFilter ?? false);
44
+
53
45
  const logger = mastra?.getLogger();
54
46
  if (!logger) {
55
47
  console.warn(
@@ -129,8 +121,8 @@ export const createGraphRAGTool = ({
129
121
  const rerankedResults = graphRag.query({
130
122
  query: queryEmbedding,
131
123
  topK: topKValue,
132
- randomWalkSteps: graphOptions.randomWalkSteps,
133
- restartProb: graphOptions.restartProb,
124
+ randomWalkSteps,
125
+ restartProb,
134
126
  });
135
127
  if (logger) {
136
128
  logger.debug('GraphRAG query returned results', { count: rerankedResults.length });
@@ -0,0 +1,49 @@
1
+ import type { EmbeddingModel } from 'ai';
2
+ import type { RerankConfig } from '../rerank';
3
+
4
/**
 * Options accepted by `createVectorQueryTool`.
 *
 * At execution time, `indexName`, `vectorStoreName`, `includeVectors`,
 * `includeSources` and `reranker` may also be overridden per call by setting
 * the same key on the tool's `runtimeContext`.
 */
export type VectorQueryToolOptions = {
  /** Tool id. Defaults to `VectorQuery ${vectorStoreName} ${indexName} Tool`. */
  id?: string;
  /** Tool description. Defaults to `defaultVectorQueryDescription()`. */
  description?: string;
  /** Index to query (overridable via `runtimeContext`). */
  indexName: string;
  /** Vector store to query (overridable via `runtimeContext`). */
  vectorStoreName: string;
  /** Embedding model used for the query text. */
  model: EmbeddingModel<string>;
  /** Selects the filter-enabled input schema (`filterSchema`). Defaults to `false`. */
  enableFilter?: boolean;
  /** Defaults to `false`. */
  includeVectors?: boolean;
  /** Defaults to `true`. */
  includeSources?: boolean;
  /** Optional rerank configuration; unset by default. */
  reranker?: RerankConfig;
};

/**
 * Options accepted by `createGraphRAGTool`.
 *
 * At execution time, `indexName`, `vectorStoreName`, `includeSources`,
 * `randomWalkSteps`, `restartProb`, `topK` and `filter` may also be overridden
 * per call by setting the same key on the tool's `runtimeContext`.
 */
export type GraphRagToolOptions = {
  /** Tool id. Defaults to `GraphRAG ${vectorStoreName} ${indexName} Tool`. */
  id?: string;
  /** Tool description. Defaults to `defaultGraphRagDescription()`. */
  description?: string;
  /** Index to query (overridable via `runtimeContext`). */
  indexName: string;
  /** Vector store to query (overridable via `runtimeContext`). */
  vectorStoreName: string;
  /** Embedding model used for the query text. */
  model: EmbeddingModel<string>;
  /** Selects the filter-enabled input schema (`filterSchema`). Defaults to `false`. */
  enableFilter?: boolean;
  /** Defaults to `true`. */
  includeSources?: boolean;
  /**
   * Graph settings, spread over `defaultGraphOptions`: omitted fields keep
   * their default. `dimension` and `threshold` are applied once when the tool
   * is created; `randomWalkSteps` and `restartProb` can additionally be
   * overridden per call via `runtimeContext`.
   */
  graphOptions?: {
    dimension?: number;
    randomWalkSteps?: number;
    restartProb?: number;
    threshold?: number;
  };
};

/**
 * Default options for GraphRAG. `createGraphRAGTool` spreads the caller's
 * `graphOptions` over this object, so any field the caller omits falls back
 * to the value here.
 *
 * `satisfies` guarantees every graph option has a default while keeping the
 * inferred (mutable, `number`-typed) shape of the exported constant.
 *
 * @default
 * ```json
 * {
 * "dimension": 1536,
 * "randomWalkSteps": 100,
 * "restartProb": 0.15,
 * "threshold": 0.7
 * }
 * ```
 */
export const defaultGraphOptions = {
  dimension: 1536,
  randomWalkSteps: 100,
  restartProb: 0.15,
  threshold: 0.7,
} satisfies Required<NonNullable<GraphRagToolOptions['graphOptions']>>;
@@ -1,4 +1,6 @@
1
+ import { RuntimeContext } from '@mastra/core/runtime-context';
1
2
  import { describe, it, expect, vi, beforeEach } from 'vitest';
3
+ import { rerank } from '../rerank';
2
4
  import { vectorQuerySearch } from '../utils';
3
5
  import { createVectorQueryTool } from './vector-query';
4
6
 
@@ -6,7 +8,19 @@ vi.mock('../utils', async importOriginal => {
6
8
  const actual: any = await importOriginal();
7
9
  return {
8
10
  ...actual,
9
- vectorQuerySearch: vi.fn().mockResolvedValue({ results: [] }),
11
+ vectorQuerySearch: vi.fn().mockResolvedValue({ results: [{ metadata: { text: 'foo' }, vector: [1, 2, 3] }] }),
12
+ };
13
+ });
14
+
15
+ vi.mock('../rerank', async importOriginal => {
16
+ const actual: any = await importOriginal();
17
+ return {
18
+ ...actual,
19
+ rerank: vi
20
+ .fn()
21
+ .mockResolvedValue([
22
+ { result: { id: '1', metadata: { text: 'bar' }, score: 1, details: { semantic: 1, vector: 1, position: 1 } } },
23
+ ]),
10
24
  };
11
25
  });
12
26
 
@@ -17,9 +31,12 @@ describe('createVectorQueryTool', () => {
17
31
  testStore: {
18
32
  // Mock vector store methods
19
33
  },
34
+ anotherStore: {
35
+ // Mock vector store methods
36
+ },
20
37
  },
21
- getVector: vi.fn(() => ({
22
- testStore: {
38
+ getVector: vi.fn(storeName => ({
39
+ [storeName]: {
23
40
  // Mock vector store methods
24
41
  },
25
42
  })),
@@ -154,6 +171,8 @@ describe('createVectorQueryTool', () => {
154
171
 
155
172
  describe('execute function', () => {
156
173
  it('should not process filter when enableFilter is false', async () => {
174
+ const runtimeContext = new RuntimeContext();
175
+
157
176
  // Create tool with enableFilter set to false
158
177
  const tool = createVectorQueryTool({
159
178
  vectorStoreName: 'testStore',
@@ -169,7 +188,7 @@ describe('createVectorQueryTool', () => {
169
188
  topK: 5,
170
189
  },
171
190
  mastra: mockMastra as any,
172
- runtimeContext: {} as any,
191
+ runtimeContext,
173
192
  });
174
193
 
175
194
  // Check that vectorQuerySearch was called with undefined queryFilter
@@ -181,6 +200,7 @@ describe('createVectorQueryTool', () => {
181
200
  });
182
201
 
183
202
  it('should process filter when enableFilter is true and filter is provided', async () => {
203
+ const runtimeContext = new RuntimeContext();
184
204
  // Create tool with enableFilter set to true
185
205
  const tool = createVectorQueryTool({
186
206
  vectorStoreName: 'testStore',
@@ -199,7 +219,7 @@ describe('createVectorQueryTool', () => {
199
219
  filter: filterJson,
200
220
  },
201
221
  mastra: mockMastra as any,
202
- runtimeContext: {} as any,
222
+ runtimeContext,
203
223
  });
204
224
 
205
225
  // Check that vectorQuerySearch was called with the parsed filter
@@ -211,6 +231,7 @@ describe('createVectorQueryTool', () => {
211
231
  });
212
232
 
213
233
  it('should handle string filters correctly', async () => {
234
+ const runtimeContext = new RuntimeContext();
214
235
  // Create tool with enableFilter set to true
215
236
  const tool = createVectorQueryTool({
216
237
  vectorStoreName: 'testStore',
@@ -229,7 +250,7 @@ describe('createVectorQueryTool', () => {
229
250
  filter: stringFilter,
230
251
  },
231
252
  mastra: mockMastra as any,
232
- runtimeContext: {} as any,
253
+ runtimeContext,
233
254
  });
234
255
 
235
256
  // Since this is not a valid filter, it should be ignored
@@ -240,4 +261,69 @@ describe('createVectorQueryTool', () => {
240
261
  );
241
262
  });
242
263
  });
264
+
265
+ describe('runtimeContext', () => {
266
+ it('calls vectorQuerySearch with runtimeContext params', async () => {
267
+ const tool = createVectorQueryTool({
268
+ id: 'test',
269
+ model: mockModel,
270
+ indexName: 'testIndex',
271
+ vectorStoreName: 'testStore',
272
+ });
273
+ const runtimeContext = new RuntimeContext();
274
+ runtimeContext.set('indexName', 'anotherIndex');
275
+ runtimeContext.set('vectorStoreName', 'anotherStore');
276
+ runtimeContext.set('topK', 3);
277
+ runtimeContext.set('filter', { foo: 'bar' });
278
+ runtimeContext.set('includeVectors', true);
279
+ runtimeContext.set('includeSources', false);
280
+ const result = await tool.execute({
281
+ context: { queryText: 'foo', topK: 6 },
282
+ mastra: mockMastra as any,
283
+ runtimeContext,
284
+ });
285
+ expect(result.relevantContext.length).toBeGreaterThan(0);
286
+ expect(result.sources).toEqual([]); // includeSources false
287
+ expect(vectorQuerySearch).toHaveBeenCalledWith(
288
+ expect.objectContaining({
289
+ indexName: 'anotherIndex',
290
+ vectorStore: {
291
+ anotherStore: {},
292
+ },
293
+ queryText: 'foo',
294
+ model: mockModel,
295
+ queryFilter: { foo: 'bar' },
296
+ topK: 3,
297
+ includeVectors: true,
298
+ }),
299
+ );
300
+ });
301
+
302
+ it('handles reranker from runtimeContext', async () => {
303
+ const tool = createVectorQueryTool({
304
+ id: 'test',
305
+ model: mockModel,
306
+ indexName: 'testIndex',
307
+ vectorStoreName: 'testStore',
308
+ });
309
+ const runtimeContext = new RuntimeContext();
310
+ runtimeContext.set('indexName', 'testIndex');
311
+ runtimeContext.set('vectorStoreName', 'testStore');
312
+ runtimeContext.set('reranker', { model: 'reranker-model', options: { topK: 1 } });
313
+ // Mock rerank
314
+ vi.mocked(rerank).mockResolvedValue([
315
+ {
316
+ result: { id: '1', metadata: { text: 'bar' }, score: 1 },
317
+ score: 1,
318
+ details: { semantic: 1, vector: 1, position: 1 },
319
+ },
320
+ ]);
321
+ const result = await tool.execute({
322
+ context: { queryText: 'foo', topK: 1 },
323
+ mastra: mockMastra as any,
324
+ runtimeContext,
325
+ });
326
+ expect(result.relevantContext[0]).toEqual({ text: 'bar' });
327
+ });
328
+ });
243
329
  });
@@ -1,5 +1,4 @@
1
1
  import { createTool } from '@mastra/core/tools';
2
- import type { EmbeddingModel } from 'ai';
3
2
  import { z } from 'zod';
4
3
 
5
4
  import { rerank } from '../rerank';
@@ -7,37 +6,33 @@ import type { RerankConfig } from '../rerank';
7
6
  import { vectorQuerySearch, defaultVectorQueryDescription, filterSchema, outputSchema, baseSchema } from '../utils';
8
7
  import type { RagTool } from '../utils';
9
8
  import { convertToSources } from '../utils/convert-sources';
9
+ import type { VectorQueryToolOptions } from './types';
10
10
 
11
- export const createVectorQueryTool = ({
12
- vectorStoreName,
13
- indexName,
14
- model,
15
- enableFilter = false,
16
- includeVectors = false,
17
- includeSources = true,
18
- reranker,
19
- id,
20
- description,
21
- }: {
22
- vectorStoreName: string;
23
- indexName: string;
24
- model: EmbeddingModel<string>;
25
- enableFilter?: boolean;
26
- includeVectors?: boolean;
27
- includeSources?: boolean;
28
- reranker?: RerankConfig;
29
- id?: string;
30
- description?: string;
31
- }) => {
32
- const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
11
+ export const createVectorQueryTool = (options: VectorQueryToolOptions) => {
12
+ const { model, id, description } = options;
13
+ const toolId = id || `VectorQuery ${options.vectorStoreName} ${options.indexName} Tool`;
33
14
  const toolDescription = description || defaultVectorQueryDescription();
34
- const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
15
+ const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
16
+
35
17
  return createTool({
36
18
  id: toolId,
19
+ description: toolDescription,
37
20
  inputSchema,
38
21
  outputSchema,
39
- description: toolDescription,
40
- execute: async ({ context: { queryText, topK, filter }, mastra }) => {
22
+ execute: async ({ context, mastra, runtimeContext }) => {
23
+ const indexName: string = runtimeContext.get('indexName') ?? options.indexName;
24
+ const vectorStoreName: string = runtimeContext.get('vectorStoreName') ?? options.vectorStoreName;
25
+ const includeVectors: boolean = runtimeContext.get('includeVectors') ?? options.includeVectors ?? false;
26
+ const includeSources: boolean = runtimeContext.get('includeSources') ?? options.includeSources ?? true;
27
+ const reranker: RerankConfig = runtimeContext.get('reranker') ?? options.reranker;
28
+ if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
29
+ if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
30
+
31
+ const topK: number = runtimeContext.get('topK') ?? context.topK ?? 10;
32
+ const filter: Record<string, any> = runtimeContext.get('filter') ?? context.filter;
33
+ const queryText = context.queryText;
34
+ const enableFilter = !!runtimeContext.get('filter') || (options.enableFilter ?? false);
35
+
41
36
  const logger = mastra?.getLogger();
42
37
  if (!logger) {
43
38
  console.warn(