@mastra/rag 0.0.2-alpha.41 → 0.0.2-alpha.43

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/CHANGELOG.md CHANGED
@@ -1,5 +1,21 @@
1
1
  # @mastra/rag
2
2
 
3
+ ## 0.0.2-alpha.43
4
+
5
+ ### Patch Changes
6
+
7
+ - Updated dependencies [b524c22]
8
+ - @mastra/core@0.1.27-alpha.59
9
+
10
+ ## 0.0.2-alpha.42
11
+
12
+ ### Patch Changes
13
+
14
+ - 1874f40: Added re-ranking tool to RAG
15
+ - Updated dependencies [1874f40]
16
+ - Updated dependencies [4b1ce2c]
17
+ - @mastra/core@0.1.27-alpha.58
18
+
3
19
  ## 0.0.2-alpha.41
4
20
 
5
21
  ### Patch Changes
@@ -3396,6 +3396,140 @@ var embed = function embed(chunk, options) {
3396
3396
  return core.embed(value, options);
3397
3397
  };
3398
3398
 
3399
+ // Default weights for different scoring components
3400
+ var DEFAULT_WEIGHTS = {
3401
+ semantic: 0.4,
3402
+ vector: 0.4,
3403
+ position: 0.2
3404
+ };
3405
+ // Takes in a list of results from a vector store and reranks them based on semantic, vector, and position scores
3406
+ var RagReranker = /*#__PURE__*/function () {
3407
+ function RagReranker(options) {
3408
+ this.semanticProvider = void 0;
3409
+ this.weights = void 0;
3410
+ // Set up different weights for scoring components. Uses default weights if not provided
3411
+ this.weights = _extends({}, DEFAULT_WEIGHTS, options.weights);
3412
+ // Initialize semantic provider
3413
+ if (options.semanticProvider === 'cohere') {
3414
+ var _options$cohereModel;
3415
+ if (!options.cohereApiKey) {
3416
+ throw new Error('Cohere API key required when using Cohere provider');
3417
+ }
3418
+ this.semanticProvider = new core.CohereRelevanceScorer(options.cohereApiKey, (_options$cohereModel = options.cohereModel) != null ? _options$cohereModel : '');
3419
+ } else {
3420
+ if (!options.agentProvider) {
3421
+ throw new Error('Agent provider options required when using Agent provider');
3422
+ }
3423
+ this.semanticProvider = new core.MastraAgentRelevanceScorer(options.agentProvider.provider, options.agentProvider.name);
3424
+ }
3425
+ }
3426
+ // Calculate position score based on position in original list
3427
+ var _proto = RagReranker.prototype;
3428
+ _proto.calculatePositionScore = function calculatePositionScore(position, totalChunks) {
3429
+ return 1 - position / totalChunks;
3430
+ }
3431
+ // Analyze query embedding features if needed
3432
+ ;
3433
+ _proto.analyzeQueryEmbedding = function analyzeQueryEmbedding(embedding) {
3434
+ // Calculate embedding magnitude
3435
+ var magnitude = Math.sqrt(embedding.reduce(function (sum, val) {
3436
+ return sum + val * val;
3437
+ }, 0));
3438
+ // Find dominant features (highest absolute values)
3439
+ var dominantFeatures = embedding.map(function (value, index) {
3440
+ return {
3441
+ value: Math.abs(value),
3442
+ index: index
3443
+ };
3444
+ }).sort(function (a, b) {
3445
+ return b.value - a.value;
3446
+ }).slice(0, 5).map(function (item) {
3447
+ return item.index;
3448
+ });
3449
+ return {
3450
+ magnitude: magnitude,
3451
+ dominantFeatures: dominantFeatures
3452
+ };
3453
+ }
3454
+ // Adjust scores based on query characteristics
3455
+ ;
3456
+ _proto.adjustScores = function adjustScores(score, queryAnalysis) {
3457
+ var magnitudeAdjustment = queryAnalysis.magnitude > 10 ? 1.1 : 1;
3458
+ var featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
3459
+ return score * magnitudeAdjustment * featureStrengthAdjustment;
3460
+ };
3461
+ _proto.rerank = /*#__PURE__*/function () {
3462
+ var _rerank = /*#__PURE__*/_asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee2(_ref) {
3463
+ var _this = this;
3464
+ var query, vectorStoreResults, queryEmbedding, _ref$topK, topK, resultLength, queryAnalysis, scoredResults;
3465
+ return _regeneratorRuntime().wrap(function _callee2$(_context2) {
3466
+ while (1) switch (_context2.prev = _context2.next) {
3467
+ case 0:
3468
+ query = _ref.query, vectorStoreResults = _ref.vectorStoreResults, queryEmbedding = _ref.queryEmbedding, _ref$topK = _ref.topK, topK = _ref$topK === void 0 ? 3 : _ref$topK;
3469
+ resultLength = vectorStoreResults.length;
3470
+ queryAnalysis = queryEmbedding ? this.analyzeQueryEmbedding(queryEmbedding) : null; // Get scores for each result
3471
+ _context2.next = 5;
3472
+ return Promise.all(vectorStoreResults.map(/*#__PURE__*/function () {
3473
+ var _ref2 = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee(result, index) {
3474
+ var _result$metadata;
3475
+ var semanticScore, vectorScore, positionScore, finalScore;
3476
+ return _regeneratorRuntime().wrap(function _callee$(_context) {
3477
+ while (1) switch (_context.prev = _context.next) {
3478
+ case 0:
3479
+ _context.next = 2;
3480
+ return _this.semanticProvider.getRelevanceScore(query, result == null || (_result$metadata = result.metadata) == null ? void 0 : _result$metadata.text);
3481
+ case 2:
3482
+ semanticScore = _context.sent;
3483
+ // Get existing vector score from result
3484
+ vectorScore = result.score; // Get score of vector based on position in original list
3485
+ positionScore = _this.calculatePositionScore(index, resultLength); // Combine scores using weights for each component
3486
+ finalScore = _this.weights.semantic * semanticScore + _this.weights.vector * vectorScore + _this.weights.position * positionScore;
3487
+ if (queryAnalysis) {
3488
+ finalScore = _this.adjustScores(finalScore, queryAnalysis);
3489
+ }
3490
+ return _context.abrupt("return", {
3491
+ result: result,
3492
+ score: finalScore,
3493
+ details: _extends({
3494
+ semantic: semanticScore,
3495
+ vector: vectorScore,
3496
+ position: positionScore
3497
+ }, queryAnalysis && {
3498
+ queryAnalysis: {
3499
+ magnitude: queryAnalysis.magnitude,
3500
+ dominantFeatures: queryAnalysis.dominantFeatures
3501
+ }
3502
+ })
3503
+ });
3504
+ case 8:
3505
+ case "end":
3506
+ return _context.stop();
3507
+ }
3508
+ }, _callee);
3509
+ }));
3510
+ return function (_x2, _x3) {
3511
+ return _ref2.apply(this, arguments);
3512
+ };
3513
+ }()));
3514
+ case 5:
3515
+ scoredResults = _context2.sent;
3516
+ return _context2.abrupt("return", scoredResults.sort(function (a, b) {
3517
+ return b.score - a.score;
3518
+ }).slice(0, topK));
3519
+ case 7:
3520
+ case "end":
3521
+ return _context2.stop();
3522
+ }
3523
+ }, _callee2, this);
3524
+ }));
3525
+ function rerank(_x) {
3526
+ return _rerank.apply(this, arguments);
3527
+ }
3528
+ return rerank;
3529
+ }();
3530
+ return RagReranker;
3531
+ }();
3532
+
3399
3533
  var createFilter = function createFilter(filter, vectorFilterType) {
3400
3534
  if (['pg', 'astra', 'pinecone'].includes(vectorFilterType)) {
3401
3535
  var _filter$keyword, _ref;
@@ -3418,14 +3552,44 @@ var createFilter = function createFilter(filter, vectorFilterType) {
3418
3552
  };
3419
3553
  }
3420
3554
  };
3421
- var createVectorQueryTool = function createVectorQueryTool(_ref3) {
3422
- var vectorStoreName = _ref3.vectorStoreName,
3423
- indexName = _ref3.indexName,
3424
- _ref3$topK = _ref3.topK,
3425
- topK = _ref3$topK === void 0 ? 10 : _ref3$topK,
3426
- options = _ref3.options,
3427
- _ref3$vectorFilterTyp = _ref3.vectorFilterType,
3428
- vectorFilterType = _ref3$vectorFilterTyp === void 0 ? '' : _ref3$vectorFilterTyp;
3555
+ // Separate function to handle vector query search
3556
+ // Can be imported and used in custom tools
3557
+ var vectorQuerySearch = /*#__PURE__*/function () {
3558
+ var _ref4 = /*#__PURE__*/_asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee(_ref3) {
3559
+ var indexName, vectorStore, queryText, options, _ref3$queryFilter, queryFilter, topK, _yield$embed, embedding, results;
3560
+ return _regeneratorRuntime().wrap(function _callee$(_context) {
3561
+ while (1) switch (_context.prev = _context.next) {
3562
+ case 0:
3563
+ indexName = _ref3.indexName, vectorStore = _ref3.vectorStore, queryText = _ref3.queryText, options = _ref3.options, _ref3$queryFilter = _ref3.queryFilter, queryFilter = _ref3$queryFilter === void 0 ? {} : _ref3$queryFilter, topK = _ref3.topK;
3564
+ _context.next = 3;
3565
+ return embed(queryText, options);
3566
+ case 3:
3567
+ _yield$embed = _context.sent;
3568
+ embedding = _yield$embed.embedding;
3569
+ _context.next = 7;
3570
+ return vectorStore.query(indexName, embedding, topK, queryFilter);
3571
+ case 7:
3572
+ results = _context.sent;
3573
+ return _context.abrupt("return", results);
3574
+ case 9:
3575
+ case "end":
3576
+ return _context.stop();
3577
+ }
3578
+ }, _callee);
3579
+ }));
3580
+ return function vectorQuerySearch(_x) {
3581
+ return _ref4.apply(this, arguments);
3582
+ };
3583
+ }();
3584
+ var createVectorQueryTool = function createVectorQueryTool(_ref5) {
3585
+ var vectorStoreName = _ref5.vectorStoreName,
3586
+ indexName = _ref5.indexName,
3587
+ _ref5$topK = _ref5.topK,
3588
+ topK = _ref5$topK === void 0 ? 10 : _ref5$topK,
3589
+ options = _ref5.options,
3590
+ _ref5$vectorFilterTyp = _ref5.vectorFilterType,
3591
+ vectorFilterType = _ref5$vectorFilterTyp === void 0 ? '' : _ref5$vectorFilterTyp,
3592
+ rerankOptions = _ref5.rerankOptions;
3429
3593
  return core.createTool({
3430
3594
  id: "VectorQuery " + vectorStoreName + " " + indexName + " Tool",
3431
3595
  inputSchema: zod.z.object({
@@ -3441,82 +3605,107 @@ var createVectorQueryTool = function createVectorQueryTool(_ref3) {
3441
3605
  }),
3442
3606
  description: "Fetches and combines the top " + topK + " relevant chunks from the " + vectorStoreName + " vector store using the " + indexName + " index",
3443
3607
  execute: function () {
3444
- var _execute = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee(_ref4) {
3608
+ var _execute = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee2(_ref6) {
3445
3609
  var _mastra$vectors;
3446
- var _ref4$context, queryText, filter, mastra, relevantContext, vectorStore, _yield$embed, embedding, queryFilter, results, relevantChunks;
3447
- return _regeneratorRuntime().wrap(function _callee$(_context) {
3448
- while (1) switch (_context.prev = _context.next) {
3610
+ var _ref6$context, queryText, filter, mastra, relevantContext, vectorStore, queryFilter, results, reranker, rerankedResults, _relevantChunks, relevantChunks;
3611
+ return _regeneratorRuntime().wrap(function _callee2$(_context2) {
3612
+ while (1) switch (_context2.prev = _context2.next) {
3449
3613
  case 0:
3450
- _ref4$context = _ref4.context, queryText = _ref4$context.queryText, filter = _ref4$context.filter, mastra = _ref4.mastra;
3614
+ _ref6$context = _ref6.context, queryText = _ref6$context.queryText, filter = _ref6$context.filter, mastra = _ref6.mastra;
3451
3615
  relevantContext = '';
3452
- vectorStore = mastra == null || (_mastra$vectors = mastra.vectors) == null ? void 0 : _mastra$vectors[vectorStoreName];
3453
- _context.next = 5;
3454
- return embed(queryText, options);
3455
- case 5:
3456
- _yield$embed = _context.sent;
3457
- embedding = _yield$embed.embedding;
3616
+ vectorStore = mastra == null || (_mastra$vectors = mastra.vectors) == null ? void 0 : _mastra$vectors[vectorStoreName]; // Get relevant chunks from the vector database
3458
3617
  if (!vectorStore) {
3459
- _context.next = 14;
3618
+ _context2.next = 18;
3460
3619
  break;
3461
3620
  }
3462
3621
  queryFilter = vectorFilterType && filter ? createFilter(filter, vectorFilterType) : {};
3463
- _context.next = 11;
3464
- return vectorStore.query(indexName, embedding, topK, queryFilter);
3465
- case 11:
3466
- results = _context.sent;
3467
- relevantChunks = results.map(function (result) {
3622
+ _context2.next = 7;
3623
+ return vectorQuerySearch({
3624
+ indexName: indexName,
3625
+ vectorStore: vectorStore,
3626
+ queryText: queryText,
3627
+ options: options,
3628
+ queryFilter: queryFilter,
3629
+ topK: topK
3630
+ });
3631
+ case 7:
3632
+ results = _context2.sent;
3633
+ if (!rerankOptions) {
3634
+ _context2.next = 16;
3635
+ break;
3636
+ }
3637
+ reranker = new RagReranker(rerankOptions);
3638
+ _context2.next = 12;
3639
+ return reranker.rerank({
3640
+ query: queryText,
3641
+ vectorStoreResults: results,
3642
+ topK: topK
3643
+ });
3644
+ case 12:
3645
+ rerankedResults = _context2.sent;
3646
+ _relevantChunks = rerankedResults.map(function (_ref7) {
3468
3647
  var _result$metadata;
3648
+ var result = _ref7.result;
3469
3649
  return result == null || (_result$metadata = result.metadata) == null ? void 0 : _result$metadata.text;
3650
+ });
3651
+ relevantContext = _relevantChunks.join('\n\n');
3652
+ return _context2.abrupt("return", {
3653
+ relevantContext: relevantContext
3654
+ });
3655
+ case 16:
3656
+ relevantChunks = results.map(function (result) {
3657
+ var _result$metadata2;
3658
+ return result == null || (_result$metadata2 = result.metadata) == null ? void 0 : _result$metadata2.text;
3470
3659
  }); // Combine the chunks into a context string
3471
3660
  relevantContext = relevantChunks.join('\n\n');
3472
- case 14:
3473
- return _context.abrupt("return", {
3661
+ case 18:
3662
+ return _context2.abrupt("return", {
3474
3663
  relevantContext: relevantContext
3475
3664
  });
3476
- case 15:
3665
+ case 19:
3477
3666
  case "end":
3478
- return _context.stop();
3667
+ return _context2.stop();
3479
3668
  }
3480
- }, _callee);
3669
+ }, _callee2);
3481
3670
  }));
3482
- function execute(_x) {
3671
+ function execute(_x2) {
3483
3672
  return _execute.apply(this, arguments);
3484
3673
  }
3485
3674
  return execute;
3486
3675
  }()
3487
3676
  });
3488
3677
  };
3489
- var createDocumentChunker = function createDocumentChunker(_ref5) {
3490
- var doc = _ref5.doc,
3491
- _ref5$params = _ref5.params,
3492
- params = _ref5$params === void 0 ? {
3678
+ var createDocumentChunker = function createDocumentChunker(_ref8) {
3679
+ var doc = _ref8.doc,
3680
+ _ref8$params = _ref8.params,
3681
+ params = _ref8$params === void 0 ? {
3493
3682
  strategy: 'recursive',
3494
3683
  size: 512,
3495
3684
  overlap: 50,
3496
3685
  separator: '\n'
3497
- } : _ref5$params;
3686
+ } : _ref8$params;
3498
3687
  return core.createTool({
3499
3688
  id: "Document Chunker " + params.strategy + " " + params.size,
3500
3689
  inputSchema: zod.z.object({}),
3501
3690
  description: "Chunks document using " + params.strategy + " strategy with size " + params.size + " and " + params.overlap + " overlap",
3502
3691
  execute: function () {
3503
- var _execute2 = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee2() {
3692
+ var _execute2 = _asyncToGenerator(/*#__PURE__*/_regeneratorRuntime().mark(function _callee3() {
3504
3693
  var chunks;
3505
- return _regeneratorRuntime().wrap(function _callee2$(_context2) {
3506
- while (1) switch (_context2.prev = _context2.next) {
3694
+ return _regeneratorRuntime().wrap(function _callee3$(_context3) {
3695
+ while (1) switch (_context3.prev = _context3.next) {
3507
3696
  case 0:
3508
- _context2.next = 2;
3697
+ _context3.next = 2;
3509
3698
  return doc.chunk(params);
3510
3699
  case 2:
3511
- chunks = _context2.sent;
3512
- return _context2.abrupt("return", {
3700
+ chunks = _context3.sent;
3701
+ return _context3.abrupt("return", {
3513
3702
  chunks: chunks
3514
3703
  });
3515
3704
  case 4:
3516
3705
  case "end":
3517
- return _context2.stop();
3706
+ return _context3.stop();
3518
3707
  }
3519
- }, _callee2);
3708
+ }, _callee3);
3520
3709
  }));
3521
3710
  function execute() {
3522
3711
  return _execute2.apply(this, arguments);
@@ -3531,8 +3720,10 @@ exports.MDocument = MDocument;
3531
3720
  exports.PgVector = PgVector;
3532
3721
  exports.PineconeVector = PineconeVector;
3533
3722
  exports.QdrantVector = QdrantVector;
3723
+ exports.RagReranker = RagReranker;
3534
3724
  exports.UpstashVector = UpstashVector;
3535
3725
  exports.createDocumentChunker = createDocumentChunker;
3536
3726
  exports.createVectorQueryTool = createVectorQueryTool;
3537
3727
  exports.embed = embed;
3728
+ exports.vectorQuerySearch = vectorQuerySearch;
3538
3729
  //# sourceMappingURL=rag.cjs.development.js.map