@mastra/rag 1.2.2 → 1.2.3-alpha.1

This diff shows the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
Files changed (77):
  1. package/CHANGELOG.md +22 -0
  2. package/dist/index.cjs +25 -9
  3. package/dist/index.cjs.map +1 -1
  4. package/dist/index.js +25 -9
  5. package/dist/index.js.map +1 -1
  6. package/dist/tools/graph-rag.d.ts.map +1 -1
  7. package/dist/tools/types.d.ts +18 -5
  8. package/dist/tools/types.d.ts.map +1 -1
  9. package/dist/tools/vector-query.d.ts.map +1 -1
  10. package/dist/utils/vector-search.d.ts +6 -7
  11. package/dist/utils/vector-search.d.ts.map +1 -1
  12. package/package.json +19 -6
  13. package/.turbo/turbo-build.log +0 -4
  14. package/docker-compose.yaml +0 -22
  15. package/eslint.config.js +0 -6
  16. package/src/document/document.test.ts +0 -2975
  17. package/src/document/document.ts +0 -335
  18. package/src/document/extractors/base.ts +0 -30
  19. package/src/document/extractors/index.ts +0 -5
  20. package/src/document/extractors/keywords.test.ts +0 -125
  21. package/src/document/extractors/keywords.ts +0 -126
  22. package/src/document/extractors/questions.test.ts +0 -120
  23. package/src/document/extractors/questions.ts +0 -111
  24. package/src/document/extractors/summary.test.ts +0 -107
  25. package/src/document/extractors/summary.ts +0 -122
  26. package/src/document/extractors/title.test.ts +0 -121
  27. package/src/document/extractors/title.ts +0 -185
  28. package/src/document/extractors/types.ts +0 -40
  29. package/src/document/index.ts +0 -2
  30. package/src/document/prompts/base.ts +0 -77
  31. package/src/document/prompts/format.ts +0 -9
  32. package/src/document/prompts/index.ts +0 -15
  33. package/src/document/prompts/prompt.ts +0 -60
  34. package/src/document/prompts/types.ts +0 -29
  35. package/src/document/schema/index.ts +0 -3
  36. package/src/document/schema/node.ts +0 -187
  37. package/src/document/schema/types.ts +0 -40
  38. package/src/document/transformers/character.ts +0 -267
  39. package/src/document/transformers/html.ts +0 -346
  40. package/src/document/transformers/json.ts +0 -536
  41. package/src/document/transformers/latex.ts +0 -11
  42. package/src/document/transformers/markdown.ts +0 -239
  43. package/src/document/transformers/semantic-markdown.ts +0 -227
  44. package/src/document/transformers/sentence.ts +0 -314
  45. package/src/document/transformers/text.ts +0 -158
  46. package/src/document/transformers/token.ts +0 -137
  47. package/src/document/transformers/transformer.ts +0 -5
  48. package/src/document/types.ts +0 -145
  49. package/src/document/validation.ts +0 -158
  50. package/src/graph-rag/index.test.ts +0 -235
  51. package/src/graph-rag/index.ts +0 -306
  52. package/src/index.ts +0 -8
  53. package/src/rerank/index.test.ts +0 -150
  54. package/src/rerank/index.ts +0 -198
  55. package/src/rerank/relevance/cohere/index.ts +0 -56
  56. package/src/rerank/relevance/index.ts +0 -3
  57. package/src/rerank/relevance/mastra-agent/index.ts +0 -32
  58. package/src/rerank/relevance/zeroentropy/index.ts +0 -26
  59. package/src/tools/README.md +0 -153
  60. package/src/tools/document-chunker.ts +0 -34
  61. package/src/tools/graph-rag.test.ts +0 -115
  62. package/src/tools/graph-rag.ts +0 -154
  63. package/src/tools/index.ts +0 -3
  64. package/src/tools/types.ts +0 -110
  65. package/src/tools/vector-query-database-config.test.ts +0 -190
  66. package/src/tools/vector-query.test.ts +0 -418
  67. package/src/tools/vector-query.ts +0 -169
  68. package/src/utils/convert-sources.ts +0 -43
  69. package/src/utils/default-settings.ts +0 -38
  70. package/src/utils/index.ts +0 -3
  71. package/src/utils/tool-schemas.ts +0 -38
  72. package/src/utils/vector-prompts.ts +0 -832
  73. package/src/utils/vector-search.ts +0 -117
  74. package/tsconfig.build.json +0 -9
  75. package/tsconfig.json +0 -5
  76. package/tsup.config.ts +0 -17
  77. package/vitest.config.ts +0 -8
package/src/graph-rag/index.ts DELETED
@@ -1,306 +0,0 @@
1
- /**
2
- * TODO: GraphRAG Enhancements
3
- * - Add support for more edge types (sequential, hierarchical, citation, etc)
4
- * - Allow for custom edge types
5
- * - Utilize metadata for richer connections
6
- * - Improve graph traversal and querying using types
7
- */
8
-
9
- type SupportedEdgeType = 'semantic';
10
-
11
- // Types for graph nodes and edges
12
- export interface GraphNode {
13
- id: string;
14
- content: string;
15
- embedding?: number[];
16
- metadata?: Record<string, any>;
17
- }
18
-
19
- export interface RankedNode extends GraphNode {
20
- score: number;
21
- }
22
-
23
- export interface GraphEdge {
24
- source: string;
25
- target: string;
26
- weight: number;
27
- type: SupportedEdgeType;
28
- }
29
-
30
- export interface GraphChunk {
31
- text: string;
32
- metadata: Record<string, any>;
33
- }
34
-
35
- export interface GraphEmbedding {
36
- vector: number[];
37
- }
38
-
39
- export class GraphRAG {
40
- private nodes: Map<string, GraphNode>;
41
- private edges: GraphEdge[];
42
- private dimension: number;
43
- private threshold: number;
44
-
45
- constructor(dimension: number = 1536, threshold: number = 0.7) {
46
- this.nodes = new Map();
47
- this.edges = [];
48
- this.dimension = dimension;
49
- this.threshold = threshold;
50
- }
51
-
52
- // Add a node to the graph
53
- addNode(node: GraphNode): void {
54
- if (!node.embedding) {
55
- throw new Error('Node must have an embedding');
56
- }
57
- if (node.embedding.length !== this.dimension) {
58
- throw new Error(`Embedding dimension must be ${this.dimension}`);
59
- }
60
- this.nodes.set(node.id, node);
61
- }
62
-
63
- // Add an edge between two nodes
64
- addEdge(edge: GraphEdge): void {
65
- if (!this.nodes.has(edge.source) || !this.nodes.has(edge.target)) {
66
- throw new Error('Both source and target nodes must exist');
67
- }
68
- this.edges.push(edge);
69
- // Add reverse edge
70
- this.edges.push({
71
- source: edge.target,
72
- target: edge.source,
73
- weight: edge.weight,
74
- type: edge.type,
75
- });
76
- }
77
-
78
- // Helper method to get all nodes
79
- getNodes(): GraphNode[] {
80
- return Array.from(this.nodes.values());
81
- }
82
-
83
- // Helper method to get all edges
84
- getEdges(): GraphEdge[] {
85
- return this.edges;
86
- }
87
-
88
- getEdgesByType(type: string): GraphEdge[] {
89
- return this.edges.filter(edge => edge.type === type);
90
- }
91
-
92
- clear(): void {
93
- this.nodes.clear();
94
- this.edges = [];
95
- }
96
-
97
- updateNodeContent(id: string, newContent: string): void {
98
- const node = this.nodes.get(id);
99
- if (!node) {
100
- throw new Error(`Node ${id} not found`);
101
- }
102
- node.content = newContent;
103
- }
104
-
105
- // Get neighbors of a node
106
- private getNeighbors(nodeId: string, edgeType?: string): { id: string; weight: number }[] {
107
- return this.edges
108
- .filter(edge => edge.source === nodeId && (!edgeType || edge.type === edgeType))
109
- .map(edge => ({
110
- id: edge.target,
111
- weight: edge.weight,
112
- }))
113
- .filter(node => node !== undefined);
114
- }
115
-
116
- // Calculate cosine similarity between two vectors
117
- private cosineSimilarity(vec1: number[], vec2: number[]): number {
118
- if (!vec1 || !vec2) {
119
- throw new Error('Vectors must not be null or undefined');
120
- }
121
- const vectorLength = vec1.length;
122
-
123
- if (vectorLength !== vec2.length) {
124
- throw new Error(`Vector dimensions must match: vec1(${vec1.length}) !== vec2(${vec2.length})`);
125
- }
126
-
127
- let dotProduct = 0;
128
- let normVec1 = 0;
129
- let normVec2 = 0;
130
-
131
- for (let i = 0; i < vectorLength; i++) {
132
- const a = vec1[i]!; // Non-null assertion operator
133
- const b = vec2[i]!;
134
-
135
- dotProduct += a * b;
136
- normVec1 += a * a;
137
- normVec2 += b * b;
138
- }
139
- const magnitudeProduct = Math.sqrt(normVec1 * normVec2);
140
-
141
- if (magnitudeProduct === 0) {
142
- return 0;
143
- }
144
-
145
- const similarity = dotProduct / magnitudeProduct;
146
- return Math.max(-1, Math.min(1, similarity));
147
- }
148
-
149
- createGraph(chunks: GraphChunk[], embeddings: GraphEmbedding[]) {
150
- if (!chunks?.length || !embeddings?.length) {
151
- throw new Error('Chunks and embeddings arrays must not be empty');
152
- }
153
- if (chunks.length !== embeddings.length) {
154
- throw new Error('Chunks and embeddings must have the same length');
155
- }
156
- // Create nodes from chunks
157
- chunks.forEach((chunk, index) => {
158
- const node: GraphNode = {
159
- id: index.toString(),
160
- content: chunk.text,
161
- embedding: embeddings[index]?.vector,
162
- metadata: { ...chunk.metadata },
163
- };
164
- this.addNode(node);
165
- this.nodes.set(node.id, node);
166
- });
167
-
168
- // Create edges based on cosine similarity
169
- for (let i = 0; i < chunks.length; i++) {
170
- const firstEmbedding = embeddings[i]?.vector as number[];
171
- for (let j = i + 1; j < chunks.length; j++) {
172
- const secondEmbedding = embeddings[j]?.vector as number[];
173
- const similarity = this.cosineSimilarity(firstEmbedding, secondEmbedding);
174
-
175
- // Only create edges if similarity is above threshold
176
- if (similarity > this.threshold) {
177
- this.addEdge({
178
- source: i.toString(),
179
- target: j.toString(),
180
- weight: similarity,
181
- type: 'semantic',
182
- });
183
- }
184
- }
185
- }
186
- }
187
-
188
- private selectWeightedNeighbor(neighbors: Array<{ id: string; weight: number }>): string {
189
- // Sum all weights to normalize probabilities
190
- const totalWeight = neighbors.reduce((sum, n) => sum + n.weight, 0);
191
-
192
- // Pick a random point in the total weight range
193
- let remainingWeight = Math.random() * totalWeight;
194
-
195
- // Subtract each weight from our random value until we go below 0
196
- // Higher weights will make us go below 0 more often, making them more likely to be selected
197
- for (const neighbor of neighbors) {
198
- remainingWeight -= neighbor.weight;
199
- if (remainingWeight <= 0) {
200
- return neighbor.id;
201
- }
202
- }
203
-
204
- return neighbors[neighbors.length - 1]?.id as string;
205
- }
206
-
207
- // Perform random walk with restart
208
- private randomWalkWithRestart(startNodeId: string, steps: number, restartProb: number): Map<string, number> {
209
- const visits = new Map<string, number>();
210
- let currentNodeId = startNodeId;
211
-
212
- for (let step = 0; step < steps; step++) {
213
- // Record visit
214
- visits.set(currentNodeId, (visits.get(currentNodeId) || 0) + 1);
215
-
216
- // Decide whether to restart
217
- if (Math.random() < restartProb) {
218
- currentNodeId = startNodeId;
219
- continue;
220
- }
221
-
222
- // Get neighbors
223
- const neighbors = this.getNeighbors(currentNodeId);
224
- if (neighbors.length === 0) {
225
- currentNodeId = startNodeId;
226
- continue;
227
- }
228
-
229
- // Select random weighted neighbor and set as current node
230
- currentNodeId = this.selectWeightedNeighbor(neighbors);
231
- }
232
-
233
- // Normalize visits
234
- const totalVisits = Array.from(visits.values()).reduce((a, b) => a + b, 0);
235
- const normalizedVisits = new Map<string, number>();
236
- for (const [nodeId, count] of visits) {
237
- normalizedVisits.set(nodeId, count / totalVisits);
238
- }
239
-
240
- return normalizedVisits;
241
- }
242
-
243
- // Retrieve relevant nodes using hybrid approach
244
- query({
245
- query,
246
- topK = 10,
247
- randomWalkSteps = 100,
248
- restartProb = 0.15,
249
- }: {
250
- query: number[];
251
- topK?: number;
252
- randomWalkSteps?: number;
253
- restartProb?: number;
254
- }): RankedNode[] {
255
- if (!query || query.length !== this.dimension) {
256
- throw new Error(`Query embedding must have dimension ${this.dimension}`);
257
- }
258
- if (topK < 1) {
259
- throw new Error('TopK must be greater than 0');
260
- }
261
- if (randomWalkSteps < 1) {
262
- throw new Error('Random walk steps must be greater than 0');
263
- }
264
- if (restartProb <= 0 || restartProb >= 1) {
265
- throw new Error('Restart probability must be between 0 and 1');
266
- }
267
- // Retrieve nodes and calculate similarity
268
- const similarities = Array.from(this.nodes.values()).map(node => ({
269
- node,
270
- similarity: this.cosineSimilarity(query, node.embedding!),
271
- }));
272
-
273
- // Sort by similarity
274
- similarities.sort((a, b) => b.similarity - a.similarity);
275
- const topNodes = similarities.slice(0, topK);
276
-
277
- // Re-ranks nodes using random walk with restart
278
- const rerankedNodes = new Map<string, { node: GraphNode; score: number }>();
279
-
280
- // For each top node, perform random walk
281
- for (const { node, similarity } of topNodes) {
282
- const walkScores = this.randomWalkWithRestart(node.id, randomWalkSteps, restartProb);
283
-
284
- // Combine dense retrieval score with graph score
285
- for (const [nodeId, walkScore] of walkScores) {
286
- const node = this.nodes.get(nodeId)!;
287
- const existingScore = rerankedNodes.get(nodeId)?.score || 0;
288
- rerankedNodes.set(nodeId, {
289
- node,
290
- score: existingScore + similarity * walkScore,
291
- });
292
- }
293
- }
294
-
295
- // Sort by final score and return top K nodes
296
- return Array.from(rerankedNodes.values())
297
- .sort((a, b) => b.score - a.score)
298
- .slice(0, topK)
299
- .map(item => ({
300
- id: item.node.id,
301
- content: item.node.content,
302
- metadata: item.node.metadata,
303
- score: item.score,
304
- }));
305
- }
306
- }
package/src/index.ts DELETED
@@ -1,8 +0,0 @@
1
- export * from './document/document';
2
- export * from './document/types';
3
- export * from './rerank';
4
- export * from './rerank/relevance';
5
- export { GraphRAG } from './graph-rag';
6
- export * from './tools';
7
- export * from './utils/vector-prompts';
8
- export * from './utils/default-settings';
package/src/rerank/index.test.ts DELETED
@@ -1,150 +0,0 @@
1
- import { cohere } from '@ai-sdk/cohere';
2
- import { describe, it, expect, vi, beforeEach } from 'vitest';
3
- import { CohereRelevanceScorer } from './relevance';
4
-
5
- import { rerank } from '.';
6
-
7
- vi.spyOn(CohereRelevanceScorer.prototype, 'getRelevanceScore').mockImplementation(async () => {
8
- return 1;
9
- });
10
-
11
- const getScoreSpreads = (results1: any, results2: any) => {
12
- const scoreSpread1 = Math.max(...results1.map((r: any) => r.score)) - Math.min(...results1.map((r: any) => r.score));
13
- const scoreSpread2 = Math.max(...results2.map((r: any) => r.score)) - Math.min(...results2.map((r: any) => r.score));
14
- return { scoreSpread1, scoreSpread2 };
15
- };
16
-
17
- describe('rerank', () => {
18
- beforeEach(() => {
19
- vi.clearAllMocks();
20
- });
21
-
22
- it('should throw an error if weights do not add up to 1', async () => {
23
- const results = [
24
- { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
25
- { id: '2', metadata: { text: 'Test result 2' }, score: 0.4 },
26
- { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
27
- ];
28
- await expect(
29
- rerank(results, 'test query', cohere('rerank-v3.5'), { weights: { semantic: 0.5, vector: 0.3, position: 0.5 } }),
30
- ).rejects.toThrow('Weights must add up to 1');
31
- });
32
-
33
- it('should rerank results with default weights', async () => {
34
- const results = [
35
- { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
36
- { id: '2', metadata: { text: 'Test result 2' }, score: 0.4 },
37
- { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
38
- ];
39
-
40
- const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), { topK: 2 });
41
-
42
- expect(rerankedResults).toHaveLength(2);
43
- expect(rerankedResults[0]).toStrictEqual({
44
- result: { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
45
- score: 0.8266666666666667,
46
- details: {
47
- semantic: 1,
48
- vector: 0.9,
49
- position: 0.33333333333333337,
50
- },
51
- });
52
- expect(rerankedResults[1]).toStrictEqual({
53
- result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
54
- score: 0.8,
55
- details: {
56
- semantic: 1,
57
- vector: 0.5,
58
- position: 1,
59
- },
60
- });
61
-
62
- const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, rerankedResults);
63
- expect(scoreSpread1).toBe(0.5);
64
- expect(scoreSpread2).toBe(0.026666666666666616);
65
- });
66
-
67
- it('should rerank results with custom weights', async () => {
68
- const results = [
69
- { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
70
- { id: '2', metadata: { text: 'Test result 2' }, score: 0.4 },
71
- { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
72
- ];
73
-
74
- const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
75
- weights: {
76
- semantic: 0.5,
77
- vector: 0.4,
78
- position: 0.1,
79
- },
80
- topK: 2,
81
- });
82
-
83
- expect(rerankedResults).toHaveLength(2);
84
- expect(rerankedResults[0]).toStrictEqual({
85
- result: { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
86
- score: 0.8933333333333334,
87
- details: {
88
- semantic: 1,
89
- vector: 0.9,
90
- position: 0.33333333333333337,
91
- },
92
- });
93
- expect(rerankedResults[1]).toStrictEqual({
94
- result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
95
- score: 0.7999999999999999,
96
- details: {
97
- semantic: 1,
98
- vector: 0.5,
99
- position: 1,
100
- },
101
- });
102
- const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, rerankedResults);
103
- expect(scoreSpread1).toBe(0.5);
104
- expect(scoreSpread2).toBe(0.09333333333333349);
105
- });
106
-
107
- it('should handle query embedding when provided', async () => {
108
- const results = [
109
- { id: '1', metadata: { text: 'Test result 1' }, score: 0.8 },
110
- { id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
111
- ];
112
-
113
- const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
114
- queryEmbedding: [0.5, 0.3, -0.2, 0.4],
115
- topK: 2,
116
- });
117
-
118
- // Ensure query embedding analysis is being applied (we don't know exact score without knowing internals)
119
- expect(rerankedResults).toHaveLength(2);
120
- expect(rerankedResults[0]).toStrictEqual({
121
- result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.8 },
122
- score: 0.9200000000000002,
123
- details: {
124
- semantic: 1,
125
- vector: 0.8,
126
- position: 1,
127
- queryAnalysis: {
128
- magnitude: 0.7348469228349535,
129
- dominantFeatures: [0, 3, 1, 2],
130
- },
131
- },
132
- });
133
- expect(rerankedResults[1]).toStrictEqual({
134
- result: { id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
135
- score: 0.74,
136
- details: {
137
- semantic: 1,
138
- vector: 0.6,
139
- position: 0.5,
140
- queryAnalysis: {
141
- magnitude: 0.7348469228349535,
142
- dominantFeatures: [0, 3, 1, 2],
143
- },
144
- },
145
- });
146
- const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, rerankedResults);
147
- expect(scoreSpread1).toBe(0.20000000000000007);
148
- expect(scoreSpread2).toBe(0.18000000000000016);
149
- });
150
- });
package/src/rerank/index.ts DELETED
@@ -1,198 +0,0 @@
1
- import type { MastraLanguageModel } from '@mastra/core/agent';
2
- import type { RelevanceScoreProvider } from '@mastra/core/relevance';
3
- import type { QueryResult } from '@mastra/core/vector';
4
- import { Big } from 'big.js';
5
- import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from './relevance';
6
-
7
- // Default weights for different scoring components (must add up to 1)
8
- const DEFAULT_WEIGHTS = {
9
- semantic: 0.4,
10
- vector: 0.4,
11
- position: 0.2,
12
- } as const;
13
-
14
- type WeightConfig = {
15
- semantic?: number;
16
- vector?: number;
17
- position?: number;
18
- };
19
-
20
- interface ScoringDetails {
21
- semantic: number;
22
- vector: number;
23
- position: number;
24
- queryAnalysis?: {
25
- magnitude: number;
26
- dominantFeatures: number[];
27
- };
28
- }
29
-
30
- export interface RerankResult {
31
- result: QueryResult;
32
- score: number;
33
- details: ScoringDetails;
34
- }
35
-
36
- // For use in the vector store tool
37
- export interface RerankerOptions {
38
- weights?: WeightConfig;
39
- topK?: number;
40
- }
41
-
42
- // For use in the rerank function
43
- export interface RerankerFunctionOptions {
44
- weights?: WeightConfig;
45
- queryEmbedding?: number[];
46
- topK?: number;
47
- }
48
-
49
- export interface RerankConfig {
50
- options?: RerankerOptions;
51
- model: MastraLanguageModel | RelevanceScoreProvider;
52
- }
53
-
54
- // Calculate position score based on position in original list
55
- function calculatePositionScore(position: number, totalChunks: number): number {
56
- return 1 - position / totalChunks;
57
- }
58
-
59
- // Analyze query embedding features if needed
60
- function analyzeQueryEmbedding(embedding: number[]): {
61
- magnitude: number;
62
- dominantFeatures: number[];
63
- } {
64
- // Calculate embedding magnitude
65
- const magnitude = Math.sqrt(embedding.reduce((sum, val) => sum + val * val, 0));
66
-
67
- // Find dominant features (highest absolute values)
68
- const dominantFeatures = embedding
69
- .map((value, index) => ({ value: Math.abs(value), index }))
70
- .sort((a, b) => b.value - a.value)
71
- .slice(0, 5)
72
- .map(item => item.index);
73
-
74
- return { magnitude, dominantFeatures };
75
- }
76
-
77
- // Adjust scores based on query characteristics
78
- function adjustScores(score: number, queryAnalysis: { magnitude: number; dominantFeatures: number[] }): number {
79
- const magnitudeAdjustment = queryAnalysis.magnitude > 10 ? 1.1 : 1;
80
-
81
- const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
82
-
83
- return score * magnitudeAdjustment * featureStrengthAdjustment;
84
- }
85
-
86
- async function executeRerank({
87
- results,
88
- query,
89
- scorer,
90
- options,
91
- }: {
92
- results: QueryResult[];
93
- query: string;
94
- scorer: RelevanceScoreProvider;
95
- options: RerankerFunctionOptions;
96
- }) {
97
- const { queryEmbedding, topK = 3 } = options;
98
- const weights = {
99
- ...DEFAULT_WEIGHTS,
100
- ...options.weights,
101
- };
102
-
103
- //weights must add up to 1
104
- const sum = Object.values(weights).reduce((acc: Big, w: number) => acc.plus(w.toString()), new Big(0));
105
- if (!sum.eq(1)) {
106
- throw new Error(`Weights must add up to 1. Got ${sum} from ${weights}`);
107
- }
108
-
109
- const resultLength = results.length;
110
-
111
- const queryAnalysis = queryEmbedding ? analyzeQueryEmbedding(queryEmbedding) : null;
112
-
113
- // Get scores for each result
114
- const scoredResults = await Promise.all(
115
- results.map(async (result, index) => {
116
- // Get semantic score from chosen provider
117
- let semanticScore = 0;
118
- if (result?.metadata?.text) {
119
- semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
120
- }
121
-
122
- // Get existing vector score from result
123
- const vectorScore = result.score;
124
-
125
- // Get score of vector based on position in original list
126
- const positionScore = calculatePositionScore(index, resultLength);
127
-
128
- // Combine scores using weights for each component
129
- let finalScore =
130
- weights.semantic * semanticScore + weights.vector * vectorScore + weights.position * positionScore;
131
-
132
- if (queryAnalysis) {
133
- finalScore = adjustScores(finalScore, queryAnalysis);
134
- }
135
-
136
- return {
137
- result,
138
- score: finalScore,
139
- details: {
140
- semantic: semanticScore,
141
- vector: vectorScore,
142
- position: positionScore,
143
- ...(queryAnalysis && {
144
- queryAnalysis: {
145
- magnitude: queryAnalysis.magnitude,
146
- dominantFeatures: queryAnalysis.dominantFeatures,
147
- },
148
- }),
149
- },
150
- };
151
- }),
152
- );
153
-
154
- // Sort by score and take top K
155
- return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
156
- }
157
-
158
- export async function rerankWithScorer({
159
- results,
160
- query,
161
- scorer,
162
- options,
163
- }: {
164
- results: QueryResult[];
165
- query: string;
166
- scorer: RelevanceScoreProvider;
167
- options: RerankerFunctionOptions;
168
- }): Promise<RerankResult[]> {
169
- return executeRerank({
170
- results,
171
- query,
172
- scorer,
173
- options,
174
- });
175
- }
176
-
177
- // Takes in a list of results from a vector store and reranks them based on semantic, vector, and position scores
178
- export async function rerank(
179
- results: QueryResult[],
180
- query: string,
181
- model: MastraLanguageModel,
182
- options: RerankerFunctionOptions,
183
- ): Promise<RerankResult[]> {
184
- let semanticProvider: RelevanceScoreProvider;
185
-
186
- if (model.modelId === 'rerank-v3.5') {
187
- semanticProvider = new CohereRelevanceScorer(model.modelId);
188
- } else {
189
- semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
190
- }
191
-
192
- return executeRerank({
193
- results,
194
- query,
195
- scorer: semanticProvider,
196
- options,
197
- });
198
- }