@mastra/rag 0.1.0-alpha.86 → 0.1.0-alpha.87
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +13 -0
- package/dist/index.d.ts +12 -78
- package/dist/index.js +19 -30
- package/package.json +8 -2
- package/src/document/document.test.ts +10 -5
- package/src/index.ts +0 -1
- package/src/rerank/index.test.ts +14 -42
- package/src/rerank/index.ts +6 -11
- package/src/tools/document-chunker.ts +1 -1
- package/src/tools/graph-rag.ts +5 -5
- package/src/tools/vector-query.ts +5 -5
- package/src/utils/vector-search.ts +10 -6
- package/src/embeddings/index.ts +0 -17
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,18 @@
|
|
|
1
1
|
# @mastra/rag
|
|
2
2
|
|
|
3
|
+
## 0.1.0-alpha.87
|
|
4
|
+
|
|
5
|
+
### Minor Changes
|
|
6
|
+
|
|
7
|
+
- 8b416d9: Breaking changes
|
|
8
|
+
|
|
9
|
+
### Patch Changes
|
|
10
|
+
|
|
11
|
+
- 9c10484: update all packages
|
|
12
|
+
- Updated dependencies [9c10484]
|
|
13
|
+
- Updated dependencies [8b416d9]
|
|
14
|
+
- @mastra/core@0.2.0-alpha.94
|
|
15
|
+
|
|
3
16
|
## 0.1.0-alpha.86
|
|
4
17
|
|
|
5
18
|
### Patch Changes
|
package/dist/index.d.ts
CHANGED
|
@@ -1,11 +1,8 @@
|
|
|
1
1
|
import { LLM, TitleExtractorPrompt, TitleCombinePrompt, SummaryPrompt, QuestionExtractPrompt, Document } from 'llamaindex';
|
|
2
2
|
import { TiktokenEncoding, TiktokenModel } from 'js-tiktoken';
|
|
3
|
-
import * as _mastra_core from '@mastra/core';
|
|
4
|
-
import { EmbeddingOptions, EmbedResult, EmbedManyResult, ModelConfig } from '@mastra/core';
|
|
5
3
|
import { QueryResult } from '@mastra/core/vector';
|
|
6
|
-
import
|
|
7
|
-
import {
|
|
8
|
-
import { EmbeddingOptions as EmbeddingOptions$1 } from '@mastra/core/embeddings';
|
|
4
|
+
import { LanguageModelV1, EmbeddingModel } from 'ai';
|
|
5
|
+
import { createTool } from '@mastra/core/tools';
|
|
9
6
|
|
|
10
7
|
declare enum Language {
|
|
11
8
|
CPP = "cpp",
|
|
@@ -118,9 +115,6 @@ declare class MDocument {
|
|
|
118
115
|
getMetadata(): Record<string, any>[];
|
|
119
116
|
}
|
|
120
117
|
|
|
121
|
-
declare function embed(chunk: Document | string, options: EmbeddingOptions): Promise<EmbedResult<string>>;
|
|
122
|
-
declare function embedMany(chunks: (Document | string)[], options: EmbeddingOptions): Promise<EmbedManyResult<string>>;
|
|
123
|
-
|
|
124
118
|
type WeightConfig = {
|
|
125
119
|
semantic?: number;
|
|
126
120
|
vector?: number;
|
|
@@ -151,9 +145,9 @@ interface RerankerFunctionOptions {
|
|
|
151
145
|
}
|
|
152
146
|
interface RerankConfig {
|
|
153
147
|
options?: RerankerOptions;
|
|
154
|
-
model: ModelConfig;
|
|
148
|
+
model: LanguageModelV1;
|
|
155
149
|
}
|
|
156
|
-
declare function rerank(results: QueryResult[], query: string, modelConfig: ModelConfig, options: RerankerFunctionOptions): Promise<RerankResult[]>;
|
|
150
|
+
declare function rerank(results: QueryResult[], query: string, model: LanguageModelV1, options: RerankerFunctionOptions): Promise<RerankResult[]>;
|
|
157
151
|
|
|
158
152
|
/**
|
|
159
153
|
* TODO: GraphRAG Enhancements
|
|
@@ -214,12 +208,12 @@ declare class GraphRAG {
|
|
|
214
208
|
declare const createDocumentChunkerTool: ({ doc, params, }: {
|
|
215
209
|
doc: MDocument;
|
|
216
210
|
params?: ChunkParams;
|
|
217
|
-
}) =>
|
|
211
|
+
}) => ReturnType<typeof createTool>;
|
|
218
212
|
|
|
219
|
-
declare const createGraphRAGTool: ({ vectorStoreName, indexName, options, enableFilter, graphOptions, id, description, }: {
|
|
213
|
+
declare const createGraphRAGTool: ({ vectorStoreName, indexName, model, enableFilter, graphOptions, id, description, }: {
|
|
220
214
|
vectorStoreName: string;
|
|
221
215
|
indexName: string;
|
|
222
|
-
|
|
216
|
+
model: EmbeddingModel<string>;
|
|
223
217
|
enableFilter?: boolean;
|
|
224
218
|
graphOptions?: {
|
|
225
219
|
dimension?: number;
|
|
@@ -229,77 +223,17 @@ declare const createGraphRAGTool: ({ vectorStoreName, indexName, options, enable
|
|
|
229
223
|
};
|
|
230
224
|
id?: string;
|
|
231
225
|
description?: string;
|
|
232
|
-
}) =>
|
|
233
|
-
queryText: z.ZodString;
|
|
234
|
-
topK: z.ZodNumber;
|
|
235
|
-
filter: z.ZodString;
|
|
236
|
-
}, "strip", z.ZodTypeAny, {
|
|
237
|
-
filter: string;
|
|
238
|
-
topK: number;
|
|
239
|
-
queryText: string;
|
|
240
|
-
}, {
|
|
241
|
-
filter: string;
|
|
242
|
-
topK: number;
|
|
243
|
-
queryText: string;
|
|
244
|
-
}>, z.ZodObject<{
|
|
245
|
-
relevantContext: z.ZodAny;
|
|
246
|
-
}, "strip", z.ZodTypeAny, {
|
|
247
|
-
relevantContext?: any;
|
|
248
|
-
}, {
|
|
249
|
-
relevantContext?: any;
|
|
250
|
-
}>, _mastra_core.ToolExecutionContext<z.ZodObject<{
|
|
251
|
-
queryText: z.ZodString;
|
|
252
|
-
topK: z.ZodNumber;
|
|
253
|
-
filter: z.ZodString;
|
|
254
|
-
}, "strip", z.ZodTypeAny, {
|
|
255
|
-
filter: string;
|
|
256
|
-
topK: number;
|
|
257
|
-
queryText: string;
|
|
258
|
-
}, {
|
|
259
|
-
filter: string;
|
|
260
|
-
topK: number;
|
|
261
|
-
queryText: string;
|
|
262
|
-
}>, _mastra_core.WorkflowContext<any>>>;
|
|
226
|
+
}) => ReturnType<typeof createTool>;
|
|
263
227
|
|
|
264
|
-
declare const createVectorQueryTool: ({ vectorStoreName, indexName,
|
|
228
|
+
declare const createVectorQueryTool: ({ vectorStoreName, indexName, model, enableFilter, reranker, id, description, }: {
|
|
265
229
|
vectorStoreName: string;
|
|
266
230
|
indexName: string;
|
|
267
|
-
|
|
231
|
+
model: EmbeddingModel<string>;
|
|
268
232
|
enableFilter?: boolean;
|
|
269
233
|
reranker?: RerankConfig;
|
|
270
234
|
id?: string;
|
|
271
235
|
description?: string;
|
|
272
|
-
}) =>
|
|
273
|
-
queryText: z.ZodString;
|
|
274
|
-
topK: z.ZodNumber;
|
|
275
|
-
filter: z.ZodString;
|
|
276
|
-
}, "strip", z.ZodTypeAny, {
|
|
277
|
-
filter: string;
|
|
278
|
-
topK: number;
|
|
279
|
-
queryText: string;
|
|
280
|
-
}, {
|
|
281
|
-
filter: string;
|
|
282
|
-
topK: number;
|
|
283
|
-
queryText: string;
|
|
284
|
-
}>, z.ZodObject<{
|
|
285
|
-
relevantContext: z.ZodAny;
|
|
286
|
-
}, "strip", z.ZodTypeAny, {
|
|
287
|
-
relevantContext?: any;
|
|
288
|
-
}, {
|
|
289
|
-
relevantContext?: any;
|
|
290
|
-
}>, _mastra_core.ToolExecutionContext<z.ZodObject<{
|
|
291
|
-
queryText: z.ZodString;
|
|
292
|
-
topK: z.ZodNumber;
|
|
293
|
-
filter: z.ZodString;
|
|
294
|
-
}, "strip", z.ZodTypeAny, {
|
|
295
|
-
filter: string;
|
|
296
|
-
topK: number;
|
|
297
|
-
queryText: string;
|
|
298
|
-
}, {
|
|
299
|
-
filter: string;
|
|
300
|
-
topK: number;
|
|
301
|
-
queryText: string;
|
|
302
|
-
}>, _mastra_core.WorkflowContext<any>>>;
|
|
236
|
+
}) => ReturnType<typeof createTool>;
|
|
303
237
|
|
|
304
238
|
/**
|
|
305
239
|
* Vector store specific prompts that detail supported operators and examples.
|
|
@@ -319,4 +253,4 @@ declare const defaultFilter = "You MUST generate for each query:\n filter: qu
|
|
|
319
253
|
declare const defaultVectorQueryDescription: (vectorStoreName: string, indexName: string) => string;
|
|
320
254
|
declare const defaultGraphRagDescription: (vectorStoreName: string, indexName: string) => string;
|
|
321
255
|
|
|
322
|
-
export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, type RerankConfig, type RerankResult, type RerankerFunctionOptions, type RerankerOptions, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, embed, embedMany, rerank };
|
|
256
|
+
export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, type RerankConfig, type RerankResult, type RerankerFunctionOptions, type RerankerOptions, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, rerank };
|
package/dist/index.js
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import { Document, SummaryExtractor, QuestionsAnsweredExtractor, KeywordExtractor, TitleExtractor, IngestionPipeline } from 'llamaindex';
|
|
2
2
|
import { parse } from 'node-html-better-parser';
|
|
3
3
|
import { encodingForModel, getEncoding } from 'js-tiktoken';
|
|
4
|
-
import {
|
|
5
|
-
import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from '@mastra/core/relevance';
|
|
4
|
+
import { CohereRelevanceScorer, MastraAgentRelevanceScorer } from '@mastra/core/relevance';
|
|
6
5
|
import { createTool } from '@mastra/core/tools';
|
|
7
6
|
import { z } from 'zod';
|
|
7
|
+
import { embed } from 'ai';
|
|
8
8
|
|
|
9
9
|
// src/document/document.ts
|
|
10
10
|
|
|
@@ -1272,15 +1272,6 @@ var MDocument = class _MDocument {
|
|
|
1272
1272
|
return this.chunks.map((doc) => doc.metadata);
|
|
1273
1273
|
}
|
|
1274
1274
|
};
|
|
1275
|
-
function getText(input) {
|
|
1276
|
-
return input instanceof Document ? input.getText() : input;
|
|
1277
|
-
}
|
|
1278
|
-
function embed(chunk, options) {
|
|
1279
|
-
return embed$1(getText(chunk), options);
|
|
1280
|
-
}
|
|
1281
|
-
function embedMany(chunks, options) {
|
|
1282
|
-
return embedMany$1(chunks.map(getText), options);
|
|
1283
|
-
}
|
|
1284
1275
|
var DEFAULT_WEIGHTS = {
|
|
1285
1276
|
semantic: 0.4,
|
|
1286
1277
|
vector: 0.4,
|
|
@@ -1299,17 +1290,12 @@ function adjustScores(score, queryAnalysis) {
|
|
|
1299
1290
|
const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
|
|
1300
1291
|
return score * magnitudeAdjustment * featureStrengthAdjustment;
|
|
1301
1292
|
}
|
|
1302
|
-
async function rerank(results, query, modelConfig, options) {
|
|
1303
|
-
const { provider } = modelConfig;
|
|
1293
|
+
async function rerank(results, query, model, options) {
|
|
1304
1294
|
let semanticProvider;
|
|
1305
|
-
if (
|
|
1306
|
-
semanticProvider = new
|
|
1307
|
-
} else if (provider === "COHERE" && "name" in modelConfig && modelConfig.name === "rerank-v3.5") {
|
|
1308
|
-
semanticProvider = new CohereRelevanceScorer(modelConfig.name, modelConfig.apiKey);
|
|
1309
|
-
} else if ("name" in modelConfig) {
|
|
1310
|
-
semanticProvider = new MastraAgentRelevanceScorer(provider, modelConfig.name);
|
|
1295
|
+
if (model.modelId === "rerank-v3.5") {
|
|
1296
|
+
semanticProvider = new CohereRelevanceScorer(model.modelId);
|
|
1311
1297
|
} else {
|
|
1312
|
-
|
|
1298
|
+
semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
|
|
1313
1299
|
}
|
|
1314
1300
|
const { queryEmbedding, topK = 3 } = options;
|
|
1315
1301
|
const weights = {
|
|
@@ -1571,18 +1557,21 @@ var createDocumentChunkerTool = ({
|
|
|
1571
1557
|
}
|
|
1572
1558
|
});
|
|
1573
1559
|
};
|
|
1574
|
-
|
|
1575
|
-
// src/utils/vector-search.ts
|
|
1576
1560
|
var vectorQuerySearch = async ({
|
|
1577
1561
|
indexName,
|
|
1578
1562
|
vectorStore,
|
|
1579
1563
|
queryText,
|
|
1580
|
-
|
|
1564
|
+
model,
|
|
1581
1565
|
queryFilter = {},
|
|
1582
1566
|
topK,
|
|
1583
|
-
includeVectors = false
|
|
1567
|
+
includeVectors = false,
|
|
1568
|
+
maxRetries = 2
|
|
1584
1569
|
}) => {
|
|
1585
|
-
const { embedding } = await embed(
|
|
1570
|
+
const { embedding } = await embed({
|
|
1571
|
+
value: queryText,
|
|
1572
|
+
model,
|
|
1573
|
+
maxRetries
|
|
1574
|
+
});
|
|
1586
1575
|
const results = await vectorStore.query(indexName, embedding, topK, queryFilter, includeVectors);
|
|
1587
1576
|
return { results, queryEmbedding: embedding };
|
|
1588
1577
|
};
|
|
@@ -1621,7 +1610,7 @@ var defaultGraphRagDescription = (vectorStoreName, indexName) => `Fetches and re
|
|
|
1621
1610
|
var createGraphRAGTool = ({
|
|
1622
1611
|
vectorStoreName,
|
|
1623
1612
|
indexName,
|
|
1624
|
-
|
|
1613
|
+
model,
|
|
1625
1614
|
enableFilter = false,
|
|
1626
1615
|
graphOptions = {
|
|
1627
1616
|
dimension: 1536,
|
|
@@ -1667,7 +1656,7 @@ var createGraphRAGTool = ({
|
|
|
1667
1656
|
indexName,
|
|
1668
1657
|
vectorStore,
|
|
1669
1658
|
queryText,
|
|
1670
|
-
|
|
1659
|
+
model,
|
|
1671
1660
|
queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
|
|
1672
1661
|
topK,
|
|
1673
1662
|
includeVectors: true
|
|
@@ -1703,7 +1692,7 @@ var createGraphRAGTool = ({
|
|
|
1703
1692
|
var createVectorQueryTool = ({
|
|
1704
1693
|
vectorStoreName,
|
|
1705
1694
|
indexName,
|
|
1706
|
-
|
|
1695
|
+
model,
|
|
1707
1696
|
enableFilter = false,
|
|
1708
1697
|
reranker,
|
|
1709
1698
|
id,
|
|
@@ -1742,7 +1731,7 @@ var createVectorQueryTool = ({
|
|
|
1742
1731
|
indexName,
|
|
1743
1732
|
vectorStore,
|
|
1744
1733
|
queryText,
|
|
1745
|
-
|
|
1734
|
+
model,
|
|
1746
1735
|
queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
|
|
1747
1736
|
topK
|
|
1748
1737
|
});
|
|
@@ -2485,4 +2474,4 @@ Example Complex Query:
|
|
|
2485
2474
|
"inStock": true
|
|
2486
2475
|
}`;
|
|
2487
2476
|
|
|
2488
|
-
export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, embed, embedMany, rerank };
|
|
2477
|
+
export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultFilter, defaultGraphRagDescription, defaultTopK, defaultVectorQueryDescription, rerank };
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@mastra/rag",
|
|
3
|
-
"version": "0.1.0-alpha.86",
|
|
3
|
+
"version": "0.1.0-alpha.87",
|
|
4
4
|
"description": "",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "dist/index.js",
|
|
@@ -25,13 +25,19 @@
|
|
|
25
25
|
"node-html-better-parser": "^1.4.7",
|
|
26
26
|
"pathe": "^2.0.2",
|
|
27
27
|
"zod": "^3.24.1",
|
|
28
|
-
"@mastra/core": "^0.2.0-alpha.
|
|
28
|
+
"@mastra/core": "^0.2.0-alpha.94"
|
|
29
|
+
},
|
|
30
|
+
"peerDependencies": {
|
|
31
|
+
"ai": "^4.0.0"
|
|
29
32
|
},
|
|
30
33
|
"devDependencies": {
|
|
34
|
+
"@ai-sdk/openai": "latest",
|
|
35
|
+
"@ai-sdk/cohere": "latest",
|
|
31
36
|
"@babel/preset-env": "^7.26.0",
|
|
32
37
|
"@babel/preset-typescript": "^7.26.0",
|
|
33
38
|
"@tsconfig/recommended": "^1.0.7",
|
|
34
39
|
"@types/node": "^22.9.0",
|
|
40
|
+
"ai": "^4.0.34",
|
|
35
41
|
"tsup": "^8.0.1",
|
|
36
42
|
"vitest": "^3.0.4"
|
|
37
43
|
},
|
package/src/document/document.test.ts
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { createOpenAI } from '@ai-sdk/openai';
|
|
2
|
+
import { embedMany } from 'ai';
|
|
3
|
+
import { describe, it, expect } from 'vitest';
|
|
2
4
|
|
|
3
5
|
import { MDocument } from './document';
|
|
4
6
|
import { Language } from './types';
|
|
@@ -14,6 +16,10 @@ Welcome to our comprehensive guide on modern web development. This resource cove
|
|
|
14
16
|
- Senior developers seeking a refresher on current best practices
|
|
15
17
|
`;
|
|
16
18
|
|
|
19
|
+
const openai = createOpenAI({
|
|
20
|
+
apiKey: process.env.OPENAI_API_KEY,
|
|
21
|
+
});
|
|
22
|
+
|
|
17
23
|
describe('MDocument', () => {
|
|
18
24
|
describe('basics', () => {
|
|
19
25
|
let chunks: MDocument['chunks'];
|
|
@@ -48,10 +54,9 @@ describe('MDocument', () => {
|
|
|
48
54
|
}, 15000);
|
|
49
55
|
|
|
50
56
|
it('embed - create embedding from chunk', async () => {
|
|
51
|
-
const embeddings = await embedMany(
|
|
52
|
-
|
|
53
|
-
model: 'text-embedding-3-small',
|
|
54
|
-
maxRetries: 3,
|
|
57
|
+
const embeddings = await embedMany({
|
|
58
|
+
values: chunks.map(chunk => chunk.text),
|
|
59
|
+
model: openai.embedding('text-embedding-3-small'),
|
|
55
60
|
});
|
|
56
61
|
|
|
57
62
|
console.log(embeddings);
|
package/src/index.ts
CHANGED
package/src/rerank/index.test.ts
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { cohere } from '@ai-sdk/cohere';
|
|
1
2
|
import { CohereRelevanceScorer } from '@mastra/core/relevance';
|
|
2
3
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
3
4
|
|
|
@@ -25,12 +26,7 @@ describe('rerank', () => {
|
|
|
25
26
|
{ id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
26
27
|
];
|
|
27
28
|
await expect(
|
|
28
|
-
rerank(
|
|
29
|
-
results,
|
|
30
|
-
'test query',
|
|
31
|
-
{ provider: 'COHERE', name: 'rerank-v3.5' },
|
|
32
|
-
{ weights: { semantic: 0.5, vector: 0.3, position: 0.5 } },
|
|
33
|
-
),
|
|
29
|
+
rerank(results, 'test query', cohere('rerank-v3.5'), { weights: { semantic: 0.5, vector: 0.3, position: 0.5 } }),
|
|
34
30
|
).rejects.toThrow('Weights must add up to 1');
|
|
35
31
|
});
|
|
36
32
|
|
|
@@ -41,15 +37,7 @@ describe('rerank', () => {
|
|
|
41
37
|
{ id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
42
38
|
];
|
|
43
39
|
|
|
44
|
-
const rerankedResults = await rerank(
|
|
45
|
-
results,
|
|
46
|
-
'test query',
|
|
47
|
-
{
|
|
48
|
-
provider: 'COHERE',
|
|
49
|
-
name: 'rerank-v3.5',
|
|
50
|
-
},
|
|
51
|
-
{ topK: 2 },
|
|
52
|
-
);
|
|
40
|
+
const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), { topK: 2 });
|
|
53
41
|
|
|
54
42
|
expect(rerankedResults).toHaveLength(2);
|
|
55
43
|
expect(rerankedResults[0]).toStrictEqual({
|
|
@@ -83,22 +71,14 @@ describe('rerank', () => {
|
|
|
83
71
|
{ id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
84
72
|
];
|
|
85
73
|
|
|
86
|
-
const rerankedResults = await rerank(
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
name: 'rerank-v3.5',
|
|
92
|
-
},
|
|
93
|
-
{
|
|
94
|
-
weights: {
|
|
95
|
-
semantic: 0.5,
|
|
96
|
-
vector: 0.4,
|
|
97
|
-
position: 0.1,
|
|
98
|
-
},
|
|
99
|
-
topK: 2,
|
|
74
|
+
const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
|
|
75
|
+
weights: {
|
|
76
|
+
semantic: 0.5,
|
|
77
|
+
vector: 0.4,
|
|
78
|
+
position: 0.1,
|
|
100
79
|
},
|
|
101
|
-
|
|
80
|
+
topK: 2,
|
|
81
|
+
});
|
|
102
82
|
|
|
103
83
|
expect(rerankedResults).toHaveLength(2);
|
|
104
84
|
expect(rerankedResults[0]).toStrictEqual({
|
|
@@ -130,18 +110,10 @@ describe('rerank', () => {
|
|
|
130
110
|
{ id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
|
|
131
111
|
];
|
|
132
112
|
|
|
133
|
-
const rerankedResults = await rerank(
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
provider: 'COHERE',
|
|
138
|
-
name: 'rerank-v3.5',
|
|
139
|
-
},
|
|
140
|
-
{
|
|
141
|
-
queryEmbedding: [0.5, 0.3, -0.2, 0.4],
|
|
142
|
-
topK: 2,
|
|
143
|
-
},
|
|
144
|
-
);
|
|
113
|
+
const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
|
|
114
|
+
queryEmbedding: [0.5, 0.3, -0.2, 0.4],
|
|
115
|
+
topK: 2,
|
|
116
|
+
});
|
|
145
117
|
|
|
146
118
|
// Ensure query embedding analysis is being applied (we don't know exact score without knowing internals)
|
|
147
119
|
expect(rerankedResults).toHaveLength(2);
|
package/src/rerank/index.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { type ModelConfig } from '@mastra/core';
|
|
2
1
|
import { MastraAgentRelevanceScorer, CohereRelevanceScorer, RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
3
2
|
import { QueryResult } from '@mastra/core/vector';
|
|
3
|
+
import { LanguageModelV1 } from 'ai';
|
|
4
4
|
|
|
5
5
|
// Default weights for different scoring components (must add up to 1)
|
|
6
6
|
const DEFAULT_WEIGHTS = {
|
|
@@ -46,7 +46,7 @@ export interface RerankerFunctionOptions {
|
|
|
46
46
|
|
|
47
47
|
export interface RerankConfig {
|
|
48
48
|
options?: RerankerOptions;
|
|
49
|
-
model: ModelConfig;
|
|
49
|
+
model: LanguageModelV1;
|
|
50
50
|
}
|
|
51
51
|
|
|
52
52
|
// Calculate position score based on position in original list
|
|
@@ -85,19 +85,14 @@ function adjustScores(score: number, queryAnalysis: { magnitude: number; dominan
|
|
|
85
85
|
export async function rerank(
|
|
86
86
|
results: QueryResult[],
|
|
87
87
|
query: string,
|
|
88
|
-
|
|
88
|
+
model: LanguageModelV1,
|
|
89
89
|
options: RerankerFunctionOptions,
|
|
90
90
|
): Promise<RerankResult[]> {
|
|
91
|
-
const { provider } = modelConfig;
|
|
92
91
|
let semanticProvider: RelevanceScoreProvider;
|
|
93
|
-
if (
|
|
94
|
-
semanticProvider = new
|
|
95
|
-
} else if (provider === 'COHERE' && 'name' in modelConfig && modelConfig.name === 'rerank-v3.5') {
|
|
96
|
-
semanticProvider = new CohereRelevanceScorer(modelConfig.name, modelConfig.apiKey);
|
|
97
|
-
} else if ('name' in modelConfig) {
|
|
98
|
-
semanticProvider = new MastraAgentRelevanceScorer(provider, modelConfig.name);
|
|
92
|
+
if (model.modelId === 'rerank-v3.5') {
|
|
93
|
+
semanticProvider = new CohereRelevanceScorer(model.modelId);
|
|
99
94
|
} else {
|
|
100
|
-
|
|
95
|
+
semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
|
|
101
96
|
}
|
|
102
97
|
const { queryEmbedding, topK = 3 } = options;
|
|
103
98
|
const weights = {
|
package/src/tools/document-chunker.ts
CHANGED
|
@@ -14,7 +14,7 @@ export const createDocumentChunkerTool = ({
|
|
|
14
14
|
}: {
|
|
15
15
|
doc: MDocument;
|
|
16
16
|
params?: ChunkParams;
|
|
17
|
-
}) => {
|
|
17
|
+
}): ReturnType<typeof createTool> => {
|
|
18
18
|
return createTool({
|
|
19
19
|
id: `Document Chunker ${params.strategy} ${params.size}`,
|
|
20
20
|
inputSchema: z.object({}),
|
package/src/tools/graph-rag.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { type EmbeddingOptions } from '@mastra/core/embeddings';
|
|
2
1
|
import { createTool } from '@mastra/core/tools';
|
|
2
|
+
import { EmbeddingModel } from 'ai';
|
|
3
3
|
import { z } from 'zod';
|
|
4
4
|
|
|
5
5
|
import { GraphRAG } from '../graph-rag';
|
|
@@ -8,7 +8,7 @@ import { vectorQuerySearch, defaultGraphRagDescription } from '../utils';
|
|
|
8
8
|
export const createGraphRAGTool = ({
|
|
9
9
|
vectorStoreName,
|
|
10
10
|
indexName,
|
|
11
|
-
|
|
11
|
+
model,
|
|
12
12
|
enableFilter = false,
|
|
13
13
|
graphOptions = {
|
|
14
14
|
dimension: 1536,
|
|
@@ -21,7 +21,7 @@ export const createGraphRAGTool = ({
|
|
|
21
21
|
}: {
|
|
22
22
|
vectorStoreName: string;
|
|
23
23
|
indexName: string;
|
|
24
|
-
|
|
24
|
+
model: EmbeddingModel<string>;
|
|
25
25
|
enableFilter?: boolean;
|
|
26
26
|
graphOptions?: {
|
|
27
27
|
dimension?: number;
|
|
@@ -31,7 +31,7 @@ export const createGraphRAGTool = ({
|
|
|
31
31
|
};
|
|
32
32
|
id?: string;
|
|
33
33
|
description?: string;
|
|
34
|
-
}) => {
|
|
34
|
+
}): ReturnType<typeof createTool> => {
|
|
35
35
|
const toolId = id || `GraphRAG ${vectorStoreName} ${indexName} Tool`;
|
|
36
36
|
const toolDescription = description || defaultGraphRagDescription(vectorStoreName, indexName);
|
|
37
37
|
|
|
@@ -73,7 +73,7 @@ export const createGraphRAGTool = ({
|
|
|
73
73
|
indexName,
|
|
74
74
|
vectorStore,
|
|
75
75
|
queryText,
|
|
76
|
-
|
|
76
|
+
model,
|
|
77
77
|
queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
|
|
78
78
|
topK,
|
|
79
79
|
includeVectors: true,
|
package/src/tools/vector-query.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { type EmbeddingOptions } from '@mastra/core/embeddings';
|
|
2
1
|
import { createTool } from '@mastra/core/tools';
|
|
2
|
+
import { EmbeddingModel } from 'ai';
|
|
3
3
|
import { z } from 'zod';
|
|
4
4
|
|
|
5
5
|
import { rerank, RerankConfig } from '../rerank';
|
|
@@ -8,7 +8,7 @@ import { vectorQuerySearch, defaultVectorQueryDescription } from '../utils';
|
|
|
8
8
|
export const createVectorQueryTool = ({
|
|
9
9
|
vectorStoreName,
|
|
10
10
|
indexName,
|
|
11
|
-
|
|
11
|
+
model,
|
|
12
12
|
enableFilter = false,
|
|
13
13
|
reranker,
|
|
14
14
|
id,
|
|
@@ -16,12 +16,12 @@ export const createVectorQueryTool = ({
|
|
|
16
16
|
}: {
|
|
17
17
|
vectorStoreName: string;
|
|
18
18
|
indexName: string;
|
|
19
|
-
|
|
19
|
+
model: EmbeddingModel<string>;
|
|
20
20
|
enableFilter?: boolean;
|
|
21
21
|
reranker?: RerankConfig;
|
|
22
22
|
id?: string;
|
|
23
23
|
description?: string;
|
|
24
|
-
}) => {
|
|
24
|
+
}): ReturnType<typeof createTool> => {
|
|
25
25
|
const toolId = id || `VectorQuery ${vectorStoreName} ${indexName} Tool`;
|
|
26
26
|
const toolDescription = description || defaultVectorQueryDescription(vectorStoreName, indexName);
|
|
27
27
|
|
|
@@ -61,7 +61,7 @@ export const createVectorQueryTool = ({
|
|
|
61
61
|
indexName,
|
|
62
62
|
vectorStore,
|
|
63
63
|
queryText,
|
|
64
|
-
|
|
64
|
+
model,
|
|
65
65
|
queryFilter: Object.keys(queryFilter || {}).length > 0 ? queryFilter : undefined,
|
|
66
66
|
topK,
|
|
67
67
|
});
|
package/src/utils/vector-search.ts
CHANGED
|
@@ -1,16 +1,15 @@
|
|
|
1
|
-
import { type EmbeddingOptions } from '@mastra/core/embeddings';
|
|
2
1
|
import { type MastraVector, type QueryResult } from '@mastra/core/vector';
|
|
3
|
-
|
|
4
|
-
import { embed } from '../embeddings';
|
|
2
|
+
import { embed, EmbeddingModel } from 'ai';
|
|
5
3
|
|
|
6
4
|
interface VectorQuerySearchParams {
|
|
7
5
|
indexName: string;
|
|
8
6
|
vectorStore: MastraVector;
|
|
9
7
|
queryText: string;
|
|
10
|
-
|
|
8
|
+
model: EmbeddingModel<string>;
|
|
11
9
|
queryFilter?: any;
|
|
12
10
|
topK: number;
|
|
13
11
|
includeVectors?: boolean;
|
|
12
|
+
maxRetries?: number;
|
|
14
13
|
}
|
|
15
14
|
|
|
16
15
|
interface VectorQuerySearchResult {
|
|
@@ -23,12 +22,17 @@ export const vectorQuerySearch = async ({
|
|
|
23
22
|
indexName,
|
|
24
23
|
vectorStore,
|
|
25
24
|
queryText,
|
|
26
|
-
|
|
25
|
+
model,
|
|
27
26
|
queryFilter = {},
|
|
28
27
|
topK,
|
|
29
28
|
includeVectors = false,
|
|
29
|
+
maxRetries = 2,
|
|
30
30
|
}: VectorQuerySearchParams): Promise<VectorQuerySearchResult> => {
|
|
31
|
-
const { embedding } = await embed(
|
|
31
|
+
const { embedding } = await embed({
|
|
32
|
+
value: queryText,
|
|
33
|
+
model,
|
|
34
|
+
maxRetries,
|
|
35
|
+
});
|
|
32
36
|
// Get relevant chunks from the vector database
|
|
33
37
|
const results = await vectorStore.query(indexName, embedding, topK, queryFilter, includeVectors);
|
|
34
38
|
|
package/src/embeddings/index.ts
DELETED
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
import { type EmbedManyResult, type EmbedResult, type EmbeddingOptions } from '@mastra/core';
|
|
2
|
-
import { embed as embedCore, embedMany as embedManyCore } from '@mastra/core/embeddings';
|
|
3
|
-
import { Document as Chunk } from 'llamaindex';
|
|
4
|
-
|
|
5
|
-
function getText(input: Chunk | string): string {
|
|
6
|
-
return input instanceof Chunk ? input.getText() : input;
|
|
7
|
-
}
|
|
8
|
-
|
|
9
|
-
// Added explicit return type as it was not being inferred correctly
|
|
10
|
-
export function embed(chunk: Chunk | string, options: EmbeddingOptions): Promise<EmbedResult<string>> {
|
|
11
|
-
return embedCore(getText(chunk), options);
|
|
12
|
-
}
|
|
13
|
-
|
|
14
|
-
// Added explicit return type as it was not being inferred correctly
|
|
15
|
-
export function embedMany(chunks: (Chunk | string)[], options: EmbeddingOptions): Promise<EmbedManyResult<string>> {
|
|
16
|
-
return embedManyCore(chunks.map(getText), options);
|
|
17
|
-
}
|