@mastra/rag 0.1.0-alpha.86 → 0.1.0-alpha.87

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/CHANGELOG.md CHANGED
@@ -1,5 +1,18 @@
1
1
  # @mastra/rag
2
2
 
3
+ ## 0.1.0-alpha.87
4
+
5
+ ### Minor Changes
6
+
7
+ - 8b416d9: Breaking changes
8
+
9
+ ### Patch Changes
10
+
11
+ - 9c10484: update all packages
12
+ - Updated dependencies [9c10484]
13
+ - Updated dependencies [8b416d9]
14
+ - @mastra/core@0.2.0-alpha.94
15
+
3
16
  ## 0.1.0-alpha.86
4
17
 
5
18
  ### Patch Changes
package/dist/index.d.ts CHANGED
@@ -1,11 +1,8 @@
1
1
  import { LLM, TitleExtractorPrompt, TitleCombinePrompt, SummaryPrompt, QuestionExtractPrompt, Document } from 'llamaindex';
2
2
  import { TiktokenEncoding, TiktokenModel } from 'js-tiktoken';
3
- import * as _mastra_core from '@mastra/core';
4
- import { EmbeddingOptions, EmbedResult, EmbedManyResult, ModelConfig } from '@mastra/core';
5
3
  import { QueryResult } from '@mastra/core/vector';
6
- import * as _mastra_core_tools from '@mastra/core/tools';
7
- import { z } from 'zod';
8
- import { EmbeddingOptions as EmbeddingOptions$1 } from '@mastra/core/embeddings';
4
+ import { LanguageModelV1, EmbeddingModel } from 'ai';
5
+ import { createTool } from '@mastra/core/tools';
9
6
 
10
7
  declare enum Language {
11
8
  CPP = "cpp",
@@ -118,9 +115,6 @@ declare class MDocument {
118
115
  getMetadata(): Record<string, any>[];
119
116
  }
120
117
 
121
- declare function embed(chunk: Document | string, options: EmbeddingOptions): Promise<EmbedResult<string>>;
122
- declare function embedMany(chunks: (Document | string)[], options: EmbeddingOptions): Promise<EmbedManyResult<string>>;
123
-
124
118
  type WeightConfig = {
125
119
  semantic?: number;
126
120
  vector?: number;
@@ -151,9 +145,9 @@ interface RerankerFunctionOptions {
151
145
  }
152
146
  interface RerankConfig {
153
147
  options?: RerankerOptions;
154
- model: ModelConfig;
148
+ model: LanguageModelV1;
155
149
  }
156
- declare function rerank(results: QueryResult[], query: string, modelConfig: ModelConfig, options: RerankerFunctionOptions): Promise<RerankResult[]>;
150
+ declare function rerank(results: QueryResult[], query: string, model: LanguageModelV1, options: RerankerFunctionOptions): Promise<RerankResult[]>;
157
151
 
158
152
  /**
159
153
  * TODO: GraphRAG Enhancements
@@ -214,12 +208,12 @@ declare class GraphRAG {
214
208
  declare const createDocumentChunkerTool: ({ doc, params, }: {
215
209
  doc: MDocument;
216
210
  params?: ChunkParams;
217
- }) => _mastra_core_tools.Tool<"Document Chunker undefined undefined" | `Document Chunker undefined ${number}` | "Document Chunker markdown undefined" | `Document Chunker markdown ${number}` | "Document Chunker latex undefined" | `Document Chunker latex ${number}` | "Document Chunker html undefined" | `Document Chunker html ${number}` | "Document Chunker recursive undefined" | `Document Chunker recursive ${number}` | "Document Chunker character undefined" | `Document Chunker character ${number}` | "Document Chunker token undefined" | `Document Chunker token ${number}` | "Document Chunker json undefined" | `Document Chunker json ${number}`, z.ZodObject<{}, "strip", z.ZodTypeAny, {}, {}>, undefined, _mastra_core.ToolExecutionContext<z.ZodObject<{}, "strip", z.ZodTypeAny, {}, {}>, _mastra_core.WorkflowContext<any>>>;
211
+ }) => ReturnType<typeof createTool>;
218
212
 
219
- declare const createGraphRAGTool: ({ vectorStoreName, indexName, options, enableFilter, graphOptions, id, description, }: {
213
+ declare const createGraphRAGTool: ({ vectorStoreName, indexName, model, enableFilter, graphOptions, id, description, }: {
220
214
  vectorStoreName: string;
221
215
  indexName: string;
222
- options: EmbeddingOptions$1;
216
+ model: EmbeddingModel<string>;
223
217
  enableFilter?: boolean;
224
218
  graphOptions?: {
225
219
  dimension?: number;
@@ -229,77 +223,17 @@ declare const createGraphRAGTool: ({ vectorStoreName, indexName, options, enable
229
223
  };
230
224
  id?: string;
231
225
  description?: string;
232
- }) => _mastra_core_tools.Tool<string, z.ZodObject<{
233
- queryText: z.ZodString;
234
- topK: z.ZodNumber;
235
- filter: z.ZodString;
236
- }, "strip", z.ZodTypeAny, {
237
- filter: string;
238
- topK: number;
239
- queryText: string;
240
- }, {
241
- filter: string;
242
- topK: number;
243
- queryText: string;
244
- }>, z.ZodObject<{
245
- relevantContext: z.ZodAny;
246
- }, "strip", z.ZodTypeAny, {
247
- relevantContext?: any;
248
- }, {
249
- relevantContext?: any;
250
- }>, _mastra_core.ToolExecutionContext<z.ZodObject<{
251
- queryText: z.ZodString;
252
- topK: z.ZodNumber;
253
- filter: z.ZodString;
254
- }, "strip", z.ZodTypeAny, {
255
- filter: string;
256
- topK: number;
257
- queryText: string;
258
- }, {
259
- filter: string;
260
- topK: number;
261
- queryText: string;
262
- }>, _mastra_core.WorkflowContext<any>>>;
226
+ }) => ReturnType<typeof createTool>;
263
227
 
264
- declare const createVectorQueryTool: ({ vectorStoreName, indexName, options, enableFilter, reranker, id, description, }: {
228
+ declare const createVectorQueryTool: ({ vectorStoreName, indexName, model, enableFilter, reranker, id, description, }: {
265
229
  vectorStoreName: string;
266
230
  indexName: string;
267
- options: EmbeddingOptions$1;
231
+ model: EmbeddingModel<string>;
268
232
  enableFilter?: boolean;
269
233
  reranker?: RerankConfig;
270
234
  id?: string;
271
235
  description?: string;
272
- }) => _mastra_core_tools.Tool<string, z.ZodObject<{
273
- queryText: z.ZodString;
274
- topK: z.ZodNumber;
275
- filter: z.ZodString;
276
- }, "strip", z.ZodTypeAny, {
277
- filter: string;
278
- topK: number;
279
- queryText: string;
280
- }, {
281
- filter: string;
282
- topK: number;
283
- queryText: string;
284
- }>, z.ZodObject<{
285
- relevantContext: z.ZodAny;
286
- }, "strip", z.ZodTypeAny, {
287
- relevantContext?: any;
288
- }, {
289
- relevantContext?: any;
290
- }>, _mastra_core.ToolExecutionContext<z.ZodObject<{
291
- queryText: z.ZodString;
292
- topK: z.ZodNumber;
293
- filter: z.ZodString;
294
- }, "strip", z.ZodTypeAny, {
295
- filter: string;
296
- topK: number;
297
- queryText: string;
298
- }, {
299
- filter: string;
300
- topK: number;
301
- queryText: string;
302
- }>, _mastra_core.WorkflowContext<any>>>;
236
+ }) => ReturnType<typeof createTool>;
303
237
 
304
238
  /**
305
239
  * Vector store specific prompts that detail supported operators and examples.
@@ -319,4 +253,4 @@ declare const defaultFilter = "You MUST generate for each query:\n filter: qu
319
253
  declare const defaultVectorQueryDescription: (vectorStoreName: string, indexName: string) => string;
320
254
  declare const defaultGraphRagDescription: (vectorStoreName: string, indexName: string) => string;
321
255
 
322
- export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, type RerankConfig, type RerankResult, type RerankerFunctionOptions, type RerankerOptions, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, embed, embedMany, rerank };
256
+ export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, type RerankConfig, type RerankResult, type RerankerFunctionOptions, type RerankerOptions, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, rerank };
package/dist/index.js CHANGED
@@ -1,10 +1,10 @@
1
1
  import { Document, SummaryExtractor, QuestionsAnsweredExtractor, KeywordExtractor, TitleExtractor, IngestionPipeline } from 'llamaindex';
2
2
  import { parse } from 'node-html-better-parser';
3
3
  import { encodingForModel, getEncoding } from 'js-tiktoken';
4
- import { embed as embed$1, embedMany as embedMany$1 } from '@mastra/core/embeddings';
5
- import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from '@mastra/core/relevance';
4
+ import { CohereRelevanceScorer, MastraAgentRelevanceScorer } from '@mastra/core/relevance';
6
5
  import { createTool } from '@mastra/core/tools';
7
6
  import { z } from 'zod';
7
+ import { embed } from 'ai';
8
8
 
9
9
  // src/document/document.ts
10
10
 
@@ -1272,15 +1272,6 @@ var MDocument = class _MDocument {
1272
1272
  return this.chunks.map((doc) => doc.metadata);
1273
1273
  }
1274
1274
  };
1275
- function getText(input) {
1276
- return input instanceof Document ? input.getText() : input;
1277
- }
1278
- function embed(chunk, options) {
1279
- return embed$1(getText(chunk), options);
1280
- }
1281
- function embedMany(chunks, options) {
1282
- return embedMany$1(chunks.map(getText), options);
1283
- }
1284
1275
  var DEFAULT_WEIGHTS = {
1285
1276
  semantic: 0.4,
1286
1277
  vector: 0.4,
@@ -1299,17 +1290,12 @@ function adjustScores(score, queryAnalysis) {
1299
1290
  const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
1300
1291
  return score * magnitudeAdjustment * featureStrengthAdjustment;
1301
1292
  }
1302
- async function rerank(results, query, modelConfig, options) {
1303
- const { provider } = modelConfig;
1293
+ async function rerank(results, query, model, options) {
1304
1294
  let semanticProvider;
1305
- if ("model" in modelConfig) {
1306
- semanticProvider = new MastraAgentRelevanceScorer(provider, "CUSTOM_MODEL", modelConfig.model);
1307
- } else if (provider === "COHERE" && "name" in modelConfig && modelConfig.name === "rerank-v3.5") {
1308
- semanticProvider = new CohereRelevanceScorer(modelConfig.name, modelConfig.apiKey);
1309
- } else if ("name" in modelConfig) {
1310
- semanticProvider = new MastraAgentRelevanceScorer(provider, modelConfig.name);
1295
+ if (model.modelId === "rerank-v3.5") {
1296
+ semanticProvider = new CohereRelevanceScorer(model.modelId);
1311
1297
  } else {
1312
- throw new Error("Invalid model configuration");
1298
+ semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
1313
1299
  }
1314
1300
  const { queryEmbedding, topK = 3 } = options;
1315
1301
  const weights = {
@@ -1571,18 +1557,21 @@ var createDocumentChunkerTool = ({
1571
1557
  }
1572
1558
  });
1573
1559
  };
1574
-
1575
- // src/utils/vector-search.ts
1576
1560
  var vectorQuerySearch = async ({
1577
1561
  indexName,
1578
1562
  vectorStore,
1579
1563
  queryText,
1580
- options,
1564
+ model,
1581
1565
  queryFilter = {},
1582
1566
  topK,
1583
- includeVectors = false
1567
+ includeVectors = false,
1568
+ maxRetries = 2
1584
1569
  }) => {
1585
- const { embedding } = await embed(queryText, options);
1570
+ const { embedding } = await embed({
1571
+ value: queryText,
1572
+ model,
1573
+ maxRetries
1574
+ });
1586
1575
  const results = await vectorStore.query(indexName, embedding, topK, queryFilter, includeVectors);
1587
1576
  return { results, queryEmbedding: embedding };
1588
1577
  };
@@ -1621,7 +1610,7 @@ var defaultGraphRagDescription = (vectorStoreName, indexName) => `Fetches and re
1621
1610
  var createGraphRAGTool = ({
1622
1611
  vectorStoreName,
1623
1612
  indexName,
1624
- options,
1613
+ model,
1625
1614
  enableFilter = false,
1626
1615
  graphOptions = {
1627
1616
  dimension: 1536,
@@ -1667,7 +1656,7 @@ var createGraphRAGTool = ({
1667
1656
  indexName,
1668
1657
  vectorStore,
1669
1658
  queryText,
1670
- options,
1659
+ model,
1671
1660
  queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
1672
1661
  topK,
1673
1662
  includeVectors: true
@@ -1703,7 +1692,7 @@ var createGraphRAGTool = ({
1703
1692
  var createVectorQueryTool = ({
1704
1693
  vectorStoreName,
1705
1694
  indexName,
1706
- options,
1695
+ model,
1707
1696
  enableFilter = false,
1708
1697
  reranker,
1709
1698
  id,
@@ -1742,7 +1731,7 @@ var createVectorQueryTool = ({
1742
1731
  indexName,
1743
1732
  vectorStore,
1744
1733
  queryText,
1745
- options,
1734
+ model,
1746
1735
  queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
1747
1736
  topK
1748
1737
  });
@@ -2485,4 +2474,4 @@ Example Complex Query:
2485
2474
  "inStock": true
2486
2475
  }`;
2487
2476
 
2488
- export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, embed, embedMany, rerank };
2477
+ export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, rerank };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@mastra/rag",
3
- "version": "0.1.0-alpha.86",
3
+ "version": "0.1.0-alpha.87",
4
4
  "description": "",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -25,13 +25,19 @@
25
25
  "node-html-better-parser": "^1.4.7",
26
26
  "pathe": "^2.0.2",
27
27
  "zod": "^3.24.1",
28
- "@mastra/core": "^0.2.0-alpha.93"
28
+ "@mastra/core": "^0.2.0-alpha.94"
29
+ },
30
+ "peerDependencies": {
31
+ "ai": "^4.0.0"
29
32
  },
30
33
  "devDependencies": {
34
+ "@ai-sdk/openai": "latest",
35
+ "@ai-sdk/cohere": "latest",
31
36
  "@babel/preset-env": "^7.26.0",
32
37
  "@babel/preset-typescript": "^7.26.0",
33
38
  "@tsconfig/recommended": "^1.0.7",
34
39
  "@types/node": "^22.9.0",
40
+ "ai": "^4.0.34",
35
41
  "tsup": "^8.0.1",
36
42
  "vitest": "^3.0.4"
37
43
  },
@@ -1,4 +1,6 @@
1
- import { embedMany } from '../embeddings';
1
+ import { createOpenAI } from '@ai-sdk/openai';
2
+ import { embedMany } from 'ai';
3
+ import { describe, it, expect } from 'vitest';
2
4
 
3
5
  import { MDocument } from './document';
4
6
  import { Language } from './types';
@@ -14,6 +16,10 @@ Welcome to our comprehensive guide on modern web development. This resource cove
14
16
  - Senior developers seeking a refresher on current best practices
15
17
  `;
16
18
 
19
+ const openai = createOpenAI({
20
+ apiKey: process.env.OPENAI_API_KEY,
21
+ });
22
+
17
23
  describe('MDocument', () => {
18
24
  describe('basics', () => {
19
25
  let chunks: MDocument['chunks'];
@@ -48,10 +54,9 @@ describe('MDocument', () => {
48
54
  }, 15000);
49
55
 
50
56
  it('embed - create embedding from chunk', async () => {
51
- const embeddings = await embedMany(chunks, {
52
- provider: 'OPEN_AI',
53
- model: 'text-embedding-3-small',
54
- maxRetries: 3,
57
+ const embeddings = await embedMany({
58
+ values: chunks.map(chunk => chunk.text),
59
+ model: openai.embedding('text-embedding-3-small'),
55
60
  });
56
61
 
57
62
  console.log(embeddings);
package/src/index.ts CHANGED
@@ -1,5 +1,4 @@
1
1
  export * from './document/document';
2
- export * from './embeddings';
3
2
  export * from './rerank';
4
3
  export { GraphRAG } from './graph-rag';
5
4
  export * from './tools';
@@ -1,3 +1,4 @@
1
+ import { cohere } from '@ai-sdk/cohere';
1
2
  import { CohereRelevanceScorer } from '@mastra/core/relevance';
2
3
  import { describe, it, expect, vi, beforeEach } from 'vitest';
3
4
 
@@ -25,12 +26,7 @@ describe('rerank', () => {
25
26
  { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
26
27
  ];
27
28
  await expect(
28
- rerank(
29
- results,
30
- 'test query',
31
- { provider: 'COHERE', name: 'rerank-v3.5' },
32
- { weights: { semantic: 0.5, vector: 0.3, position: 0.5 } },
33
- ),
29
+ rerank(results, 'test query', cohere('rerank-v3.5'), { weights: { semantic: 0.5, vector: 0.3, position: 0.5 } }),
34
30
  ).rejects.toThrow('Weights must add up to 1');
35
31
  });
36
32
 
@@ -41,15 +37,7 @@ describe('rerank', () => {
41
37
  { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
42
38
  ];
43
39
 
44
- const rerankedResults = await rerank(
45
- results,
46
- 'test query',
47
- {
48
- provider: 'COHERE',
49
- name: 'rerank-v3.5',
50
- },
51
- { topK: 2 },
52
- );
40
+ const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), { topK: 2 });
53
41
 
54
42
  expect(rerankedResults).toHaveLength(2);
55
43
  expect(rerankedResults[0]).toStrictEqual({
@@ -83,22 +71,14 @@ describe('rerank', () => {
83
71
  { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
84
72
  ];
85
73
 
86
- const rerankedResults = await rerank(
87
- results,
88
- 'test query',
89
- {
90
- provider: 'COHERE',
91
- name: 'rerank-v3.5',
92
- },
93
- {
94
- weights: {
95
- semantic: 0.5,
96
- vector: 0.4,
97
- position: 0.1,
98
- },
99
- topK: 2,
74
+ const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
75
+ weights: {
76
+ semantic: 0.5,
77
+ vector: 0.4,
78
+ position: 0.1,
100
79
  },
101
- );
80
+ topK: 2,
81
+ });
102
82
 
103
83
  expect(rerankedResults).toHaveLength(2);
104
84
  expect(rerankedResults[0]).toStrictEqual({
@@ -130,18 +110,10 @@ describe('rerank', () => {
130
110
  { id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
131
111
  ];
132
112
 
133
- const rerankedResults = await rerank(
134
- results,
135
- 'test query',
136
- {
137
- provider: 'COHERE',
138
- name: 'rerank-v3.5',
139
- },
140
- {
141
- queryEmbedding: [0.5, 0.3, -0.2, 0.4],
142
- topK: 2,
143
- },
144
- );
113
+ const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
114
+ queryEmbedding: [0.5, 0.3, -0.2, 0.4],
115
+ topK: 2,
116
+ });
145
117
 
146
118
  // Ensure query embedding analysis is being applied (we don't know exact score without knowing internals)
147
119
  expect(rerankedResults).toHaveLength(2);
@@ -1,6 +1,6 @@
1
- import { type ModelConfig } from '@mastra/core';
2
1
  import { MastraAgentRelevanceScorer, CohereRelevanceScorer, RelevanceScoreProvider } from '@mastra/core/relevance';
3
2
  import { QueryResult } from '@mastra/core/vector';
3
+ import { LanguageModelV1 } from 'ai';
4
4
 
5
5
  // Default weights for different scoring components (must add up to 1)
6
6
  const DEFAULT_WEIGHTS = {
@@ -46,7 +46,7 @@ export interface RerankerFunctionOptions {
46
46
 
47
47
  export interface RerankConfig {
48
48
  options?: RerankerOptions;
49
- model: ModelConfig;
49
+ model: LanguageModelV1;
50
50
  }
51
51
 
52
52
  // Calculate position score based on position in original list
@@ -85,19 +85,14 @@ function adjustScores(score: number, queryAnalysis: { magnitude: number; dominan
85
85
  export async function rerank(
86
86
  results: QueryResult[],
87
87
  query: string,
88
- modelConfig: ModelConfig,
88
+ model: LanguageModelV1,
89
89
  options: RerankerFunctionOptions,
90
90
  ): Promise<RerankResult[]> {
91
- const { provider } = modelConfig;
92
91
  let semanticProvider: RelevanceScoreProvider;
93
- if ('model' in modelConfig) {
94
- semanticProvider = new MastraAgentRelevanceScorer(provider, 'CUSTOM_MODEL', modelConfig.model);
95
- } else if (provider === 'COHERE' && 'name' in modelConfig && modelConfig.name === 'rerank-v3.5') {
96
- semanticProvider = new CohereRelevanceScorer(modelConfig.name, modelConfig.apiKey);
97
- } else if ('name' in modelConfig) {
98
- semanticProvider = new MastraAgentRelevanceScorer(provider, modelConfig.name);
92
+ if (model.modelId === 'rerank-v3.5') {
93
+ semanticProvider = new CohereRelevanceScorer(model.modelId);
99
94
  } else {
100
- throw new Error('Invalid model configuration');
95
+ semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
101
96
  }
102
97
  const { queryEmbedding, topK = 3 } = options;
103
98
  const weights = {
@@ -14,7 +14,7 @@ export const createDocumentChunkerTool = ({
14
14
  }: {
15
15
  doc: MDocument;
16
16
  params?: ChunkParams;
17
- }) => {
17
+ }): ReturnType<typeof createTool> => {
18
18
  return createTool({
19
19
  id: `Document Chunker ${params.strategy} ${params.size}`,
20
20
  inputSchema: z.object({}),
@@ -1,5 +1,5 @@
1
- import { type EmbeddingOptions } from '@mastra/core/embeddings';
2
1
  import { createTool } from '@mastra/core/tools';
2
+ import { EmbeddingModel } from 'ai';
3
3
  import { z } from 'zod';
4
4
 
5
5
  import { GraphRAG } from '../graph-rag';
@@ -8,7 +8,7 @@ import { vectorQuerySearch, defaultGraphRagDescription } from '../utils';
8
8
  export const createGraphRAGTool = ({
9
9
  vectorStoreName,
10
10
  indexName,
11
- options,
11
+ model,
12
12
  enableFilter = false,
13
13
  graphOptions = {
14
14
  dimension: 1536,
@@ -21,7 +21,7 @@ export const createGraphRAGTool = ({
21
21
  }: {
22
22
  vectorStoreName: string;
23
23
  indexName: string;
24
- options: EmbeddingOptions;
24
+ model: EmbeddingModel<string>;
25
25
  enableFilter?: boolean;
26
26
  graphOptions?: {
27
27
  dimension?: number;
@@ -31,7 +31,7 @@ export const createGraphRAGTool = ({
31
31
  };
32
32
  id?: string;
33
33
  description?: string;
34
- }) => {
34
+ }): ReturnType<typeof createTool> => {
35
35
  const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
36
36
  const toolDescription = description || defaultGraphRagDescription(vectorStoreName, indexName);
37
37
 
@@ -73,7 +73,7 @@ export const createGraphRAGTool = ({
73
73
  indexName,
74
74
  vectorStore,
75
75
  queryText,
76
- options,
76
+ model,
77
77
  queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
78
78
  topK,
79
79
  includeVectors: true,
@@ -1,5 +1,5 @@
1
- import { type EmbeddingOptions } from '@mastra/core/embeddings';
2
1
  import { createTool } from '@mastra/core/tools';
2
+ import { EmbeddingModel } from 'ai';
3
3
  import { z } from 'zod';
4
4
 
5
5
  import { rerank, RerankConfig } from '../rerank';
@@ -8,7 +8,7 @@ import { vectorQuerySearch, defaultVectorQueryDescription } from '../utils';
8
8
  export const createVectorQueryTool = ({
9
9
  vectorStoreName,
10
10
  indexName,
11
- options,
11
+ model,
12
12
  enableFilter = false,
13
13
  reranker,
14
14
  id,
@@ -16,12 +16,12 @@ export const createVectorQueryTool = ({
16
16
  }: {
17
17
  vectorStoreName: string;
18
18
  indexName: string;
19
- options: EmbeddingOptions;
19
+ model: EmbeddingModel<string>;
20
20
  enableFilter?: boolean;
21
21
  reranker?: RerankConfig;
22
22
  id?: string;
23
23
  description?: string;
24
- }) => {
24
+ }): ReturnType<typeof createTool> => {
25
25
  const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
26
26
  const toolDescription = description || defaultVectorQueryDescription(vectorStoreName, indexName);
27
27
 
@@ -61,7 +61,7 @@ export const createVectorQueryTool = ({
61
61
  indexName,
62
62
  vectorStore,
63
63
  queryText,
64
- options,
64
+ model,
65
65
  queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
66
66
  topK,
67
67
  });
@@ -1,16 +1,15 @@
1
- import { type EmbeddingOptions } from '@mastra/core/embeddings';
2
1
  import { type MastraVector, type QueryResult } from '@mastra/core/vector';
3
-
4
- import { embed } from '../embeddings';
2
+ import { embed, EmbeddingModel } from 'ai';
5
3
 
6
4
  interface VectorQuerySearchParams {
7
5
  indexName: string;
8
6
  vectorStore: MastraVector;
9
7
  queryText: string;
10
- options: EmbeddingOptions;
8
+ model: EmbeddingModel<string>;
11
9
  queryFilter?: any;
12
10
  topK: number;
13
11
  includeVectors?: boolean;
12
+ maxRetries?: number;
14
13
  }
15
14
 
16
15
  interface VectorQuerySearchResult {
@@ -23,12 +22,17 @@ export const vectorQuerySearch = async ({
23
22
  indexName,
24
23
  vectorStore,
25
24
  queryText,
26
- options,
25
+ model,
27
26
  queryFilter = {},
28
27
  topK,
29
28
  includeVectors = false,
29
+ maxRetries = 2,
30
30
  }: VectorQuerySearchParams): Promise<VectorQuerySearchResult> => {
31
- const { embedding } = await embed(queryText, options);
31
+ const { embedding } = await embed({
32
+ value: queryText,
33
+ model,
34
+ maxRetries,
35
+ });
32
36
  // Get relevant chunks from the vector database
33
37
  const results = await vectorStore.query(indexName, embedding, topK, queryFilter, includeVectors);
34
38
 
@@ -1,17 +0,0 @@
1
- import { type EmbedManyResult, type EmbedResult, type EmbeddingOptions } from '@mastra/core';
2
- import { embed as embedCore, embedMany as embedManyCore } from '@mastra/core/embeddings';
3
- import { Document as Chunk } from 'llamaindex';
4
-
5
- function getText(input: Chunk | string): string {
6
- return input instanceof Chunk ? input.getText() : input;
7
- }
8
-
9
- // Added explicit return type as it was not being inferred correctly
10
- export function embed(chunk: Chunk | string, options: EmbeddingOptions): Promise<EmbedResult<string>> {
11
- return embedCore(getText(chunk), options);
12
- }
13
-
14
- // Added explicit return type as it was not being inferred correctly
15
- export function embedMany(chunks: (Chunk | string)[], options: EmbeddingOptions): Promise<EmbedManyResult<string>> {
16
- return embedManyCore(chunks.map(getText), options);
17
- }