@mastra/rag 0.1.23 → 0.2.0-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -3,6 +3,7 @@ import { z } from 'zod';
3
3
  import { parse } from 'node-html-better-parser';
4
4
  import { encodingForModel, getEncoding } from 'js-tiktoken';
5
5
  import { CohereRelevanceScorer, MastraAgentRelevanceScorer } from '@mastra/core/relevance';
6
+ import { Big } from 'big.js';
6
7
  import { createTool } from '@mastra/core/tools';
7
8
  import { embed } from 'ai';
8
9
 
@@ -5948,9 +5949,9 @@ async function rerank(results, query, model, options) {
5948
5949
  ...DEFAULT_WEIGHTS,
5949
5950
  ...options.weights
5950
5951
  };
5951
- const totalWeights = Object.values(weights).reduce((sum, weight) => sum + weight, 0);
5952
- if (totalWeights !== 1) {
5953
- throw new Error("Weights must add up to 1");
5952
+ const sum = Object.values(weights).reduce((acc, w) => acc.plus(w.toString()), new Big(0));
5953
+ if (!sum.eq(1)) {
5954
+ throw new Error(`Weights must add up to 1. Got ${sum} from ${weights}`);
5954
5955
  }
5955
5956
  const resultLength = results.length;
5956
5957
  const queryAnalysis = queryEmbedding ? analyzeQueryEmbedding(queryEmbedding) : null;
@@ -6269,12 +6270,44 @@ var filterDescription = `JSON-formatted criteria to refine search results.
6269
6270
  - IMPORTANT: Always ensure JSON is properly closed with matching brackets
6270
6271
  - Multiple filters can be combined`;
6271
6272
 
6273
+ // src/utils/convert-sources.ts
6274
// Normalize RankedNode / RerankResult / QueryResult inputs into the flat
// { id, vector, score, metadata, document } objects returned as `sources`.
// The variant is detected structurally: `content` first, then `result`,
// otherwise the item is treated as a plain QueryResult.
var convertToSources = (results) => results.map((item) => {
  if ("content" in item) {
    // Graph node: embedding/content map onto vector/document.
    const { id, embedding, score, metadata, content } = item;
    return { id, vector: embedding || [], score, metadata, document: content || "" };
  }
  if ("result" in item) {
    // Rerank wrapper: fields come from the nested result, score from the wrapper.
    const inner = item.result;
    return {
      id: inner.id,
      vector: inner.vector || [],
      score: item.score,
      metadata: inner.metadata,
      document: inner.document || ""
    };
  }
  // Plain query result: copy through with empty defaults.
  const { id, vector, score, metadata, document } = item;
  return { id, vector: vector || [], score, metadata, document: document || "" };
});
6303
+
6272
6304
  // src/tools/graph-rag.ts
6273
6305
  var createGraphRAGTool = ({
6274
6306
  vectorStoreName,
6275
6307
  indexName,
6276
6308
  model,
6277
6309
  enableFilter = false,
6310
+ includeSources = true,
6278
6311
  graphOptions = {
6279
6312
  dimension: 1536,
6280
6313
  randomWalkSteps: 100,
@@ -6299,8 +6332,27 @@ var createGraphRAGTool = ({
6299
6332
  return createTool({
6300
6333
  id: toolId,
6301
6334
  inputSchema,
6335
+ // Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
6336
+ // Each source contains all information needed to reference
6337
+ // the original document, chunk, and similarity score.
6302
6338
  outputSchema: z.object({
6303
- relevantContext: z.any()
6339
+ // Array of metadata or content for compatibility with prior usage
6340
+ relevantContext: z.any(),
6341
+ // Array of full retrieval result objects
6342
+ sources: z.array(
6343
+ z.object({
6344
+ id: z.string(),
6345
+ // Unique chunk/document identifier
6346
+ metadata: z.any(),
6347
+ // All metadata fields (document ID, etc.)
6348
+ vector: z.array(z.number()),
6349
+ // Embedding vector (if available)
6350
+ score: z.number(),
6351
+ // Similarity score for this retrieval
6352
+ document: z.string()
6353
+ // Full chunk/document text (if available)
6354
+ })
6355
+ )
6304
6356
  }),
6305
6357
  description: toolDescription,
6306
6358
  execute: async ({ context: { queryText, topK, filter }, mastra }) => {
@@ -6320,7 +6372,7 @@ var createGraphRAGTool = ({
6320
6372
  if (logger) {
6321
6373
  logger.error("Vector store not found", { vectorStoreName });
6322
6374
  }
6323
- return { relevantContext: [] };
6375
+ return { relevantContext: [], sources: [] };
6324
6376
  }
6325
6377
  let queryFilter = {};
6326
6378
  if (enableFilter) {
@@ -6379,8 +6431,10 @@ var createGraphRAGTool = ({
6379
6431
  if (logger) {
6380
6432
  logger.debug("Returning relevant context chunks", { count: relevantChunks.length });
6381
6433
  }
6434
+ const sources = includeSources ? convertToSources(rerankedResults) : [];
6382
6435
  return {
6383
- relevantContext: relevantChunks
6436
+ relevantContext: relevantChunks,
6437
+ sources
6384
6438
  };
6385
6439
  } catch (err) {
6386
6440
  if (logger) {
@@ -6390,7 +6444,7 @@ var createGraphRAGTool = ({
6390
6444
  errorStack: err instanceof Error ? err.stack : void 0
6391
6445
  });
6392
6446
  }
6393
- return { relevantContext: [] };
6447
+ return { relevantContext: [], sources: [] };
6394
6448
  }
6395
6449
  }
6396
6450
  });
@@ -6400,6 +6454,8 @@ var createVectorQueryTool = ({
6400
6454
  indexName,
6401
6455
  model,
6402
6456
  enableFilter = false,
6457
+ includeVectors = false,
6458
+ includeSources = true,
6403
6459
  reranker,
6404
6460
  id,
6405
6461
  description
@@ -6417,8 +6473,27 @@ var createVectorQueryTool = ({
6417
6473
  return createTool({
6418
6474
  id: toolId,
6419
6475
  inputSchema,
6476
+ // Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
6477
+ // Each source contains all information needed to reference
6478
+ // the original document, chunk, and similarity score.
6420
6479
  outputSchema: z.object({
6421
- relevantContext: z.any()
6480
+ // Array of metadata or content for compatibility with prior usage
6481
+ relevantContext: z.any(),
6482
+ // Array of full retrieval result objects
6483
+ sources: z.array(
6484
+ z.object({
6485
+ id: z.string(),
6486
+ // Unique chunk/document identifier
6487
+ metadata: z.any(),
6488
+ // All metadata fields (document ID, etc.)
6489
+ vector: z.array(z.number()),
6490
+ // Embedding vector (if available)
6491
+ score: z.number(),
6492
+ // Similarity score for this retrieval
6493
+ document: z.string()
6494
+ // Full chunk/document text (if available)
6495
+ })
6496
+ )
6422
6497
  }),
6423
6498
  description: toolDescription,
6424
6499
  execute: async ({ context: { queryText, topK, filter }, mastra }) => {
@@ -6438,7 +6513,7 @@ var createVectorQueryTool = ({
6438
6513
  if (logger) {
6439
6514
  logger.error("Vector store not found", { vectorStoreName });
6440
6515
  }
6441
- return { relevantContext: [] };
6516
+ return { relevantContext: [], sources: [] };
6442
6517
  }
6443
6518
  let queryFilter = {};
6444
6519
  if (enableFilter && filter) {
@@ -6462,7 +6537,8 @@ var createVectorQueryTool = ({
6462
6537
  queryText,
6463
6538
  model,
6464
6539
  queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : void 0,
6465
- topK: topKValue
6540
+ topK: topKValue,
6541
+ includeVectors
6466
6542
  });
6467
6543
  if (logger) {
6468
6544
  logger.debug("vectorQuerySearch returned results", { count: results.length });
@@ -6482,14 +6558,17 @@ var createVectorQueryTool = ({
6482
6558
  if (logger) {
6483
6559
  logger.debug("Returning reranked relevant context chunks", { count: relevantChunks2.length });
6484
6560
  }
6485
- return { relevantContext: relevantChunks2 };
6561
+ const sources2 = includeSources ? convertToSources(rerankedResults) : [];
6562
+ return { relevantContext: relevantChunks2, sources: sources2 };
6486
6563
  }
6487
6564
  const relevantChunks = results.map((result) => result?.metadata);
6488
6565
  if (logger) {
6489
6566
  logger.debug("Returning relevant context chunks", { count: relevantChunks.length });
6490
6567
  }
6568
+ const sources = includeSources ? convertToSources(results) : [];
6491
6569
  return {
6492
- relevantContext: relevantChunks
6570
+ relevantContext: relevantChunks,
6571
+ sources
6493
6572
  };
6494
6573
  } catch (err) {
6495
6574
  if (logger) {
@@ -6499,7 +6578,7 @@ var createVectorQueryTool = ({
6499
6578
  errorStack: err instanceof Error ? err.stack : void 0
6500
6579
  });
6501
6580
  }
6502
- return { relevantContext: [] };
6581
+ return { relevantContext: [], sources: [] };
6503
6582
  }
6504
6583
  }
6505
6584
  });
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@mastra/rag",
3
- "version": "0.1.23",
3
+ "version": "0.2.0-alpha.1",
4
4
  "description": "",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -22,19 +22,21 @@
22
22
  "license": "Elastic-2.0",
23
23
  "dependencies": {
24
24
  "@paralleldrive/cuid2": "^2.2.2",
25
+ "big.js": "^7.0.1",
25
26
  "js-tiktoken": "^1.0.19",
26
27
  "node-html-better-parser": "^1.4.7",
27
28
  "pathe": "^2.0.3",
28
- "zod": "^3.24.3",
29
- "@mastra/core": "^0.9.4"
29
+ "zod": "^3.24.3"
30
30
  },
31
31
  "peerDependencies": {
32
- "ai": "^4.0.0"
32
+ "ai": "^4.0.0",
33
+ "@mastra/core": "^0.9.4"
33
34
  },
34
35
  "devDependencies": {
35
36
  "@ai-sdk/cohere": "latest",
36
37
  "@ai-sdk/openai": "latest",
37
38
  "@microsoft/api-extractor": "^7.52.5",
39
+ "@types/big.js": "^6.2.2",
38
40
  "@types/node": "^20.17.27",
39
41
  "ai": "^4.2.2",
40
42
  "dotenv": "^16.4.7",
@@ -42,7 +44,8 @@
42
44
  "tsup": "^8.4.0",
43
45
  "typescript": "^5.8.2",
44
46
  "vitest": "^3.1.2",
45
- "@internal/lint": "0.0.5"
47
+ "@internal/lint": "0.0.5",
48
+ "@mastra/core": "0.10.0-alpha.1"
46
49
  },
47
50
  "keywords": [
48
51
  "rag",
@@ -16,7 +16,7 @@ export interface GraphNode {
16
16
  metadata?: Record<string, any>;
17
17
  }
18
18
 
19
- interface RankedNode extends GraphNode {
19
+ export interface RankedNode extends GraphNode {
20
20
  score: number;
21
21
  }
22
22
 
@@ -2,6 +2,7 @@ import type { MastraLanguageModel } from '@mastra/core/agent';
2
2
  import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from '@mastra/core/relevance';
3
3
  import type { RelevanceScoreProvider } from '@mastra/core/relevance';
4
4
  import type { QueryResult } from '@mastra/core/vector';
5
+ import { Big } from 'big.js';
5
6
 
6
7
  // Default weights for different scoring components (must add up to 1)
7
8
  const DEFAULT_WEIGHTS = {
@@ -102,9 +103,9 @@ export async function rerank(
102
103
  };
103
104
 
104
105
  //weights must add up to 1
105
- const totalWeights = Object.values(weights).reduce((sum, weight) => sum + weight, 0);
106
- if (totalWeights !== 1) {
107
- throw new Error('Weights must add up to 1');
106
+ const sum = Object.values(weights).reduce((acc: Big, w: number) => acc.plus(w.toString()), new Big(0));
107
+ if (!sum.eq(1)) {
108
+ throw new Error(`Weights must add up to 1. Got ${sum} from ${weights}`);
108
109
  }
109
110
 
110
111
  const resultLength = results.length;
@@ -10,12 +10,14 @@ import {
10
10
  topKDescription,
11
11
  queryTextDescription,
12
12
  } from '../utils';
13
+ import { convertToSources } from '../utils/convert-sources';
13
14
 
14
15
  export const createGraphRAGTool = ({
15
16
  vectorStoreName,
16
17
  indexName,
17
18
  model,
18
19
  enableFilter = false,
20
+ includeSources = true,
19
21
  graphOptions = {
20
22
  dimension: 1536,
21
23
  randomWalkSteps: 100,
@@ -29,6 +31,7 @@ export const createGraphRAGTool = ({
29
31
  indexName: string;
30
32
  model: EmbeddingModel<string>;
31
33
  enableFilter?: boolean;
34
+ includeSources?: boolean;
32
35
  graphOptions?: {
33
36
  dimension?: number;
34
37
  randomWalkSteps?: number;
@@ -37,7 +40,7 @@ export const createGraphRAGTool = ({
37
40
  };
38
41
  id?: string;
39
42
  description?: string;
40
- }): ReturnType<typeof createTool> => {
43
+ }) => {
41
44
  const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
42
45
  const toolDescription = description || defaultGraphRagDescription();
43
46
  // Initialize GraphRAG
@@ -59,8 +62,22 @@ export const createGraphRAGTool = ({
59
62
  return createTool({
60
63
  id: toolId,
61
64
  inputSchema,
65
+ // Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
66
+ // Each source contains all information needed to reference
67
+ // the original document, chunk, and similarity score.
62
68
  outputSchema: z.object({
69
+ // Array of metadata or content for compatibility with prior usage
63
70
  relevantContext: z.any(),
71
+ // Array of full retrieval result objects
72
+ sources: z.array(
73
+ z.object({
74
+ id: z.string(), // Unique chunk/document identifier
75
+ metadata: z.any(), // All metadata fields (document ID, etc.)
76
+ vector: z.array(z.number()), // Embedding vector (if available)
77
+ score: z.number(), // Similarity score for this retrieval
78
+ document: z.string(), // Full chunk/document text (if available)
79
+ }),
80
+ ),
64
81
  }),
65
82
  description: toolDescription,
66
83
  execute: async ({ context: { queryText, topK, filter }, mastra }) => {
@@ -86,7 +103,7 @@ export const createGraphRAGTool = ({
86
103
  if (logger) {
87
104
  logger.error('Vector store not found', { vectorStoreName });
88
105
  }
89
- return { relevantContext: [] };
106
+ return { relevantContext: [], sources: [] };
90
107
  }
91
108
 
92
109
  let queryFilter = {};
@@ -154,8 +171,11 @@ export const createGraphRAGTool = ({
154
171
  if (logger) {
155
172
  logger.debug('Returning relevant context chunks', { count: relevantChunks.length });
156
173
  }
174
+ // `sources` exposes the full retrieval objects
175
+ const sources = includeSources ? convertToSources(rerankedResults) : [];
157
176
  return {
158
177
  relevantContext: relevantChunks,
178
+ sources,
159
179
  };
160
180
  } catch (err) {
161
181
  if (logger) {
@@ -165,7 +185,7 @@ export const createGraphRAGTool = ({
165
185
  errorStack: err instanceof Error ? err.stack : undefined,
166
186
  });
167
187
  }
168
- return { relevantContext: [] };
188
+ return { relevantContext: [], sources: [] };
169
189
  }
170
190
  },
171
191
  });
@@ -11,12 +11,15 @@ import {
11
11
  topKDescription,
12
12
  queryTextDescription,
13
13
  } from '../utils';
14
+ import { convertToSources } from '../utils/convert-sources';
14
15
 
15
16
  export const createVectorQueryTool = ({
16
17
  vectorStoreName,
17
18
  indexName,
18
19
  model,
19
20
  enableFilter = false,
21
+ includeVectors = false,
22
+ includeSources = true,
20
23
  reranker,
21
24
  id,
22
25
  description,
@@ -25,10 +28,12 @@ export const createVectorQueryTool = ({
25
28
  indexName: string;
26
29
  model: EmbeddingModel<string>;
27
30
  enableFilter?: boolean;
31
+ includeVectors?: boolean;
32
+ includeSources?: boolean;
28
33
  reranker?: RerankConfig;
29
34
  id?: string;
30
35
  description?: string;
31
- }): ReturnType<typeof createTool> => {
36
+ }) => {
32
37
  const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
33
38
  const toolDescription = description || defaultVectorQueryDescription();
34
39
  // Create base schema with required fields
@@ -47,8 +52,22 @@ export const createVectorQueryTool = ({
47
52
  return createTool({
48
53
  id: toolId,
49
54
  inputSchema,
55
+ // Output schema includes `sources`, which exposes the full set of retrieved chunks (QueryResult objects)
56
+ // Each source contains all information needed to reference
57
+ // the original document, chunk, and similarity score.
50
58
  outputSchema: z.object({
59
+ // Array of metadata or content for compatibility with prior usage
51
60
  relevantContext: z.any(),
61
+ // Array of full retrieval result objects
62
+ sources: z.array(
63
+ z.object({
64
+ id: z.string(), // Unique chunk/document identifier
65
+ metadata: z.any(), // All metadata fields (document ID, etc.)
66
+ vector: z.array(z.number()), // Embedding vector (if available)
67
+ score: z.number(), // Similarity score for this retrieval
68
+ document: z.string(), // Full chunk/document text (if available)
69
+ }),
70
+ ),
52
71
  }),
53
72
  description: toolDescription,
54
73
  execute: async ({ context: { queryText, topK, filter }, mastra }) => {
@@ -75,7 +94,7 @@ export const createVectorQueryTool = ({
75
94
  if (logger) {
76
95
  logger.error('Vector store not found', { vectorStoreName });
77
96
  }
78
- return { relevantContext: [] };
97
+ return { relevantContext: [], sources: [] };
79
98
  }
80
99
  // Get relevant chunks from the vector database
81
100
  let queryFilter = {};
@@ -103,6 +122,7 @@ export const createVectorQueryTool = ({
103
122
  model,
104
123
  queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
105
124
  topK: topKValue,
125
+ includeVectors,
106
126
  });
107
127
  if (logger) {
108
128
  logger.debug('vectorQuerySearch returned results', { count: results.length });
@@ -122,15 +142,19 @@ export const createVectorQueryTool = ({
122
142
  if (logger) {
123
143
  logger.debug('Returning reranked relevant context chunks', { count: relevantChunks.length });
124
144
  }
125
- return { relevantContext: relevantChunks };
145
+ const sources = includeSources ? convertToSources(rerankedResults) : [];
146
+ return { relevantContext: relevantChunks, sources };
126
147
  }
127
148
 
128
149
  const relevantChunks = results.map(result => result?.metadata);
129
150
  if (logger) {
130
151
  logger.debug('Returning relevant context chunks', { count: relevantChunks.length });
131
152
  }
153
+ // `sources` exposes the full retrieval objects
154
+ const sources = includeSources ? convertToSources(results) : [];
132
155
  return {
133
156
  relevantContext: relevantChunks,
157
+ sources,
134
158
  };
135
159
  } catch (err) {
136
160
  if (logger) {
@@ -140,7 +164,7 @@ export const createVectorQueryTool = ({
140
164
  errorStack: err instanceof Error ? err.stack : undefined,
141
165
  });
142
166
  }
143
- return { relevantContext: [] };
167
+ return { relevantContext: [], sources: [] };
144
168
  }
145
169
  },
146
170
  });
@@ -0,0 +1,43 @@
1
+ import type { QueryResult } from '@mastra/core';
2
+ import type { RankedNode } from '../graph-rag';
3
+ import type { RerankResult } from '../rerank';
4
+
5
type SourceInput = QueryResult | RankedNode | RerankResult;

/**
 * Normalized view of one retrieved chunk, as exposed in the `sources` output of the
 * vector-query and GraphRAG tools. Keep this in sync with the `sources` item schema
 * declared in those tools' zod `outputSchema`.
 */
export interface RetrievalSource {
  /** Identifier of the node / stored vector the chunk came from. */
  id: string;
  /** Embedding vector; `[]` when the input did not carry one. */
  vector: number[];
  /**
   * Score taken from the input. For a RerankResult this is the wrapper's own top-level
   * `score`, not the nested `result.score`.
   */
  score: number;
  /** Metadata exactly as carried by the input; may be `undefined`. */
  metadata: Record<string, any> | undefined;
  /** Chunk text; `''` when the input did not carry any. */
  document: string;
}

/**
 * Convert retrieval results into the uniform `sources` shape returned by the RAG tools.
 *
 * Each input is discriminated structurally, in this order:
 * - has a `content` property → RankedNode (GraphRAG): `embedding`/`content` become
 *   `vector`/`document`;
 * - has a `result` property → RerankResult: fields come from the nested QueryResult,
 *   except `score`, which is the RerankResult's own top-level score;
 * - otherwise → plain QueryResult, copied field by field.
 *
 * Missing vectors become `[]` and missing text becomes `''`, so `vector` and `document`
 * are always defined; `metadata` is passed through unchanged.
 *
 * @param results Source inputs to convert (not mutated); an empty array yields `[]`.
 * @returns One RetrievalSource per input, in input order.
 */
export const convertToSources = (results: readonly SourceInput[]): RetrievalSource[] => {
  return results.map((result): RetrievalSource => {
    // RankedNode
    if ('content' in result) {
      return {
        id: result.id,
        vector: result.embedding ?? [],
        score: result.score,
        metadata: result.metadata,
        document: result.content ?? '',
      };
    }
    // RerankResult
    if ('result' in result) {
      const inner = result.result;
      return {
        id: inner.id,
        vector: inner.vector ?? [],
        score: result.score,
        metadata: inner.metadata,
        document: inner.document ?? '',
      };
    }
    // QueryResult
    return {
      id: result.id,
      vector: result.vector ?? [],
      score: result.score,
      metadata: result.metadata,
      document: result.document ?? '',
    };
  });
};