@mastra/rag 0.0.2-alpha.41 → 0.0.2-alpha.43
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +16 -0
- package/dist/rag.cjs.development.js +235 -44
- package/dist/rag.cjs.development.js.map +1 -1
- package/dist/rag.cjs.production.min.js +1 -1
- package/dist/rag.cjs.production.min.js.map +1 -1
- package/dist/rag.esm.js +235 -46
- package/dist/rag.esm.js.map +1 -1
- package/dist/utils/index.d.ts +1 -0
- package/dist/utils/index.d.ts.map +1 -1
- package/dist/utils/rag-tools.d.ts +16 -6
- package/dist/utils/rag-tools.d.ts.map +1 -1
- package/dist/utils/re-ranker.d.ts +47 -0
- package/dist/utils/re-ranker.d.ts.map +1 -0
- package/package.json +2 -2
- package/src/astra-db/index.test.ts +1 -1
- package/src/pg/{index_test.ts → index.test.ts} +12 -12
- package/src/utils/index.ts +1 -0
- package/src/utils/rag-tools.ts +49 -3
- package/src/utils/re-ranker.test.ts +222 -0
- package/src/utils/re-ranker.ts +159 -0
- package/dist/pg/index_test.d.ts +0 -2
- package/dist/pg/index_test.d.ts.map +0 -1
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,21 @@
|
|
|
1
1
|
# @mastra/rag
|
|
2
2
|
|
|
3
|
+
## 0.0.2-alpha.43
|
|
4
|
+
|
|
5
|
+
### Patch Changes
|
|
6
|
+
|
|
7
|
+
- Updated dependencies [b524c22]
|
|
8
|
+
- @mastra/core@0.1.27-alpha.59
|
|
9
|
+
|
|
10
|
+
## 0.0.2-alpha.42
|
|
11
|
+
|
|
12
|
+
### Patch Changes
|
|
13
|
+
|
|
14
|
+
- 1874f40: Added re ranking tool to RAG
|
|
15
|
+
- Updated dependencies [1874f40]
|
|
16
|
+
- Updated dependencies [4b1ce2c]
|
|
17
|
+
- @mastra/core@0.1.27-alpha.58
|
|
18
|
+
|
|
3
19
|
## 0.0.2-alpha.41
|
|
4
20
|
|
|
5
21
|
### Patch Changes
|
|
@@ -3396,6 +3396,140 @@ var embed = function embed(chunk, options) {
|
|
|
3396
3396
|
return core.embed(value, options);
|
|
3397
3397
|
};
|
|
3398
3398
|
|
|
3399
|
+
// Default weights for different scoring components
|
|
3400
|
+
var DEFAULT_WEIGHTS = {
|
|
3401
|
+
semantic: 0.4,
|
|
3402
|
+
vector: 0.4,
|
|
3403
|
+
position: 0.2
|
|
3404
|
+
};
|
|
3405
|
+
// Takes in a list of results from a vector store and reranks them based on semantic, vector, and position scores
|
|
3406
|
+
var RagReranker = /*#__PURE__*/function () {
|
|
3407
|
+
function RagReranker(options) {
|
|
3408
|
+
this.semanticProvider = void 0;
|
|
3409
|
+
this.weights = void 0;
|
|
3410
|
+
// Set up different weights for scoring components. Uses default weights if not provided
|
|
3411
|
+
this.weights = _extends({}, DEFAULT_WEIGHTS, options.weights);
|
|
3412
|
+
// Initialize semantic provider
|
|
3413
|
+
if (options.semanticProvider === 'cohere') {
|
|
3414
|
+
var _options$cohereModel;
|
|
3415
|
+
if (!options.cohereApiKey) {
|
|
3416
|
+
throw new Error('Cohere API key required when using Cohere provider');
|
|
3417
|
+
}
|
|
3418
|
+
this.semanticProvider = new core.CohereRelevanceScorer(options.cohereApiKey, (_options$cohereModel = options.cohereModel) != null ? _options$cohereModel : '');
|
|
3419
|
+
} else {
|
|
3420
|
+
if (!options.agentProvider) {
|
|
3421
|
+
throw new Error('Agent provider options required when using Agent provider');
|
|
3422
|
+
}
|
|
3423
|
+
this.semanticProvider = new core.MastraAgentRelevanceScorer(options.agentProvider.provider, options.agentProvider.name);
|
|
3424
|
+
}
|
|
3425
|
+
}
|
|
3426
|
+
// Calculate position score based on position in original list
|
|
3427
|
+
var _proto = RagReranker.prototype;
|
|
3428
|
+
_proto.calculatePositionScore = function calculatePositionScore(position, totalChunks) {
|
|
3429
|
+
return 1 - position / totalChunks;
|
|
3430
|
+
}
|
|
3431
|
+
// Analyze query embedding features if needed
|
|
3432
|
+
;
|
|
3433
|
+
_proto.analyzeQueryEmbedding = function analyzeQueryEmbedding(embedding) {
|
|
3434
|
+
// Calculate embedding magnitude
|
|
3435
|
+
var magnitude = Math.sqrt(embedding.reduce(function (sum, val) {
|
|
3436
|
+
return sum + val * val;
|
|
3437
|
+
}, 0));
|
|
3438
|
+
// Find dominant features (highest absolute values)
|
|
3439
|
+
var dominantFeatures = embedding.map(function (value, index) {
|
|
3440
|
+
return {
|
|
3441
|
+
value: Math.abs(value),
|
|
3442
|
+
index: index
|
|
3443
|
+
};
|
|
3444
|
+
}).sort(function (a, b) {
|
|
3445
|
+
return b.value - a.value;
|
|
3446
|
+
}).slice(0, 5).map(function (item) {
|
|
3447
|
+
return item.index;
|
|
3448
|
+
});
|
|
3449
|
+
return {
|
|
3450
|
+
magnitude: magnitude,
|
|
3451
|
+
dominantFeatures: dominantFeatures
|
|
3452
|
+
};
|
|
3453
|
+
}
|
|
3454
|
+
// Adjust scores based on query characteristics
|
|
3455
|
+
;
|
|
3456
|
+
_proto.adjustScores = function adjustScores(score, queryAnalysis) {
|
|
3457
|
+
var magnitudeAdjustment = queryAnalysis.magnitude > 10 ? 1.1 : 1;
|
|
3458
|
+
var featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
|
|
3459
|
+
return score * magnitudeAdjustment * featureStrengthAdjustment;
|
|
3460
|
+
};
|
|
3461
|
+
_proto.rerank = /*#__PURE__*/function () {
|
|
3462
|
+
var _rerank = /*#__PURE__*/_asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee2(_ref) {
|
|
3463
|
+
var _this = this;
|
|
3464
|
+
var query, vectorStoreResults, queryEmbedding, _ref$topK, topK, resultLength, queryAnalysis, scoredResults;
|
|
3465
|
+
return _regeneratorRuntime().wrap(function _callee2$(_context2) {
|
|
3466
|
+
while (1) switch (_context2.prev = _context2.next) {
|
|
3467
|
+
case 0:
|
|
3468
|
+
query = _ref.query, vectorStoreResults = _ref.vectorStoreResults, queryEmbedding = _ref.queryEmbedding, _ref$topK = _ref.topK, topK = _ref$topK === void 0 ? 3 : _ref$topK;
|
|
3469
|
+
resultLength = vectorStoreResults.length;
|
|
3470
|
+
queryAnalysis = queryEmbedding ? this.analyzeQueryEmbedding(queryEmbedding) : null; // Get scores for each result
|
|
3471
|
+
_context2.next = 5;
|
|
3472
|
+
return Promise.all(vectorStoreResults.map(/*#__PURE__*/function () {
|
|
3473
|
+
var _ref2 = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee(result, index) {
|
|
3474
|
+
var _result$metadata;
|
|
3475
|
+
var semanticScore, vectorScore, positionScore, finalScore;
|
|
3476
|
+
return _regeneratorRuntime().wrap(function _callee$(_context) {
|
|
3477
|
+
while (1) switch (_context.prev = _context.next) {
|
|
3478
|
+
case 0:
|
|
3479
|
+
_context.next = 2;
|
|
3480
|
+
return _this.semanticProvider.getRelevanceScore(query, result == null || (_result$metadata = result.metadata) == null ? void 0 : _result$metadata.text);
|
|
3481
|
+
case 2:
|
|
3482
|
+
semanticScore = _context.sent;
|
|
3483
|
+
// Get existing vector score from result
|
|
3484
|
+
vectorScore = result.score; // Get score of vector based on position in original list
|
|
3485
|
+
positionScore = _this.calculatePositionScore(index, resultLength); // Combine scores using weights for each component
|
|
3486
|
+
finalScore = _this.weights.semantic * semanticScore + _this.weights.vector * vectorScore + _this.weights.position * positionScore;
|
|
3487
|
+
if (queryAnalysis) {
|
|
3488
|
+
finalScore = _this.adjustScores(finalScore, queryAnalysis);
|
|
3489
|
+
}
|
|
3490
|
+
return _context.abrupt("return", {
|
|
3491
|
+
result: result,
|
|
3492
|
+
score: finalScore,
|
|
3493
|
+
details: _extends({
|
|
3494
|
+
semantic: semanticScore,
|
|
3495
|
+
vector: vectorScore,
|
|
3496
|
+
position: positionScore
|
|
3497
|
+
}, queryAnalysis && {
|
|
3498
|
+
queryAnalysis: {
|
|
3499
|
+
magnitude: queryAnalysis.magnitude,
|
|
3500
|
+
dominantFeatures: queryAnalysis.dominantFeatures
|
|
3501
|
+
}
|
|
3502
|
+
})
|
|
3503
|
+
});
|
|
3504
|
+
case 8:
|
|
3505
|
+
case "end":
|
|
3506
|
+
return _context.stop();
|
|
3507
|
+
}
|
|
3508
|
+
}, _callee);
|
|
3509
|
+
}));
|
|
3510
|
+
return function (_x2, _x3) {
|
|
3511
|
+
return _ref2.apply(this, arguments);
|
|
3512
|
+
};
|
|
3513
|
+
}()));
|
|
3514
|
+
case 5:
|
|
3515
|
+
scoredResults = _context2.sent;
|
|
3516
|
+
return _context2.abrupt("return", scoredResults.sort(function (a, b) {
|
|
3517
|
+
return b.score - a.score;
|
|
3518
|
+
}).slice(0, topK));
|
|
3519
|
+
case 7:
|
|
3520
|
+
case "end":
|
|
3521
|
+
return _context2.stop();
|
|
3522
|
+
}
|
|
3523
|
+
}, _callee2, this);
|
|
3524
|
+
}));
|
|
3525
|
+
function rerank(_x) {
|
|
3526
|
+
return _rerank.apply(this, arguments);
|
|
3527
|
+
}
|
|
3528
|
+
return rerank;
|
|
3529
|
+
}();
|
|
3530
|
+
return RagReranker;
|
|
3531
|
+
}();
|
|
3532
|
+
|
|
3399
3533
|
var createFilter = function createFilter(filter, vectorFilterType) {
|
|
3400
3534
|
if (['pg', 'astra', 'pinecone'].includes(vectorFilterType)) {
|
|
3401
3535
|
var _filter$keyword, _ref;
|
|
@@ -3418,14 +3552,44 @@ var createFilter = function createFilter(filter, vectorFilterType) {
|
|
|
3418
3552
|
};
|
|
3419
3553
|
}
|
|
3420
3554
|
};
|
|
3421
|
-
|
|
3422
|
-
|
|
3423
|
-
|
|
3424
|
-
|
|
3425
|
-
|
|
3426
|
-
|
|
3427
|
-
|
|
3428
|
-
|
|
3555
|
+
// Separate function to handle vector query search
|
|
3556
|
+
// Can be imported and used in custom tools
|
|
3557
|
+
var vectorQuerySearch = /*#__PURE__*/function () {
|
|
3558
|
+
var _ref4 = /*#__PURE__*/_asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee(_ref3) {
|
|
3559
|
+
var indexName, vectorStore, queryText, options, _ref3$queryFilter, queryFilter, topK, _yield$embed, embedding, results;
|
|
3560
|
+
return _regeneratorRuntime().wrap(function _callee$(_context) {
|
|
3561
|
+
while (1) switch (_context.prev = _context.next) {
|
|
3562
|
+
case 0:
|
|
3563
|
+
indexName = _ref3.indexName, vectorStore = _ref3.vectorStore, queryText = _ref3.queryText, options = _ref3.options, _ref3$queryFilter = _ref3.queryFilter, queryFilter = _ref3$queryFilter === void 0 ? {} : _ref3$queryFilter, topK = _ref3.topK;
|
|
3564
|
+
_context.next = 3;
|
|
3565
|
+
return embed(queryText, options);
|
|
3566
|
+
case 3:
|
|
3567
|
+
_yield$embed = _context.sent;
|
|
3568
|
+
embedding = _yield$embed.embedding;
|
|
3569
|
+
_context.next = 7;
|
|
3570
|
+
return vectorStore.query(indexName, embedding, topK, queryFilter);
|
|
3571
|
+
case 7:
|
|
3572
|
+
results = _context.sent;
|
|
3573
|
+
return _context.abrupt("return", results);
|
|
3574
|
+
case 9:
|
|
3575
|
+
case "end":
|
|
3576
|
+
return _context.stop();
|
|
3577
|
+
}
|
|
3578
|
+
}, _callee);
|
|
3579
|
+
}));
|
|
3580
|
+
return function vectorQuerySearch(_x) {
|
|
3581
|
+
return _ref4.apply(this, arguments);
|
|
3582
|
+
};
|
|
3583
|
+
}();
|
|
3584
|
+
var createVectorQueryTool = function createVectorQueryTool(_ref5) {
|
|
3585
|
+
var vectorStoreName = _ref5.vectorStoreName,
|
|
3586
|
+
indexName = _ref5.indexName,
|
|
3587
|
+
_ref5$topK = _ref5.topK,
|
|
3588
|
+
topK = _ref5$topK === void 0 ? 10 : _ref5$topK,
|
|
3589
|
+
options = _ref5.options,
|
|
3590
|
+
_ref5$vectorFilterTyp = _ref5.vectorFilterType,
|
|
3591
|
+
vectorFilterType = _ref5$vectorFilterTyp === void 0 ? '' : _ref5$vectorFilterTyp,
|
|
3592
|
+
rerankOptions = _ref5.rerankOptions;
|
|
3429
3593
|
return core.createTool({
|
|
3430
3594
|
id: "VectorQuery " + vectorStoreName + " " + indexName + " Tool",
|
|
3431
3595
|
inputSchema: zod.z.object({
|
|
@@ -3441,82 +3605,107 @@ var createVectorQueryTool = function createVectorQueryTool(_ref3) {
|
|
|
3441
3605
|
}),
|
|
3442
3606
|
description: "Fetches and combines the top " + topK + " relevant chunks from the " + vectorStoreName + " vector store using the " + indexName + " index",
|
|
3443
3607
|
execute: function () {
|
|
3444
|
-
var _execute = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function
|
|
3608
|
+
var _execute = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee2(_ref6) {
|
|
3445
3609
|
var _mastra$vectors;
|
|
3446
|
-
var
|
|
3447
|
-
return _regeneratorRuntime().wrap(function
|
|
3448
|
-
while (1) switch (
|
|
3610
|
+
var _ref6$context, queryText, filter, mastra, relevantContext, vectorStore, queryFilter, results, reranker, rerankedResults, _relevantChunks, relevantChunks;
|
|
3611
|
+
return _regeneratorRuntime().wrap(function _callee2$(_context2) {
|
|
3612
|
+
while (1) switch (_context2.prev = _context2.next) {
|
|
3449
3613
|
case 0:
|
|
3450
|
-
|
|
3614
|
+
_ref6$context = _ref6.context, queryText = _ref6$context.queryText, filter = _ref6$context.filter, mastra = _ref6.mastra;
|
|
3451
3615
|
relevantContext = '';
|
|
3452
|
-
vectorStore = mastra == null || (_mastra$vectors = mastra.vectors) == null ? void 0 : _mastra$vectors[vectorStoreName];
|
|
3453
|
-
_context.next = 5;
|
|
3454
|
-
return embed(queryText, options);
|
|
3455
|
-
case 5:
|
|
3456
|
-
_yield$embed = _context.sent;
|
|
3457
|
-
embedding = _yield$embed.embedding;
|
|
3616
|
+
vectorStore = mastra == null || (_mastra$vectors = mastra.vectors) == null ? void 0 : _mastra$vectors[vectorStoreName]; // Get relevant chunks from the vector database
|
|
3458
3617
|
if (!vectorStore) {
|
|
3459
|
-
|
|
3618
|
+
_context2.next = 18;
|
|
3460
3619
|
break;
|
|
3461
3620
|
}
|
|
3462
3621
|
queryFilter = vectorFilterType && filter ? createFilter(filter, vectorFilterType) : {};
|
|
3463
|
-
|
|
3464
|
-
return
|
|
3465
|
-
|
|
3466
|
-
|
|
3467
|
-
|
|
3622
|
+
_context2.next = 7;
|
|
3623
|
+
return vectorQuerySearch({
|
|
3624
|
+
indexName: indexName,
|
|
3625
|
+
vectorStore: vectorStore,
|
|
3626
|
+
queryText: queryText,
|
|
3627
|
+
options: options,
|
|
3628
|
+
queryFilter: queryFilter,
|
|
3629
|
+
topK: topK
|
|
3630
|
+
});
|
|
3631
|
+
case 7:
|
|
3632
|
+
results = _context2.sent;
|
|
3633
|
+
if (!rerankOptions) {
|
|
3634
|
+
_context2.next = 16;
|
|
3635
|
+
break;
|
|
3636
|
+
}
|
|
3637
|
+
reranker = new RagReranker(rerankOptions);
|
|
3638
|
+
_context2.next = 12;
|
|
3639
|
+
return reranker.rerank({
|
|
3640
|
+
query: queryText,
|
|
3641
|
+
vectorStoreResults: results,
|
|
3642
|
+
topK: topK
|
|
3643
|
+
});
|
|
3644
|
+
case 12:
|
|
3645
|
+
rerankedResults = _context2.sent;
|
|
3646
|
+
_relevantChunks = rerankedResults.map(function (_ref7) {
|
|
3468
3647
|
var _result$metadata;
|
|
3648
|
+
var result = _ref7.result;
|
|
3469
3649
|
return result == null || (_result$metadata = result.metadata) == null ? void 0 : _result$metadata.text;
|
|
3650
|
+
});
|
|
3651
|
+
relevantContext = _relevantChunks.join('\n\n');
|
|
3652
|
+
return _context2.abrupt("return", {
|
|
3653
|
+
relevantContext: relevantContext
|
|
3654
|
+
});
|
|
3655
|
+
case 16:
|
|
3656
|
+
relevantChunks = results.map(function (result) {
|
|
3657
|
+
var _result$metadata2;
|
|
3658
|
+
return result == null || (_result$metadata2 = result.metadata) == null ? void 0 : _result$metadata2.text;
|
|
3470
3659
|
}); // Combine the chunks into a context string
|
|
3471
3660
|
relevantContext = relevantChunks.join('\n\n');
|
|
3472
|
-
case
|
|
3473
|
-
return
|
|
3661
|
+
case 18:
|
|
3662
|
+
return _context2.abrupt("return", {
|
|
3474
3663
|
relevantContext: relevantContext
|
|
3475
3664
|
});
|
|
3476
|
-
case
|
|
3665
|
+
case 19:
|
|
3477
3666
|
case "end":
|
|
3478
|
-
return
|
|
3667
|
+
return _context2.stop();
|
|
3479
3668
|
}
|
|
3480
|
-
},
|
|
3669
|
+
}, _callee2);
|
|
3481
3670
|
}));
|
|
3482
|
-
function execute(
|
|
3671
|
+
function execute(_x2) {
|
|
3483
3672
|
return _execute.apply(this, arguments);
|
|
3484
3673
|
}
|
|
3485
3674
|
return execute;
|
|
3486
3675
|
}()
|
|
3487
3676
|
});
|
|
3488
3677
|
};
|
|
3489
|
-
var createDocumentChunker = function createDocumentChunker(
|
|
3490
|
-
var doc =
|
|
3491
|
-
|
|
3492
|
-
params =
|
|
3678
|
+
var createDocumentChunker = function createDocumentChunker(_ref8) {
|
|
3679
|
+
var doc = _ref8.doc,
|
|
3680
|
+
_ref8$params = _ref8.params,
|
|
3681
|
+
params = _ref8$params === void 0 ? {
|
|
3493
3682
|
strategy: 'recursive',
|
|
3494
3683
|
size: 512,
|
|
3495
3684
|
overlap: 50,
|
|
3496
3685
|
separator: '\n'
|
|
3497
|
-
} :
|
|
3686
|
+
} : _ref8$params;
|
|
3498
3687
|
return core.createTool({
|
|
3499
3688
|
id: "Document Chunker " + params.strategy + " " + params.size,
|
|
3500
3689
|
inputSchema: zod.z.object({}),
|
|
3501
3690
|
description: "Chunks document using " + params.strategy + " strategy with size " + params.size + " and " + params.overlap + " overlap",
|
|
3502
3691
|
execute: function () {
|
|
3503
|
-
var _execute2 = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function
|
|
3692
|
+
var _execute2 = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
|
|
3504
3693
|
var chunks;
|
|
3505
|
-
return _regeneratorRuntime().wrap(function
|
|
3506
|
-
while (1) switch (
|
|
3694
|
+
return _regeneratorRuntime().wrap(function _callee3$(_context3) {
|
|
3695
|
+
while (1) switch (_context3.prev = _context3.next) {
|
|
3507
3696
|
case 0:
|
|
3508
|
-
|
|
3697
|
+
_context3.next = 2;
|
|
3509
3698
|
return doc.chunk(params);
|
|
3510
3699
|
case 2:
|
|
3511
|
-
chunks =
|
|
3512
|
-
return
|
|
3700
|
+
chunks = _context3.sent;
|
|
3701
|
+
return _context3.abrupt("return", {
|
|
3513
3702
|
chunks: chunks
|
|
3514
3703
|
});
|
|
3515
3704
|
case 4:
|
|
3516
3705
|
case "end":
|
|
3517
|
-
return
|
|
3706
|
+
return _context3.stop();
|
|
3518
3707
|
}
|
|
3519
|
-
},
|
|
3708
|
+
}, _callee3);
|
|
3520
3709
|
}));
|
|
3521
3710
|
function execute() {
|
|
3522
3711
|
return _execute2.apply(this, arguments);
|
|
@@ -3531,8 +3720,10 @@ exports.MDocument = MDocument;
|
|
|
3531
3720
|
exports.PgVector = PgVector;
|
|
3532
3721
|
exports.PineconeVector = PineconeVector;
|
|
3533
3722
|
exports.QdrantVector = QdrantVector;
|
|
3723
|
+
exports.RagReranker = RagReranker;
|
|
3534
3724
|
exports.UpstashVector = UpstashVector;
|
|
3535
3725
|
exports.createDocumentChunker = createDocumentChunker;
|
|
3536
3726
|
exports.createVectorQueryTool = createVectorQueryTool;
|
|
3537
3727
|
exports.embed = embed;
|
|
3728
|
+
exports.vectorQuerySearch = vectorQuerySearch;
|
|
3538
3729
|
//# sourceMappingURL=rag.cjs.development.js.map
|