@mastra/rag 1.2.3-alpha.0 → 1.2.3-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/CHANGELOG.md +9 -0
  2. package/package.json +18 -5
  3. package/.turbo/turbo-build.log +0 -4
  4. package/docker-compose.yaml +0 -22
  5. package/eslint.config.js +0 -6
  6. package/src/document/document.test.ts +0 -2975
  7. package/src/document/document.ts +0 -335
  8. package/src/document/extractors/base.ts +0 -30
  9. package/src/document/extractors/index.ts +0 -5
  10. package/src/document/extractors/keywords.test.ts +0 -125
  11. package/src/document/extractors/keywords.ts +0 -126
  12. package/src/document/extractors/questions.test.ts +0 -120
  13. package/src/document/extractors/questions.ts +0 -111
  14. package/src/document/extractors/summary.test.ts +0 -107
  15. package/src/document/extractors/summary.ts +0 -122
  16. package/src/document/extractors/title.test.ts +0 -121
  17. package/src/document/extractors/title.ts +0 -185
  18. package/src/document/extractors/types.ts +0 -40
  19. package/src/document/index.ts +0 -2
  20. package/src/document/prompts/base.ts +0 -77
  21. package/src/document/prompts/format.ts +0 -9
  22. package/src/document/prompts/index.ts +0 -15
  23. package/src/document/prompts/prompt.ts +0 -60
  24. package/src/document/prompts/types.ts +0 -29
  25. package/src/document/schema/index.ts +0 -3
  26. package/src/document/schema/node.ts +0 -187
  27. package/src/document/schema/types.ts +0 -40
  28. package/src/document/transformers/character.ts +0 -267
  29. package/src/document/transformers/html.ts +0 -346
  30. package/src/document/transformers/json.ts +0 -536
  31. package/src/document/transformers/latex.ts +0 -11
  32. package/src/document/transformers/markdown.ts +0 -239
  33. package/src/document/transformers/semantic-markdown.ts +0 -227
  34. package/src/document/transformers/sentence.ts +0 -314
  35. package/src/document/transformers/text.ts +0 -158
  36. package/src/document/transformers/token.ts +0 -137
  37. package/src/document/transformers/transformer.ts +0 -5
  38. package/src/document/types.ts +0 -145
  39. package/src/document/validation.ts +0 -158
  40. package/src/graph-rag/index.test.ts +0 -235
  41. package/src/graph-rag/index.ts +0 -306
  42. package/src/index.ts +0 -8
  43. package/src/rerank/index.test.ts +0 -150
  44. package/src/rerank/index.ts +0 -198
  45. package/src/rerank/relevance/cohere/index.ts +0 -56
  46. package/src/rerank/relevance/index.ts +0 -3
  47. package/src/rerank/relevance/mastra-agent/index.ts +0 -32
  48. package/src/rerank/relevance/zeroentropy/index.ts +0 -26
  49. package/src/tools/README.md +0 -153
  50. package/src/tools/document-chunker.ts +0 -34
  51. package/src/tools/graph-rag.test.ts +0 -115
  52. package/src/tools/graph-rag.ts +0 -157
  53. package/src/tools/index.ts +0 -3
  54. package/src/tools/types.ts +0 -126
  55. package/src/tools/vector-query-database-config.test.ts +0 -190
  56. package/src/tools/vector-query.test.ts +0 -477
  57. package/src/tools/vector-query.ts +0 -171
  58. package/src/utils/convert-sources.ts +0 -43
  59. package/src/utils/default-settings.ts +0 -38
  60. package/src/utils/index.ts +0 -3
  61. package/src/utils/tool-schemas.ts +0 -38
  62. package/src/utils/vector-prompts.ts +0 -832
  63. package/src/utils/vector-search.ts +0 -130
  64. package/tsconfig.build.json +0 -9
  65. package/tsconfig.json +0 -5
  66. package/tsup.config.ts +0 -17
  67. package/vitest.config.ts +0 -8
package/src/graph-rag/index.ts DELETED
@@ -1,306 +0,0 @@
1
/**
 * TODO: GraphRAG enhancements
 * - Support more edge types (sequential, hierarchical, citation, etc.)
 * - Allow custom edge types
 * - Use metadata for richer connections
 * - Improve graph traversal and querying using edge types
 */

/** Edge kinds the graph can represent; currently only similarity-derived ("semantic") edges. */
type SupportedEdgeType = 'semantic';

// ---------------------------------------------------------------------------
// Graph node and edge types
// ---------------------------------------------------------------------------

/** A vertex of the graph: one chunk of text plus the vector that represents it. */
export interface GraphNode {
  /** Unique key of the node within the graph. */
  id: string;
  /** Text carried by the node. */
  content: string;
  /** Embedding vector; optional in the type, but `GraphRAG.addNode` rejects nodes without one. */
  embedding?: number[];
  /** Arbitrary caller-supplied metadata. */
  metadata?: Record<string, any>;
}

/** A node returned from a query, annotated with its final ranking score. */
export interface RankedNode extends GraphNode {
  score: number;
}

/** A weighted, typed connection from `source` to `target` (both node ids). */
export interface GraphEdge {
  source: string;
  target: string;
  weight: number;
  type: SupportedEdgeType;
}

/** Raw text chunk used as input when building a graph. */
export interface GraphChunk {
  text: string;
  metadata: Record<string, any>;
}

/** Embedding paired positionally with a `GraphChunk` when building a graph. */
export interface GraphEmbedding {
  vector: number[];
}
38
-
39
/**
 * In-memory similarity graph over embedded text chunks.
 *
 * Retrieval is hybrid. Nodes are first ranked by cosine similarity to the query embedding. The
 * top candidates are then re-scored with a random walk with restart over the graph's edges, so
 * that well-connected neighbours of strong matches are boosted.
 */
export class GraphRAG {
  private readonly nodes: Map<string, GraphNode>;
  private edges: GraphEdge[];
  private readonly dimension: number;
  private readonly threshold: number;

  /**
   * @param dimension - Required length of every node embedding and query embedding.
   * @param threshold - `createGraph` links two chunks only when their cosine similarity is strictly greater than this.
   */
  constructor(dimension: number = 1536, threshold: number = 0.7) {
    this.nodes = new Map();
    this.edges = [];
    this.dimension = dimension;
    this.threshold = threshold;
  }

  /**
   * Stores a node, replacing any existing node with the same id.
   * @throws Error if the node has no embedding or its length differs from `dimension`.
   */
  addNode(node: GraphNode): void {
    if (!node.embedding) {
      throw new Error('Node must have an embedding');
    }
    if (node.embedding.length !== this.dimension) {
      throw new Error(`Embedding dimension must be ${this.dimension}`);
    }
    this.nodes.set(node.id, node);
  }

  /**
   * Adds an undirected connection by storing the edge plus its reverse twin.
   * @throws Error if either endpoint is not a known node id.
   */
  addEdge(edge: GraphEdge): void {
    if (!this.nodes.has(edge.source) || !this.nodes.has(edge.target)) {
      throw new Error('Both source and target nodes must exist');
    }
    this.edges.push(edge);
    // Store the reverse direction too, so traversal works from either endpoint.
    this.edges.push({
      source: edge.target,
      target: edge.source,
      weight: edge.weight,
      type: edge.type,
    });
  }

  /** Returns all nodes in insertion order. */
  getNodes(): GraphNode[] {
    return Array.from(this.nodes.values());
  }

  /** Returns the internal edge list (both directions of every connection). */
  getEdges(): GraphEdge[] {
    return this.edges;
  }

  /** Returns the edges whose `type` equals `type`. */
  getEdgesByType(type: string): GraphEdge[] {
    return this.edges.filter(edge => edge.type === type);
  }

  /** Removes every node and edge. */
  clear(): void {
    this.nodes.clear();
    this.edges = [];
  }

  /**
   * Replaces the text of an existing node in place; the embedding is left unchanged.
   * @throws Error if no node has the given id.
   */
  updateNodeContent(id: string, newContent: string): void {
    const node = this.nodes.get(id);
    if (!node) {
      throw new Error(`Node ${id} not found`);
    }
    node.content = newContent;
  }

  /** Outgoing neighbours of `nodeId` (optionally restricted to one edge type), in edge-insertion order. */
  private getNeighbors(nodeId: string, edgeType?: string): { id: string; weight: number }[] {
    return this.edges
      .filter(edge => edge.source === nodeId && (!edgeType || edge.type === edgeType))
      .map(edge => ({
        id: edge.target,
        weight: edge.weight,
      }));
  }

  /**
   * Groups outgoing edges by source id, preserving edge-insertion order.
   *
   * Lets random walks look up neighbours in O(1). Without it, every step would rescan the whole
   * edge list.
   */
  private buildAdjacency(): Map<string, { id: string; weight: number }[]> {
    const adjacency = new Map<string, { id: string; weight: number }[]>();
    for (const edge of this.edges) {
      let outgoing = adjacency.get(edge.source);
      if (!outgoing) {
        outgoing = [];
        adjacency.set(edge.source, outgoing);
      }
      outgoing.push({ id: edge.target, weight: edge.weight });
    }
    return adjacency;
  }

  /**
   * Cosine similarity of two equal-length vectors, clamped to [-1, 1].
   * Returns 0 when either vector has zero magnitude.
   * @throws Error if a vector is missing or the lengths differ.
   */
  private cosineSimilarity(vec1: number[], vec2: number[]): number {
    if (!vec1 || !vec2) {
      throw new Error('Vectors must not be null or undefined');
    }
    const vectorLength = vec1.length;

    if (vectorLength !== vec2.length) {
      throw new Error(`Vector dimensions must match: vec1(${vec1.length}) !== vec2(${vec2.length})`);
    }

    let dotProduct = 0;
    let normVec1 = 0;
    let normVec2 = 0;

    for (let i = 0; i < vectorLength; i++) {
      const a = vec1[i]!;
      const b = vec2[i]!;

      dotProduct += a * b;
      normVec1 += a * a;
      normVec2 += b * b;
    }
    const magnitudeProduct = Math.sqrt(normVec1 * normVec2);

    if (magnitudeProduct === 0) {
      return 0;
    }

    const similarity = dotProduct / magnitudeProduct;
    // Clamp to absorb floating-point overshoot.
    return Math.max(-1, Math.min(1, similarity));
  }

  /**
   * Adds one node per chunk and links every pair whose similarity exceeds `threshold`.
   *
   * Node ids are the chunk indices as strings ("0", "1", ...). Existing nodes with those ids are
   * replaced, and existing edges are kept.
   * @throws Error if either array is empty, their lengths differ, or an embedding is missing or the wrong size.
   */
  createGraph(chunks: GraphChunk[], embeddings: GraphEmbedding[]) {
    if (!chunks?.length || !embeddings?.length) {
      throw new Error('Chunks and embeddings arrays must not be empty');
    }
    if (chunks.length !== embeddings.length) {
      throw new Error('Chunks and embeddings must have the same length');
    }

    const vectors: number[][] = [];
    chunks.forEach((chunk, index) => {
      const node: GraphNode = {
        id: index.toString(),
        content: chunk.text,
        embedding: embeddings[index]?.vector,
        metadata: { ...chunk.metadata },
      };
      // addNode validates presence and dimension, so the pushed vector is always defined.
      this.addNode(node);
      vectors.push(node.embedding as number[]);
    });

    // Link every pair above the similarity threshold (each unordered pair is visited once).
    for (let i = 0; i < vectors.length; i++) {
      for (let j = i + 1; j < vectors.length; j++) {
        const similarity = this.cosineSimilarity(vectors[i]!, vectors[j]!);
        if (similarity > this.threshold) {
          this.addEdge({
            source: i.toString(),
            target: j.toString(),
            weight: similarity,
            type: 'semantic',
          });
        }
      }
    }
  }

  /**
   * Picks one neighbour at random, with probability proportional to its edge weight.
   * Falls back to the last neighbour if rounding leaves a remainder.
   */
  private selectWeightedNeighbor(neighbors: Array<{ id: string; weight: number }>): string {
    const totalWeight = neighbors.reduce((sum, n) => sum + n.weight, 0);

    // Draw a point in [0, totalWeight) and walk the cumulative weights until it is covered.
    let remainingWeight = Math.random() * totalWeight;
    for (const neighbor of neighbors) {
      remainingWeight -= neighbor.weight;
      if (remainingWeight <= 0) {
        return neighbor.id;
      }
    }

    return neighbors[neighbors.length - 1]?.id as string;
  }

  /**
   * Runs a random walk with restart from `startNodeId`.
   *
   * @param adjacency - Optional precomputed neighbour index (see `buildAdjacency`). When omitted,
   *   neighbours are found by scanning the edge list.
   * @returns Map from visited node id to its share of all recorded visits (values sum to 1).
   */
  private randomWalkWithRestart(
    startNodeId: string,
    steps: number,
    restartProb: number,
    adjacency?: Map<string, { id: string; weight: number }[]>,
  ): Map<string, number> {
    const visits = new Map<string, number>();
    let currentNodeId = startNodeId;

    for (let step = 0; step < steps; step++) {
      visits.set(currentNodeId, (visits.get(currentNodeId) ?? 0) + 1);

      if (Math.random() < restartProb) {
        currentNodeId = startNodeId;
        continue;
      }

      const neighbors = adjacency ? (adjacency.get(currentNodeId) ?? []) : this.getNeighbors(currentNodeId);
      if (neighbors.length === 0) {
        // Dead end: jump back to the start node.
        currentNodeId = startNodeId;
        continue;
      }

      currentNodeId = this.selectWeightedNeighbor(neighbors);
    }

    const totalVisits = Array.from(visits.values()).reduce((a, b) => a + b, 0);
    const normalizedVisits = new Map<string, number>();
    for (const [nodeId, count] of visits) {
      normalizedVisits.set(nodeId, count / totalVisits);
    }
    return normalizedVisits;
  }

  /**
   * Retrieves up to `topK` nodes using dense similarity re-ranked by random walks.
   *
   * Each of the `topK` most similar nodes seeds a walk. Every visited node accumulates
   * `seedSimilarity * visitShare`, and the highest totals are returned.
   *
   * @throws Error if the query length differs from `dimension`, `topK` or `randomWalkSteps` is
   *   below 1 (or NaN), or `restartProb` is not strictly between 0 and 1 (or NaN).
   */
  query({
    query,
    topK = 10,
    randomWalkSteps = 100,
    restartProb = 0.15,
  }: {
    query: number[];
    topK?: number;
    randomWalkSteps?: number;
    restartProb?: number;
  }): RankedNode[] {
    if (!query || query.length !== this.dimension) {
      throw new Error(`Query embedding must have dimension ${this.dimension}`);
    }
    // Written as negated positive checks so NaN is rejected too (every comparison with NaN is false).
    if (!(topK >= 1)) {
      throw new Error('TopK must be greater than 0');
    }
    if (!(randomWalkSteps >= 1)) {
      throw new Error('Random walk steps must be greater than 0');
    }
    if (!(restartProb > 0 && restartProb < 1)) {
      throw new Error('Restart probability must be between 0 and 1');
    }

    // Dense retrieval: similarity of the query to every node.
    // addNode guarantees each stored node has an embedding.
    const similarities = Array.from(this.nodes.values()).map(node => ({
      node,
      similarity: this.cosineSimilarity(query, node.embedding!),
    }));
    similarities.sort((a, b) => b.similarity - a.similarity);
    const topNodes = similarities.slice(0, topK);

    // Build the neighbour index once and reuse it for every walk, instead of rescanning all edges per step.
    const adjacency = this.buildAdjacency();
    const rerankedNodes = new Map<string, { node: GraphNode; score: number }>();

    for (const { node: seed, similarity } of topNodes) {
      const walkScores = this.randomWalkWithRestart(seed.id, randomWalkSteps, restartProb, adjacency);

      // Blend the seed's dense score with how often each node was visited from it.
      for (const [nodeId, walkScore] of walkScores) {
        const visited = this.nodes.get(nodeId)!;
        const existingScore = rerankedNodes.get(nodeId)?.score ?? 0;
        rerankedNodes.set(nodeId, {
          node: visited,
          score: existingScore + similarity * walkScore,
        });
      }
    }

    return Array.from(rerankedNodes.values())
      .sort((a, b) => b.score - a.score)
      .slice(0, topK)
      .map(item => ({
        id: item.node.id,
        content: item.node.content,
        metadata: item.node.metadata,
        score: item.score,
      }));
  }
}
package/src/index.ts DELETED
@@ -1,8 +0,0 @@
1
- export * from './document/document';
2
- export * from './document/types';
3
- export * from './rerank';
4
- export * from './rerank/relevance';
5
- export { GraphRAG } from './graph-rag';
6
- export * from './tools';
7
- export * from './utils/vector-prompts';
8
- export * from './utils/default-settings';
package/src/rerank/index.test.ts DELETED
@@ -1,150 +0,0 @@
1
- import { cohere } from '@ai-sdk/cohere';
2
- import { describe, it, expect, vi, beforeEach } from 'vitest';
3
- import { CohereRelevanceScorer } from './relevance';
4
-
5
- import { rerank } from '.';
6
-
7
// Stub the Cohere scorer so every result gets a perfect semantic score of 1
// and tests never reach the real implementation.
vi.spyOn(CohereRelevanceScorer.prototype, 'getRelevanceScore').mockImplementation(async () => 1);
10
-
11
/** Returns the max-minus-min `score` range of each of the two result lists. */
const getScoreSpreads = (results1: any, results2: any) => {
  const spreadOf = (items: any[]): number => {
    const scores = items.map(item => item.score);
    return Math.max(...scores) - Math.min(...scores);
  };
  return { scoreSpread1: spreadOf(results1), scoreSpread2: spreadOf(results2) };
};
16
-
17
describe('rerank', () => {
  beforeEach(() => {
    vi.clearAllMocks();
  });

  // Fresh fixture per test: vector scores (0.5, 0.4, 0.9) deliberately disagree with list order.
  const makeResults = () => [
    { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
    { id: '2', metadata: { text: 'Test result 2' }, score: 0.4 },
    { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
  ];

  it('should throw an error if weights do not add up to 1', async () => {
    const pending = rerank(makeResults(), 'test query', cohere('rerank-v3.5'), {
      weights: { semantic: 0.5, vector: 0.3, position: 0.5 },
    });

    await expect(pending).rejects.toThrow('Weights must add up to 1');
  });

  it('should rerank results with default weights', async () => {
    const results = makeResults();

    const reranked = await rerank(results, 'test query', cohere('rerank-v3.5'), { topK: 2 });

    // The semantic score is stubbed to 1, so ordering comes from the vector and position scores.
    expect(reranked).toHaveLength(2);
    expect(reranked[0]).toStrictEqual({
      result: { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
      score: 0.8266666666666667,
      details: { semantic: 1, vector: 0.9, position: 0.33333333333333337 },
    });
    expect(reranked[1]).toStrictEqual({
      result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
      score: 0.8,
      details: { semantic: 1, vector: 0.5, position: 1 },
    });

    const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, reranked);
    expect(scoreSpread1).toBe(0.5);
    expect(scoreSpread2).toBe(0.026666666666666616);
  });

  it('should rerank results with custom weights', async () => {
    const results = makeResults();

    const reranked = await rerank(results, 'test query', cohere('rerank-v3.5'), {
      weights: { semantic: 0.5, vector: 0.4, position: 0.1 },
      topK: 2,
    });

    expect(reranked).toHaveLength(2);
    expect(reranked[0]).toStrictEqual({
      result: { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
      score: 0.8933333333333334,
      details: { semantic: 1, vector: 0.9, position: 0.33333333333333337 },
    });
    expect(reranked[1]).toStrictEqual({
      result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
      score: 0.7999999999999999,
      details: { semantic: 1, vector: 0.5, position: 1 },
    });

    const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, reranked);
    expect(scoreSpread1).toBe(0.5);
    expect(scoreSpread2).toBe(0.09333333333333349);
  });

  it('should handle query embedding when provided', async () => {
    const results = [
      { id: '1', metadata: { text: 'Test result 1' }, score: 0.8 },
      { id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
    ];

    const reranked = await rerank(results, 'test query', cohere('rerank-v3.5'), {
      queryEmbedding: [0.5, 0.3, -0.2, 0.4],
      topK: 2,
    });

    // The embedding's magnitude (~0.73) is below both adjustment thresholds, so scores are
    // unadjusted; the query analysis is still reported in each result's details.
    const expectedAnalysis = {
      magnitude: 0.7348469228349535,
      dominantFeatures: [0, 3, 1, 2],
    };
    expect(reranked).toHaveLength(2);
    expect(reranked[0]).toStrictEqual({
      result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.8 },
      score: 0.9200000000000002,
      details: { semantic: 1, vector: 0.8, position: 1, queryAnalysis: expectedAnalysis },
    });
    expect(reranked[1]).toStrictEqual({
      result: { id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
      score: 0.74,
      details: { semantic: 1, vector: 0.6, position: 0.5, queryAnalysis: expectedAnalysis },
    });

    const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, reranked);
    expect(scoreSpread1).toBe(0.20000000000000007);
    expect(scoreSpread2).toBe(0.18000000000000016);
  });
});
package/src/rerank/index.ts DELETED
@@ -1,198 +0,0 @@
1
- import type { MastraLanguageModel } from '@mastra/core/agent';
2
- import type { RelevanceScoreProvider } from '@mastra/core/relevance';
3
- import type { QueryResult } from '@mastra/core/vector';
4
- import { Big } from 'big.js';
5
- import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from './relevance';
6
-
7
/**
 * Weights applied to each scoring component when the caller does not override them.
 * The three values must add up to 1.
 */
const DEFAULT_WEIGHTS = {
  semantic: 0.4,
  vector: 0.4,
  position: 0.2,
} as const;

/** Per-component weight overrides; omitted fields fall back to `DEFAULT_WEIGHTS`. */
type WeightConfig = {
  /** Weight of the relevance score returned by the semantic scorer. */
  semantic?: number;
  /** Weight of the result's original vector-store score. */
  vector?: number;
  /** Weight of the score derived from the result's position in the input list. */
  position?: number;
};

/** Breakdown of the component scores that produced one result's final score. */
interface ScoringDetails {
  semantic: number;
  vector: number;
  position: number;
  /** Present only when a query embedding was supplied. */
  queryAnalysis?: {
    /** Euclidean norm of the query embedding. */
    magnitude: number;
    /** Indices of the embedding's largest-magnitude components. */
    dominantFeatures: number[];
  };
}
29
-
30
/** One reranked entry: the untouched original result, its blended score, and the score breakdown. */
export interface RerankResult {
  result: QueryResult;
  score: number;
  details: ScoringDetails;
}

/** Reranker options as supplied to the vector store tool. */
export interface RerankerOptions {
  weights?: WeightConfig;
  topK?: number;
}

/** Options accepted by the rerank functions: the tool options plus an optional query embedding. */
export interface RerankerFunctionOptions extends RerankerOptions {
  queryEmbedding?: number[];
}

/** Pairs reranker options with either a language model or a ready-made relevance score provider. */
export interface RerankConfig {
  options?: RerankerOptions;
  model: MastraLanguageModel | RelevanceScoreProvider;
}
53
-
54
/**
 * Scores a result by where it sat in the original list.
 *
 * The first item gets 1, and each later item loses `1 / totalChunks`.
 */
function calculatePositionScore(position: number, totalChunks: number): number {
  const fractionBefore = position / totalChunks;
  return 1 - fractionBefore;
}
58
-
59
/**
 * Summarises a query embedding.
 *
 * @returns `magnitude`, the Euclidean norm of the embedding, and `dominantFeatures`, the indices
 *   of up to 5 components with the largest absolute value. These are ordered largest first, with
 *   ties keeping index order because `Array.prototype.sort` is stable.
 */
function analyzeQueryEmbedding(embedding: number[]): {
  magnitude: number;
  dominantFeatures: number[];
} {
  let sumOfSquares = 0;
  embedding.forEach(component => {
    sumOfSquares += component * component;
  });

  const ranked = embedding.map((component, position) => [Math.abs(component), position] as const);
  ranked.sort(([left], [right]) => right - left);

  return {
    magnitude: Math.sqrt(sumOfSquares),
    dominantFeatures: ranked.slice(0, 5).map(([, position]) => position),
  };
}
76
-
77
/**
 * Boosts a score based on the query embedding's magnitude.
 *
 * The score is multiplied by 1.1 when the magnitude exceeds 10 and by 1.05 when it exceeds 5;
 * both multipliers apply above 10. `dominantFeatures` is not consulted.
 */
function adjustScores(score: number, queryAnalysis: { magnitude: number; dominantFeatures: number[] }): number {
  const { magnitude } = queryAnalysis;
  let adjusted = score;
  adjusted *= magnitude > 10 ? 1.1 : 1;
  adjusted *= magnitude > 5 ? 1.05 : 1;
  return adjusted;
}
85
-
86
/**
 * Scores each result as a weighted blend of three components and returns the best `topK`.
 *
 * The components are:
 * - semantic relevance from `scorer`,
 * - the original vector score,
 * - the result's position in the input list.
 *
 * @param results - Vector-store hits in their original order; the order drives the position score.
 * @param query - Query text passed to the relevance scorer.
 * @param scorer - Provides the semantic relevance score for a (query, text) pair.
 * @param options - Optional weight overrides, query embedding, and `topK` (default 3).
 * @returns Up to `topK` scored results, highest score first.
 * @throws Error if the effective weights do not add up to exactly 1.
 */
async function executeRerank({
  results,
  query,
  scorer,
  options,
}: {
  results: QueryResult[];
  query: string;
  scorer: RelevanceScoreProvider;
  options: RerankerFunctionOptions;
}) {
  const { queryEmbedding, topK = 3 } = options;

  // Merge per field with `??`. With a spread, an explicit `undefined`/`null` override replaced the
  // default and then crashed the sum below, and unknown extra keys leaked into the sum.
  const weights = {
    semantic: options.weights?.semantic ?? DEFAULT_WEIGHTS.semantic,
    vector: options.weights?.vector ?? DEFAULT_WEIGHTS.vector,
    position: options.weights?.position ?? DEFAULT_WEIGHTS.position,
  };

  // Sum with Big so decimal inputs such as 0.1 + 0.2 + 0.7 compare exactly equal to 1.
  const sum = [weights.semantic, weights.vector, weights.position].reduce(
    (acc: Big, w: number) => acc.plus(w.toString()),
    new Big(0),
  );
  if (!sum.eq(1)) {
    // JSON.stringify: interpolating the object directly printed "[object Object]".
    throw new Error(`Weights must add up to 1. Got ${sum} from ${JSON.stringify(weights)}`);
  }

  const resultLength = results.length;
  const queryAnalysis = queryEmbedding ? analyzeQueryEmbedding(queryEmbedding) : null;

  const scoredResults = await Promise.all(
    results.map(async (result, index) => {
      // Results without text get no semantic credit and skip the scorer call.
      let semanticScore = 0;
      if (result?.metadata?.text) {
        semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
      }

      const vectorScore = result.score;
      const positionScore = calculatePositionScore(index, resultLength);

      let finalScore =
        weights.semantic * semanticScore + weights.vector * vectorScore + weights.position * positionScore;

      if (queryAnalysis) {
        finalScore = adjustScores(finalScore, queryAnalysis);
      }

      return {
        result,
        score: finalScore,
        details: {
          semantic: semanticScore,
          vector: vectorScore,
          position: positionScore,
          ...(queryAnalysis && {
            queryAnalysis: {
              magnitude: queryAnalysis.magnitude,
              dominantFeatures: queryAnalysis.dominantFeatures,
            },
          }),
        },
      };
    }),
  );

  return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
}
157
-
158
/**
 * Reranks `results` using a caller-supplied relevance scorer instead of deriving one from a model.
 * This is a thin wrapper that delegates directly to `executeRerank`.
 */
export async function rerankWithScorer(params: {
  results: QueryResult[];
  query: string;
  scorer: RelevanceScoreProvider;
  options: RerankerFunctionOptions;
}): Promise<RerankResult[]> {
  const { results, query, scorer, options } = params;
  return executeRerank({ results, query, scorer, options });
}
176
-
177
/**
 * Reranks vector-store results by semantic, vector and position scores.
 *
 * When `model.modelId` is 'rerank-v3.5', semantic relevance comes from `CohereRelevanceScorer`.
 * Otherwise the model itself is used through `MastraAgentRelevanceScorer`.
 */
export async function rerank(
  results: QueryResult[],
  query: string,
  model: MastraLanguageModel,
  options: RerankerFunctionOptions,
): Promise<RerankResult[]> {
  const scorer: RelevanceScoreProvider =
    model.modelId === 'rerank-v3.5'
      ? new CohereRelevanceScorer(model.modelId)
      : new MastraAgentRelevanceScorer(model.provider, model);

  return executeRerank({ results, query, scorer, options });
}