@mastra/rag 1.2.2 → 1.2.3-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +22 -0
- package/dist/index.cjs +25 -9
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +25 -9
- package/dist/index.js.map +1 -1
- package/dist/tools/graph-rag.d.ts.map +1 -1
- package/dist/tools/types.d.ts +18 -5
- package/dist/tools/types.d.ts.map +1 -1
- package/dist/tools/vector-query.d.ts.map +1 -1
- package/dist/utils/vector-search.d.ts +6 -7
- package/dist/utils/vector-search.d.ts.map +1 -1
- package/package.json +19 -6
- package/.turbo/turbo-build.log +0 -4
- package/docker-compose.yaml +0 -22
- package/eslint.config.js +0 -6
- package/src/document/document.test.ts +0 -2975
- package/src/document/document.ts +0 -335
- package/src/document/extractors/base.ts +0 -30
- package/src/document/extractors/index.ts +0 -5
- package/src/document/extractors/keywords.test.ts +0 -125
- package/src/document/extractors/keywords.ts +0 -126
- package/src/document/extractors/questions.test.ts +0 -120
- package/src/document/extractors/questions.ts +0 -111
- package/src/document/extractors/summary.test.ts +0 -107
- package/src/document/extractors/summary.ts +0 -122
- package/src/document/extractors/title.test.ts +0 -121
- package/src/document/extractors/title.ts +0 -185
- package/src/document/extractors/types.ts +0 -40
- package/src/document/index.ts +0 -2
- package/src/document/prompts/base.ts +0 -77
- package/src/document/prompts/format.ts +0 -9
- package/src/document/prompts/index.ts +0 -15
- package/src/document/prompts/prompt.ts +0 -60
- package/src/document/prompts/types.ts +0 -29
- package/src/document/schema/index.ts +0 -3
- package/src/document/schema/node.ts +0 -187
- package/src/document/schema/types.ts +0 -40
- package/src/document/transformers/character.ts +0 -267
- package/src/document/transformers/html.ts +0 -346
- package/src/document/transformers/json.ts +0 -536
- package/src/document/transformers/latex.ts +0 -11
- package/src/document/transformers/markdown.ts +0 -239
- package/src/document/transformers/semantic-markdown.ts +0 -227
- package/src/document/transformers/sentence.ts +0 -314
- package/src/document/transformers/text.ts +0 -158
- package/src/document/transformers/token.ts +0 -137
- package/src/document/transformers/transformer.ts +0 -5
- package/src/document/types.ts +0 -145
- package/src/document/validation.ts +0 -158
- package/src/graph-rag/index.test.ts +0 -235
- package/src/graph-rag/index.ts +0 -306
- package/src/index.ts +0 -8
- package/src/rerank/index.test.ts +0 -150
- package/src/rerank/index.ts +0 -198
- package/src/rerank/relevance/cohere/index.ts +0 -56
- package/src/rerank/relevance/index.ts +0 -3
- package/src/rerank/relevance/mastra-agent/index.ts +0 -32
- package/src/rerank/relevance/zeroentropy/index.ts +0 -26
- package/src/tools/README.md +0 -153
- package/src/tools/document-chunker.ts +0 -34
- package/src/tools/graph-rag.test.ts +0 -115
- package/src/tools/graph-rag.ts +0 -154
- package/src/tools/index.ts +0 -3
- package/src/tools/types.ts +0 -110
- package/src/tools/vector-query-database-config.test.ts +0 -190
- package/src/tools/vector-query.test.ts +0 -418
- package/src/tools/vector-query.ts +0 -169
- package/src/utils/convert-sources.ts +0 -43
- package/src/utils/default-settings.ts +0 -38
- package/src/utils/index.ts +0 -3
- package/src/utils/tool-schemas.ts +0 -38
- package/src/utils/vector-prompts.ts +0 -832
- package/src/utils/vector-search.ts +0 -117
- package/tsconfig.build.json +0 -9
- package/tsconfig.json +0 -5
- package/tsup.config.ts +0 -17
- package/vitest.config.ts +0 -8
package/src/graph-rag/index.ts
DELETED
|
@@ -1,306 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* TODO: GraphRAG Enhancements
|
|
3
|
-
* - Add support for more edge types (sequential, hierarchical, citation, etc)
|
|
4
|
-
* - Allow for custom edge types
|
|
5
|
-
* - Utilize metadata for richer connections
|
|
6
|
-
* - Improve graph traversal and querying using types
|
|
7
|
-
*/
|
|
8
|
-
|
|
9
|
-
type SupportedEdgeType = 'semantic';
|
|
10
|
-
|
|
11
|
-
// Types for graph nodes and edges
|
|
12
|
-
export interface GraphNode {
|
|
13
|
-
id: string;
|
|
14
|
-
content: string;
|
|
15
|
-
embedding?: number[];
|
|
16
|
-
metadata?: Record<string, any>;
|
|
17
|
-
}
|
|
18
|
-
|
|
19
|
-
export interface RankedNode extends GraphNode {
|
|
20
|
-
score: number;
|
|
21
|
-
}
|
|
22
|
-
|
|
23
|
-
export interface GraphEdge {
|
|
24
|
-
source: string;
|
|
25
|
-
target: string;
|
|
26
|
-
weight: number;
|
|
27
|
-
type: SupportedEdgeType;
|
|
28
|
-
}
|
|
29
|
-
|
|
30
|
-
export interface GraphChunk {
|
|
31
|
-
text: string;
|
|
32
|
-
metadata: Record<string, any>;
|
|
33
|
-
}
|
|
34
|
-
|
|
35
|
-
export interface GraphEmbedding {
|
|
36
|
-
vector: number[];
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
export class GraphRAG {
|
|
40
|
-
private nodes: Map<string, GraphNode>;
|
|
41
|
-
private edges: GraphEdge[];
|
|
42
|
-
private dimension: number;
|
|
43
|
-
private threshold: number;
|
|
44
|
-
|
|
45
|
-
constructor(dimension: number = 1536, threshold: number = 0.7) {
|
|
46
|
-
this.nodes = new Map();
|
|
47
|
-
this.edges = [];
|
|
48
|
-
this.dimension = dimension;
|
|
49
|
-
this.threshold = threshold;
|
|
50
|
-
}
|
|
51
|
-
|
|
52
|
-
// Add a node to the graph
|
|
53
|
-
addNode(node: GraphNode): void {
|
|
54
|
-
if (!node.embedding) {
|
|
55
|
-
throw new Error('Node must have an embedding');
|
|
56
|
-
}
|
|
57
|
-
if (node.embedding.length !== this.dimension) {
|
|
58
|
-
throw new Error(`Embedding dimension must be ${this.dimension}`);
|
|
59
|
-
}
|
|
60
|
-
this.nodes.set(node.id, node);
|
|
61
|
-
}
|
|
62
|
-
|
|
63
|
-
// Add an edge between two nodes
|
|
64
|
-
addEdge(edge: GraphEdge): void {
|
|
65
|
-
if (!this.nodes.has(edge.source) || !this.nodes.has(edge.target)) {
|
|
66
|
-
throw new Error('Both source and target nodes must exist');
|
|
67
|
-
}
|
|
68
|
-
this.edges.push(edge);
|
|
69
|
-
// Add reverse edge
|
|
70
|
-
this.edges.push({
|
|
71
|
-
source: edge.target,
|
|
72
|
-
target: edge.source,
|
|
73
|
-
weight: edge.weight,
|
|
74
|
-
type: edge.type,
|
|
75
|
-
});
|
|
76
|
-
}
|
|
77
|
-
|
|
78
|
-
// Helper method to get all nodes
|
|
79
|
-
getNodes(): GraphNode[] {
|
|
80
|
-
return Array.from(this.nodes.values());
|
|
81
|
-
}
|
|
82
|
-
|
|
83
|
-
// Helper method to get all edges
|
|
84
|
-
getEdges(): GraphEdge[] {
|
|
85
|
-
return this.edges;
|
|
86
|
-
}
|
|
87
|
-
|
|
88
|
-
getEdgesByType(type: string): GraphEdge[] {
|
|
89
|
-
return this.edges.filter(edge => edge.type === type);
|
|
90
|
-
}
|
|
91
|
-
|
|
92
|
-
clear(): void {
|
|
93
|
-
this.nodes.clear();
|
|
94
|
-
this.edges = [];
|
|
95
|
-
}
|
|
96
|
-
|
|
97
|
-
updateNodeContent(id: string, newContent: string): void {
|
|
98
|
-
const node = this.nodes.get(id);
|
|
99
|
-
if (!node) {
|
|
100
|
-
throw new Error(`Node ${id} not found`);
|
|
101
|
-
}
|
|
102
|
-
node.content = newContent;
|
|
103
|
-
}
|
|
104
|
-
|
|
105
|
-
// Get neighbors of a node
|
|
106
|
-
private getNeighbors(nodeId: string, edgeType?: string): { id: string; weight: number }[] {
|
|
107
|
-
return this.edges
|
|
108
|
-
.filter(edge => edge.source === nodeId && (!edgeType || edge.type === edgeType))
|
|
109
|
-
.map(edge => ({
|
|
110
|
-
id: edge.target,
|
|
111
|
-
weight: edge.weight,
|
|
112
|
-
}))
|
|
113
|
-
.filter(node => node !== undefined);
|
|
114
|
-
}
|
|
115
|
-
|
|
116
|
-
// Calculate cosine similarity between two vectors
|
|
117
|
-
private cosineSimilarity(vec1: number[], vec2: number[]): number {
|
|
118
|
-
if (!vec1 || !vec2) {
|
|
119
|
-
throw new Error('Vectors must not be null or undefined');
|
|
120
|
-
}
|
|
121
|
-
const vectorLength = vec1.length;
|
|
122
|
-
|
|
123
|
-
if (vectorLength !== vec2.length) {
|
|
124
|
-
throw new Error(`Vector dimensions must match: vec1(${vec1.length}) !== vec2(${vec2.length})`);
|
|
125
|
-
}
|
|
126
|
-
|
|
127
|
-
let dotProduct = 0;
|
|
128
|
-
let normVec1 = 0;
|
|
129
|
-
let normVec2 = 0;
|
|
130
|
-
|
|
131
|
-
for (let i = 0; i < vectorLength; i++) {
|
|
132
|
-
const a = vec1[i]!; // Non-null assertion operator
|
|
133
|
-
const b = vec2[i]!;
|
|
134
|
-
|
|
135
|
-
dotProduct += a * b;
|
|
136
|
-
normVec1 += a * a;
|
|
137
|
-
normVec2 += b * b;
|
|
138
|
-
}
|
|
139
|
-
const magnitudeProduct = Math.sqrt(normVec1 * normVec2);
|
|
140
|
-
|
|
141
|
-
if (magnitudeProduct === 0) {
|
|
142
|
-
return 0;
|
|
143
|
-
}
|
|
144
|
-
|
|
145
|
-
const similarity = dotProduct / magnitudeProduct;
|
|
146
|
-
return Math.max(-1, Math.min(1, similarity));
|
|
147
|
-
}
|
|
148
|
-
|
|
149
|
-
createGraph(chunks: GraphChunk[], embeddings: GraphEmbedding[]) {
|
|
150
|
-
if (!chunks?.length || !embeddings?.length) {
|
|
151
|
-
throw new Error('Chunks and embeddings arrays must not be empty');
|
|
152
|
-
}
|
|
153
|
-
if (chunks.length !== embeddings.length) {
|
|
154
|
-
throw new Error('Chunks and embeddings must have the same length');
|
|
155
|
-
}
|
|
156
|
-
// Create nodes from chunks
|
|
157
|
-
chunks.forEach((chunk, index) => {
|
|
158
|
-
const node: GraphNode = {
|
|
159
|
-
id: index.toString(),
|
|
160
|
-
content: chunk.text,
|
|
161
|
-
embedding: embeddings[index]?.vector,
|
|
162
|
-
metadata: { ...chunk.metadata },
|
|
163
|
-
};
|
|
164
|
-
this.addNode(node);
|
|
165
|
-
this.nodes.set(node.id, node);
|
|
166
|
-
});
|
|
167
|
-
|
|
168
|
-
// Create edges based on cosine similarity
|
|
169
|
-
for (let i = 0; i < chunks.length; i++) {
|
|
170
|
-
const firstEmbedding = embeddings[i]?.vector as number[];
|
|
171
|
-
for (let j = i + 1; j < chunks.length; j++) {
|
|
172
|
-
const secondEmbedding = embeddings[j]?.vector as number[];
|
|
173
|
-
const similarity = this.cosineSimilarity(firstEmbedding, secondEmbedding);
|
|
174
|
-
|
|
175
|
-
// Only create edges if similarity is above threshold
|
|
176
|
-
if (similarity > this.threshold) {
|
|
177
|
-
this.addEdge({
|
|
178
|
-
source: i.toString(),
|
|
179
|
-
target: j.toString(),
|
|
180
|
-
weight: similarity,
|
|
181
|
-
type: 'semantic',
|
|
182
|
-
});
|
|
183
|
-
}
|
|
184
|
-
}
|
|
185
|
-
}
|
|
186
|
-
}
|
|
187
|
-
|
|
188
|
-
private selectWeightedNeighbor(neighbors: Array<{ id: string; weight: number }>): string {
|
|
189
|
-
// Sum all weights to normalize probabilities
|
|
190
|
-
const totalWeight = neighbors.reduce((sum, n) => sum + n.weight, 0);
|
|
191
|
-
|
|
192
|
-
// Pick a random point in the total weight range
|
|
193
|
-
let remainingWeight = Math.random() * totalWeight;
|
|
194
|
-
|
|
195
|
-
// Subtract each weight from our random value until we go below 0
|
|
196
|
-
// Higher weights will make us go below 0 more often, making them more likely to be selected
|
|
197
|
-
for (const neighbor of neighbors) {
|
|
198
|
-
remainingWeight -= neighbor.weight;
|
|
199
|
-
if (remainingWeight <= 0) {
|
|
200
|
-
return neighbor.id;
|
|
201
|
-
}
|
|
202
|
-
}
|
|
203
|
-
|
|
204
|
-
return neighbors[neighbors.length - 1]?.id as string;
|
|
205
|
-
}
|
|
206
|
-
|
|
207
|
-
// Perform random walk with restart
|
|
208
|
-
private randomWalkWithRestart(startNodeId: string, steps: number, restartProb: number): Map<string, number> {
|
|
209
|
-
const visits = new Map<string, number>();
|
|
210
|
-
let currentNodeId = startNodeId;
|
|
211
|
-
|
|
212
|
-
for (let step = 0; step < steps; step++) {
|
|
213
|
-
// Record visit
|
|
214
|
-
visits.set(currentNodeId, (visits.get(currentNodeId) || 0) + 1);
|
|
215
|
-
|
|
216
|
-
// Decide whether to restart
|
|
217
|
-
if (Math.random() < restartProb) {
|
|
218
|
-
currentNodeId = startNodeId;
|
|
219
|
-
continue;
|
|
220
|
-
}
|
|
221
|
-
|
|
222
|
-
// Get neighbors
|
|
223
|
-
const neighbors = this.getNeighbors(currentNodeId);
|
|
224
|
-
if (neighbors.length === 0) {
|
|
225
|
-
currentNodeId = startNodeId;
|
|
226
|
-
continue;
|
|
227
|
-
}
|
|
228
|
-
|
|
229
|
-
// Select random weighted neighbor and set as current node
|
|
230
|
-
currentNodeId = this.selectWeightedNeighbor(neighbors);
|
|
231
|
-
}
|
|
232
|
-
|
|
233
|
-
// Normalize visits
|
|
234
|
-
const totalVisits = Array.from(visits.values()).reduce((a, b) => a + b, 0);
|
|
235
|
-
const normalizedVisits = new Map<string, number>();
|
|
236
|
-
for (const [nodeId, count] of visits) {
|
|
237
|
-
normalizedVisits.set(nodeId, count / totalVisits);
|
|
238
|
-
}
|
|
239
|
-
|
|
240
|
-
return normalizedVisits;
|
|
241
|
-
}
|
|
242
|
-
|
|
243
|
-
// Retrieve relevant nodes using hybrid approach
|
|
244
|
-
query({
|
|
245
|
-
query,
|
|
246
|
-
topK = 10,
|
|
247
|
-
randomWalkSteps = 100,
|
|
248
|
-
restartProb = 0.15,
|
|
249
|
-
}: {
|
|
250
|
-
query: number[];
|
|
251
|
-
topK?: number;
|
|
252
|
-
randomWalkSteps?: number;
|
|
253
|
-
restartProb?: number;
|
|
254
|
-
}): RankedNode[] {
|
|
255
|
-
if (!query || query.length !== this.dimension) {
|
|
256
|
-
throw new Error(`Query embedding must have dimension ${this.dimension}`);
|
|
257
|
-
}
|
|
258
|
-
if (topK < 1) {
|
|
259
|
-
throw new Error('TopK must be greater than 0');
|
|
260
|
-
}
|
|
261
|
-
if (randomWalkSteps < 1) {
|
|
262
|
-
throw new Error('Random walk steps must be greater than 0');
|
|
263
|
-
}
|
|
264
|
-
if (restartProb <= 0 || restartProb >= 1) {
|
|
265
|
-
throw new Error('Restart probability must be between 0 and 1');
|
|
266
|
-
}
|
|
267
|
-
// Retrieve nodes and calculate similarity
|
|
268
|
-
const similarities = Array.from(this.nodes.values()).map(node => ({
|
|
269
|
-
node,
|
|
270
|
-
similarity: this.cosineSimilarity(query, node.embedding!),
|
|
271
|
-
}));
|
|
272
|
-
|
|
273
|
-
// Sort by similarity
|
|
274
|
-
similarities.sort((a, b) => b.similarity - a.similarity);
|
|
275
|
-
const topNodes = similarities.slice(0, topK);
|
|
276
|
-
|
|
277
|
-
// Re-ranks nodes using random walk with restart
|
|
278
|
-
const rerankedNodes = new Map<string, { node: GraphNode; score: number }>();
|
|
279
|
-
|
|
280
|
-
// For each top node, perform random walk
|
|
281
|
-
for (const { node, similarity } of topNodes) {
|
|
282
|
-
const walkScores = this.randomWalkWithRestart(node.id, randomWalkSteps, restartProb);
|
|
283
|
-
|
|
284
|
-
// Combine dense retrieval score with graph score
|
|
285
|
-
for (const [nodeId, walkScore] of walkScores) {
|
|
286
|
-
const node = this.nodes.get(nodeId)!;
|
|
287
|
-
const existingScore = rerankedNodes.get(nodeId)?.score || 0;
|
|
288
|
-
rerankedNodes.set(nodeId, {
|
|
289
|
-
node,
|
|
290
|
-
score: existingScore + similarity * walkScore,
|
|
291
|
-
});
|
|
292
|
-
}
|
|
293
|
-
}
|
|
294
|
-
|
|
295
|
-
// Sort by final score and return top K nodes
|
|
296
|
-
return Array.from(rerankedNodes.values())
|
|
297
|
-
.sort((a, b) => b.score - a.score)
|
|
298
|
-
.slice(0, topK)
|
|
299
|
-
.map(item => ({
|
|
300
|
-
id: item.node.id,
|
|
301
|
-
content: item.node.content,
|
|
302
|
-
metadata: item.node.metadata,
|
|
303
|
-
score: item.score,
|
|
304
|
-
}));
|
|
305
|
-
}
|
|
306
|
-
}
|
package/src/index.ts
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
export * from './document/document';
|
|
2
|
-
export * from './document/types';
|
|
3
|
-
export * from './rerank';
|
|
4
|
-
export * from './rerank/relevance';
|
|
5
|
-
export { GraphRAG } from './graph-rag';
|
|
6
|
-
export * from './tools';
|
|
7
|
-
export * from './utils/vector-prompts';
|
|
8
|
-
export * from './utils/default-settings';
|
package/src/rerank/index.test.ts
DELETED
|
@@ -1,150 +0,0 @@
|
|
|
1
|
-
import { cohere } from '@ai-sdk/cohere';
|
|
2
|
-
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
3
|
-
import { CohereRelevanceScorer } from './relevance';
|
|
4
|
-
|
|
5
|
-
import { rerank } from '.';
|
|
6
|
-
|
|
7
|
-
vi.spyOn(CohereRelevanceScorer.prototype, 'getRelevanceScore').mockImplementation(async () => {
|
|
8
|
-
return 1;
|
|
9
|
-
});
|
|
10
|
-
|
|
11
|
-
const getScoreSpreads = (results1: any, results2: any) => {
|
|
12
|
-
const scoreSpread1 = Math.max(...results1.map((r: any) => r.score)) - Math.min(...results1.map((r: any) => r.score));
|
|
13
|
-
const scoreSpread2 = Math.max(...results2.map((r: any) => r.score)) - Math.min(...results2.map((r: any) => r.score));
|
|
14
|
-
return { scoreSpread1, scoreSpread2 };
|
|
15
|
-
};
|
|
16
|
-
|
|
17
|
-
describe('rerank', () => {
|
|
18
|
-
beforeEach(() => {
|
|
19
|
-
vi.clearAllMocks();
|
|
20
|
-
});
|
|
21
|
-
|
|
22
|
-
it('should throw an error if weights do not add up to 1', async () => {
|
|
23
|
-
const results = [
|
|
24
|
-
{ id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
|
|
25
|
-
{ id: '2', metadata: { text: 'Test result 2' }, score: 0.4 },
|
|
26
|
-
{ id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
27
|
-
];
|
|
28
|
-
await expect(
|
|
29
|
-
rerank(results, 'test query', cohere('rerank-v3.5'), { weights: { semantic: 0.5, vector: 0.3, position: 0.5 } }),
|
|
30
|
-
).rejects.toThrow('Weights must add up to 1');
|
|
31
|
-
});
|
|
32
|
-
|
|
33
|
-
it('should rerank results with default weights', async () => {
|
|
34
|
-
const results = [
|
|
35
|
-
{ id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
|
|
36
|
-
{ id: '2', metadata: { text: 'Test result 2' }, score: 0.4 },
|
|
37
|
-
{ id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
38
|
-
];
|
|
39
|
-
|
|
40
|
-
const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), { topK: 2 });
|
|
41
|
-
|
|
42
|
-
expect(rerankedResults).toHaveLength(2);
|
|
43
|
-
expect(rerankedResults[0]).toStrictEqual({
|
|
44
|
-
result: { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
45
|
-
score: 0.8266666666666667,
|
|
46
|
-
details: {
|
|
47
|
-
semantic: 1,
|
|
48
|
-
vector: 0.9,
|
|
49
|
-
position: 0.33333333333333337,
|
|
50
|
-
},
|
|
51
|
-
});
|
|
52
|
-
expect(rerankedResults[1]).toStrictEqual({
|
|
53
|
-
result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
|
|
54
|
-
score: 0.8,
|
|
55
|
-
details: {
|
|
56
|
-
semantic: 1,
|
|
57
|
-
vector: 0.5,
|
|
58
|
-
position: 1,
|
|
59
|
-
},
|
|
60
|
-
});
|
|
61
|
-
|
|
62
|
-
const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, rerankedResults);
|
|
63
|
-
expect(scoreSpread1).toBe(0.5);
|
|
64
|
-
expect(scoreSpread2).toBe(0.026666666666666616);
|
|
65
|
-
});
|
|
66
|
-
|
|
67
|
-
it('should rerank results with custom weights', async () => {
|
|
68
|
-
const results = [
|
|
69
|
-
{ id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
|
|
70
|
-
{ id: '2', metadata: { text: 'Test result 2' }, score: 0.4 },
|
|
71
|
-
{ id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
72
|
-
];
|
|
73
|
-
|
|
74
|
-
const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
|
|
75
|
-
weights: {
|
|
76
|
-
semantic: 0.5,
|
|
77
|
-
vector: 0.4,
|
|
78
|
-
position: 0.1,
|
|
79
|
-
},
|
|
80
|
-
topK: 2,
|
|
81
|
-
});
|
|
82
|
-
|
|
83
|
-
expect(rerankedResults).toHaveLength(2);
|
|
84
|
-
expect(rerankedResults[0]).toStrictEqual({
|
|
85
|
-
result: { id: '3', metadata: { text: 'Test result 3' }, score: 0.9 },
|
|
86
|
-
score: 0.8933333333333334,
|
|
87
|
-
details: {
|
|
88
|
-
semantic: 1,
|
|
89
|
-
vector: 0.9,
|
|
90
|
-
position: 0.33333333333333337,
|
|
91
|
-
},
|
|
92
|
-
});
|
|
93
|
-
expect(rerankedResults[1]).toStrictEqual({
|
|
94
|
-
result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.5 },
|
|
95
|
-
score: 0.7999999999999999,
|
|
96
|
-
details: {
|
|
97
|
-
semantic: 1,
|
|
98
|
-
vector: 0.5,
|
|
99
|
-
position: 1,
|
|
100
|
-
},
|
|
101
|
-
});
|
|
102
|
-
const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, rerankedResults);
|
|
103
|
-
expect(scoreSpread1).toBe(0.5);
|
|
104
|
-
expect(scoreSpread2).toBe(0.09333333333333349);
|
|
105
|
-
});
|
|
106
|
-
|
|
107
|
-
it('should handle query embedding when provided', async () => {
|
|
108
|
-
const results = [
|
|
109
|
-
{ id: '1', metadata: { text: 'Test result 1' }, score: 0.8 },
|
|
110
|
-
{ id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
|
|
111
|
-
];
|
|
112
|
-
|
|
113
|
-
const rerankedResults = await rerank(results, 'test query', cohere('rerank-v3.5'), {
|
|
114
|
-
queryEmbedding: [0.5, 0.3, -0.2, 0.4],
|
|
115
|
-
topK: 2,
|
|
116
|
-
});
|
|
117
|
-
|
|
118
|
-
// Ensure query embedding analysis is being applied (we don't know exact score without knowing internals)
|
|
119
|
-
expect(rerankedResults).toHaveLength(2);
|
|
120
|
-
expect(rerankedResults[0]).toStrictEqual({
|
|
121
|
-
result: { id: '1', metadata: { text: 'Test result 1' }, score: 0.8 },
|
|
122
|
-
score: 0.9200000000000002,
|
|
123
|
-
details: {
|
|
124
|
-
semantic: 1,
|
|
125
|
-
vector: 0.8,
|
|
126
|
-
position: 1,
|
|
127
|
-
queryAnalysis: {
|
|
128
|
-
magnitude: 0.7348469228349535,
|
|
129
|
-
dominantFeatures: [0, 3, 1, 2],
|
|
130
|
-
},
|
|
131
|
-
},
|
|
132
|
-
});
|
|
133
|
-
expect(rerankedResults[1]).toStrictEqual({
|
|
134
|
-
result: { id: '2', metadata: { text: 'Test result 2' }, score: 0.6 },
|
|
135
|
-
score: 0.74,
|
|
136
|
-
details: {
|
|
137
|
-
semantic: 1,
|
|
138
|
-
vector: 0.6,
|
|
139
|
-
position: 0.5,
|
|
140
|
-
queryAnalysis: {
|
|
141
|
-
magnitude: 0.7348469228349535,
|
|
142
|
-
dominantFeatures: [0, 3, 1, 2],
|
|
143
|
-
},
|
|
144
|
-
},
|
|
145
|
-
});
|
|
146
|
-
const { scoreSpread1, scoreSpread2 } = getScoreSpreads(results, rerankedResults);
|
|
147
|
-
expect(scoreSpread1).toBe(0.20000000000000007);
|
|
148
|
-
expect(scoreSpread2).toBe(0.18000000000000016);
|
|
149
|
-
});
|
|
150
|
-
});
|
package/src/rerank/index.ts
DELETED
|
@@ -1,198 +0,0 @@
|
|
|
1
|
-
import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
2
|
-
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
3
|
-
import type { QueryResult } from '@mastra/core/vector';
|
|
4
|
-
import { Big } from 'big.js';
|
|
5
|
-
import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from './relevance';
|
|
6
|
-
|
|
7
|
-
// Default weights for different scoring components (must add up to 1)
|
|
8
|
-
const DEFAULT_WEIGHTS = {
|
|
9
|
-
semantic: 0.4,
|
|
10
|
-
vector: 0.4,
|
|
11
|
-
position: 0.2,
|
|
12
|
-
} as const;
|
|
13
|
-
|
|
14
|
-
type WeightConfig = {
|
|
15
|
-
semantic?: number;
|
|
16
|
-
vector?: number;
|
|
17
|
-
position?: number;
|
|
18
|
-
};
|
|
19
|
-
|
|
20
|
-
interface ScoringDetails {
|
|
21
|
-
semantic: number;
|
|
22
|
-
vector: number;
|
|
23
|
-
position: number;
|
|
24
|
-
queryAnalysis?: {
|
|
25
|
-
magnitude: number;
|
|
26
|
-
dominantFeatures: number[];
|
|
27
|
-
};
|
|
28
|
-
}
|
|
29
|
-
|
|
30
|
-
export interface RerankResult {
|
|
31
|
-
result: QueryResult;
|
|
32
|
-
score: number;
|
|
33
|
-
details: ScoringDetails;
|
|
34
|
-
}
|
|
35
|
-
|
|
36
|
-
// For use in the vector store tool
|
|
37
|
-
export interface RerankerOptions {
|
|
38
|
-
weights?: WeightConfig;
|
|
39
|
-
topK?: number;
|
|
40
|
-
}
|
|
41
|
-
|
|
42
|
-
// For use in the rerank function
|
|
43
|
-
export interface RerankerFunctionOptions {
|
|
44
|
-
weights?: WeightConfig;
|
|
45
|
-
queryEmbedding?: number[];
|
|
46
|
-
topK?: number;
|
|
47
|
-
}
|
|
48
|
-
|
|
49
|
-
export interface RerankConfig {
|
|
50
|
-
options?: RerankerOptions;
|
|
51
|
-
model: MastraLanguageModel | RelevanceScoreProvider;
|
|
52
|
-
}
|
|
53
|
-
|
|
54
|
-
// Calculate position score based on position in original list
|
|
55
|
-
function calculatePositionScore(position: number, totalChunks: number): number {
|
|
56
|
-
return 1 - position / totalChunks;
|
|
57
|
-
}
|
|
58
|
-
|
|
59
|
-
// Analyze query embedding features if needed
|
|
60
|
-
function analyzeQueryEmbedding(embedding: number[]): {
|
|
61
|
-
magnitude: number;
|
|
62
|
-
dominantFeatures: number[];
|
|
63
|
-
} {
|
|
64
|
-
// Calculate embedding magnitude
|
|
65
|
-
const magnitude = Math.sqrt(embedding.reduce((sum, val) => sum + val * val, 0));
|
|
66
|
-
|
|
67
|
-
// Find dominant features (highest absolute values)
|
|
68
|
-
const dominantFeatures = embedding
|
|
69
|
-
.map((value, index) => ({ value: Math.abs(value), index }))
|
|
70
|
-
.sort((a, b) => b.value - a.value)
|
|
71
|
-
.slice(0, 5)
|
|
72
|
-
.map(item => item.index);
|
|
73
|
-
|
|
74
|
-
return { magnitude, dominantFeatures };
|
|
75
|
-
}
|
|
76
|
-
|
|
77
|
-
// Adjust scores based on query characteristics
|
|
78
|
-
function adjustScores(score: number, queryAnalysis: { magnitude: number; dominantFeatures: number[] }): number {
|
|
79
|
-
const magnitudeAdjustment = queryAnalysis.magnitude > 10 ? 1.1 : 1;
|
|
80
|
-
|
|
81
|
-
const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
|
|
82
|
-
|
|
83
|
-
return score * magnitudeAdjustment * featureStrengthAdjustment;
|
|
84
|
-
}
|
|
85
|
-
|
|
86
|
-
async function executeRerank({
|
|
87
|
-
results,
|
|
88
|
-
query,
|
|
89
|
-
scorer,
|
|
90
|
-
options,
|
|
91
|
-
}: {
|
|
92
|
-
results: QueryResult[];
|
|
93
|
-
query: string;
|
|
94
|
-
scorer: RelevanceScoreProvider;
|
|
95
|
-
options: RerankerFunctionOptions;
|
|
96
|
-
}) {
|
|
97
|
-
const { queryEmbedding, topK = 3 } = options;
|
|
98
|
-
const weights = {
|
|
99
|
-
...DEFAULT_WEIGHTS,
|
|
100
|
-
...options.weights,
|
|
101
|
-
};
|
|
102
|
-
|
|
103
|
-
//weights must add up to 1
|
|
104
|
-
const sum = Object.values(weights).reduce((acc: Big, w: number) => acc.plus(w.toString()), new Big(0));
|
|
105
|
-
if (!sum.eq(1)) {
|
|
106
|
-
throw new Error(`Weights must add up to 1. Got ${sum} from ${weights}`);
|
|
107
|
-
}
|
|
108
|
-
|
|
109
|
-
const resultLength = results.length;
|
|
110
|
-
|
|
111
|
-
const queryAnalysis = queryEmbedding ? analyzeQueryEmbedding(queryEmbedding) : null;
|
|
112
|
-
|
|
113
|
-
// Get scores for each result
|
|
114
|
-
const scoredResults = await Promise.all(
|
|
115
|
-
results.map(async (result, index) => {
|
|
116
|
-
// Get semantic score from chosen provider
|
|
117
|
-
let semanticScore = 0;
|
|
118
|
-
if (result?.metadata?.text) {
|
|
119
|
-
semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
|
|
120
|
-
}
|
|
121
|
-
|
|
122
|
-
// Get existing vector score from result
|
|
123
|
-
const vectorScore = result.score;
|
|
124
|
-
|
|
125
|
-
// Get score of vector based on position in original list
|
|
126
|
-
const positionScore = calculatePositionScore(index, resultLength);
|
|
127
|
-
|
|
128
|
-
// Combine scores using weights for each component
|
|
129
|
-
let finalScore =
|
|
130
|
-
weights.semantic * semanticScore + weights.vector * vectorScore + weights.position * positionScore;
|
|
131
|
-
|
|
132
|
-
if (queryAnalysis) {
|
|
133
|
-
finalScore = adjustScores(finalScore, queryAnalysis);
|
|
134
|
-
}
|
|
135
|
-
|
|
136
|
-
return {
|
|
137
|
-
result,
|
|
138
|
-
score: finalScore,
|
|
139
|
-
details: {
|
|
140
|
-
semantic: semanticScore,
|
|
141
|
-
vector: vectorScore,
|
|
142
|
-
position: positionScore,
|
|
143
|
-
...(queryAnalysis && {
|
|
144
|
-
queryAnalysis: {
|
|
145
|
-
magnitude: queryAnalysis.magnitude,
|
|
146
|
-
dominantFeatures: queryAnalysis.dominantFeatures,
|
|
147
|
-
},
|
|
148
|
-
}),
|
|
149
|
-
},
|
|
150
|
-
};
|
|
151
|
-
}),
|
|
152
|
-
);
|
|
153
|
-
|
|
154
|
-
// Sort by score and take top K
|
|
155
|
-
return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
|
|
156
|
-
}
|
|
157
|
-
|
|
158
|
-
export async function rerankWithScorer({
|
|
159
|
-
results,
|
|
160
|
-
query,
|
|
161
|
-
scorer,
|
|
162
|
-
options,
|
|
163
|
-
}: {
|
|
164
|
-
results: QueryResult[];
|
|
165
|
-
query: string;
|
|
166
|
-
scorer: RelevanceScoreProvider;
|
|
167
|
-
options: RerankerFunctionOptions;
|
|
168
|
-
}): Promise<RerankResult[]> {
|
|
169
|
-
return executeRerank({
|
|
170
|
-
results,
|
|
171
|
-
query,
|
|
172
|
-
scorer,
|
|
173
|
-
options,
|
|
174
|
-
});
|
|
175
|
-
}
|
|
176
|
-
|
|
177
|
-
// Takes in a list of results from a vector store and reranks them based on semantic, vector, and position scores
|
|
178
|
-
export async function rerank(
|
|
179
|
-
results: QueryResult[],
|
|
180
|
-
query: string,
|
|
181
|
-
model: MastraLanguageModel,
|
|
182
|
-
options: RerankerFunctionOptions,
|
|
183
|
-
): Promise<RerankResult[]> {
|
|
184
|
-
let semanticProvider: RelevanceScoreProvider;
|
|
185
|
-
|
|
186
|
-
if (model.modelId === 'rerank-v3.5') {
|
|
187
|
-
semanticProvider = new CohereRelevanceScorer(model.modelId);
|
|
188
|
-
} else {
|
|
189
|
-
semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
|
|
190
|
-
}
|
|
191
|
-
|
|
192
|
-
return executeRerank({
|
|
193
|
-
results,
|
|
194
|
-
query,
|
|
195
|
-
scorer: semanticProvider,
|
|
196
|
-
options,
|
|
197
|
-
});
|
|
198
|
-
}
|