@mastra/rag 1.3.5-alpha.0 → 1.3.6-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +18 -0
- package/dist/graph-rag/index.d.ts +13 -2
- package/dist/graph-rag/index.d.ts.map +1 -1
- package/dist/index.cjs +23 -5
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +23 -5
- package/dist/index.js.map +1 -1
- package/dist/utils/convert-sources.d.ts +3 -1
- package/dist/utils/convert-sources.d.ts.map +1 -1
- package/package.json +4 -4
package/dist/index.js
CHANGED
|
@@ -6808,7 +6808,7 @@ var GraphRAG = class {
|
|
|
6808
6808
|
return neighbors[neighbors.length - 1]?.id;
|
|
6809
6809
|
}
|
|
6810
6810
|
// Perform random walk with restart
|
|
6811
|
-
randomWalkWithRestart(startNodeId, steps, restartProb) {
|
|
6811
|
+
randomWalkWithRestart(startNodeId, steps, restartProb, allowedNodeIds) {
|
|
6812
6812
|
const visits = /* @__PURE__ */ new Map();
|
|
6813
6813
|
let currentNodeId = startNodeId;
|
|
6814
6814
|
for (let step = 0; step < steps; step++) {
|
|
@@ -6817,7 +6817,10 @@ var GraphRAG = class {
|
|
|
6817
6817
|
currentNodeId = startNodeId;
|
|
6818
6818
|
continue;
|
|
6819
6819
|
}
|
|
6820
|
-
|
|
6820
|
+
let neighbors = this.getNeighbors(currentNodeId);
|
|
6821
|
+
if (allowedNodeIds) {
|
|
6822
|
+
neighbors = neighbors.filter((n) => allowedNodeIds.has(n.id));
|
|
6823
|
+
}
|
|
6821
6824
|
if (neighbors.length === 0) {
|
|
6822
6825
|
currentNodeId = startNodeId;
|
|
6823
6826
|
continue;
|
|
@@ -6831,12 +6834,22 @@ var GraphRAG = class {
|
|
|
6831
6834
|
}
|
|
6832
6835
|
return normalizedVisits;
|
|
6833
6836
|
}
|
|
6837
|
+
/**
|
|
6838
|
+
* Query the graph with a dense embedding and optional metadata filter.
|
|
6839
|
+
*
|
|
6840
|
+
* @param query - The embedding vector to query.
|
|
6841
|
+
* @param topK - Number of top results to return.
|
|
6842
|
+
* @param randomWalkSteps - Steps for random walk reranking.
|
|
6843
|
+
* @param restartProb - Restart probability for random walk.
|
|
6844
|
+
* @param filter - Optional strict metadata filter. All key-value pairs must match exactly.
|
|
6845
|
+
*/
|
|
6834
6846
|
// Retrieve relevant nodes using hybrid approach
|
|
6835
6847
|
query({
|
|
6836
6848
|
query,
|
|
6837
6849
|
topK = 10,
|
|
6838
6850
|
randomWalkSteps = 100,
|
|
6839
|
-
restartProb = 0.15
|
|
6851
|
+
restartProb = 0.15,
|
|
6852
|
+
filter
|
|
6840
6853
|
}) {
|
|
6841
6854
|
if (!query || query.length !== this.dimension) {
|
|
6842
6855
|
throw new Error(`Query embedding must have dimension ${this.dimension}`);
|
|
@@ -6850,15 +6863,20 @@ var GraphRAG = class {
|
|
|
6850
6863
|
if (restartProb <= 0 || restartProb >= 1) {
|
|
6851
6864
|
throw new Error("Restart probability must be between 0 and 1");
|
|
6852
6865
|
}
|
|
6853
|
-
const
|
|
6866
|
+
const filterEntries = Object.entries(filter ?? {});
|
|
6867
|
+
const matchesFilter = (node) => filterEntries.length === 0 ? true : filterEntries.every(([key, value]) => node.metadata?.[key] === value);
|
|
6868
|
+
const nodesToSearch = Array.from(this.nodes.values()).filter(matchesFilter);
|
|
6869
|
+
const similarities = nodesToSearch.map((node) => ({
|
|
6854
6870
|
node,
|
|
6855
6871
|
similarity: this.cosineSimilarity(query, node.embedding)
|
|
6856
6872
|
}));
|
|
6857
6873
|
similarities.sort((a, b) => b.similarity - a.similarity);
|
|
6858
6874
|
const topNodes = similarities.slice(0, topK);
|
|
6875
|
+
const useFilter = filterEntries.length > 0;
|
|
6876
|
+
const allowedNodeIds = useFilter ? new Set(nodesToSearch.map((n) => n.id)) : void 0;
|
|
6859
6877
|
const rerankedNodes = /* @__PURE__ */ new Map();
|
|
6860
6878
|
for (const { node, similarity } of topNodes) {
|
|
6861
|
-
const walkScores = this.randomWalkWithRestart(node.id, randomWalkSteps, restartProb);
|
|
6879
|
+
const walkScores = this.randomWalkWithRestart(node.id, randomWalkSteps, restartProb, allowedNodeIds);
|
|
6862
6880
|
for (const [nodeId, walkScore] of walkScores) {
|
|
6863
6881
|
const node2 = this.nodes.get(nodeId);
|
|
6864
6882
|
const existingScore = rerankedNodes.get(nodeId)?.score || 0;
|