@mastra/rag 0.10.1-alpha.1 → 0.10.2-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +26 -0
- package/dist/_tsup-dts-rollup.d.cts +49 -26
- package/dist/_tsup-dts-rollup.d.ts +49 -26
- package/dist/index.cjs +47 -35
- package/dist/index.js +47 -35
- package/package.json +4 -4
- package/src/tools/graph-rag.test.ts +115 -0
- package/src/tools/graph-rag.ts +27 -35
- package/src/tools/types.ts +49 -0
- package/src/tools/vector-query.test.ts +92 -6
- package/src/tools/vector-query.ts +21 -26
package/.turbo/turbo-build.log
CHANGED
|
@@ -1,23 +1,23 @@
|
|
|
1
1
|
|
|
2
|
-
> @mastra/rag@0.10.1-alpha.1 build /home/runner/work/mastra/mastra/packages/rag
|
|
2
|
+
> @mastra/rag@0.10.2-alpha.0 build /home/runner/work/mastra/mastra/packages/rag
|
|
3
3
|
> tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
|
|
4
4
|
|
|
5
5
|
[34mCLI[39m Building entry: src/index.ts
|
|
6
6
|
[34mCLI[39m Using tsconfig: tsconfig.json
|
|
7
7
|
[34mCLI[39m tsup v8.4.0
|
|
8
8
|
[34mTSC[39m Build start
|
|
9
|
-
[32mTSC[39m ⚡️ Build success in
|
|
9
|
+
[32mTSC[39m ⚡️ Build success in 13387ms
|
|
10
10
|
[34mDTS[39m Build start
|
|
11
11
|
[34mCLI[39m Target: es2022
|
|
12
12
|
Analysis will use the bundled TypeScript version 5.8.3
|
|
13
13
|
[36mWriting package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.ts[39m
|
|
14
14
|
Analysis will use the bundled TypeScript version 5.8.3
|
|
15
15
|
[36mWriting package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.cts[39m
|
|
16
|
-
[32mDTS[39m ⚡️ Build success in
|
|
16
|
+
[32mDTS[39m ⚡️ Build success in 10836ms
|
|
17
17
|
[34mCLI[39m Cleaning output folder
|
|
18
18
|
[34mESM[39m Build start
|
|
19
19
|
[34mCJS[39m Build start
|
|
20
|
-
[
|
|
21
|
-
[
|
|
22
|
-
[
|
|
23
|
-
[
|
|
20
|
+
[32mESM[39m [1mdist/index.js [22m[32m239.65 KB[39m
|
|
21
|
+
[32mESM[39m ⚡️ Build success in 3608ms
|
|
22
|
+
[32mCJS[39m [1mdist/index.cjs [22m[32m241.36 KB[39m
|
|
23
|
+
[32mCJS[39m ⚡️ Build success in 3609ms
|
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,31 @@
|
|
|
1
1
|
# @mastra/rag
|
|
2
2
|
|
|
3
|
+
## 0.10.2-alpha.0
|
|
4
|
+
|
|
5
|
+
### Patch Changes
|
|
6
|
+
|
|
7
|
+
- f0d559f: Fix peerdeps for alpha channel
|
|
8
|
+
- Updated dependencies [1e8bb40]
|
|
9
|
+
- @mastra/core@0.10.2-alpha.2
|
|
10
|
+
|
|
11
|
+
## 0.10.1
|
|
12
|
+
|
|
13
|
+
### Patch Changes
|
|
14
|
+
|
|
15
|
+
- 8784cef: Changed stripHeaders for markdown chunking to strip headers correctly from output when true
|
|
16
|
+
- f56fd29: added return type for vector-query and graph-rag
|
|
17
|
+
- Updated dependencies [d70b807]
|
|
18
|
+
- Updated dependencies [6d16390]
|
|
19
|
+
- Updated dependencies [1e4a421]
|
|
20
|
+
- Updated dependencies [200d0da]
|
|
21
|
+
- Updated dependencies [bf5f17b]
|
|
22
|
+
- Updated dependencies [5343f93]
|
|
23
|
+
- Updated dependencies [38aee50]
|
|
24
|
+
- Updated dependencies [5c41100]
|
|
25
|
+
- Updated dependencies [d6a759b]
|
|
26
|
+
- Updated dependencies [6015bdf]
|
|
27
|
+
- @mastra/core@0.10.1
|
|
28
|
+
|
|
3
29
|
## 0.10.1-alpha.1
|
|
4
30
|
|
|
5
31
|
### Patch Changes
|
|
package/dist/_tsup-dts-rollup.d.cts
CHANGED
|
@@ -174,21 +174,7 @@ export { createDocumentChunkerTool }
|
|
|
174
174
|
export { createDocumentChunkerTool as createDocumentChunkerTool_alias_1 }
|
|
175
175
|
export { createDocumentChunkerTool as createDocumentChunkerTool_alias_2 }
|
|
176
176
|
|
|
177
|
-
declare const createGraphRAGTool: (
|
|
178
|
-
vectorStoreName: string;
|
|
179
|
-
indexName: string;
|
|
180
|
-
model: EmbeddingModel<string>;
|
|
181
|
-
enableFilter?: boolean;
|
|
182
|
-
includeSources?: boolean;
|
|
183
|
-
graphOptions?: {
|
|
184
|
-
dimension?: number;
|
|
185
|
-
randomWalkSteps?: number;
|
|
186
|
-
restartProb?: number;
|
|
187
|
-
threshold?: number;
|
|
188
|
-
};
|
|
189
|
-
id?: string;
|
|
190
|
-
description?: string;
|
|
191
|
-
}) => RagTool<z.ZodObject<{
|
|
177
|
+
declare const createGraphRAGTool: (options: GraphRagToolOptions) => RagTool<z.ZodObject<{
|
|
192
178
|
filter: z.ZodString;
|
|
193
179
|
queryText: z.ZodString;
|
|
194
180
|
topK: z.ZodNumber;
|
|
@@ -214,17 +200,7 @@ export { createGraphRAGTool }
|
|
|
214
200
|
export { createGraphRAGTool as createGraphRAGTool_alias_1 }
|
|
215
201
|
export { createGraphRAGTool as createGraphRAGTool_alias_2 }
|
|
216
202
|
|
|
217
|
-
declare const createVectorQueryTool: (
|
|
218
|
-
vectorStoreName: string;
|
|
219
|
-
indexName: string;
|
|
220
|
-
model: EmbeddingModel<string>;
|
|
221
|
-
enableFilter?: boolean;
|
|
222
|
-
includeVectors?: boolean;
|
|
223
|
-
includeSources?: boolean;
|
|
224
|
-
reranker?: RerankConfig;
|
|
225
|
-
id?: string;
|
|
226
|
-
description?: string;
|
|
227
|
-
}) => RagTool<z.ZodObject<{
|
|
203
|
+
declare const createVectorQueryTool: (options: VectorQueryToolOptions) => RagTool<z.ZodObject<{
|
|
228
204
|
filter: z.ZodString;
|
|
229
205
|
queryText: z.ZodString;
|
|
230
206
|
topK: z.ZodNumber;
|
|
@@ -250,6 +226,25 @@ export { createVectorQueryTool }
|
|
|
250
226
|
export { createVectorQueryTool as createVectorQueryTool_alias_1 }
|
|
251
227
|
export { createVectorQueryTool as createVectorQueryTool_alias_2 }
|
|
252
228
|
|
|
229
|
+
/**
|
|
230
|
+
* Default options for GraphRAG
|
|
231
|
+
* @default
|
|
232
|
+
* ```json
|
|
233
|
+
* {
|
|
234
|
+
* "dimension": 1536,
|
|
235
|
+
* "randomWalkSteps": 100,
|
|
236
|
+
* "restartProb": 0.15,
|
|
237
|
+
* "threshold": 0.7
|
|
238
|
+
* }
|
|
239
|
+
* ```
|
|
240
|
+
*/
|
|
241
|
+
export declare const defaultGraphOptions: {
|
|
242
|
+
dimension: number;
|
|
243
|
+
randomWalkSteps: number;
|
|
244
|
+
restartProb: number;
|
|
245
|
+
threshold: number;
|
|
246
|
+
};
|
|
247
|
+
|
|
253
248
|
declare const defaultGraphRagDescription: () => string;
|
|
254
249
|
export { defaultGraphRagDescription }
|
|
255
250
|
export { defaultGraphRagDescription as defaultGraphRagDescription_alias_1 }
|
|
@@ -403,6 +398,22 @@ declare class GraphRAG {
|
|
|
403
398
|
export { GraphRAG }
|
|
404
399
|
export { GraphRAG as GraphRAG_alias_1 }
|
|
405
400
|
|
|
401
|
+
export declare type GraphRagToolOptions = {
|
|
402
|
+
id?: string;
|
|
403
|
+
description?: string;
|
|
404
|
+
indexName: string;
|
|
405
|
+
vectorStoreName: string;
|
|
406
|
+
model: EmbeddingModel<string>;
|
|
407
|
+
enableFilter?: boolean;
|
|
408
|
+
includeSources?: boolean;
|
|
409
|
+
graphOptions?: {
|
|
410
|
+
dimension?: number;
|
|
411
|
+
randomWalkSteps?: number;
|
|
412
|
+
restartProb?: number;
|
|
413
|
+
threshold?: number;
|
|
414
|
+
};
|
|
415
|
+
};
|
|
416
|
+
|
|
406
417
|
export declare class HTMLHeaderTransformer {
|
|
407
418
|
private headersToSplitOn;
|
|
408
419
|
private returnEachElement;
|
|
@@ -1165,6 +1176,18 @@ declare interface VectorQuerySearchResult {
|
|
|
1165
1176
|
queryEmbedding: number[];
|
|
1166
1177
|
}
|
|
1167
1178
|
|
|
1179
|
+
export declare type VectorQueryToolOptions = {
|
|
1180
|
+
id?: string;
|
|
1181
|
+
description?: string;
|
|
1182
|
+
indexName: string;
|
|
1183
|
+
vectorStoreName: string;
|
|
1184
|
+
model: EmbeddingModel<string>;
|
|
1185
|
+
enableFilter?: boolean;
|
|
1186
|
+
includeVectors?: boolean;
|
|
1187
|
+
includeSources?: boolean;
|
|
1188
|
+
reranker?: RerankConfig;
|
|
1189
|
+
};
|
|
1190
|
+
|
|
1168
1191
|
declare type WeightConfig = {
|
|
1169
1192
|
semantic?: number;
|
|
1170
1193
|
vector?: number;
|
|
package/dist/_tsup-dts-rollup.d.ts
CHANGED
|
@@ -174,21 +174,7 @@ export { createDocumentChunkerTool }
|
|
|
174
174
|
export { createDocumentChunkerTool as createDocumentChunkerTool_alias_1 }
|
|
175
175
|
export { createDocumentChunkerTool as createDocumentChunkerTool_alias_2 }
|
|
176
176
|
|
|
177
|
-
declare const createGraphRAGTool: (
|
|
178
|
-
vectorStoreName: string;
|
|
179
|
-
indexName: string;
|
|
180
|
-
model: EmbeddingModel<string>;
|
|
181
|
-
enableFilter?: boolean;
|
|
182
|
-
includeSources?: boolean;
|
|
183
|
-
graphOptions?: {
|
|
184
|
-
dimension?: number;
|
|
185
|
-
randomWalkSteps?: number;
|
|
186
|
-
restartProb?: number;
|
|
187
|
-
threshold?: number;
|
|
188
|
-
};
|
|
189
|
-
id?: string;
|
|
190
|
-
description?: string;
|
|
191
|
-
}) => RagTool<z.ZodObject<{
|
|
177
|
+
declare const createGraphRAGTool: (options: GraphRagToolOptions) => RagTool<z.ZodObject<{
|
|
192
178
|
filter: z.ZodString;
|
|
193
179
|
queryText: z.ZodString;
|
|
194
180
|
topK: z.ZodNumber;
|
|
@@ -214,17 +200,7 @@ export { createGraphRAGTool }
|
|
|
214
200
|
export { createGraphRAGTool as createGraphRAGTool_alias_1 }
|
|
215
201
|
export { createGraphRAGTool as createGraphRAGTool_alias_2 }
|
|
216
202
|
|
|
217
|
-
declare const createVectorQueryTool: (
|
|
218
|
-
vectorStoreName: string;
|
|
219
|
-
indexName: string;
|
|
220
|
-
model: EmbeddingModel<string>;
|
|
221
|
-
enableFilter?: boolean;
|
|
222
|
-
includeVectors?: boolean;
|
|
223
|
-
includeSources?: boolean;
|
|
224
|
-
reranker?: RerankConfig;
|
|
225
|
-
id?: string;
|
|
226
|
-
description?: string;
|
|
227
|
-
}) => RagTool<z.ZodObject<{
|
|
203
|
+
declare const createVectorQueryTool: (options: VectorQueryToolOptions) => RagTool<z.ZodObject<{
|
|
228
204
|
filter: z.ZodString;
|
|
229
205
|
queryText: z.ZodString;
|
|
230
206
|
topK: z.ZodNumber;
|
|
@@ -250,6 +226,25 @@ export { createVectorQueryTool }
|
|
|
250
226
|
export { createVectorQueryTool as createVectorQueryTool_alias_1 }
|
|
251
227
|
export { createVectorQueryTool as createVectorQueryTool_alias_2 }
|
|
252
228
|
|
|
229
|
+
/**
|
|
230
|
+
* Default options for GraphRAG
|
|
231
|
+
* @default
|
|
232
|
+
* ```json
|
|
233
|
+
* {
|
|
234
|
+
* "dimension": 1536,
|
|
235
|
+
* "randomWalkSteps": 100,
|
|
236
|
+
* "restartProb": 0.15,
|
|
237
|
+
* "threshold": 0.7
|
|
238
|
+
* }
|
|
239
|
+
* ```
|
|
240
|
+
*/
|
|
241
|
+
export declare const defaultGraphOptions: {
|
|
242
|
+
dimension: number;
|
|
243
|
+
randomWalkSteps: number;
|
|
244
|
+
restartProb: number;
|
|
245
|
+
threshold: number;
|
|
246
|
+
};
|
|
247
|
+
|
|
253
248
|
declare const defaultGraphRagDescription: () => string;
|
|
254
249
|
export { defaultGraphRagDescription }
|
|
255
250
|
export { defaultGraphRagDescription as defaultGraphRagDescription_alias_1 }
|
|
@@ -403,6 +398,22 @@ declare class GraphRAG {
|
|
|
403
398
|
export { GraphRAG }
|
|
404
399
|
export { GraphRAG as GraphRAG_alias_1 }
|
|
405
400
|
|
|
401
|
+
export declare type GraphRagToolOptions = {
|
|
402
|
+
id?: string;
|
|
403
|
+
description?: string;
|
|
404
|
+
indexName: string;
|
|
405
|
+
vectorStoreName: string;
|
|
406
|
+
model: EmbeddingModel<string>;
|
|
407
|
+
enableFilter?: boolean;
|
|
408
|
+
includeSources?: boolean;
|
|
409
|
+
graphOptions?: {
|
|
410
|
+
dimension?: number;
|
|
411
|
+
randomWalkSteps?: number;
|
|
412
|
+
restartProb?: number;
|
|
413
|
+
threshold?: number;
|
|
414
|
+
};
|
|
415
|
+
};
|
|
416
|
+
|
|
406
417
|
export declare class HTMLHeaderTransformer {
|
|
407
418
|
private headersToSplitOn;
|
|
408
419
|
private returnEachElement;
|
|
@@ -1165,6 +1176,18 @@ declare interface VectorQuerySearchResult {
|
|
|
1165
1176
|
queryEmbedding: number[];
|
|
1166
1177
|
}
|
|
1167
1178
|
|
|
1179
|
+
export declare type VectorQueryToolOptions = {
|
|
1180
|
+
id?: string;
|
|
1181
|
+
description?: string;
|
|
1182
|
+
indexName: string;
|
|
1183
|
+
vectorStoreName: string;
|
|
1184
|
+
model: EmbeddingModel<string>;
|
|
1185
|
+
enableFilter?: boolean;
|
|
1186
|
+
includeVectors?: boolean;
|
|
1187
|
+
includeSources?: boolean;
|
|
1188
|
+
reranker?: RerankConfig;
|
|
1189
|
+
};
|
|
1190
|
+
|
|
1168
1191
|
declare type WeightConfig = {
|
|
1169
1192
|
semantic?: number;
|
|
1170
1193
|
vector?: number;
|
package/dist/index.cjs
CHANGED
|
@@ -6334,33 +6334,43 @@ var convertToSources = (results) => {
|
|
|
6334
6334
|
});
|
|
6335
6335
|
};
|
|
6336
6336
|
|
|
6337
|
+
// src/tools/types.ts
|
|
6338
|
+
var defaultGraphOptions = {
|
|
6339
|
+
dimension: 1536,
|
|
6340
|
+
randomWalkSteps: 100,
|
|
6341
|
+
restartProb: 0.15,
|
|
6342
|
+
threshold: 0.7
|
|
6343
|
+
};
|
|
6344
|
+
|
|
6337
6345
|
// src/tools/graph-rag.ts
|
|
6338
|
-
var createGraphRAGTool = ({
|
|
6339
|
-
|
|
6340
|
-
indexName,
|
|
6341
|
-
model,
|
|
6342
|
-
enableFilter = false,
|
|
6343
|
-
includeSources = true,
|
|
6344
|
-
graphOptions = {
|
|
6345
|
-
dimension: 1536,
|
|
6346
|
-
randomWalkSteps: 100,
|
|
6347
|
-
restartProb: 0.15,
|
|
6348
|
-
threshold: 0.7
|
|
6349
|
-
},
|
|
6350
|
-
id,
|
|
6351
|
-
description
|
|
6352
|
-
}) => {
|
|
6353
|
-
const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
|
|
6346
|
+
var createGraphRAGTool = (options) => {
|
|
6347
|
+
const { model, id, description } = options;
|
|
6348
|
+
const toolId = id || `GraphRAG ${options.vectorStoreName} ${options.indexName} Tool`;
|
|
6354
6349
|
const toolDescription = description || defaultGraphRagDescription();
|
|
6350
|
+
const graphOptions = {
|
|
6351
|
+
...defaultGraphOptions,
|
|
6352
|
+
...options.graphOptions || {}
|
|
6353
|
+
};
|
|
6355
6354
|
const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
|
|
6356
6355
|
let isInitialized = false;
|
|
6357
|
-
const inputSchema = enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
|
|
6356
|
+
const inputSchema = options.enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
|
|
6358
6357
|
return tools.createTool({
|
|
6359
6358
|
id: toolId,
|
|
6360
6359
|
inputSchema,
|
|
6361
6360
|
outputSchema,
|
|
6362
6361
|
description: toolDescription,
|
|
6363
|
-
execute: async ({ context
|
|
6362
|
+
execute: async ({ context, mastra, runtimeContext }) => {
|
|
6363
|
+
const indexName = runtimeContext.get("indexName") ?? options.indexName;
|
|
6364
|
+
const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
|
|
6365
|
+
if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
|
|
6366
|
+
if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
|
|
6367
|
+
const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
|
|
6368
|
+
const randomWalkSteps = runtimeContext.get("randomWalkSteps") ?? graphOptions.randomWalkSteps;
|
|
6369
|
+
const restartProb = runtimeContext.get("restartProb") ?? graphOptions.restartProb;
|
|
6370
|
+
const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
|
|
6371
|
+
const filter = runtimeContext.get("filter") ?? context.filter;
|
|
6372
|
+
const queryText = context.queryText;
|
|
6373
|
+
const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
|
|
6364
6374
|
const logger = mastra?.getLogger();
|
|
6365
6375
|
if (!logger) {
|
|
6366
6376
|
console.warn(
|
|
@@ -6426,8 +6436,8 @@ var createGraphRAGTool = ({
|
|
|
6426
6436
|
const rerankedResults = graphRag.query({
|
|
6427
6437
|
query: queryEmbedding,
|
|
6428
6438
|
topK: topKValue,
|
|
6429
|
-
randomWalkSteps
|
|
6430
|
-
restartProb
|
|
6439
|
+
randomWalkSteps,
|
|
6440
|
+
restartProb
|
|
6431
6441
|
});
|
|
6432
6442
|
if (logger) {
|
|
6433
6443
|
logger.debug("GraphRAG query returned results", { count: rerankedResults.length });
|
|
@@ -6455,26 +6465,28 @@ var createGraphRAGTool = ({
|
|
|
6455
6465
|
// Use any for output schema as the structure of the output causes type inference issues
|
|
6456
6466
|
});
|
|
6457
6467
|
};
|
|
6458
|
-
var createVectorQueryTool = ({
|
|
6459
|
-
|
|
6460
|
-
indexName,
|
|
6461
|
-
model,
|
|
6462
|
-
enableFilter = false,
|
|
6463
|
-
includeVectors = false,
|
|
6464
|
-
includeSources = true,
|
|
6465
|
-
reranker,
|
|
6466
|
-
id,
|
|
6467
|
-
description
|
|
6468
|
-
}) => {
|
|
6469
|
-
const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
|
|
6468
|
+
var createVectorQueryTool = (options) => {
|
|
6469
|
+
const { model, id, description } = options;
|
|
6470
|
+
const toolId = id || `VectorQuery ${options.vectorStoreName} ${options.indexName} Tool`;
|
|
6470
6471
|
const toolDescription = description || defaultVectorQueryDescription();
|
|
6471
|
-
const inputSchema = enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
|
|
6472
|
+
const inputSchema = options.enableFilter ? filterSchema : zod.z.object(baseSchema).passthrough();
|
|
6472
6473
|
return tools.createTool({
|
|
6473
6474
|
id: toolId,
|
|
6475
|
+
description: toolDescription,
|
|
6474
6476
|
inputSchema,
|
|
6475
6477
|
outputSchema,
|
|
6476
|
-
|
|
6477
|
-
|
|
6478
|
+
execute: async ({ context, mastra, runtimeContext }) => {
|
|
6479
|
+
const indexName = runtimeContext.get("indexName") ?? options.indexName;
|
|
6480
|
+
const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
|
|
6481
|
+
const includeVectors = runtimeContext.get("includeVectors") ?? options.includeVectors ?? false;
|
|
6482
|
+
const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
|
|
6483
|
+
const reranker = runtimeContext.get("reranker") ?? options.reranker;
|
|
6484
|
+
if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
|
|
6485
|
+
if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
|
|
6486
|
+
const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
|
|
6487
|
+
const filter = runtimeContext.get("filter") ?? context.filter;
|
|
6488
|
+
const queryText = context.queryText;
|
|
6489
|
+
const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
|
|
6478
6490
|
const logger = mastra?.getLogger();
|
|
6479
6491
|
if (!logger) {
|
|
6480
6492
|
console.warn(
|
package/dist/index.js
CHANGED
|
@@ -6332,33 +6332,43 @@ var convertToSources = (results) => {
|
|
|
6332
6332
|
});
|
|
6333
6333
|
};
|
|
6334
6334
|
|
|
6335
|
+
// src/tools/types.ts
|
|
6336
|
+
var defaultGraphOptions = {
|
|
6337
|
+
dimension: 1536,
|
|
6338
|
+
randomWalkSteps: 100,
|
|
6339
|
+
restartProb: 0.15,
|
|
6340
|
+
threshold: 0.7
|
|
6341
|
+
};
|
|
6342
|
+
|
|
6335
6343
|
// src/tools/graph-rag.ts
|
|
6336
|
-
var createGraphRAGTool = ({
|
|
6337
|
-
|
|
6338
|
-
indexName,
|
|
6339
|
-
model,
|
|
6340
|
-
enableFilter = false,
|
|
6341
|
-
includeSources = true,
|
|
6342
|
-
graphOptions = {
|
|
6343
|
-
dimension: 1536,
|
|
6344
|
-
randomWalkSteps: 100,
|
|
6345
|
-
restartProb: 0.15,
|
|
6346
|
-
threshold: 0.7
|
|
6347
|
-
},
|
|
6348
|
-
id,
|
|
6349
|
-
description
|
|
6350
|
-
}) => {
|
|
6351
|
-
const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
|
|
6344
|
+
var createGraphRAGTool = (options) => {
|
|
6345
|
+
const { model, id, description } = options;
|
|
6346
|
+
const toolId = id || `GraphRAG ${options.vectorStoreName} ${options.indexName} Tool`;
|
|
6352
6347
|
const toolDescription = description || defaultGraphRagDescription();
|
|
6348
|
+
const graphOptions = {
|
|
6349
|
+
...defaultGraphOptions,
|
|
6350
|
+
...options.graphOptions || {}
|
|
6351
|
+
};
|
|
6353
6352
|
const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
|
|
6354
6353
|
let isInitialized = false;
|
|
6355
|
-
const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
6354
|
+
const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
6356
6355
|
return createTool({
|
|
6357
6356
|
id: toolId,
|
|
6358
6357
|
inputSchema,
|
|
6359
6358
|
outputSchema,
|
|
6360
6359
|
description: toolDescription,
|
|
6361
|
-
execute: async ({ context
|
|
6360
|
+
execute: async ({ context, mastra, runtimeContext }) => {
|
|
6361
|
+
const indexName = runtimeContext.get("indexName") ?? options.indexName;
|
|
6362
|
+
const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
|
|
6363
|
+
if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
|
|
6364
|
+
if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
|
|
6365
|
+
const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
|
|
6366
|
+
const randomWalkSteps = runtimeContext.get("randomWalkSteps") ?? graphOptions.randomWalkSteps;
|
|
6367
|
+
const restartProb = runtimeContext.get("restartProb") ?? graphOptions.restartProb;
|
|
6368
|
+
const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
|
|
6369
|
+
const filter = runtimeContext.get("filter") ?? context.filter;
|
|
6370
|
+
const queryText = context.queryText;
|
|
6371
|
+
const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
|
|
6362
6372
|
const logger = mastra?.getLogger();
|
|
6363
6373
|
if (!logger) {
|
|
6364
6374
|
console.warn(
|
|
@@ -6424,8 +6434,8 @@ var createGraphRAGTool = ({
|
|
|
6424
6434
|
const rerankedResults = graphRag.query({
|
|
6425
6435
|
query: queryEmbedding,
|
|
6426
6436
|
topK: topKValue,
|
|
6427
|
-
randomWalkSteps
|
|
6428
|
-
restartProb
|
|
6437
|
+
randomWalkSteps,
|
|
6438
|
+
restartProb
|
|
6429
6439
|
});
|
|
6430
6440
|
if (logger) {
|
|
6431
6441
|
logger.debug("GraphRAG query returned results", { count: rerankedResults.length });
|
|
@@ -6453,26 +6463,28 @@ var createGraphRAGTool = ({
|
|
|
6453
6463
|
// Use any for output schema as the structure of the output causes type inference issues
|
|
6454
6464
|
});
|
|
6455
6465
|
};
|
|
6456
|
-
var createVectorQueryTool = ({
|
|
6457
|
-
|
|
6458
|
-
indexName,
|
|
6459
|
-
model,
|
|
6460
|
-
enableFilter = false,
|
|
6461
|
-
includeVectors = false,
|
|
6462
|
-
includeSources = true,
|
|
6463
|
-
reranker,
|
|
6464
|
-
id,
|
|
6465
|
-
description
|
|
6466
|
-
}) => {
|
|
6467
|
-
const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
|
|
6466
|
+
var createVectorQueryTool = (options) => {
|
|
6467
|
+
const { model, id, description } = options;
|
|
6468
|
+
const toolId = id || `VectorQuery ${options.vectorStoreName} ${options.indexName} Tool`;
|
|
6468
6469
|
const toolDescription = description || defaultVectorQueryDescription();
|
|
6469
|
-
const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
6470
|
+
const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
6470
6471
|
return createTool({
|
|
6471
6472
|
id: toolId,
|
|
6473
|
+
description: toolDescription,
|
|
6472
6474
|
inputSchema,
|
|
6473
6475
|
outputSchema,
|
|
6474
|
-
|
|
6475
|
-
|
|
6476
|
+
execute: async ({ context, mastra, runtimeContext }) => {
|
|
6477
|
+
const indexName = runtimeContext.get("indexName") ?? options.indexName;
|
|
6478
|
+
const vectorStoreName = runtimeContext.get("vectorStoreName") ?? options.vectorStoreName;
|
|
6479
|
+
const includeVectors = runtimeContext.get("includeVectors") ?? options.includeVectors ?? false;
|
|
6480
|
+
const includeSources = runtimeContext.get("includeSources") ?? options.includeSources ?? true;
|
|
6481
|
+
const reranker = runtimeContext.get("reranker") ?? options.reranker;
|
|
6482
|
+
if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
|
|
6483
|
+
if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
|
|
6484
|
+
const topK = runtimeContext.get("topK") ?? context.topK ?? 10;
|
|
6485
|
+
const filter = runtimeContext.get("filter") ?? context.filter;
|
|
6486
|
+
const queryText = context.queryText;
|
|
6487
|
+
const enableFilter = !!runtimeContext.get("filter") || (options.enableFilter ?? false);
|
|
6476
6488
|
const logger = mastra?.getLogger();
|
|
6477
6489
|
if (!logger) {
|
|
6478
6490
|
console.warn(
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@mastra/rag",
|
|
3
|
-
"version": "0.10.
|
|
3
|
+
"version": "0.10.2-alpha.0",
|
|
4
4
|
"description": "",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "dist/index.js",
|
|
@@ -30,7 +30,7 @@
|
|
|
30
30
|
},
|
|
31
31
|
"peerDependencies": {
|
|
32
32
|
"ai": "^4.0.0",
|
|
33
|
-
"@mastra/core": "^0.10.0"
|
|
33
|
+
"@mastra/core": "^0.10.0-alpha.0"
|
|
34
34
|
},
|
|
35
35
|
"devDependencies": {
|
|
36
36
|
"@ai-sdk/cohere": "latest",
|
|
@@ -44,8 +44,8 @@
|
|
|
44
44
|
"tsup": "^8.4.0",
|
|
45
45
|
"typescript": "^5.8.2",
|
|
46
46
|
"vitest": "^3.1.2",
|
|
47
|
-
"@internal/lint": "0.0.
|
|
48
|
-
"@mastra/core": "0.10.
|
|
47
|
+
"@internal/lint": "0.0.7",
|
|
48
|
+
"@mastra/core": "0.10.2-alpha.2"
|
|
49
49
|
},
|
|
50
50
|
"keywords": [
|
|
51
51
|
"rag",
|
|
package/src/tools/graph-rag.test.ts
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import { RuntimeContext } from '@mastra/core/runtime-context';
|
|
2
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
3
|
+
import { GraphRAG } from '../graph-rag';
|
|
4
|
+
import { vectorQuerySearch } from '../utils';
|
|
5
|
+
import { createGraphRAGTool } from './graph-rag';
|
|
6
|
+
|
|
7
|
+
vi.mock('../utils', async importOriginal => {
|
|
8
|
+
const actual: any = await importOriginal();
|
|
9
|
+
return {
|
|
10
|
+
...actual,
|
|
11
|
+
vectorQuerySearch: vi.fn().mockResolvedValue({
|
|
12
|
+
results: [
|
|
13
|
+
{ metadata: { text: 'foo' }, vector: [1, 2, 3] },
|
|
14
|
+
{ metadata: { text: 'bar' }, vector: [4, 5, 6] },
|
|
15
|
+
],
|
|
16
|
+
queryEmbedding: [1, 2, 3],
|
|
17
|
+
}),
|
|
18
|
+
};
|
|
19
|
+
});
|
|
20
|
+
|
|
21
|
+
vi.mock('../graph-rag', async importOriginal => {
|
|
22
|
+
const actual: any = await importOriginal();
|
|
23
|
+
return {
|
|
24
|
+
...actual,
|
|
25
|
+
GraphRAG: vi.fn().mockImplementation(() => {
|
|
26
|
+
return {
|
|
27
|
+
createGraph: vi.fn(),
|
|
28
|
+
query: vi.fn(() => [
|
|
29
|
+
{ content: 'foo', metadata: { text: 'foo' } },
|
|
30
|
+
{ content: 'bar', metadata: { text: 'bar' } },
|
|
31
|
+
]),
|
|
32
|
+
};
|
|
33
|
+
}),
|
|
34
|
+
};
|
|
35
|
+
});
|
|
36
|
+
|
|
37
|
+
const mockModel = { name: 'test-model' } as any;
|
|
38
|
+
const mockMastra = {
|
|
39
|
+
getVector: vi.fn(storeName => ({
|
|
40
|
+
[storeName]: {},
|
|
41
|
+
})),
|
|
42
|
+
getLogger: vi.fn(() => ({
|
|
43
|
+
debug: vi.fn(),
|
|
44
|
+
warn: vi.fn(),
|
|
45
|
+
info: vi.fn(),
|
|
46
|
+
error: vi.fn(),
|
|
47
|
+
})),
|
|
48
|
+
};
|
|
49
|
+
|
|
50
|
+
describe('createGraphRAGTool', () => {
|
|
51
|
+
beforeEach(() => {
|
|
52
|
+
vi.clearAllMocks();
|
|
53
|
+
});
|
|
54
|
+
|
|
55
|
+
it('validates input schema', () => {
|
|
56
|
+
const tool = createGraphRAGTool({
|
|
57
|
+
id: 'test',
|
|
58
|
+
model: mockModel,
|
|
59
|
+
vectorStoreName: 'testStore',
|
|
60
|
+
indexName: 'testIndex',
|
|
61
|
+
});
|
|
62
|
+
expect(() => tool.inputSchema?.parse({ queryText: 'foo', topK: 10 })).not.toThrow();
|
|
63
|
+
expect(() => tool.inputSchema?.parse({})).toThrow();
|
|
64
|
+
});
|
|
65
|
+
|
|
66
|
+
describe('runtimeContext', () => {
|
|
67
|
+
it('calls vectorQuerySearch and GraphRAG with runtimeContext params', async () => {
|
|
68
|
+
const tool = createGraphRAGTool({
|
|
69
|
+
id: 'test',
|
|
70
|
+
model: mockModel,
|
|
71
|
+
indexName: 'testIndex',
|
|
72
|
+
vectorStoreName: 'testStore',
|
|
73
|
+
});
|
|
74
|
+
const runtimeContext = new RuntimeContext();
|
|
75
|
+
runtimeContext.set('indexName', 'anotherIndex');
|
|
76
|
+
runtimeContext.set('vectorStoreName', 'anotherStore');
|
|
77
|
+
runtimeContext.set('topK', 5);
|
|
78
|
+
runtimeContext.set('filter', { foo: 'bar' });
|
|
79
|
+
runtimeContext.set('randomWalkSteps', 99);
|
|
80
|
+
runtimeContext.set('restartProb', 0.42);
|
|
81
|
+
const result = await tool.execute({
|
|
82
|
+
context: { queryText: 'foo', topK: 2 },
|
|
83
|
+
mastra: mockMastra as any,
|
|
84
|
+
runtimeContext,
|
|
85
|
+
});
|
|
86
|
+
expect(result.relevantContext).toEqual(['foo', 'bar']);
|
|
87
|
+
expect(result.sources.length).toBe(2);
|
|
88
|
+
expect(vectorQuerySearch).toHaveBeenCalledWith(
|
|
89
|
+
expect.objectContaining({
|
|
90
|
+
indexName: 'anotherIndex',
|
|
91
|
+
vectorStore: {
|
|
92
|
+
anotherStore: {},
|
|
93
|
+
},
|
|
94
|
+
queryText: 'foo',
|
|
95
|
+
model: mockModel,
|
|
96
|
+
queryFilter: { foo: 'bar' },
|
|
97
|
+
topK: 5,
|
|
98
|
+
includeVectors: true,
|
|
99
|
+
}),
|
|
100
|
+
);
|
|
101
|
+
// GraphRAG createGraph and query should be called
|
|
102
|
+
expect(GraphRAG).toHaveBeenCalled();
|
|
103
|
+
const instance = (GraphRAG as any).mock.results[0].value;
|
|
104
|
+
expect(instance.createGraph).toHaveBeenCalled();
|
|
105
|
+
expect(instance.query).toHaveBeenCalledWith(
|
|
106
|
+
expect.objectContaining({
|
|
107
|
+
query: [1, 2, 3],
|
|
108
|
+
topK: 5,
|
|
109
|
+
randomWalkSteps: 99,
|
|
110
|
+
restartProb: 0.42,
|
|
111
|
+
}),
|
|
112
|
+
);
|
|
113
|
+
});
|
|
114
|
+
});
|
|
115
|
+
});
|
package/src/tools/graph-rag.ts
CHANGED
|
@@ -1,55 +1,47 @@
|
|
|
1
1
|
import { createTool } from '@mastra/core/tools';
|
|
2
|
-
import type { EmbeddingModel } from 'ai';
|
|
3
2
|
import { z } from 'zod';
|
|
4
3
|
|
|
5
4
|
import { GraphRAG } from '../graph-rag';
|
|
6
5
|
import { vectorQuerySearch, defaultGraphRagDescription, filterSchema, outputSchema, baseSchema } from '../utils';
|
|
7
6
|
import type { RagTool } from '../utils';
|
|
8
7
|
import { convertToSources } from '../utils/convert-sources';
|
|
8
|
+
import type { GraphRagToolOptions } from './types';
|
|
9
|
+
import { defaultGraphOptions } from './types';
|
|
9
10
|
|
|
10
|
-
export const createGraphRAGTool = ({
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
enableFilter = false,
|
|
15
|
-
includeSources = true,
|
|
16
|
-
graphOptions = {
|
|
17
|
-
dimension: 1536,
|
|
18
|
-
randomWalkSteps: 100,
|
|
19
|
-
restartProb: 0.15,
|
|
20
|
-
threshold: 0.7,
|
|
21
|
-
},
|
|
22
|
-
id,
|
|
23
|
-
description,
|
|
24
|
-
}: {
|
|
25
|
-
vectorStoreName: string;
|
|
26
|
-
indexName: string;
|
|
27
|
-
model: EmbeddingModel<string>;
|
|
28
|
-
enableFilter?: boolean;
|
|
29
|
-
includeSources?: boolean;
|
|
30
|
-
graphOptions?: {
|
|
31
|
-
dimension?: number;
|
|
32
|
-
randomWalkSteps?: number;
|
|
33
|
-
restartProb?: number;
|
|
34
|
-
threshold?: number;
|
|
35
|
-
};
|
|
36
|
-
id?: string;
|
|
37
|
-
description?: string;
|
|
38
|
-
}) => {
|
|
39
|
-
const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
|
|
11
|
+
export const createGraphRAGTool = (options: GraphRagToolOptions) => {
|
|
12
|
+
const { model, id, description } = options;
|
|
13
|
+
|
|
14
|
+
const toolId = id || `GraphRAG ${options.vectorStoreName} ${options.indexName} Tool`;
|
|
40
15
|
const toolDescription = description || defaultGraphRagDescription();
|
|
16
|
+
const graphOptions = {
|
|
17
|
+
...defaultGraphOptions,
|
|
18
|
+
...(options.graphOptions || {}),
|
|
19
|
+
};
|
|
41
20
|
// Initialize GraphRAG
|
|
42
21
|
const graphRag = new GraphRAG(graphOptions.dimension, graphOptions.threshold);
|
|
43
22
|
let isInitialized = false;
|
|
44
23
|
|
|
45
|
-
const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
24
|
+
const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
46
25
|
|
|
47
26
|
return createTool({
|
|
48
27
|
id: toolId,
|
|
49
28
|
inputSchema,
|
|
50
29
|
outputSchema,
|
|
51
30
|
description: toolDescription,
|
|
52
|
-
execute: async ({ context
|
|
31
|
+
execute: async ({ context, mastra, runtimeContext }) => {
|
|
32
|
+
const indexName: string = runtimeContext.get('indexName') ?? options.indexName;
|
|
33
|
+
const vectorStoreName: string = runtimeContext.get('vectorStoreName') ?? options.vectorStoreName;
|
|
34
|
+
if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
|
|
35
|
+
if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
|
|
36
|
+
const includeSources: boolean = runtimeContext.get('includeSources') ?? options.includeSources ?? true;
|
|
37
|
+
const randomWalkSteps: number | undefined = runtimeContext.get('randomWalkSteps') ?? graphOptions.randomWalkSteps;
|
|
38
|
+
const restartProb: number | undefined = runtimeContext.get('restartProb') ?? graphOptions.restartProb;
|
|
39
|
+
const topK: number = runtimeContext.get('topK') ?? context.topK ?? 10;
|
|
40
|
+
const filter: Record<string, any> = runtimeContext.get('filter') ?? context.filter;
|
|
41
|
+
const queryText = context.queryText;
|
|
42
|
+
|
|
43
|
+
const enableFilter = !!runtimeContext.get('filter') || (options.enableFilter ?? false);
|
|
44
|
+
|
|
53
45
|
const logger = mastra?.getLogger();
|
|
54
46
|
if (!logger) {
|
|
55
47
|
console.warn(
|
|
@@ -129,8 +121,8 @@ export const createGraphRAGTool = ({
|
|
|
129
121
|
const rerankedResults = graphRag.query({
|
|
130
122
|
query: queryEmbedding,
|
|
131
123
|
topK: topKValue,
|
|
132
|
-
randomWalkSteps
|
|
133
|
-
restartProb
|
|
124
|
+
randomWalkSteps,
|
|
125
|
+
restartProb,
|
|
134
126
|
});
|
|
135
127
|
if (logger) {
|
|
136
128
|
logger.debug('GraphRAG query returned results', { count: rerankedResults.length });
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import type { EmbeddingModel } from 'ai';
|
|
2
|
+
import type { RerankConfig } from '../rerank';
|
|
3
|
+
|
|
4
|
+
export type VectorQueryToolOptions = {
|
|
5
|
+
id?: string;
|
|
6
|
+
description?: string;
|
|
7
|
+
indexName: string;
|
|
8
|
+
vectorStoreName: string;
|
|
9
|
+
model: EmbeddingModel<string>;
|
|
10
|
+
enableFilter?: boolean;
|
|
11
|
+
includeVectors?: boolean;
|
|
12
|
+
includeSources?: boolean;
|
|
13
|
+
reranker?: RerankConfig;
|
|
14
|
+
};
|
|
15
|
+
|
|
16
|
+
export type GraphRagToolOptions = {
|
|
17
|
+
id?: string;
|
|
18
|
+
description?: string;
|
|
19
|
+
indexName: string;
|
|
20
|
+
vectorStoreName: string;
|
|
21
|
+
model: EmbeddingModel<string>;
|
|
22
|
+
enableFilter?: boolean;
|
|
23
|
+
includeSources?: boolean;
|
|
24
|
+
graphOptions?: {
|
|
25
|
+
dimension?: number;
|
|
26
|
+
randomWalkSteps?: number;
|
|
27
|
+
restartProb?: number;
|
|
28
|
+
threshold?: number;
|
|
29
|
+
};
|
|
30
|
+
};
|
|
31
|
+
|
|
32
|
+
/**
|
|
33
|
+
* Default options for GraphRAG
|
|
34
|
+
* @default
|
|
35
|
+
* ```json
|
|
36
|
+
* {
|
|
37
|
+
* "dimension": 1536,
|
|
38
|
+
* "randomWalkSteps": 100,
|
|
39
|
+
* "restartProb": 0.15,
|
|
40
|
+
* "threshold": 0.7
|
|
41
|
+
* }
|
|
42
|
+
* ```
|
|
43
|
+
*/
|
|
44
|
+
export const defaultGraphOptions = {
|
|
45
|
+
dimension: 1536,
|
|
46
|
+
randomWalkSteps: 100,
|
|
47
|
+
restartProb: 0.15,
|
|
48
|
+
threshold: 0.7,
|
|
49
|
+
};
|
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
import { RuntimeContext } from '@mastra/core/runtime-context';
|
|
1
2
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
3
|
+
import { rerank } from '../rerank';
|
|
2
4
|
import { vectorQuerySearch } from '../utils';
|
|
3
5
|
import { createVectorQueryTool } from './vector-query';
|
|
4
6
|
|
|
@@ -6,7 +8,19 @@ vi.mock('../utils', async importOriginal => {
|
|
|
6
8
|
const actual: any = await importOriginal();
|
|
7
9
|
return {
|
|
8
10
|
...actual,
|
|
9
|
-
vectorQuerySearch: vi.fn().mockResolvedValue({ results: [] }),
|
|
11
|
+
vectorQuerySearch: vi.fn().mockResolvedValue({ results: [{ metadata: { text: 'foo' }, vector: [1, 2, 3] }] }),
|
|
12
|
+
};
|
|
13
|
+
});
|
|
14
|
+
|
|
15
|
+
vi.mock('../rerank', async importOriginal => {
|
|
16
|
+
const actual: any = await importOriginal();
|
|
17
|
+
return {
|
|
18
|
+
...actual,
|
|
19
|
+
rerank: vi
|
|
20
|
+
.fn()
|
|
21
|
+
.mockResolvedValue([
|
|
22
|
+
{ result: { id: '1', metadata: { text: 'bar' }, score: 1, details: { semantic: 1, vector: 1, position: 1 } } },
|
|
23
|
+
]),
|
|
10
24
|
};
|
|
11
25
|
});
|
|
12
26
|
|
|
@@ -17,9 +31,12 @@ describe('createVectorQueryTool', () => {
|
|
|
17
31
|
testStore: {
|
|
18
32
|
// Mock vector store methods
|
|
19
33
|
},
|
|
34
|
+
anotherStore: {
|
|
35
|
+
// Mock vector store methods
|
|
36
|
+
},
|
|
20
37
|
},
|
|
21
|
-
getVector: vi.fn(
|
|
22
|
-
|
|
38
|
+
getVector: vi.fn(storeName => ({
|
|
39
|
+
[storeName]: {
|
|
23
40
|
// Mock vector store methods
|
|
24
41
|
},
|
|
25
42
|
})),
|
|
@@ -154,6 +171,8 @@ describe('createVectorQueryTool', () => {
|
|
|
154
171
|
|
|
155
172
|
describe('execute function', () => {
|
|
156
173
|
it('should not process filter when enableFilter is false', async () => {
|
|
174
|
+
const runtimeContext = new RuntimeContext();
|
|
175
|
+
|
|
157
176
|
// Create tool with enableFilter set to false
|
|
158
177
|
const tool = createVectorQueryTool({
|
|
159
178
|
vectorStoreName: 'testStore',
|
|
@@ -169,7 +188,7 @@ describe('createVectorQueryTool', () => {
|
|
|
169
188
|
topK: 5,
|
|
170
189
|
},
|
|
171
190
|
mastra: mockMastra as any,
|
|
172
|
-
runtimeContext
|
|
191
|
+
runtimeContext,
|
|
173
192
|
});
|
|
174
193
|
|
|
175
194
|
// Check that vectorQuerySearch was called with undefined queryFilter
|
|
@@ -181,6 +200,7 @@ describe('createVectorQueryTool', () => {
|
|
|
181
200
|
});
|
|
182
201
|
|
|
183
202
|
it('should process filter when enableFilter is true and filter is provided', async () => {
|
|
203
|
+
const runtimeContext = new RuntimeContext();
|
|
184
204
|
// Create tool with enableFilter set to true
|
|
185
205
|
const tool = createVectorQueryTool({
|
|
186
206
|
vectorStoreName: 'testStore',
|
|
@@ -199,7 +219,7 @@ describe('createVectorQueryTool', () => {
|
|
|
199
219
|
filter: filterJson,
|
|
200
220
|
},
|
|
201
221
|
mastra: mockMastra as any,
|
|
202
|
-
runtimeContext
|
|
222
|
+
runtimeContext,
|
|
203
223
|
});
|
|
204
224
|
|
|
205
225
|
// Check that vectorQuerySearch was called with the parsed filter
|
|
@@ -211,6 +231,7 @@ describe('createVectorQueryTool', () => {
|
|
|
211
231
|
});
|
|
212
232
|
|
|
213
233
|
it('should handle string filters correctly', async () => {
|
|
234
|
+
const runtimeContext = new RuntimeContext();
|
|
214
235
|
// Create tool with enableFilter set to true
|
|
215
236
|
const tool = createVectorQueryTool({
|
|
216
237
|
vectorStoreName: 'testStore',
|
|
@@ -229,7 +250,7 @@ describe('createVectorQueryTool', () => {
|
|
|
229
250
|
filter: stringFilter,
|
|
230
251
|
},
|
|
231
252
|
mastra: mockMastra as any,
|
|
232
|
-
runtimeContext
|
|
253
|
+
runtimeContext,
|
|
233
254
|
});
|
|
234
255
|
|
|
235
256
|
// Since this is not a valid filter, it should be ignored
|
|
@@ -240,4 +261,69 @@ describe('createVectorQueryTool', () => {
|
|
|
240
261
|
);
|
|
241
262
|
});
|
|
242
263
|
});
|
|
264
|
+
|
|
265
|
+
describe('runtimeContext', () => {
|
|
266
|
+
it('calls vectorQuerySearch with runtimeContext params', async () => {
|
|
267
|
+
const tool = createVectorQueryTool({
|
|
268
|
+
id: 'test',
|
|
269
|
+
model: mockModel,
|
|
270
|
+
indexName: 'testIndex',
|
|
271
|
+
vectorStoreName: 'testStore',
|
|
272
|
+
});
|
|
273
|
+
const runtimeContext = new RuntimeContext();
|
|
274
|
+
runtimeContext.set('indexName', 'anotherIndex');
|
|
275
|
+
runtimeContext.set('vectorStoreName', 'anotherStore');
|
|
276
|
+
runtimeContext.set('topK', 3);
|
|
277
|
+
runtimeContext.set('filter', { foo: 'bar' });
|
|
278
|
+
runtimeContext.set('includeVectors', true);
|
|
279
|
+
runtimeContext.set('includeSources', false);
|
|
280
|
+
const result = await tool.execute({
|
|
281
|
+
context: { queryText: 'foo', topK: 6 },
|
|
282
|
+
mastra: mockMastra as any,
|
|
283
|
+
runtimeContext,
|
|
284
|
+
});
|
|
285
|
+
expect(result.relevantContext.length).toBeGreaterThan(0);
|
|
286
|
+
expect(result.sources).toEqual([]); // includeSources false
|
|
287
|
+
expect(vectorQuerySearch).toHaveBeenCalledWith(
|
|
288
|
+
expect.objectContaining({
|
|
289
|
+
indexName: 'anotherIndex',
|
|
290
|
+
vectorStore: {
|
|
291
|
+
anotherStore: {},
|
|
292
|
+
},
|
|
293
|
+
queryText: 'foo',
|
|
294
|
+
model: mockModel,
|
|
295
|
+
queryFilter: { foo: 'bar' },
|
|
296
|
+
topK: 3,
|
|
297
|
+
includeVectors: true,
|
|
298
|
+
}),
|
|
299
|
+
);
|
|
300
|
+
});
|
|
301
|
+
|
|
302
|
+
it('handles reranker from runtimeContext', async () => {
|
|
303
|
+
const tool = createVectorQueryTool({
|
|
304
|
+
id: 'test',
|
|
305
|
+
model: mockModel,
|
|
306
|
+
indexName: 'testIndex',
|
|
307
|
+
vectorStoreName: 'testStore',
|
|
308
|
+
});
|
|
309
|
+
const runtimeContext = new RuntimeContext();
|
|
310
|
+
runtimeContext.set('indexName', 'testIndex');
|
|
311
|
+
runtimeContext.set('vectorStoreName', 'testStore');
|
|
312
|
+
runtimeContext.set('reranker', { model: 'reranker-model', options: { topK: 1 } });
|
|
313
|
+
// Mock rerank
|
|
314
|
+
vi.mocked(rerank).mockResolvedValue([
|
|
315
|
+
{
|
|
316
|
+
result: { id: '1', metadata: { text: 'bar' }, score: 1 },
|
|
317
|
+
score: 1,
|
|
318
|
+
details: { semantic: 1, vector: 1, position: 1 },
|
|
319
|
+
},
|
|
320
|
+
]);
|
|
321
|
+
const result = await tool.execute({
|
|
322
|
+
context: { queryText: 'foo', topK: 1 },
|
|
323
|
+
mastra: mockMastra as any,
|
|
324
|
+
runtimeContext,
|
|
325
|
+
});
|
|
326
|
+
expect(result.relevantContext[0]).toEqual({ text: 'bar' });
|
|
327
|
+
});
|
|
328
|
+
});
|
|
243
329
|
});
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import { createTool } from '@mastra/core/tools';
|
|
2
|
-
import type { EmbeddingModel } from 'ai';
|
|
3
2
|
import { z } from 'zod';
|
|
4
3
|
|
|
5
4
|
import { rerank } from '../rerank';
|
|
@@ -7,37 +6,33 @@ import type { RerankConfig } from '../rerank';
|
|
|
7
6
|
import { vectorQuerySearch, defaultVectorQueryDescription, filterSchema, outputSchema, baseSchema } from '../utils';
|
|
8
7
|
import type { RagTool } from '../utils';
|
|
9
8
|
import { convertToSources } from '../utils/convert-sources';
|
|
9
|
+
import type { VectorQueryToolOptions } from './types';
|
|
10
10
|
|
|
11
|
-
export const createVectorQueryTool = ({
|
|
12
|
-
|
|
13
|
-
indexName
|
|
14
|
-
model,
|
|
15
|
-
enableFilter = false,
|
|
16
|
-
includeVectors = false,
|
|
17
|
-
includeSources = true,
|
|
18
|
-
reranker,
|
|
19
|
-
id,
|
|
20
|
-
description,
|
|
21
|
-
}: {
|
|
22
|
-
vectorStoreName: string;
|
|
23
|
-
indexName: string;
|
|
24
|
-
model: EmbeddingModel<string>;
|
|
25
|
-
enableFilter?: boolean;
|
|
26
|
-
includeVectors?: boolean;
|
|
27
|
-
includeSources?: boolean;
|
|
28
|
-
reranker?: RerankConfig;
|
|
29
|
-
id?: string;
|
|
30
|
-
description?: string;
|
|
31
|
-
}) => {
|
|
32
|
-
const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
|
|
11
|
+
export const createVectorQueryTool = (options: VectorQueryToolOptions) => {
|
|
12
|
+
const { model, id, description } = options;
|
|
13
|
+
const toolId = id || `VectorQuery ${options.vectorStoreName} ${options.indexName} Tool`;
|
|
33
14
|
const toolDescription = description || defaultVectorQueryDescription();
|
|
34
|
-
const inputSchema = enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
15
|
+
const inputSchema = options.enableFilter ? filterSchema : z.object(baseSchema).passthrough();
|
|
16
|
+
|
|
35
17
|
return createTool({
|
|
36
18
|
id: toolId,
|
|
19
|
+
description: toolDescription,
|
|
37
20
|
inputSchema,
|
|
38
21
|
outputSchema,
|
|
39
|
-
|
|
40
|
-
|
|
22
|
+
execute: async ({ context, mastra, runtimeContext }) => {
|
|
23
|
+
const indexName: string = runtimeContext.get('indexName') ?? options.indexName;
|
|
24
|
+
const vectorStoreName: string = runtimeContext.get('vectorStoreName') ?? options.vectorStoreName;
|
|
25
|
+
const includeVectors: boolean = runtimeContext.get('includeVectors') ?? options.includeVectors ?? false;
|
|
26
|
+
const includeSources: boolean = runtimeContext.get('includeSources') ?? options.includeSources ?? true;
|
|
27
|
+
const reranker: RerankConfig = runtimeContext.get('reranker') ?? options.reranker;
|
|
28
|
+
if (!indexName) throw new Error(`indexName is required, got: ${indexName}`);
|
|
29
|
+
if (!vectorStoreName) throw new Error(`vectorStoreName is required, got: ${vectorStoreName}`);
|
|
30
|
+
|
|
31
|
+
const topK: number = runtimeContext.get('topK') ?? context.topK ?? 10;
|
|
32
|
+
const filter: Record<string, any> = runtimeContext.get('filter') ?? context.filter;
|
|
33
|
+
const queryText = context.queryText;
|
|
34
|
+
const enableFilter = !!runtimeContext.get('filter') || (options.enableFilter ?? false);
|
|
35
|
+
|
|
41
36
|
const logger = mastra?.getLogger();
|
|
42
37
|
if (!logger) {
|
|
43
38
|
console.warn(
|