@mastra/rag 1.0.1 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,23 +1,23 @@
1
1
 
2
- > @mastra/rag@1.0.1-alpha.0 build /home/runner/work/mastra/mastra/packages/rag
2
+ > @mastra/rag@1.0.2-alpha.0 build /home/runner/work/mastra/mastra/packages/rag
3
3
  > tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
4
4
 
5
5
  CLI Building entry: src/index.ts
6
6
  CLI Using tsconfig: tsconfig.json
7
7
  CLI tsup v8.5.0
8
8
  TSC Build start
9
- TSC ⚡️ Build success in 14155ms
9
+ TSC ⚡️ Build success in 17466ms
10
10
  DTS Build start
11
11
  CLI Target: es2022
12
12
  Analysis will use the bundled TypeScript version 5.8.3
13
13
  Writing package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.ts
14
14
  Analysis will use the bundled TypeScript version 5.8.3
15
15
  Writing package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.cts
16
- DTS ⚡️ Build success in 13672ms
16
+ DTS ⚡️ Build success in 13205ms
17
17
  CLI Cleaning output folder
18
18
  ESM Build start
19
19
  CJS Build start
20
- ESM dist/index.js 242.08 KB
21
- ESM ⚡️ Build success in 4294ms
22
- CJS dist/index.cjs 243.79 KB
23
- CJS ⚡️ Build success in 4295ms
20
+ ESM dist/index.js 245.51 KB
21
+ ESM ⚡️ Build success in 4603ms
22
+ CJS dist/index.cjs 247.58 KB
23
+ CJS ⚡️ Build success in 4609ms
package/CHANGELOG.md CHANGED
@@ -1,5 +1,38 @@
1
1
  # @mastra/rag
2
2
 
3
+ ## 1.0.2
4
+
5
+ ### Patch Changes
6
+
7
+ - 43da563: Refactor relevance provider
8
+ - Updated dependencies [2873c7f]
9
+ - Updated dependencies [1c1c6a1]
10
+ - Updated dependencies [f8ce2cc]
11
+ - Updated dependencies [8c846b6]
12
+ - Updated dependencies [c7bbf1e]
13
+ - Updated dependencies [8722d53]
14
+ - Updated dependencies [565cc0c]
15
+ - Updated dependencies [b790fd1]
16
+ - Updated dependencies [132027f]
17
+ - Updated dependencies [0c85311]
18
+ - Updated dependencies [d7ed04d]
19
+ - Updated dependencies [cb16baf]
20
+ - Updated dependencies [f36e4f1]
21
+ - Updated dependencies [7f6e403]
22
+ - @mastra/core@0.10.11
23
+
24
+ ## 1.0.2-alpha.0
25
+
26
+ ### Patch Changes
27
+
28
+ - 43da563: Refactor relevance provider
29
+ - Updated dependencies [c7bbf1e]
30
+ - Updated dependencies [8722d53]
31
+ - Updated dependencies [132027f]
32
+ - Updated dependencies [0c85311]
33
+ - Updated dependencies [cb16baf]
34
+ - @mastra/core@0.10.11-alpha.3
35
+
3
36
  ## 1.0.1
4
37
 
5
38
  ### Patch Changes
package/LICENSE.md CHANGED
@@ -1,46 +1,15 @@
1
- # Elastic License 2.0 (ELv2)
1
+ # Apache License 2.0
2
2
 
3
- Copyright (c) 2025 Mastra AI, Inc.
3
+ Copyright (c) 2025 Kepler Software, Inc.
4
4
 
5
- **Acceptance**
6
- By using the software, you agree to all of the terms and conditions below.
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
7
8
 
8
- **Copyright License**
9
- The licensor grants you a non-exclusive, royalty-free, worldwide, non-sublicensable, non-transferable license to use, copy, distribute, make available, and prepare derivative works of the software, in each case subject to the limitations and conditions below
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
10
 
11
- **Limitations**
12
- You may not provide the software to third parties as a hosted or managed service, where the service provides users with access to any substantial set of the features or functionality of the software.
13
-
14
- You may not move, change, disable, or circumvent the license key functionality in the software, and you may not remove or obscure any functionality in the software that is protected by the license key.
15
-
16
- You may not alter, remove, or obscure any licensing, copyright, or other notices of the licensor in the software. Any use of the licensor’s trademarks is subject to applicable law.
17
-
18
- **Patents**
19
- The licensor grants you a license, under any patent claims the licensor can license, or becomes able to license, to make, have made, use, sell, offer for sale, import and have imported the software, in each case subject to the limitations and conditions in this license. This license does not cover any patent claims that you cause to be infringed by modifications or additions to the software. If you or your company make any written claim that the software infringes or contributes to infringement of any patent, your patent license for the software granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company.
20
-
21
- **Notices**
22
- You must ensure that anyone who gets a copy of any part of the software from you also gets a copy of these terms.
23
-
24
- If you modify the software, you must include in any modified copies of the software prominent notices stating that you have modified the software.
25
-
26
- **No Other Rights**
27
- These terms do not imply any licenses other than those expressly granted in these terms.
28
-
29
- **Termination**
30
- If you use the software in violation of these terms, such use is not licensed, and your licenses will automatically terminate. If the licensor provides you with a notice of your violation, and you cease all violation of this license no later than 30 days after you receive that notice, your licenses will be reinstated retroactively. However, if you violate these terms after such reinstatement, any additional violation of these terms will cause your licenses to terminate automatically and permanently.
31
-
32
- **No Liability**
33
- As far as the law allows, the software comes as is, without any warranty or condition, and the licensor will not be liable to you for any damages arising out of these terms or the use or nature of the software, under any kind of legal claim.
34
-
35
- **Definitions**
36
- The _licensor_ is the entity offering these terms, and the _software_ is the software the licensor makes available under these terms, including any portion of it.
37
-
38
- _you_ refers to the individual or entity agreeing to these terms.
39
-
40
- _your company_ is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. _control_ means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect.
41
-
42
- _your licenses_ are all the licenses granted to you for the software under these terms.
43
-
44
- _use_ means anything you do with the software requiring one of your licenses.
45
-
46
- _trademark_ means trademarks, service marks, and similar rights.
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
@@ -4,6 +4,7 @@ import type { MastraLanguageModel } from '@mastra/core/agent';
4
4
  import type { MastraVector } from '@mastra/core/vector';
5
5
  import type { QueryResult } from '@mastra/core/vector';
6
6
  import type { QueryResult as QueryResult_2 } from '@mastra/core';
7
+ import type { RelevanceScoreProvider } from '@mastra/core/relevance';
7
8
  import type { TiktokenEncoding } from 'js-tiktoken';
8
9
  import type { TiktokenModel } from 'js-tiktoken';
9
10
  import type { Tool } from '@mastra/core/tools';
@@ -162,6 +163,16 @@ declare type ChunkStrategy = 'recursive' | 'character' | 'token' | 'markdown' |
162
163
  export { ChunkStrategy }
163
164
  export { ChunkStrategy as ChunkStrategy_alias_1 }
164
165
 
166
+ declare class CohereRelevanceScorer implements RelevanceScoreProvider {
167
+ private client;
168
+ private model;
169
+ constructor(model: string, apiKey?: string);
170
+ getRelevanceScore(query: string, text: string): Promise<number>;
171
+ }
172
+ export { CohereRelevanceScorer }
173
+ export { CohereRelevanceScorer as CohereRelevanceScorer_alias_1 }
174
+ export { CohereRelevanceScorer as CohereRelevanceScorer_alias_2 }
175
+
165
176
  /**
166
177
  * Convert an array of source inputs (QueryResult, RankedNode, or RerankResult) to an array of sources.
167
178
  * @param results Array of source inputs to convert.
@@ -593,6 +604,15 @@ export declare class MarkdownTransformer extends RecursiveCharacterTransformer {
593
604
  });
594
605
  }
595
606
 
607
+ declare class MastraAgentRelevanceScorer implements RelevanceScoreProvider {
608
+ private agent;
609
+ constructor(name: string, model: MastraLanguageModel);
610
+ getRelevanceScore(query: string, text: string): Promise<number>;
611
+ }
612
+ export { MastraAgentRelevanceScorer }
613
+ export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_1 }
614
+ export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_2 }
615
+
596
616
  declare class MDocument {
597
617
  private chunks;
598
618
  private type;
@@ -950,7 +970,7 @@ export { rerank as rerank_alias_1 }
950
970
 
951
971
  declare interface RerankConfig {
952
972
  options?: RerankerOptions;
953
- model: MastraLanguageModel;
973
+ model: MastraLanguageModel | RelevanceScoreProvider;
954
974
  }
955
975
  export { RerankConfig }
956
976
  export { RerankConfig as RerankConfig_alias_1 }
@@ -978,6 +998,15 @@ declare interface RerankResult {
978
998
  export { RerankResult }
979
999
  export { RerankResult as RerankResult_alias_1 }
980
1000
 
1001
+ declare function rerankWithScorer({ results, query, scorer, options, }: {
1002
+ results: QueryResult[];
1003
+ query: string;
1004
+ scorer: RelevanceScoreProvider;
1005
+ options: RerankerFunctionOptions;
1006
+ }): Promise<RerankResult[]>;
1007
+ export { rerankWithScorer }
1008
+ export { rerankWithScorer as rerankWithScorer_alias_1 }
1009
+
981
1010
  declare interface ScoringDetails {
982
1011
  semantic: number;
983
1012
  vector: number;
@@ -1256,4 +1285,14 @@ declare type WhereDocumentOperator = '$contains' | '$not_contains' | LogicalOper
1256
1285
 
1257
1286
  declare type WhereOperator = '$gt' | '$gte' | '$lt' | '$lte' | '$ne' | '$eq';
1258
1287
 
1288
+ declare class ZeroEntropyRelevanceScorer implements RelevanceScoreProvider {
1289
+ private client;
1290
+ private model;
1291
+ constructor(model?: string, apiKey?: string);
1292
+ getRelevanceScore(query: string, text: string): Promise<number>;
1293
+ }
1294
+ export { ZeroEntropyRelevanceScorer }
1295
+ export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_1 }
1296
+ export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_2 }
1297
+
1259
1298
  export { }
@@ -4,6 +4,7 @@ import type { MastraLanguageModel } from '@mastra/core/agent';
4
4
  import type { MastraVector } from '@mastra/core/vector';
5
5
  import type { QueryResult } from '@mastra/core/vector';
6
6
  import type { QueryResult as QueryResult_2 } from '@mastra/core';
7
+ import type { RelevanceScoreProvider } from '@mastra/core/relevance';
7
8
  import type { TiktokenEncoding } from 'js-tiktoken';
8
9
  import type { TiktokenModel } from 'js-tiktoken';
9
10
  import type { Tool } from '@mastra/core/tools';
@@ -162,6 +163,16 @@ declare type ChunkStrategy = 'recursive' | 'character' | 'token' | 'markdown' |
162
163
  export { ChunkStrategy }
163
164
  export { ChunkStrategy as ChunkStrategy_alias_1 }
164
165
 
166
+ declare class CohereRelevanceScorer implements RelevanceScoreProvider {
167
+ private client;
168
+ private model;
169
+ constructor(model: string, apiKey?: string);
170
+ getRelevanceScore(query: string, text: string): Promise<number>;
171
+ }
172
+ export { CohereRelevanceScorer }
173
+ export { CohereRelevanceScorer as CohereRelevanceScorer_alias_1 }
174
+ export { CohereRelevanceScorer as CohereRelevanceScorer_alias_2 }
175
+
165
176
  /**
166
177
  * Convert an array of source inputs (QueryResult, RankedNode, or RerankResult) to an array of sources.
167
178
  * @param results Array of source inputs to convert.
@@ -593,6 +604,15 @@ export declare class MarkdownTransformer extends RecursiveCharacterTransformer {
593
604
  });
594
605
  }
595
606
 
607
+ declare class MastraAgentRelevanceScorer implements RelevanceScoreProvider {
608
+ private agent;
609
+ constructor(name: string, model: MastraLanguageModel);
610
+ getRelevanceScore(query: string, text: string): Promise<number>;
611
+ }
612
+ export { MastraAgentRelevanceScorer }
613
+ export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_1 }
614
+ export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_2 }
615
+
596
616
  declare class MDocument {
597
617
  private chunks;
598
618
  private type;
@@ -950,7 +970,7 @@ export { rerank as rerank_alias_1 }
950
970
 
951
971
  declare interface RerankConfig {
952
972
  options?: RerankerOptions;
953
- model: MastraLanguageModel;
973
+ model: MastraLanguageModel | RelevanceScoreProvider;
954
974
  }
955
975
  export { RerankConfig }
956
976
  export { RerankConfig as RerankConfig_alias_1 }
@@ -978,6 +998,15 @@ declare interface RerankResult {
978
998
  export { RerankResult }
979
999
  export { RerankResult as RerankResult_alias_1 }
980
1000
 
1001
+ declare function rerankWithScorer({ results, query, scorer, options, }: {
1002
+ results: QueryResult[];
1003
+ query: string;
1004
+ scorer: RelevanceScoreProvider;
1005
+ options: RerankerFunctionOptions;
1006
+ }): Promise<RerankResult[]>;
1007
+ export { rerankWithScorer }
1008
+ export { rerankWithScorer as rerankWithScorer_alias_1 }
1009
+
981
1010
  declare interface ScoringDetails {
982
1011
  semantic: number;
983
1012
  vector: number;
@@ -1256,4 +1285,14 @@ declare type WhereDocumentOperator = '$contains' | '$not_contains' | LogicalOper
1256
1285
 
1257
1286
  declare type WhereOperator = '$gt' | '$gte' | '$lt' | '$lte' | '$ne' | '$eq';
1258
1287
 
1288
+ declare class ZeroEntropyRelevanceScorer implements RelevanceScoreProvider {
1289
+ private client;
1290
+ private model;
1291
+ constructor(model?: string, apiKey?: string);
1292
+ getRelevanceScore(query: string, text: string): Promise<number>;
1293
+ }
1294
+ export { ZeroEntropyRelevanceScorer }
1295
+ export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_1 }
1296
+ export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_2 }
1297
+
1259
1298
  export { }
package/dist/index.cjs CHANGED
@@ -4,11 +4,18 @@ var crypto = require('crypto');
4
4
  var zod = require('zod');
5
5
  var nodeHtmlBetterParser = require('node-html-better-parser');
6
6
  var jsTiktoken = require('js-tiktoken');
7
- var relevance = require('@mastra/core/relevance');
8
7
  var big_js = require('big.js');
8
+ var cohereAi = require('cohere-ai');
9
+ var agent = require('@mastra/core/agent');
10
+ var relevance = require('@mastra/core/relevance');
11
+ var ZeroEntropy = require('zeroentropy');
9
12
  var tools = require('@mastra/core/tools');
10
13
  var ai = require('ai');
11
14
 
15
+ function _interopDefault (e) { return e && e.__esModule ? e : { default: e }; }
16
+
17
+ var ZeroEntropy__default = /*#__PURE__*/_interopDefault(ZeroEntropy);
18
+
12
19
  var __create = Object.create;
13
20
  var __defProp = Object.defineProperty;
14
21
  var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
@@ -3394,15 +3401,16 @@ var OpenAIResponsesLanguageModel = class {
3394
3401
  async doGenerate(options) {
3395
3402
  var _a15, _b, _c, _d, _e, _f, _g;
3396
3403
  const { args: body, warnings } = this.getArgs(options);
3404
+ const url = this.config.url({
3405
+ path: "/responses",
3406
+ modelId: this.modelId
3407
+ });
3397
3408
  const {
3398
3409
  responseHeaders,
3399
3410
  value: response,
3400
3411
  rawValue: rawResponse
3401
3412
  } = await postJsonToApi({
3402
- url: this.config.url({
3403
- path: "/responses",
3404
- modelId: this.modelId
3405
- }),
3413
+ url,
3406
3414
  headers: combineHeaders(this.config.headers(), options.headers),
3407
3415
  body,
3408
3416
  failedResponseHandler: openaiFailedResponseHandler,
@@ -3410,6 +3418,10 @@ var OpenAIResponsesLanguageModel = class {
3410
3418
  zod.z.object({
3411
3419
  id: zod.z.string(),
3412
3420
  created_at: zod.z.number(),
3421
+ error: zod.z.object({
3422
+ message: zod.z.string(),
3423
+ code: zod.z.string()
3424
+ }).nullish(),
3413
3425
  model: zod.z.string(),
3414
3426
  output: zod.z.array(
3415
3427
  zod.z.discriminatedUnion("type", [
@@ -3462,6 +3474,17 @@ var OpenAIResponsesLanguageModel = class {
3462
3474
  abortSignal: options.abortSignal,
3463
3475
  fetch: this.config.fetch
3464
3476
  });
3477
+ if (response.error) {
3478
+ throw new APICallError({
3479
+ message: response.error.message,
3480
+ url,
3481
+ requestBodyValues: body,
3482
+ statusCode: 400,
3483
+ responseHeaders,
3484
+ responseBody: rawResponse,
3485
+ isRetryable: false
3486
+ });
3487
+ }
3465
3488
  const outputTextElements = response.output.filter((output) => output.type === "message").flatMap((output) => output.content).filter((content) => content.type === "output_text");
3466
3489
  const toolCalls = response.output.filter((output) => output.type === "function_call").map((output) => ({
3467
3490
  toolCallType: "function",
@@ -3633,6 +3656,8 @@ var OpenAIResponsesLanguageModel = class {
3633
3656
  title: value.annotation.title
3634
3657
  }
3635
3658
  });
3659
+ } else if (isErrorChunk(value)) {
3660
+ controller.enqueue({ type: "error", error: value });
3636
3661
  }
3637
3662
  },
3638
3663
  flush(controller) {
@@ -3742,6 +3767,13 @@ var responseReasoningSummaryTextDeltaSchema = zod.z.object({
3742
3767
  summary_index: zod.z.number(),
3743
3768
  delta: zod.z.string()
3744
3769
  });
3770
+ var errorChunkSchema = zod.z.object({
3771
+ type: zod.z.literal("error"),
3772
+ code: zod.z.string(),
3773
+ message: zod.z.string(),
3774
+ param: zod.z.string().nullish(),
3775
+ sequence_number: zod.z.number()
3776
+ });
3745
3777
  var openaiResponsesChunkSchema = zod.z.union([
3746
3778
  textDeltaChunkSchema,
3747
3779
  responseFinishedChunkSchema,
@@ -3751,6 +3783,7 @@ var openaiResponsesChunkSchema = zod.z.union([
3751
3783
  responseOutputItemAddedSchema,
3752
3784
  responseAnnotationAddedSchema,
3753
3785
  responseReasoningSummaryTextDeltaSchema,
3786
+ errorChunkSchema,
3754
3787
  zod.z.object({ type: zod.z.string() }).passthrough()
3755
3788
  // fallback for unknown chunks
3756
3789
  ]);
@@ -3778,6 +3811,9 @@ function isResponseAnnotationAddedChunk(chunk) {
3778
3811
  function isResponseReasoningSummaryTextDeltaChunk(chunk) {
3779
3812
  return chunk.type === "response.reasoning_summary_text.delta";
3780
3813
  }
3814
+ function isErrorChunk(chunk) {
3815
+ return chunk.type === "error";
3816
+ }
3781
3817
  function getResponsesModelConfig(modelId) {
3782
3818
  if (modelId.startsWith("o")) {
3783
3819
  if (modelId.startsWith("o1-mini") || modelId.startsWith("o1-preview")) {
@@ -5940,6 +5976,70 @@ var MDocument = class _MDocument {
5940
5976
  return this.chunks.map((doc) => doc.metadata);
5941
5977
  }
5942
5978
  };
5979
+ var CohereRelevanceScorer = class {
5980
+ client;
5981
+ model;
5982
+ constructor(model, apiKey) {
5983
+ this.client = new cohereAi.CohereClient({
5984
+ token: apiKey || process.env.COHERE_API_KEY || ""
5985
+ });
5986
+ this.model = model;
5987
+ }
5988
+ async getRelevanceScore(query, text) {
5989
+ const response = await this.client.rerank({
5990
+ query,
5991
+ documents: [text],
5992
+ model: this.model,
5993
+ topN: 1
5994
+ });
5995
+ return response.results[0].relevanceScore;
5996
+ }
5997
+ };
5998
+ var MastraAgentRelevanceScorer = class {
5999
+ agent;
6000
+ constructor(name14, model) {
6001
+ this.agent = new agent.Agent({
6002
+ name: `Relevance Scorer ${name14}`,
6003
+ instructions: `You are a specialized agent for evaluating the relevance of text to queries.
6004
+ Your task is to rate how well a text passage answers a given query.
6005
+ Output only a number between 0 and 1, where:
6006
+ 1.0 = Perfectly relevant, directly answers the query
6007
+ 0.0 = Completely irrelevant
6008
+ Consider:
6009
+ - Direct relevance to the question
6010
+ - Completeness of information
6011
+ - Quality and specificity
6012
+ Always return just the number, no explanation.`,
6013
+ model
6014
+ });
6015
+ }
6016
+ async getRelevanceScore(query, text) {
6017
+ const prompt = relevance.createSimilarityPrompt(query, text);
6018
+ const response = await this.agent.generate(prompt);
6019
+ return parseFloat(response.text);
6020
+ }
6021
+ };
6022
+ var ZeroEntropyRelevanceScorer = class {
6023
+ client;
6024
+ model;
6025
+ constructor(model, apiKey) {
6026
+ this.client = new ZeroEntropy__default.default({
6027
+ apiKey: apiKey || process.env.ZEROENTROPY_API_KEY || ""
6028
+ });
6029
+ this.model = model || "zerank-1";
6030
+ }
6031
+ async getRelevanceScore(query, text) {
6032
+ const response = await this.client.models.rerank({
6033
+ query,
6034
+ documents: [text],
6035
+ model: this.model,
6036
+ top_n: 1
6037
+ });
6038
+ return response.results[0]?.relevance_score ?? 0;
6039
+ }
6040
+ };
6041
+
6042
+ // src/rerank/index.ts
5943
6043
  var DEFAULT_WEIGHTS = {
5944
6044
  semantic: 0.4,
5945
6045
  vector: 0.4,
@@ -5958,13 +6058,12 @@ function adjustScores(score, queryAnalysis) {
5958
6058
  const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
5959
6059
  return score * magnitudeAdjustment * featureStrengthAdjustment;
5960
6060
  }
5961
- async function rerank(results, query, model, options) {
5962
- let semanticProvider;
5963
- if (model.modelId === "rerank-v3.5") {
5964
- semanticProvider = new relevance.CohereRelevanceScorer(model.modelId);
5965
- } else {
5966
- semanticProvider = new relevance.MastraAgentRelevanceScorer(model.provider, model);
5967
- }
6061
+ async function executeRerank({
6062
+ results,
6063
+ query,
6064
+ scorer,
6065
+ options
6066
+ }) {
5968
6067
  const { queryEmbedding, topK = 3 } = options;
5969
6068
  const weights = {
5970
6069
  ...DEFAULT_WEIGHTS,
@@ -5980,7 +6079,7 @@ async function rerank(results, query, model, options) {
5980
6079
  results.map(async (result, index) => {
5981
6080
  let semanticScore = 0;
5982
6081
  if (result?.metadata?.text) {
5983
- semanticScore = await semanticProvider.getRelevanceScore(query, result?.metadata?.text);
6082
+ semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
5984
6083
  }
5985
6084
  const vectorScore = result.score;
5986
6085
  const positionScore = calculatePositionScore(index, resultLength);
@@ -6007,6 +6106,33 @@ async function rerank(results, query, model, options) {
6007
6106
  );
6008
6107
  return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
6009
6108
  }
6109
+ async function rerankWithScorer({
6110
+ results,
6111
+ query,
6112
+ scorer,
6113
+ options
6114
+ }) {
6115
+ return executeRerank({
6116
+ results,
6117
+ query,
6118
+ scorer,
6119
+ options
6120
+ });
6121
+ }
6122
+ async function rerank(results, query, model, options) {
6123
+ let semanticProvider;
6124
+ if (model.modelId === "rerank-v3.5") {
6125
+ semanticProvider = new CohereRelevanceScorer(model.modelId);
6126
+ } else {
6127
+ semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
6128
+ }
6129
+ return executeRerank({
6130
+ results,
6131
+ query,
6132
+ scorer: semanticProvider,
6133
+ options
6134
+ });
6135
+ }
6010
6136
 
6011
6137
  // src/graph-rag/index.ts
6012
6138
  var GraphRAG = class {
@@ -6605,10 +6731,23 @@ var createVectorQueryTool = (options) => {
6605
6731
  if (logger) {
6606
6732
  logger.debug("Reranking results", { rerankerModel: reranker.model, rerankerOptions: reranker.options });
6607
6733
  }
6608
- const rerankedResults = await rerank(results, queryText, reranker.model, {
6609
- ...reranker.options,
6610
- topK: reranker.options?.topK || topKValue
6611
- });
6734
+ let rerankedResults = [];
6735
+ if (typeof reranker?.model === "object" && "getRelevanceScore" in reranker?.model) {
6736
+ rerankedResults = await rerankWithScorer({
6737
+ results,
6738
+ query: queryText,
6739
+ scorer: reranker.model,
6740
+ options: {
6741
+ ...reranker.options,
6742
+ topK: reranker.options?.topK || topKValue
6743
+ }
6744
+ });
6745
+ } else {
6746
+ rerankedResults = await rerank(results, queryText, reranker.model, {
6747
+ ...reranker.options,
6748
+ topK: reranker.options?.topK || topKValue
6749
+ });
6750
+ }
6612
6751
  if (logger) {
6613
6752
  logger.debug("Reranking complete", { rerankedCount: rerankedResults.length });
6614
6753
  }
@@ -7448,15 +7587,18 @@ Example Complex Query:
7448
7587
 
7449
7588
  exports.ASTRA_PROMPT = ASTRA_PROMPT;
7450
7589
  exports.CHROMA_PROMPT = CHROMA_PROMPT;
7590
+ exports.CohereRelevanceScorer = CohereRelevanceScorer;
7451
7591
  exports.GraphRAG = GraphRAG;
7452
7592
  exports.LIBSQL_PROMPT = LIBSQL_PROMPT;
7453
7593
  exports.MDocument = MDocument;
7454
7594
  exports.MONGODB_PROMPT = MONGODB_PROMPT;
7595
+ exports.MastraAgentRelevanceScorer = MastraAgentRelevanceScorer;
7455
7596
  exports.PGVECTOR_PROMPT = PGVECTOR_PROMPT;
7456
7597
  exports.PINECONE_PROMPT = PINECONE_PROMPT;
7457
7598
  exports.QDRANT_PROMPT = QDRANT_PROMPT;
7458
7599
  exports.UPSTASH_PROMPT = UPSTASH_PROMPT;
7459
7600
  exports.VECTORIZE_PROMPT = VECTORIZE_PROMPT;
7601
+ exports.ZeroEntropyRelevanceScorer = ZeroEntropyRelevanceScorer;
7460
7602
  exports.createDocumentChunkerTool = createDocumentChunkerTool;
7461
7603
  exports.createGraphRAGTool = createGraphRAGTool;
7462
7604
  exports.createVectorQueryTool = createVectorQueryTool;
@@ -7465,4 +7607,5 @@ exports.defaultVectorQueryDescription = defaultVectorQueryDescription;
7465
7607
  exports.filterDescription = filterDescription;
7466
7608
  exports.queryTextDescription = queryTextDescription;
7467
7609
  exports.rerank = rerank;
7610
+ exports.rerankWithScorer = rerankWithScorer;
7468
7611
  exports.topKDescription = topKDescription;
package/dist/index.d.cts CHANGED
@@ -1,10 +1,14 @@
1
1
  export { GraphRAG } from './_tsup-dts-rollup.cjs';
2
2
  export { MDocument } from './_tsup-dts-rollup.cjs';
3
+ export { rerankWithScorer } from './_tsup-dts-rollup.cjs';
3
4
  export { rerank } from './_tsup-dts-rollup.cjs';
4
5
  export { RerankResult } from './_tsup-dts-rollup.cjs';
5
6
  export { RerankerOptions } from './_tsup-dts-rollup.cjs';
6
7
  export { RerankerFunctionOptions } from './_tsup-dts-rollup.cjs';
7
8
  export { RerankConfig } from './_tsup-dts-rollup.cjs';
9
+ export { CohereRelevanceScorer } from './_tsup-dts-rollup.cjs';
10
+ export { MastraAgentRelevanceScorer } from './_tsup-dts-rollup.cjs';
11
+ export { ZeroEntropyRelevanceScorer } from './_tsup-dts-rollup.cjs';
8
12
  export { createDocumentChunkerTool } from './_tsup-dts-rollup.cjs';
9
13
  export { createGraphRAGTool } from './_tsup-dts-rollup.cjs';
10
14
  export { createVectorQueryTool } from './_tsup-dts-rollup.cjs';
package/dist/index.d.ts CHANGED
@@ -1,10 +1,14 @@
1
1
  export { GraphRAG } from './_tsup-dts-rollup.js';
2
2
  export { MDocument } from './_tsup-dts-rollup.js';
3
+ export { rerankWithScorer } from './_tsup-dts-rollup.js';
3
4
  export { rerank } from './_tsup-dts-rollup.js';
4
5
  export { RerankResult } from './_tsup-dts-rollup.js';
5
6
  export { RerankerOptions } from './_tsup-dts-rollup.js';
6
7
  export { RerankerFunctionOptions } from './_tsup-dts-rollup.js';
7
8
  export { RerankConfig } from './_tsup-dts-rollup.js';
9
+ export { CohereRelevanceScorer } from './_tsup-dts-rollup.js';
10
+ export { MastraAgentRelevanceScorer } from './_tsup-dts-rollup.js';
11
+ export { ZeroEntropyRelevanceScorer } from './_tsup-dts-rollup.js';
8
12
  export { createDocumentChunkerTool } from './_tsup-dts-rollup.js';
9
13
  export { createGraphRAGTool } from './_tsup-dts-rollup.js';
10
14
  export { createVectorQueryTool } from './_tsup-dts-rollup.js';
package/dist/index.js CHANGED
@@ -2,8 +2,11 @@ import { randomUUID, createHash } from 'crypto';
2
2
  import { z } from 'zod';
3
3
  import { parse } from 'node-html-better-parser';
4
4
  import { encodingForModel, getEncoding } from 'js-tiktoken';
5
- import { CohereRelevanceScorer, MastraAgentRelevanceScorer } from '@mastra/core/relevance';
6
5
  import { Big } from 'big.js';
6
+ import { CohereClient } from 'cohere-ai';
7
+ import { Agent } from '@mastra/core/agent';
8
+ import { createSimilarityPrompt } from '@mastra/core/relevance';
9
+ import ZeroEntropy from 'zeroentropy';
7
10
  import { createTool } from '@mastra/core/tools';
8
11
  import { embed } from 'ai';
9
12
 
@@ -3392,15 +3395,16 @@ var OpenAIResponsesLanguageModel = class {
3392
3395
  async doGenerate(options) {
3393
3396
  var _a15, _b, _c, _d, _e, _f, _g;
3394
3397
  const { args: body, warnings } = this.getArgs(options);
3398
+ const url = this.config.url({
3399
+ path: "/responses",
3400
+ modelId: this.modelId
3401
+ });
3395
3402
  const {
3396
3403
  responseHeaders,
3397
3404
  value: response,
3398
3405
  rawValue: rawResponse
3399
3406
  } = await postJsonToApi({
3400
- url: this.config.url({
3401
- path: "/responses",
3402
- modelId: this.modelId
3403
- }),
3407
+ url,
3404
3408
  headers: combineHeaders(this.config.headers(), options.headers),
3405
3409
  body,
3406
3410
  failedResponseHandler: openaiFailedResponseHandler,
@@ -3408,6 +3412,10 @@ var OpenAIResponsesLanguageModel = class {
3408
3412
  z.object({
3409
3413
  id: z.string(),
3410
3414
  created_at: z.number(),
3415
+ error: z.object({
3416
+ message: z.string(),
3417
+ code: z.string()
3418
+ }).nullish(),
3411
3419
  model: z.string(),
3412
3420
  output: z.array(
3413
3421
  z.discriminatedUnion("type", [
@@ -3460,6 +3468,17 @@ var OpenAIResponsesLanguageModel = class {
3460
3468
  abortSignal: options.abortSignal,
3461
3469
  fetch: this.config.fetch
3462
3470
  });
3471
+ if (response.error) {
3472
+ throw new APICallError({
3473
+ message: response.error.message,
3474
+ url,
3475
+ requestBodyValues: body,
3476
+ statusCode: 400,
3477
+ responseHeaders,
3478
+ responseBody: rawResponse,
3479
+ isRetryable: false
3480
+ });
3481
+ }
3463
3482
  const outputTextElements = response.output.filter((output) => output.type === "message").flatMap((output) => output.content).filter((content) => content.type === "output_text");
3464
3483
  const toolCalls = response.output.filter((output) => output.type === "function_call").map((output) => ({
3465
3484
  toolCallType: "function",
@@ -3631,6 +3650,8 @@ var OpenAIResponsesLanguageModel = class {
3631
3650
  title: value.annotation.title
3632
3651
  }
3633
3652
  });
3653
+ } else if (isErrorChunk(value)) {
3654
+ controller.enqueue({ type: "error", error: value });
3634
3655
  }
3635
3656
  },
3636
3657
  flush(controller) {
@@ -3740,6 +3761,13 @@ var responseReasoningSummaryTextDeltaSchema = z.object({
3740
3761
  summary_index: z.number(),
3741
3762
  delta: z.string()
3742
3763
  });
3764
+ var errorChunkSchema = z.object({
3765
+ type: z.literal("error"),
3766
+ code: z.string(),
3767
+ message: z.string(),
3768
+ param: z.string().nullish(),
3769
+ sequence_number: z.number()
3770
+ });
3743
3771
  var openaiResponsesChunkSchema = z.union([
3744
3772
  textDeltaChunkSchema,
3745
3773
  responseFinishedChunkSchema,
@@ -3749,6 +3777,7 @@ var openaiResponsesChunkSchema = z.union([
3749
3777
  responseOutputItemAddedSchema,
3750
3778
  responseAnnotationAddedSchema,
3751
3779
  responseReasoningSummaryTextDeltaSchema,
3780
+ errorChunkSchema,
3752
3781
  z.object({ type: z.string() }).passthrough()
3753
3782
  // fallback for unknown chunks
3754
3783
  ]);
@@ -3776,6 +3805,9 @@ function isResponseAnnotationAddedChunk(chunk) {
3776
3805
  function isResponseReasoningSummaryTextDeltaChunk(chunk) {
3777
3806
  return chunk.type === "response.reasoning_summary_text.delta";
3778
3807
  }
3808
+ function isErrorChunk(chunk) {
3809
+ return chunk.type === "error";
3810
+ }
3779
3811
  function getResponsesModelConfig(modelId) {
3780
3812
  if (modelId.startsWith("o")) {
3781
3813
  if (modelId.startsWith("o1-mini") || modelId.startsWith("o1-preview")) {
@@ -5938,6 +5970,70 @@ var MDocument = class _MDocument {
5938
5970
  return this.chunks.map((doc) => doc.metadata);
5939
5971
  }
5940
5972
  };
5973
+ var CohereRelevanceScorer = class {
5974
+ client;
5975
+ model;
5976
+ constructor(model, apiKey) {
5977
+ this.client = new CohereClient({
5978
+ token: apiKey || process.env.COHERE_API_KEY || ""
5979
+ });
5980
+ this.model = model;
5981
+ }
5982
+ async getRelevanceScore(query, text) {
5983
+ const response = await this.client.rerank({
5984
+ query,
5985
+ documents: [text],
5986
+ model: this.model,
5987
+ topN: 1
5988
+ });
5989
+ return response.results[0].relevanceScore;
5990
+ }
5991
+ };
5992
+ var MastraAgentRelevanceScorer = class {
5993
+ agent;
5994
+ constructor(name14, model) {
5995
+ this.agent = new Agent({
5996
+ name: `Relevance Scorer ${name14}`,
5997
+ instructions: `You are a specialized agent for evaluating the relevance of text to queries.
5998
+ Your task is to rate how well a text passage answers a given query.
5999
+ Output only a number between 0 and 1, where:
6000
+ 1.0 = Perfectly relevant, directly answers the query
6001
+ 0.0 = Completely irrelevant
6002
+ Consider:
6003
+ - Direct relevance to the question
6004
+ - Completeness of information
6005
+ - Quality and specificity
6006
+ Always return just the number, no explanation.`,
6007
+ model
6008
+ });
6009
+ }
6010
+ async getRelevanceScore(query, text) {
6011
+ const prompt = createSimilarityPrompt(query, text);
6012
+ const response = await this.agent.generate(prompt);
6013
+ return parseFloat(response.text);
6014
+ }
6015
+ };
6016
+ var ZeroEntropyRelevanceScorer = class {
6017
+ client;
6018
+ model;
6019
+ constructor(model, apiKey) {
6020
+ this.client = new ZeroEntropy({
6021
+ apiKey: apiKey || process.env.ZEROENTROPY_API_KEY || ""
6022
+ });
6023
+ this.model = model || "zerank-1";
6024
+ }
6025
+ async getRelevanceScore(query, text) {
6026
+ const response = await this.client.models.rerank({
6027
+ query,
6028
+ documents: [text],
6029
+ model: this.model,
6030
+ top_n: 1
6031
+ });
6032
+ return response.results[0]?.relevance_score ?? 0;
6033
+ }
6034
+ };
6035
+
6036
+ // src/rerank/index.ts
5941
6037
  var DEFAULT_WEIGHTS = {
5942
6038
  semantic: 0.4,
5943
6039
  vector: 0.4,
@@ -5956,13 +6052,12 @@ function adjustScores(score, queryAnalysis) {
5956
6052
  const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
5957
6053
  return score * magnitudeAdjustment * featureStrengthAdjustment;
5958
6054
  }
5959
- async function rerank(results, query, model, options) {
5960
- let semanticProvider;
5961
- if (model.modelId === "rerank-v3.5") {
5962
- semanticProvider = new CohereRelevanceScorer(model.modelId);
5963
- } else {
5964
- semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
5965
- }
6055
+ async function executeRerank({
6056
+ results,
6057
+ query,
6058
+ scorer,
6059
+ options
6060
+ }) {
5966
6061
  const { queryEmbedding, topK = 3 } = options;
5967
6062
  const weights = {
5968
6063
  ...DEFAULT_WEIGHTS,
@@ -5978,7 +6073,7 @@ async function rerank(results, query, model, options) {
5978
6073
  results.map(async (result, index) => {
5979
6074
  let semanticScore = 0;
5980
6075
  if (result?.metadata?.text) {
5981
- semanticScore = await semanticProvider.getRelevanceScore(query, result?.metadata?.text);
6076
+ semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
5982
6077
  }
5983
6078
  const vectorScore = result.score;
5984
6079
  const positionScore = calculatePositionScore(index, resultLength);
@@ -6005,6 +6100,33 @@ async function rerank(results, query, model, options) {
6005
6100
  );
6006
6101
  return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
6007
6102
  }
6103
+ async function rerankWithScorer({
6104
+ results,
6105
+ query,
6106
+ scorer,
6107
+ options
6108
+ }) {
6109
+ return executeRerank({
6110
+ results,
6111
+ query,
6112
+ scorer,
6113
+ options
6114
+ });
6115
+ }
6116
+ async function rerank(results, query, model, options) {
6117
+ let semanticProvider;
6118
+ if (model.modelId === "rerank-v3.5") {
6119
+ semanticProvider = new CohereRelevanceScorer(model.modelId);
6120
+ } else {
6121
+ semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
6122
+ }
6123
+ return executeRerank({
6124
+ results,
6125
+ query,
6126
+ scorer: semanticProvider,
6127
+ options
6128
+ });
6129
+ }
6008
6130
 
6009
6131
  // src/graph-rag/index.ts
6010
6132
  var GraphRAG = class {
@@ -6603,10 +6725,23 @@ var createVectorQueryTool = (options) => {
6603
6725
  if (logger) {
6604
6726
  logger.debug("Reranking results", { rerankerModel: reranker.model, rerankerOptions: reranker.options });
6605
6727
  }
6606
- const rerankedResults = await rerank(results, queryText, reranker.model, {
6607
- ...reranker.options,
6608
- topK: reranker.options?.topK || topKValue
6609
- });
6728
+ let rerankedResults = [];
6729
+ if (typeof reranker?.model === "object" && "getRelevanceScore" in reranker?.model) {
6730
+ rerankedResults = await rerankWithScorer({
6731
+ results,
6732
+ query: queryText,
6733
+ scorer: reranker.model,
6734
+ options: {
6735
+ ...reranker.options,
6736
+ topK: reranker.options?.topK || topKValue
6737
+ }
6738
+ });
6739
+ } else {
6740
+ rerankedResults = await rerank(results, queryText, reranker.model, {
6741
+ ...reranker.options,
6742
+ topK: reranker.options?.topK || topKValue
6743
+ });
6744
+ }
6610
6745
  if (logger) {
6611
6746
  logger.debug("Reranking complete", { rerankedCount: rerankedResults.length });
6612
6747
  }
@@ -7444,4 +7579,4 @@ Example Complex Query:
7444
7579
  }
7445
7580
  `;
7446
7581
 
7447
- export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, MONGODB_PROMPT, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultGraphRagDescription, defaultVectorQueryDescription, filterDescription, queryTextDescription, rerank, topKDescription };
7582
+ export { ASTRA_PROMPT, CHROMA_PROMPT, CohereRelevanceScorer, GraphRAG, LIBSQL_PROMPT, MDocument, MONGODB_PROMPT, MastraAgentRelevanceScorer, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, ZeroEntropyRelevanceScorer, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultGraphRagDescription, defaultVectorQueryDescription, filterDescription, queryTextDescription, rerank, rerankWithScorer, topKDescription };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@mastra/rag",
3
- "version": "1.0.1",
3
+ "version": "1.0.2",
4
4
  "description": "",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -19,13 +19,15 @@
19
19
  "./package.json": "./package.json"
20
20
  },
21
21
  "author": "",
22
- "license": "Elastic-2.0",
22
+ "license": "Apache-2.0",
23
23
  "dependencies": {
24
24
  "@paralleldrive/cuid2": "^2.2.2",
25
25
  "big.js": "^7.0.1",
26
+ "cohere-ai": "^7.17.1",
26
27
  "js-tiktoken": "^1.0.20",
27
28
  "node-html-better-parser": "^1.4.11",
28
29
  "pathe": "^2.0.3",
30
+ "zeroentropy": "0.1.0-alpha.6",
29
31
  "zod": "^3.25.67"
30
32
  },
31
33
  "peerDependencies": {
@@ -39,13 +41,13 @@
39
41
  "@types/big.js": "^6.2.2",
40
42
  "@types/node": "^20.19.0",
41
43
  "ai": "^4.3.16",
42
- "dotenv": "^16.5.0",
44
+ "dotenv": "^17.0.0",
43
45
  "eslint": "^9.29.0",
44
46
  "tsup": "^8.5.0",
45
47
  "typescript": "^5.8.3",
46
- "vitest": "^3.2.3",
47
- "@mastra/core": "0.10.7",
48
- "@internal/lint": "0.0.14"
48
+ "vitest": "^3.2.4",
49
+ "@internal/lint": "0.0.18",
50
+ "@mastra/core": "0.10.11"
49
51
  },
50
52
  "keywords": [
51
53
  "rag",
@@ -66,6 +68,7 @@
66
68
  "scripts": {
67
69
  "build": "tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting",
68
70
  "buld:watch": "pnpm build --watch",
71
+ "vitest": "vitest",
69
72
  "test": "vitest run",
70
73
  "lint": "eslint ."
71
74
  }
package/src/index.ts CHANGED
@@ -1,5 +1,6 @@
1
1
  export * from './document/document';
2
2
  export * from './rerank';
3
+ export * from './rerank/relevance';
3
4
  export { GraphRAG } from './graph-rag';
4
5
  export * from './tools';
5
6
  export * from './utils/vector-prompts';
@@ -1,6 +1,6 @@
1
1
  import { cohere } from '@ai-sdk/cohere';
2
- import { CohereRelevanceScorer } from '@mastra/core/relevance';
3
2
  import { describe, it, expect, vi, beforeEach } from 'vitest';
3
+ import { CohereRelevanceScorer } from './relevance';
4
4
 
5
5
  import { rerank } from '.';
6
6
 
@@ -1,8 +1,8 @@
1
1
  import type { MastraLanguageModel } from '@mastra/core/agent';
2
- import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from '@mastra/core/relevance';
3
2
  import type { RelevanceScoreProvider } from '@mastra/core/relevance';
4
3
  import type { QueryResult } from '@mastra/core/vector';
5
4
  import { Big } from 'big.js';
5
+ import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from './relevance';
6
6
 
7
7
  // Default weights for different scoring components (must add up to 1)
8
8
  const DEFAULT_WEIGHTS = {
@@ -48,7 +48,7 @@ export interface RerankerFunctionOptions {
48
48
 
49
49
  export interface RerankConfig {
50
50
  options?: RerankerOptions;
51
- model: MastraLanguageModel;
51
+ model: MastraLanguageModel | RelevanceScoreProvider;
52
52
  }
53
53
 
54
54
  // Calculate position score based on position in original list
@@ -83,19 +83,17 @@ function adjustScores(score: number, queryAnalysis: { magnitude: number; dominan
83
83
  return score * magnitudeAdjustment * featureStrengthAdjustment;
84
84
  }
85
85
 
86
- // Takes in a list of results from a vector store and reranks them based on semantic, vector, and position scores
87
- export async function rerank(
88
- results: QueryResult[],
89
- query: string,
90
- model: MastraLanguageModel,
91
- options: RerankerFunctionOptions,
92
- ): Promise<RerankResult[]> {
93
- let semanticProvider: RelevanceScoreProvider;
94
- if (model.modelId === 'rerank-v3.5') {
95
- semanticProvider = new CohereRelevanceScorer(model.modelId);
96
- } else {
97
- semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
98
- }
86
+ async function executeRerank({
87
+ results,
88
+ query,
89
+ scorer,
90
+ options,
91
+ }: {
92
+ results: QueryResult[];
93
+ query: string;
94
+ scorer: RelevanceScoreProvider;
95
+ options: RerankerFunctionOptions;
96
+ }) {
99
97
  const { queryEmbedding, topK = 3 } = options;
100
98
  const weights = {
101
99
  ...DEFAULT_WEIGHTS,
@@ -118,7 +116,7 @@ export async function rerank(
118
116
  // Get semantic score from chosen provider
119
117
  let semanticScore = 0;
120
118
  if (result?.metadata?.text) {
121
- semanticScore = await semanticProvider.getRelevanceScore(query, result?.metadata?.text);
119
+ semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
122
120
  }
123
121
 
124
122
  // Get existing vector score from result
@@ -156,3 +154,45 @@ export async function rerank(
156
154
  // Sort by score and take top K
157
155
  return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
158
156
  }
157
+
158
+ export async function rerankWithScorer({
159
+ results,
160
+ query,
161
+ scorer,
162
+ options,
163
+ }: {
164
+ results: QueryResult[];
165
+ query: string;
166
+ scorer: RelevanceScoreProvider;
167
+ options: RerankerFunctionOptions;
168
+ }): Promise<RerankResult[]> {
169
+ return executeRerank({
170
+ results,
171
+ query,
172
+ scorer,
173
+ options,
174
+ });
175
+ }
176
+
177
+ // Takes in a list of results from a vector store and reranks them based on semantic, vector, and position scores
178
+ export async function rerank(
179
+ results: QueryResult[],
180
+ query: string,
181
+ model: MastraLanguageModel,
182
+ options: RerankerFunctionOptions,
183
+ ): Promise<RerankResult[]> {
184
+ let semanticProvider: RelevanceScoreProvider;
185
+
186
+ if (model.modelId === 'rerank-v3.5') {
187
+ semanticProvider = new CohereRelevanceScorer(model.modelId);
188
+ } else {
189
+ semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
190
+ }
191
+
192
+ return executeRerank({
193
+ results,
194
+ query,
195
+ scorer: semanticProvider,
196
+ options,
197
+ });
198
+ }
@@ -0,0 +1,26 @@
1
+ import type { RelevanceScoreProvider } from '@mastra/core/relevance';
2
+ import { CohereClient } from 'cohere-ai';
3
+
4
+ // Cohere implementation
5
+ export class CohereRelevanceScorer implements RelevanceScoreProvider {
6
+ private client: any;
7
+ private model: string;
8
+ constructor(model: string, apiKey?: string) {
9
+ this.client = new CohereClient({
10
+ token: apiKey || process.env.COHERE_API_KEY || '',
11
+ });
12
+
13
+ this.model = model;
14
+ }
15
+
16
+ async getRelevanceScore(query: string, text: string): Promise<number> {
17
+ const response = await this.client.rerank({
18
+ query,
19
+ documents: [text],
20
+ model: this.model,
21
+ topN: 1,
22
+ });
23
+
24
+ return response.results[0].relevanceScore;
25
+ }
26
+ }
@@ -0,0 +1,3 @@
1
+ export * from './cohere';
2
+ export * from './mastra-agent';
3
+ export * from './zeroentropy';
@@ -0,0 +1,32 @@
1
+ import { Agent } from '@mastra/core/agent';
2
+ import type { MastraLanguageModel } from '@mastra/core/agent';
3
+ import { createSimilarityPrompt } from '@mastra/core/relevance';
4
+ import type { RelevanceScoreProvider } from '@mastra/core/relevance';
5
+
6
+ // Mastra Agent implementation
7
+ export class MastraAgentRelevanceScorer implements RelevanceScoreProvider {
8
+ private agent: Agent;
9
+
10
+ constructor(name: string, model: MastraLanguageModel) {
11
+ this.agent = new Agent({
12
+ name: `Relevance Scorer ${name}`,
13
+ instructions: `You are a specialized agent for evaluating the relevance of text to queries.
14
+ Your task is to rate how well a text passage answers a given query.
15
+ Output only a number between 0 and 1, where:
16
+ 1.0 = Perfectly relevant, directly answers the query
17
+ 0.0 = Completely irrelevant
18
+ Consider:
19
+ - Direct relevance to the question
20
+ - Completeness of information
21
+ - Quality and specificity
22
+ Always return just the number, no explanation.`,
23
+ model,
24
+ });
25
+ }
26
+
27
+ async getRelevanceScore(query: string, text: string): Promise<number> {
28
+ const prompt = createSimilarityPrompt(query, text);
29
+ const response = await this.agent.generate(prompt);
30
+ return parseFloat(response.text);
31
+ }
32
+ }
@@ -0,0 +1,26 @@
1
+ import type { RelevanceScoreProvider } from '@mastra/core/relevance';
2
+ import ZeroEntropy from 'zeroentropy';
3
+
4
+ // ZeroEntropy implementation
5
+ export class ZeroEntropyRelevanceScorer implements RelevanceScoreProvider {
6
+ private client: ZeroEntropy;
7
+ private model: string;
8
+
9
+ constructor(model?: string, apiKey?: string) {
10
+ this.client = new ZeroEntropy({
11
+ apiKey: apiKey || process.env.ZEROENTROPY_API_KEY || '',
12
+ });
13
+ this.model = model || 'zerank-1';
14
+ }
15
+
16
+ async getRelevanceScore(query: string, text: string): Promise<number> {
17
+ const response = await this.client.models.rerank({
18
+ query,
19
+ documents: [text],
20
+ model: this.model,
21
+ top_n: 1,
22
+ });
23
+
24
+ return response.results[0]?.relevance_score ?? 0;
25
+ }
26
+ }
@@ -44,11 +44,13 @@ describe('createVectorQueryTool', () => {
44
44
  debug: vi.fn(),
45
45
  warn: vi.fn(),
46
46
  info: vi.fn(),
47
+ error: vi.fn(),
47
48
  },
48
49
  getLogger: vi.fn(() => ({
49
50
  debug: vi.fn(),
50
51
  warn: vi.fn(),
51
52
  info: vi.fn(),
53
+ error: vi.fn(),
52
54
  })),
53
55
  };
54
56
 
@@ -2,8 +2,8 @@ import { createTool } from '@mastra/core/tools';
2
2
  import type { EmbeddingModel } from 'ai';
3
3
  import { z } from 'zod';
4
4
 
5
- import { rerank } from '../rerank';
6
- import type { RerankConfig } from '../rerank';
5
+ import { rerank, rerankWithScorer } from '../rerank';
6
+ import type { RerankConfig, RerankResult } from '../rerank';
7
7
  import { vectorQuerySearch, defaultVectorQueryDescription, filterSchema, outputSchema, baseSchema } from '../utils';
8
8
  import type { RagTool } from '../utils';
9
9
  import { convertToSources } from '../utils/convert-sources';
@@ -94,26 +94,48 @@ export const createVectorQueryTool = (options: VectorQueryToolOptions) => {
94
94
  if (logger) {
95
95
  logger.debug('vectorQuerySearch returned results', { count: results.length });
96
96
  }
97
+
97
98
  if (reranker) {
98
99
  if (logger) {
99
100
  logger.debug('Reranking results', { rerankerModel: reranker.model, rerankerOptions: reranker.options });
100
101
  }
101
- const rerankedResults = await rerank(results, queryText, reranker.model, {
102
- ...reranker.options,
103
- topK: reranker.options?.topK || topKValue,
104
- });
102
+
103
+ let rerankedResults: RerankResult[] = [];
104
+
105
+ if (typeof reranker?.model === 'object' && 'getRelevanceScore' in reranker?.model) {
106
+ rerankedResults = await rerankWithScorer({
107
+ results,
108
+ query: queryText,
109
+ scorer: reranker.model,
110
+ options: {
111
+ ...reranker.options,
112
+ topK: reranker.options?.topK || topKValue,
113
+ },
114
+ });
115
+ } else {
116
+ rerankedResults = await rerank(results, queryText, reranker.model, {
117
+ ...reranker.options,
118
+ topK: reranker.options?.topK || topKValue,
119
+ });
120
+ }
121
+
105
122
  if (logger) {
106
123
  logger.debug('Reranking complete', { rerankedCount: rerankedResults.length });
107
124
  }
125
+
108
126
  const relevantChunks = rerankedResults.map(({ result }) => result?.metadata);
127
+
109
128
  if (logger) {
110
129
  logger.debug('Returning reranked relevant context chunks', { count: relevantChunks.length });
111
130
  }
131
+
112
132
  const sources = includeSources ? convertToSources(rerankedResults) : [];
133
+
113
134
  return { relevantContext: relevantChunks, sources };
114
135
  }
115
136
 
116
137
  const relevantChunks = results.map(result => result?.metadata);
138
+
117
139
  if (logger) {
118
140
  logger.debug('Returning relevant context chunks', { count: relevantChunks.length });
119
141
  }