@mastra/rag 1.0.1 → 1.0.2-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +12 -0
- package/dist/_tsup-dts-rollup.d.cts +40 -1
- package/dist/_tsup-dts-rollup.d.ts +40 -1
- package/dist/index.cjs +160 -17
- package/dist/index.d.cts +4 -0
- package/dist/index.d.ts +4 -0
- package/dist/index.js +153 -18
- package/package.json +8 -5
- package/src/index.ts +1 -0
- package/src/rerank/index.test.ts +1 -1
- package/src/rerank/index.ts +56 -16
- package/src/rerank/relevance/cohere/index.ts +26 -0
- package/src/rerank/relevance/index.ts +3 -0
- package/src/rerank/relevance/mastra-agent/index.ts +32 -0
- package/src/rerank/relevance/zeroentropy/index.ts +26 -0
- package/src/tools/vector-query.test.ts +2 -0
- package/src/tools/vector-query.ts +28 -6
package/.turbo/turbo-build.log
CHANGED
|
@@ -1,23 +1,23 @@
|
|
|
1
1
|
|
|
2
|
-
> @mastra/rag@1.0.
|
|
2
|
+
> @mastra/rag@1.0.2-alpha.0 build /home/runner/work/mastra/mastra/packages/rag
|
|
3
3
|
> tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
|
|
4
4
|
|
|
5
5
|
[34mCLI[39m Building entry: src/index.ts
|
|
6
6
|
[34mCLI[39m Using tsconfig: tsconfig.json
|
|
7
7
|
[34mCLI[39m tsup v8.5.0
|
|
8
8
|
[34mTSC[39m Build start
|
|
9
|
-
[32mTSC[39m ⚡️ Build success in
|
|
9
|
+
[32mTSC[39m ⚡️ Build success in 14378ms
|
|
10
10
|
[34mDTS[39m Build start
|
|
11
11
|
[34mCLI[39m Target: es2022
|
|
12
12
|
Analysis will use the bundled TypeScript version 5.8.3
|
|
13
13
|
[36mWriting package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.ts[39m
|
|
14
14
|
Analysis will use the bundled TypeScript version 5.8.3
|
|
15
15
|
[36mWriting package typings: /home/runner/work/mastra/mastra/packages/rag/dist/_tsup-dts-rollup.d.cts[39m
|
|
16
|
-
[32mDTS[39m ⚡️ Build success in
|
|
16
|
+
[32mDTS[39m ⚡️ Build success in 12936ms
|
|
17
17
|
[34mCLI[39m Cleaning output folder
|
|
18
18
|
[34mESM[39m Build start
|
|
19
19
|
[34mCJS[39m Build start
|
|
20
|
-
[
|
|
21
|
-
[
|
|
22
|
-
[
|
|
23
|
-
[
|
|
20
|
+
[32mCJS[39m [1mdist/index.cjs [22m[32m247.58 KB[39m
|
|
21
|
+
[32mCJS[39m ⚡️ Build success in 4578ms
|
|
22
|
+
[32mESM[39m [1mdist/index.js [22m[32m245.51 KB[39m
|
|
23
|
+
[32mESM[39m ⚡️ Build success in 4580ms
|
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,17 @@
|
|
|
1
1
|
# @mastra/rag
|
|
2
2
|
|
|
3
|
+
## 1.0.2-alpha.0
|
|
4
|
+
|
|
5
|
+
### Patch Changes
|
|
6
|
+
|
|
7
|
+
- 43da563: Refactor relevance provider
|
|
8
|
+
- Updated dependencies [c7bbf1e]
|
|
9
|
+
- Updated dependencies [8722d53]
|
|
10
|
+
- Updated dependencies [132027f]
|
|
11
|
+
- Updated dependencies [0c85311]
|
|
12
|
+
- Updated dependencies [cb16baf]
|
|
13
|
+
- @mastra/core@0.10.11-alpha.3
|
|
14
|
+
|
|
3
15
|
## 1.0.1
|
|
4
16
|
|
|
5
17
|
### Patch Changes
|
|
@@ -4,6 +4,7 @@ import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
|
4
4
|
import type { MastraVector } from '@mastra/core/vector';
|
|
5
5
|
import type { QueryResult } from '@mastra/core/vector';
|
|
6
6
|
import type { QueryResult as QueryResult_2 } from '@mastra/core';
|
|
7
|
+
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
7
8
|
import type { TiktokenEncoding } from 'js-tiktoken';
|
|
8
9
|
import type { TiktokenModel } from 'js-tiktoken';
|
|
9
10
|
import type { Tool } from '@mastra/core/tools';
|
|
@@ -162,6 +163,16 @@ declare type ChunkStrategy = 'recursive' | 'character' | 'token' | 'markdown' |
|
|
|
162
163
|
export { ChunkStrategy }
|
|
163
164
|
export { ChunkStrategy as ChunkStrategy_alias_1 }
|
|
164
165
|
|
|
166
|
+
declare class CohereRelevanceScorer implements RelevanceScoreProvider {
|
|
167
|
+
private client;
|
|
168
|
+
private model;
|
|
169
|
+
constructor(model: string, apiKey?: string);
|
|
170
|
+
getRelevanceScore(query: string, text: string): Promise<number>;
|
|
171
|
+
}
|
|
172
|
+
export { CohereRelevanceScorer }
|
|
173
|
+
export { CohereRelevanceScorer as CohereRelevanceScorer_alias_1 }
|
|
174
|
+
export { CohereRelevanceScorer as CohereRelevanceScorer_alias_2 }
|
|
175
|
+
|
|
165
176
|
/**
|
|
166
177
|
* Convert an array of source inputs (QueryResult, RankedNode, or RerankResult) to an array of sources.
|
|
167
178
|
* @param results Array of source inputs to convert.
|
|
@@ -593,6 +604,15 @@ export declare class MarkdownTransformer extends RecursiveCharacterTransformer {
|
|
|
593
604
|
});
|
|
594
605
|
}
|
|
595
606
|
|
|
607
|
+
declare class MastraAgentRelevanceScorer implements RelevanceScoreProvider {
|
|
608
|
+
private agent;
|
|
609
|
+
constructor(name: string, model: MastraLanguageModel);
|
|
610
|
+
getRelevanceScore(query: string, text: string): Promise<number>;
|
|
611
|
+
}
|
|
612
|
+
export { MastraAgentRelevanceScorer }
|
|
613
|
+
export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_1 }
|
|
614
|
+
export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_2 }
|
|
615
|
+
|
|
596
616
|
declare class MDocument {
|
|
597
617
|
private chunks;
|
|
598
618
|
private type;
|
|
@@ -950,7 +970,7 @@ export { rerank as rerank_alias_1 }
|
|
|
950
970
|
|
|
951
971
|
declare interface RerankConfig {
|
|
952
972
|
options?: RerankerOptions;
|
|
953
|
-
model: MastraLanguageModel;
|
|
973
|
+
model: MastraLanguageModel | RelevanceScoreProvider;
|
|
954
974
|
}
|
|
955
975
|
export { RerankConfig }
|
|
956
976
|
export { RerankConfig as RerankConfig_alias_1 }
|
|
@@ -978,6 +998,15 @@ declare interface RerankResult {
|
|
|
978
998
|
export { RerankResult }
|
|
979
999
|
export { RerankResult as RerankResult_alias_1 }
|
|
980
1000
|
|
|
1001
|
+
declare function rerankWithScorer({ results, query, scorer, options, }: {
|
|
1002
|
+
results: QueryResult[];
|
|
1003
|
+
query: string;
|
|
1004
|
+
scorer: RelevanceScoreProvider;
|
|
1005
|
+
options: RerankerFunctionOptions;
|
|
1006
|
+
}): Promise<RerankResult[]>;
|
|
1007
|
+
export { rerankWithScorer }
|
|
1008
|
+
export { rerankWithScorer as rerankWithScorer_alias_1 }
|
|
1009
|
+
|
|
981
1010
|
declare interface ScoringDetails {
|
|
982
1011
|
semantic: number;
|
|
983
1012
|
vector: number;
|
|
@@ -1256,4 +1285,14 @@ declare type WhereDocumentOperator = '$contains' | '$not_contains' | LogicalOper
|
|
|
1256
1285
|
|
|
1257
1286
|
declare type WhereOperator = '$gt' | '$gte' | '$lt' | '$lte' | '$ne' | '$eq';
|
|
1258
1287
|
|
|
1288
|
+
declare class ZeroEntropyRelevanceScorer implements RelevanceScoreProvider {
|
|
1289
|
+
private client;
|
|
1290
|
+
private model;
|
|
1291
|
+
constructor(model?: string, apiKey?: string);
|
|
1292
|
+
getRelevanceScore(query: string, text: string): Promise<number>;
|
|
1293
|
+
}
|
|
1294
|
+
export { ZeroEntropyRelevanceScorer }
|
|
1295
|
+
export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_1 }
|
|
1296
|
+
export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_2 }
|
|
1297
|
+
|
|
1259
1298
|
export { }
|
|
@@ -4,6 +4,7 @@ import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
|
4
4
|
import type { MastraVector } from '@mastra/core/vector';
|
|
5
5
|
import type { QueryResult } from '@mastra/core/vector';
|
|
6
6
|
import type { QueryResult as QueryResult_2 } from '@mastra/core';
|
|
7
|
+
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
7
8
|
import type { TiktokenEncoding } from 'js-tiktoken';
|
|
8
9
|
import type { TiktokenModel } from 'js-tiktoken';
|
|
9
10
|
import type { Tool } from '@mastra/core/tools';
|
|
@@ -162,6 +163,16 @@ declare type ChunkStrategy = 'recursive' | 'character' | 'token' | 'markdown' |
|
|
|
162
163
|
export { ChunkStrategy }
|
|
163
164
|
export { ChunkStrategy as ChunkStrategy_alias_1 }
|
|
164
165
|
|
|
166
|
+
declare class CohereRelevanceScorer implements RelevanceScoreProvider {
|
|
167
|
+
private client;
|
|
168
|
+
private model;
|
|
169
|
+
constructor(model: string, apiKey?: string);
|
|
170
|
+
getRelevanceScore(query: string, text: string): Promise<number>;
|
|
171
|
+
}
|
|
172
|
+
export { CohereRelevanceScorer }
|
|
173
|
+
export { CohereRelevanceScorer as CohereRelevanceScorer_alias_1 }
|
|
174
|
+
export { CohereRelevanceScorer as CohereRelevanceScorer_alias_2 }
|
|
175
|
+
|
|
165
176
|
/**
|
|
166
177
|
* Convert an array of source inputs (QueryResult, RankedNode, or RerankResult) to an array of sources.
|
|
167
178
|
* @param results Array of source inputs to convert.
|
|
@@ -593,6 +604,15 @@ export declare class MarkdownTransformer extends RecursiveCharacterTransformer {
|
|
|
593
604
|
});
|
|
594
605
|
}
|
|
595
606
|
|
|
607
|
+
declare class MastraAgentRelevanceScorer implements RelevanceScoreProvider {
|
|
608
|
+
private agent;
|
|
609
|
+
constructor(name: string, model: MastraLanguageModel);
|
|
610
|
+
getRelevanceScore(query: string, text: string): Promise<number>;
|
|
611
|
+
}
|
|
612
|
+
export { MastraAgentRelevanceScorer }
|
|
613
|
+
export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_1 }
|
|
614
|
+
export { MastraAgentRelevanceScorer as MastraAgentRelevanceScorer_alias_2 }
|
|
615
|
+
|
|
596
616
|
declare class MDocument {
|
|
597
617
|
private chunks;
|
|
598
618
|
private type;
|
|
@@ -950,7 +970,7 @@ export { rerank as rerank_alias_1 }
|
|
|
950
970
|
|
|
951
971
|
declare interface RerankConfig {
|
|
952
972
|
options?: RerankerOptions;
|
|
953
|
-
model: MastraLanguageModel;
|
|
973
|
+
model: MastraLanguageModel | RelevanceScoreProvider;
|
|
954
974
|
}
|
|
955
975
|
export { RerankConfig }
|
|
956
976
|
export { RerankConfig as RerankConfig_alias_1 }
|
|
@@ -978,6 +998,15 @@ declare interface RerankResult {
|
|
|
978
998
|
export { RerankResult }
|
|
979
999
|
export { RerankResult as RerankResult_alias_1 }
|
|
980
1000
|
|
|
1001
|
+
declare function rerankWithScorer({ results, query, scorer, options, }: {
|
|
1002
|
+
results: QueryResult[];
|
|
1003
|
+
query: string;
|
|
1004
|
+
scorer: RelevanceScoreProvider;
|
|
1005
|
+
options: RerankerFunctionOptions;
|
|
1006
|
+
}): Promise<RerankResult[]>;
|
|
1007
|
+
export { rerankWithScorer }
|
|
1008
|
+
export { rerankWithScorer as rerankWithScorer_alias_1 }
|
|
1009
|
+
|
|
981
1010
|
declare interface ScoringDetails {
|
|
982
1011
|
semantic: number;
|
|
983
1012
|
vector: number;
|
|
@@ -1256,4 +1285,14 @@ declare type WhereDocumentOperator = '$contains' | '$not_contains' | LogicalOper
|
|
|
1256
1285
|
|
|
1257
1286
|
declare type WhereOperator = '$gt' | '$gte' | '$lt' | '$lte' | '$ne' | '$eq';
|
|
1258
1287
|
|
|
1288
|
+
declare class ZeroEntropyRelevanceScorer implements RelevanceScoreProvider {
|
|
1289
|
+
private client;
|
|
1290
|
+
private model;
|
|
1291
|
+
constructor(model?: string, apiKey?: string);
|
|
1292
|
+
getRelevanceScore(query: string, text: string): Promise<number>;
|
|
1293
|
+
}
|
|
1294
|
+
export { ZeroEntropyRelevanceScorer }
|
|
1295
|
+
export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_1 }
|
|
1296
|
+
export { ZeroEntropyRelevanceScorer as ZeroEntropyRelevanceScorer_alias_2 }
|
|
1297
|
+
|
|
1259
1298
|
export { }
|
package/dist/index.cjs
CHANGED
|
@@ -4,11 +4,18 @@ var crypto = require('crypto');
|
|
|
4
4
|
var zod = require('zod');
|
|
5
5
|
var nodeHtmlBetterParser = require('node-html-better-parser');
|
|
6
6
|
var jsTiktoken = require('js-tiktoken');
|
|
7
|
-
var relevance = require('@mastra/core/relevance');
|
|
8
7
|
var big_js = require('big.js');
|
|
8
|
+
var cohereAi = require('cohere-ai');
|
|
9
|
+
var agent = require('@mastra/core/agent');
|
|
10
|
+
var relevance = require('@mastra/core/relevance');
|
|
11
|
+
var ZeroEntropy = require('zeroentropy');
|
|
9
12
|
var tools = require('@mastra/core/tools');
|
|
10
13
|
var ai = require('ai');
|
|
11
14
|
|
|
15
|
+
function _interopDefault (e) { return e && e.__esModule ? e : { default: e }; }
|
|
16
|
+
|
|
17
|
+
var ZeroEntropy__default = /*#__PURE__*/_interopDefault(ZeroEntropy);
|
|
18
|
+
|
|
12
19
|
var __create = Object.create;
|
|
13
20
|
var __defProp = Object.defineProperty;
|
|
14
21
|
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
|
|
@@ -3394,15 +3401,16 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3394
3401
|
async doGenerate(options) {
|
|
3395
3402
|
var _a15, _b, _c, _d, _e, _f, _g;
|
|
3396
3403
|
const { args: body, warnings } = this.getArgs(options);
|
|
3404
|
+
const url = this.config.url({
|
|
3405
|
+
path: "/responses",
|
|
3406
|
+
modelId: this.modelId
|
|
3407
|
+
});
|
|
3397
3408
|
const {
|
|
3398
3409
|
responseHeaders,
|
|
3399
3410
|
value: response,
|
|
3400
3411
|
rawValue: rawResponse
|
|
3401
3412
|
} = await postJsonToApi({
|
|
3402
|
-
url
|
|
3403
|
-
path: "/responses",
|
|
3404
|
-
modelId: this.modelId
|
|
3405
|
-
}),
|
|
3413
|
+
url,
|
|
3406
3414
|
headers: combineHeaders(this.config.headers(), options.headers),
|
|
3407
3415
|
body,
|
|
3408
3416
|
failedResponseHandler: openaiFailedResponseHandler,
|
|
@@ -3410,6 +3418,10 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3410
3418
|
zod.z.object({
|
|
3411
3419
|
id: zod.z.string(),
|
|
3412
3420
|
created_at: zod.z.number(),
|
|
3421
|
+
error: zod.z.object({
|
|
3422
|
+
message: zod.z.string(),
|
|
3423
|
+
code: zod.z.string()
|
|
3424
|
+
}).nullish(),
|
|
3413
3425
|
model: zod.z.string(),
|
|
3414
3426
|
output: zod.z.array(
|
|
3415
3427
|
zod.z.discriminatedUnion("type", [
|
|
@@ -3462,6 +3474,17 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3462
3474
|
abortSignal: options.abortSignal,
|
|
3463
3475
|
fetch: this.config.fetch
|
|
3464
3476
|
});
|
|
3477
|
+
if (response.error) {
|
|
3478
|
+
throw new APICallError({
|
|
3479
|
+
message: response.error.message,
|
|
3480
|
+
url,
|
|
3481
|
+
requestBodyValues: body,
|
|
3482
|
+
statusCode: 400,
|
|
3483
|
+
responseHeaders,
|
|
3484
|
+
responseBody: rawResponse,
|
|
3485
|
+
isRetryable: false
|
|
3486
|
+
});
|
|
3487
|
+
}
|
|
3465
3488
|
const outputTextElements = response.output.filter((output) => output.type === "message").flatMap((output) => output.content).filter((content) => content.type === "output_text");
|
|
3466
3489
|
const toolCalls = response.output.filter((output) => output.type === "function_call").map((output) => ({
|
|
3467
3490
|
toolCallType: "function",
|
|
@@ -3633,6 +3656,8 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3633
3656
|
title: value.annotation.title
|
|
3634
3657
|
}
|
|
3635
3658
|
});
|
|
3659
|
+
} else if (isErrorChunk(value)) {
|
|
3660
|
+
controller.enqueue({ type: "error", error: value });
|
|
3636
3661
|
}
|
|
3637
3662
|
},
|
|
3638
3663
|
flush(controller) {
|
|
@@ -3742,6 +3767,13 @@ var responseReasoningSummaryTextDeltaSchema = zod.z.object({
|
|
|
3742
3767
|
summary_index: zod.z.number(),
|
|
3743
3768
|
delta: zod.z.string()
|
|
3744
3769
|
});
|
|
3770
|
+
var errorChunkSchema = zod.z.object({
|
|
3771
|
+
type: zod.z.literal("error"),
|
|
3772
|
+
code: zod.z.string(),
|
|
3773
|
+
message: zod.z.string(),
|
|
3774
|
+
param: zod.z.string().nullish(),
|
|
3775
|
+
sequence_number: zod.z.number()
|
|
3776
|
+
});
|
|
3745
3777
|
var openaiResponsesChunkSchema = zod.z.union([
|
|
3746
3778
|
textDeltaChunkSchema,
|
|
3747
3779
|
responseFinishedChunkSchema,
|
|
@@ -3751,6 +3783,7 @@ var openaiResponsesChunkSchema = zod.z.union([
|
|
|
3751
3783
|
responseOutputItemAddedSchema,
|
|
3752
3784
|
responseAnnotationAddedSchema,
|
|
3753
3785
|
responseReasoningSummaryTextDeltaSchema,
|
|
3786
|
+
errorChunkSchema,
|
|
3754
3787
|
zod.z.object({ type: zod.z.string() }).passthrough()
|
|
3755
3788
|
// fallback for unknown chunks
|
|
3756
3789
|
]);
|
|
@@ -3778,6 +3811,9 @@ function isResponseAnnotationAddedChunk(chunk) {
|
|
|
3778
3811
|
function isResponseReasoningSummaryTextDeltaChunk(chunk) {
|
|
3779
3812
|
return chunk.type === "response.reasoning_summary_text.delta";
|
|
3780
3813
|
}
|
|
3814
|
+
function isErrorChunk(chunk) {
|
|
3815
|
+
return chunk.type === "error";
|
|
3816
|
+
}
|
|
3781
3817
|
function getResponsesModelConfig(modelId) {
|
|
3782
3818
|
if (modelId.startsWith("o")) {
|
|
3783
3819
|
if (modelId.startsWith("o1-mini") || modelId.startsWith("o1-preview")) {
|
|
@@ -5940,6 +5976,70 @@ var MDocument = class _MDocument {
|
|
|
5940
5976
|
return this.chunks.map((doc) => doc.metadata);
|
|
5941
5977
|
}
|
|
5942
5978
|
};
|
|
5979
|
+
var CohereRelevanceScorer = class {
|
|
5980
|
+
client;
|
|
5981
|
+
model;
|
|
5982
|
+
constructor(model, apiKey) {
|
|
5983
|
+
this.client = new cohereAi.CohereClient({
|
|
5984
|
+
token: apiKey || process.env.COHERE_API_KEY || ""
|
|
5985
|
+
});
|
|
5986
|
+
this.model = model;
|
|
5987
|
+
}
|
|
5988
|
+
async getRelevanceScore(query, text) {
|
|
5989
|
+
const response = await this.client.rerank({
|
|
5990
|
+
query,
|
|
5991
|
+
documents: [text],
|
|
5992
|
+
model: this.model,
|
|
5993
|
+
topN: 1
|
|
5994
|
+
});
|
|
5995
|
+
return response.results[0].relevanceScore;
|
|
5996
|
+
}
|
|
5997
|
+
};
|
|
5998
|
+
var MastraAgentRelevanceScorer = class {
|
|
5999
|
+
agent;
|
|
6000
|
+
constructor(name14, model) {
|
|
6001
|
+
this.agent = new agent.Agent({
|
|
6002
|
+
name: `Relevance Scorer ${name14}`,
|
|
6003
|
+
instructions: `You are a specialized agent for evaluating the relevance of text to queries.
|
|
6004
|
+
Your task is to rate how well a text passage answers a given query.
|
|
6005
|
+
Output only a number between 0 and 1, where:
|
|
6006
|
+
1.0 = Perfectly relevant, directly answers the query
|
|
6007
|
+
0.0 = Completely irrelevant
|
|
6008
|
+
Consider:
|
|
6009
|
+
- Direct relevance to the question
|
|
6010
|
+
- Completeness of information
|
|
6011
|
+
- Quality and specificity
|
|
6012
|
+
Always return just the number, no explanation.`,
|
|
6013
|
+
model
|
|
6014
|
+
});
|
|
6015
|
+
}
|
|
6016
|
+
async getRelevanceScore(query, text) {
|
|
6017
|
+
const prompt = relevance.createSimilarityPrompt(query, text);
|
|
6018
|
+
const response = await this.agent.generate(prompt);
|
|
6019
|
+
return parseFloat(response.text);
|
|
6020
|
+
}
|
|
6021
|
+
};
|
|
6022
|
+
var ZeroEntropyRelevanceScorer = class {
|
|
6023
|
+
client;
|
|
6024
|
+
model;
|
|
6025
|
+
constructor(model, apiKey) {
|
|
6026
|
+
this.client = new ZeroEntropy__default.default({
|
|
6027
|
+
apiKey: apiKey || process.env.ZEROENTROPY_API_KEY || ""
|
|
6028
|
+
});
|
|
6029
|
+
this.model = model || "zerank-1";
|
|
6030
|
+
}
|
|
6031
|
+
async getRelevanceScore(query, text) {
|
|
6032
|
+
const response = await this.client.models.rerank({
|
|
6033
|
+
query,
|
|
6034
|
+
documents: [text],
|
|
6035
|
+
model: this.model,
|
|
6036
|
+
top_n: 1
|
|
6037
|
+
});
|
|
6038
|
+
return response.results[0]?.relevance_score ?? 0;
|
|
6039
|
+
}
|
|
6040
|
+
};
|
|
6041
|
+
|
|
6042
|
+
// src/rerank/index.ts
|
|
5943
6043
|
var DEFAULT_WEIGHTS = {
|
|
5944
6044
|
semantic: 0.4,
|
|
5945
6045
|
vector: 0.4,
|
|
@@ -5958,13 +6058,12 @@ function adjustScores(score, queryAnalysis) {
|
|
|
5958
6058
|
const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
|
|
5959
6059
|
return score * magnitudeAdjustment * featureStrengthAdjustment;
|
|
5960
6060
|
}
|
|
5961
|
-
async function
|
|
5962
|
-
|
|
5963
|
-
|
|
5964
|
-
|
|
5965
|
-
|
|
5966
|
-
|
|
5967
|
-
}
|
|
6061
|
+
async function executeRerank({
|
|
6062
|
+
results,
|
|
6063
|
+
query,
|
|
6064
|
+
scorer,
|
|
6065
|
+
options
|
|
6066
|
+
}) {
|
|
5968
6067
|
const { queryEmbedding, topK = 3 } = options;
|
|
5969
6068
|
const weights = {
|
|
5970
6069
|
...DEFAULT_WEIGHTS,
|
|
@@ -5980,7 +6079,7 @@ async function rerank(results, query, model, options) {
|
|
|
5980
6079
|
results.map(async (result, index) => {
|
|
5981
6080
|
let semanticScore = 0;
|
|
5982
6081
|
if (result?.metadata?.text) {
|
|
5983
|
-
semanticScore = await
|
|
6082
|
+
semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
|
|
5984
6083
|
}
|
|
5985
6084
|
const vectorScore = result.score;
|
|
5986
6085
|
const positionScore = calculatePositionScore(index, resultLength);
|
|
@@ -6007,6 +6106,33 @@ async function rerank(results, query, model, options) {
|
|
|
6007
6106
|
);
|
|
6008
6107
|
return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
|
|
6009
6108
|
}
|
|
6109
|
+
async function rerankWithScorer({
|
|
6110
|
+
results,
|
|
6111
|
+
query,
|
|
6112
|
+
scorer,
|
|
6113
|
+
options
|
|
6114
|
+
}) {
|
|
6115
|
+
return executeRerank({
|
|
6116
|
+
results,
|
|
6117
|
+
query,
|
|
6118
|
+
scorer,
|
|
6119
|
+
options
|
|
6120
|
+
});
|
|
6121
|
+
}
|
|
6122
|
+
async function rerank(results, query, model, options) {
|
|
6123
|
+
let semanticProvider;
|
|
6124
|
+
if (model.modelId === "rerank-v3.5") {
|
|
6125
|
+
semanticProvider = new CohereRelevanceScorer(model.modelId);
|
|
6126
|
+
} else {
|
|
6127
|
+
semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
|
|
6128
|
+
}
|
|
6129
|
+
return executeRerank({
|
|
6130
|
+
results,
|
|
6131
|
+
query,
|
|
6132
|
+
scorer: semanticProvider,
|
|
6133
|
+
options
|
|
6134
|
+
});
|
|
6135
|
+
}
|
|
6010
6136
|
|
|
6011
6137
|
// src/graph-rag/index.ts
|
|
6012
6138
|
var GraphRAG = class {
|
|
@@ -6605,10 +6731,23 @@ var createVectorQueryTool = (options) => {
|
|
|
6605
6731
|
if (logger) {
|
|
6606
6732
|
logger.debug("Reranking results", { rerankerModel: reranker.model, rerankerOptions: reranker.options });
|
|
6607
6733
|
}
|
|
6608
|
-
|
|
6609
|
-
|
|
6610
|
-
|
|
6611
|
-
|
|
6734
|
+
let rerankedResults = [];
|
|
6735
|
+
if (typeof reranker?.model === "object" && "getRelevanceScore" in reranker?.model) {
|
|
6736
|
+
rerankedResults = await rerankWithScorer({
|
|
6737
|
+
results,
|
|
6738
|
+
query: queryText,
|
|
6739
|
+
scorer: reranker.model,
|
|
6740
|
+
options: {
|
|
6741
|
+
...reranker.options,
|
|
6742
|
+
topK: reranker.options?.topK || topKValue
|
|
6743
|
+
}
|
|
6744
|
+
});
|
|
6745
|
+
} else {
|
|
6746
|
+
rerankedResults = await rerank(results, queryText, reranker.model, {
|
|
6747
|
+
...reranker.options,
|
|
6748
|
+
topK: reranker.options?.topK || topKValue
|
|
6749
|
+
});
|
|
6750
|
+
}
|
|
6612
6751
|
if (logger) {
|
|
6613
6752
|
logger.debug("Reranking complete", { rerankedCount: rerankedResults.length });
|
|
6614
6753
|
}
|
|
@@ -7448,15 +7587,18 @@ Example Complex Query:
|
|
|
7448
7587
|
|
|
7449
7588
|
exports.ASTRA_PROMPT = ASTRA_PROMPT;
|
|
7450
7589
|
exports.CHROMA_PROMPT = CHROMA_PROMPT;
|
|
7590
|
+
exports.CohereRelevanceScorer = CohereRelevanceScorer;
|
|
7451
7591
|
exports.GraphRAG = GraphRAG;
|
|
7452
7592
|
exports.LIBSQL_PROMPT = LIBSQL_PROMPT;
|
|
7453
7593
|
exports.MDocument = MDocument;
|
|
7454
7594
|
exports.MONGODB_PROMPT = MONGODB_PROMPT;
|
|
7595
|
+
exports.MastraAgentRelevanceScorer = MastraAgentRelevanceScorer;
|
|
7455
7596
|
exports.PGVECTOR_PROMPT = PGVECTOR_PROMPT;
|
|
7456
7597
|
exports.PINECONE_PROMPT = PINECONE_PROMPT;
|
|
7457
7598
|
exports.QDRANT_PROMPT = QDRANT_PROMPT;
|
|
7458
7599
|
exports.UPSTASH_PROMPT = UPSTASH_PROMPT;
|
|
7459
7600
|
exports.VECTORIZE_PROMPT = VECTORIZE_PROMPT;
|
|
7601
|
+
exports.ZeroEntropyRelevanceScorer = ZeroEntropyRelevanceScorer;
|
|
7460
7602
|
exports.createDocumentChunkerTool = createDocumentChunkerTool;
|
|
7461
7603
|
exports.createGraphRAGTool = createGraphRAGTool;
|
|
7462
7604
|
exports.createVectorQueryTool = createVectorQueryTool;
|
|
@@ -7465,4 +7607,5 @@ exports.defaultVectorQueryDescription = defaultVectorQueryDescription;
|
|
|
7465
7607
|
exports.filterDescription = filterDescription;
|
|
7466
7608
|
exports.queryTextDescription = queryTextDescription;
|
|
7467
7609
|
exports.rerank = rerank;
|
|
7610
|
+
exports.rerankWithScorer = rerankWithScorer;
|
|
7468
7611
|
exports.topKDescription = topKDescription;
|
package/dist/index.d.cts
CHANGED
|
@@ -1,10 +1,14 @@
|
|
|
1
1
|
export { GraphRAG } from './_tsup-dts-rollup.cjs';
|
|
2
2
|
export { MDocument } from './_tsup-dts-rollup.cjs';
|
|
3
|
+
export { rerankWithScorer } from './_tsup-dts-rollup.cjs';
|
|
3
4
|
export { rerank } from './_tsup-dts-rollup.cjs';
|
|
4
5
|
export { RerankResult } from './_tsup-dts-rollup.cjs';
|
|
5
6
|
export { RerankerOptions } from './_tsup-dts-rollup.cjs';
|
|
6
7
|
export { RerankerFunctionOptions } from './_tsup-dts-rollup.cjs';
|
|
7
8
|
export { RerankConfig } from './_tsup-dts-rollup.cjs';
|
|
9
|
+
export { CohereRelevanceScorer } from './_tsup-dts-rollup.cjs';
|
|
10
|
+
export { MastraAgentRelevanceScorer } from './_tsup-dts-rollup.cjs';
|
|
11
|
+
export { ZeroEntropyRelevanceScorer } from './_tsup-dts-rollup.cjs';
|
|
8
12
|
export { createDocumentChunkerTool } from './_tsup-dts-rollup.cjs';
|
|
9
13
|
export { createGraphRAGTool } from './_tsup-dts-rollup.cjs';
|
|
10
14
|
export { createVectorQueryTool } from './_tsup-dts-rollup.cjs';
|
package/dist/index.d.ts
CHANGED
|
@@ -1,10 +1,14 @@
|
|
|
1
1
|
export { GraphRAG } from './_tsup-dts-rollup.js';
|
|
2
2
|
export { MDocument } from './_tsup-dts-rollup.js';
|
|
3
|
+
export { rerankWithScorer } from './_tsup-dts-rollup.js';
|
|
3
4
|
export { rerank } from './_tsup-dts-rollup.js';
|
|
4
5
|
export { RerankResult } from './_tsup-dts-rollup.js';
|
|
5
6
|
export { RerankerOptions } from './_tsup-dts-rollup.js';
|
|
6
7
|
export { RerankerFunctionOptions } from './_tsup-dts-rollup.js';
|
|
7
8
|
export { RerankConfig } from './_tsup-dts-rollup.js';
|
|
9
|
+
export { CohereRelevanceScorer } from './_tsup-dts-rollup.js';
|
|
10
|
+
export { MastraAgentRelevanceScorer } from './_tsup-dts-rollup.js';
|
|
11
|
+
export { ZeroEntropyRelevanceScorer } from './_tsup-dts-rollup.js';
|
|
8
12
|
export { createDocumentChunkerTool } from './_tsup-dts-rollup.js';
|
|
9
13
|
export { createGraphRAGTool } from './_tsup-dts-rollup.js';
|
|
10
14
|
export { createVectorQueryTool } from './_tsup-dts-rollup.js';
|
package/dist/index.js
CHANGED
|
@@ -2,8 +2,11 @@ import { randomUUID, createHash } from 'crypto';
|
|
|
2
2
|
import { z } from 'zod';
|
|
3
3
|
import { parse } from 'node-html-better-parser';
|
|
4
4
|
import { encodingForModel, getEncoding } from 'js-tiktoken';
|
|
5
|
-
import { CohereRelevanceScorer, MastraAgentRelevanceScorer } from '@mastra/core/relevance';
|
|
6
5
|
import { Big } from 'big.js';
|
|
6
|
+
import { CohereClient } from 'cohere-ai';
|
|
7
|
+
import { Agent } from '@mastra/core/agent';
|
|
8
|
+
import { createSimilarityPrompt } from '@mastra/core/relevance';
|
|
9
|
+
import ZeroEntropy from 'zeroentropy';
|
|
7
10
|
import { createTool } from '@mastra/core/tools';
|
|
8
11
|
import { embed } from 'ai';
|
|
9
12
|
|
|
@@ -3392,15 +3395,16 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3392
3395
|
async doGenerate(options) {
|
|
3393
3396
|
var _a15, _b, _c, _d, _e, _f, _g;
|
|
3394
3397
|
const { args: body, warnings } = this.getArgs(options);
|
|
3398
|
+
const url = this.config.url({
|
|
3399
|
+
path: "/responses",
|
|
3400
|
+
modelId: this.modelId
|
|
3401
|
+
});
|
|
3395
3402
|
const {
|
|
3396
3403
|
responseHeaders,
|
|
3397
3404
|
value: response,
|
|
3398
3405
|
rawValue: rawResponse
|
|
3399
3406
|
} = await postJsonToApi({
|
|
3400
|
-
url
|
|
3401
|
-
path: "/responses",
|
|
3402
|
-
modelId: this.modelId
|
|
3403
|
-
}),
|
|
3407
|
+
url,
|
|
3404
3408
|
headers: combineHeaders(this.config.headers(), options.headers),
|
|
3405
3409
|
body,
|
|
3406
3410
|
failedResponseHandler: openaiFailedResponseHandler,
|
|
@@ -3408,6 +3412,10 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3408
3412
|
z.object({
|
|
3409
3413
|
id: z.string(),
|
|
3410
3414
|
created_at: z.number(),
|
|
3415
|
+
error: z.object({
|
|
3416
|
+
message: z.string(),
|
|
3417
|
+
code: z.string()
|
|
3418
|
+
}).nullish(),
|
|
3411
3419
|
model: z.string(),
|
|
3412
3420
|
output: z.array(
|
|
3413
3421
|
z.discriminatedUnion("type", [
|
|
@@ -3460,6 +3468,17 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3460
3468
|
abortSignal: options.abortSignal,
|
|
3461
3469
|
fetch: this.config.fetch
|
|
3462
3470
|
});
|
|
3471
|
+
if (response.error) {
|
|
3472
|
+
throw new APICallError({
|
|
3473
|
+
message: response.error.message,
|
|
3474
|
+
url,
|
|
3475
|
+
requestBodyValues: body,
|
|
3476
|
+
statusCode: 400,
|
|
3477
|
+
responseHeaders,
|
|
3478
|
+
responseBody: rawResponse,
|
|
3479
|
+
isRetryable: false
|
|
3480
|
+
});
|
|
3481
|
+
}
|
|
3463
3482
|
const outputTextElements = response.output.filter((output) => output.type === "message").flatMap((output) => output.content).filter((content) => content.type === "output_text");
|
|
3464
3483
|
const toolCalls = response.output.filter((output) => output.type === "function_call").map((output) => ({
|
|
3465
3484
|
toolCallType: "function",
|
|
@@ -3631,6 +3650,8 @@ var OpenAIResponsesLanguageModel = class {
|
|
|
3631
3650
|
title: value.annotation.title
|
|
3632
3651
|
}
|
|
3633
3652
|
});
|
|
3653
|
+
} else if (isErrorChunk(value)) {
|
|
3654
|
+
controller.enqueue({ type: "error", error: value });
|
|
3634
3655
|
}
|
|
3635
3656
|
},
|
|
3636
3657
|
flush(controller) {
|
|
@@ -3740,6 +3761,13 @@ var responseReasoningSummaryTextDeltaSchema = z.object({
|
|
|
3740
3761
|
summary_index: z.number(),
|
|
3741
3762
|
delta: z.string()
|
|
3742
3763
|
});
|
|
3764
|
+
var errorChunkSchema = z.object({
|
|
3765
|
+
type: z.literal("error"),
|
|
3766
|
+
code: z.string(),
|
|
3767
|
+
message: z.string(),
|
|
3768
|
+
param: z.string().nullish(),
|
|
3769
|
+
sequence_number: z.number()
|
|
3770
|
+
});
|
|
3743
3771
|
var openaiResponsesChunkSchema = z.union([
|
|
3744
3772
|
textDeltaChunkSchema,
|
|
3745
3773
|
responseFinishedChunkSchema,
|
|
@@ -3749,6 +3777,7 @@ var openaiResponsesChunkSchema = z.union([
|
|
|
3749
3777
|
responseOutputItemAddedSchema,
|
|
3750
3778
|
responseAnnotationAddedSchema,
|
|
3751
3779
|
responseReasoningSummaryTextDeltaSchema,
|
|
3780
|
+
errorChunkSchema,
|
|
3752
3781
|
z.object({ type: z.string() }).passthrough()
|
|
3753
3782
|
// fallback for unknown chunks
|
|
3754
3783
|
]);
|
|
@@ -3776,6 +3805,9 @@ function isResponseAnnotationAddedChunk(chunk) {
|
|
|
3776
3805
|
function isResponseReasoningSummaryTextDeltaChunk(chunk) {
|
|
3777
3806
|
return chunk.type === "response.reasoning_summary_text.delta";
|
|
3778
3807
|
}
|
|
3808
|
+
function isErrorChunk(chunk) {
|
|
3809
|
+
return chunk.type === "error";
|
|
3810
|
+
}
|
|
3779
3811
|
function getResponsesModelConfig(modelId) {
|
|
3780
3812
|
if (modelId.startsWith("o")) {
|
|
3781
3813
|
if (modelId.startsWith("o1-mini") || modelId.startsWith("o1-preview")) {
|
|
@@ -5938,6 +5970,70 @@ var MDocument = class _MDocument {
|
|
|
5938
5970
|
return this.chunks.map((doc) => doc.metadata);
|
|
5939
5971
|
}
|
|
5940
5972
|
};
|
|
5973
|
+
var CohereRelevanceScorer = class {
|
|
5974
|
+
client;
|
|
5975
|
+
model;
|
|
5976
|
+
constructor(model, apiKey) {
|
|
5977
|
+
this.client = new CohereClient({
|
|
5978
|
+
token: apiKey || process.env.COHERE_API_KEY || ""
|
|
5979
|
+
});
|
|
5980
|
+
this.model = model;
|
|
5981
|
+
}
|
|
5982
|
+
async getRelevanceScore(query, text) {
|
|
5983
|
+
const response = await this.client.rerank({
|
|
5984
|
+
query,
|
|
5985
|
+
documents: [text],
|
|
5986
|
+
model: this.model,
|
|
5987
|
+
topN: 1
|
|
5988
|
+
});
|
|
5989
|
+
return response.results[0].relevanceScore;
|
|
5990
|
+
}
|
|
5991
|
+
};
|
|
5992
|
+
var MastraAgentRelevanceScorer = class {
|
|
5993
|
+
agent;
|
|
5994
|
+
constructor(name14, model) {
|
|
5995
|
+
this.agent = new Agent({
|
|
5996
|
+
name: `Relevance Scorer ${name14}`,
|
|
5997
|
+
instructions: `You are a specialized agent for evaluating the relevance of text to queries.
|
|
5998
|
+
Your task is to rate how well a text passage answers a given query.
|
|
5999
|
+
Output only a number between 0 and 1, where:
|
|
6000
|
+
1.0 = Perfectly relevant, directly answers the query
|
|
6001
|
+
0.0 = Completely irrelevant
|
|
6002
|
+
Consider:
|
|
6003
|
+
- Direct relevance to the question
|
|
6004
|
+
- Completeness of information
|
|
6005
|
+
- Quality and specificity
|
|
6006
|
+
Always return just the number, no explanation.`,
|
|
6007
|
+
model
|
|
6008
|
+
});
|
|
6009
|
+
}
|
|
6010
|
+
async getRelevanceScore(query, text) {
|
|
6011
|
+
const prompt = createSimilarityPrompt(query, text);
|
|
6012
|
+
const response = await this.agent.generate(prompt);
|
|
6013
|
+
return parseFloat(response.text);
|
|
6014
|
+
}
|
|
6015
|
+
};
|
|
6016
|
+
var ZeroEntropyRelevanceScorer = class {
|
|
6017
|
+
client;
|
|
6018
|
+
model;
|
|
6019
|
+
constructor(model, apiKey) {
|
|
6020
|
+
this.client = new ZeroEntropy({
|
|
6021
|
+
apiKey: apiKey || process.env.ZEROENTROPY_API_KEY || ""
|
|
6022
|
+
});
|
|
6023
|
+
this.model = model || "zerank-1";
|
|
6024
|
+
}
|
|
6025
|
+
async getRelevanceScore(query, text) {
|
|
6026
|
+
const response = await this.client.models.rerank({
|
|
6027
|
+
query,
|
|
6028
|
+
documents: [text],
|
|
6029
|
+
model: this.model,
|
|
6030
|
+
top_n: 1
|
|
6031
|
+
});
|
|
6032
|
+
return response.results[0]?.relevance_score ?? 0;
|
|
6033
|
+
}
|
|
6034
|
+
};
|
|
6035
|
+
|
|
6036
|
+
// src/rerank/index.ts
|
|
5941
6037
|
var DEFAULT_WEIGHTS = {
|
|
5942
6038
|
semantic: 0.4,
|
|
5943
6039
|
vector: 0.4,
|
|
@@ -5956,13 +6052,12 @@ function adjustScores(score, queryAnalysis) {
|
|
|
5956
6052
|
const featureStrengthAdjustment = queryAnalysis.magnitude > 5 ? 1.05 : 1;
|
|
5957
6053
|
return score * magnitudeAdjustment * featureStrengthAdjustment;
|
|
5958
6054
|
}
|
|
5959
|
-
async function
|
|
5960
|
-
|
|
5961
|
-
|
|
5962
|
-
|
|
5963
|
-
|
|
5964
|
-
|
|
5965
|
-
}
|
|
6055
|
+
async function executeRerank({
|
|
6056
|
+
results,
|
|
6057
|
+
query,
|
|
6058
|
+
scorer,
|
|
6059
|
+
options
|
|
6060
|
+
}) {
|
|
5966
6061
|
const { queryEmbedding, topK = 3 } = options;
|
|
5967
6062
|
const weights = {
|
|
5968
6063
|
...DEFAULT_WEIGHTS,
|
|
@@ -5978,7 +6073,7 @@ async function rerank(results, query, model, options) {
|
|
|
5978
6073
|
results.map(async (result, index) => {
|
|
5979
6074
|
let semanticScore = 0;
|
|
5980
6075
|
if (result?.metadata?.text) {
|
|
5981
|
-
semanticScore = await
|
|
6076
|
+
semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
|
|
5982
6077
|
}
|
|
5983
6078
|
const vectorScore = result.score;
|
|
5984
6079
|
const positionScore = calculatePositionScore(index, resultLength);
|
|
@@ -6005,6 +6100,33 @@ async function rerank(results, query, model, options) {
|
|
|
6005
6100
|
);
|
|
6006
6101
|
return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
|
|
6007
6102
|
}
|
|
6103
|
+
async function rerankWithScorer({
|
|
6104
|
+
results,
|
|
6105
|
+
query,
|
|
6106
|
+
scorer,
|
|
6107
|
+
options
|
|
6108
|
+
}) {
|
|
6109
|
+
return executeRerank({
|
|
6110
|
+
results,
|
|
6111
|
+
query,
|
|
6112
|
+
scorer,
|
|
6113
|
+
options
|
|
6114
|
+
});
|
|
6115
|
+
}
|
|
6116
|
+
async function rerank(results, query, model, options) {
|
|
6117
|
+
let semanticProvider;
|
|
6118
|
+
if (model.modelId === "rerank-v3.5") {
|
|
6119
|
+
semanticProvider = new CohereRelevanceScorer(model.modelId);
|
|
6120
|
+
} else {
|
|
6121
|
+
semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
|
|
6122
|
+
}
|
|
6123
|
+
return executeRerank({
|
|
6124
|
+
results,
|
|
6125
|
+
query,
|
|
6126
|
+
scorer: semanticProvider,
|
|
6127
|
+
options
|
|
6128
|
+
});
|
|
6129
|
+
}
|
|
6008
6130
|
|
|
6009
6131
|
// src/graph-rag/index.ts
|
|
6010
6132
|
var GraphRAG = class {
|
|
@@ -6603,10 +6725,23 @@ var createVectorQueryTool = (options) => {
|
|
|
6603
6725
|
if (logger) {
|
|
6604
6726
|
logger.debug("Reranking results", { rerankerModel: reranker.model, rerankerOptions: reranker.options });
|
|
6605
6727
|
}
|
|
6606
|
-
|
|
6607
|
-
|
|
6608
|
-
|
|
6609
|
-
|
|
6728
|
+
let rerankedResults = [];
|
|
6729
|
+
if (typeof reranker?.model === "object" && "getRelevanceScore" in reranker?.model) {
|
|
6730
|
+
rerankedResults = await rerankWithScorer({
|
|
6731
|
+
results,
|
|
6732
|
+
query: queryText,
|
|
6733
|
+
scorer: reranker.model,
|
|
6734
|
+
options: {
|
|
6735
|
+
...reranker.options,
|
|
6736
|
+
topK: reranker.options?.topK || topKValue
|
|
6737
|
+
}
|
|
6738
|
+
});
|
|
6739
|
+
} else {
|
|
6740
|
+
rerankedResults = await rerank(results, queryText, reranker.model, {
|
|
6741
|
+
...reranker.options,
|
|
6742
|
+
topK: reranker.options?.topK || topKValue
|
|
6743
|
+
});
|
|
6744
|
+
}
|
|
6610
6745
|
if (logger) {
|
|
6611
6746
|
logger.debug("Reranking complete", { rerankedCount: rerankedResults.length });
|
|
6612
6747
|
}
|
|
@@ -7444,4 +7579,4 @@ Example Complex Query:
|
|
|
7444
7579
|
}
|
|
7445
7580
|
`;
|
|
7446
7581
|
|
|
7447
|
-
export { ASTRA_PROMPT, CHROMA_PROMPT, GraphRAG, LIBSQL_PROMPT, MDocument, MONGODB_PROMPT, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultGraphRagDescription, defaultVectorQueryDescription, filterDescription, queryTextDescription, rerank, topKDescription };
|
|
7582
|
+
export { ASTRA_PROMPT, CHROMA_PROMPT, CohereRelevanceScorer, GraphRAG, LIBSQL_PROMPT, MDocument, MONGODB_PROMPT, MastraAgentRelevanceScorer, PGVECTOR_PROMPT, PINECONE_PROMPT, QDRANT_PROMPT, UPSTASH_PROMPT, VECTORIZE_PROMPT, ZeroEntropyRelevanceScorer, createDocumentChunkerTool, createGraphRAGTool, createVectorQueryTool, defaultGraphRagDescription, defaultVectorQueryDescription, filterDescription, queryTextDescription, rerank, rerankWithScorer, topKDescription };
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@mastra/rag",
|
|
3
|
-
"version": "1.0.
|
|
3
|
+
"version": "1.0.2-alpha.0",
|
|
4
4
|
"description": "",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "dist/index.js",
|
|
@@ -23,9 +23,11 @@
|
|
|
23
23
|
"dependencies": {
|
|
24
24
|
"@paralleldrive/cuid2": "^2.2.2",
|
|
25
25
|
"big.js": "^7.0.1",
|
|
26
|
+
"cohere-ai": "^7.17.1",
|
|
26
27
|
"js-tiktoken": "^1.0.20",
|
|
27
28
|
"node-html-better-parser": "^1.4.11",
|
|
28
29
|
"pathe": "^2.0.3",
|
|
30
|
+
"zeroentropy": "0.1.0-alpha.6",
|
|
29
31
|
"zod": "^3.25.67"
|
|
30
32
|
},
|
|
31
33
|
"peerDependencies": {
|
|
@@ -39,13 +41,13 @@
|
|
|
39
41
|
"@types/big.js": "^6.2.2",
|
|
40
42
|
"@types/node": "^20.19.0",
|
|
41
43
|
"ai": "^4.3.16",
|
|
42
|
-
"dotenv": "^
|
|
44
|
+
"dotenv": "^17.0.0",
|
|
43
45
|
"eslint": "^9.29.0",
|
|
44
46
|
"tsup": "^8.5.0",
|
|
45
47
|
"typescript": "^5.8.3",
|
|
46
|
-
"vitest": "^3.2.
|
|
47
|
-
"@
|
|
48
|
-
"@
|
|
48
|
+
"vitest": "^3.2.4",
|
|
49
|
+
"@internal/lint": "0.0.17",
|
|
50
|
+
"@mastra/core": "0.10.11-alpha.3"
|
|
49
51
|
},
|
|
50
52
|
"keywords": [
|
|
51
53
|
"rag",
|
|
@@ -66,6 +68,7 @@
|
|
|
66
68
|
"scripts": {
|
|
67
69
|
"build": "tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting",
|
|
68
70
|
"buld:watch": "pnpm build --watch",
|
|
71
|
+
"vitest": "vitest",
|
|
69
72
|
"test": "vitest run",
|
|
70
73
|
"lint": "eslint ."
|
|
71
74
|
}
|
package/src/index.ts
CHANGED
package/src/rerank/index.test.ts
CHANGED
package/src/rerank/index.ts
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
2
|
-
import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from '@mastra/core/relevance';
|
|
3
2
|
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
4
3
|
import type { QueryResult } from '@mastra/core/vector';
|
|
5
4
|
import { Big } from 'big.js';
|
|
5
|
+
import { MastraAgentRelevanceScorer, CohereRelevanceScorer } from './relevance';
|
|
6
6
|
|
|
7
7
|
// Default weights for different scoring components (must add up to 1)
|
|
8
8
|
const DEFAULT_WEIGHTS = {
|
|
@@ -48,7 +48,7 @@ export interface RerankerFunctionOptions {
|
|
|
48
48
|
|
|
49
49
|
export interface RerankConfig {
|
|
50
50
|
options?: RerankerOptions;
|
|
51
|
-
model: MastraLanguageModel;
|
|
51
|
+
model: MastraLanguageModel | RelevanceScoreProvider;
|
|
52
52
|
}
|
|
53
53
|
|
|
54
54
|
// Calculate position score based on position in original list
|
|
@@ -83,19 +83,17 @@ function adjustScores(score: number, queryAnalysis: { magnitude: number; dominan
|
|
|
83
83
|
return score * magnitudeAdjustment * featureStrengthAdjustment;
|
|
84
84
|
}
|
|
85
85
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
|
|
98
|
-
}
|
|
86
|
+
async function executeRerank({
|
|
87
|
+
results,
|
|
88
|
+
query,
|
|
89
|
+
scorer,
|
|
90
|
+
options,
|
|
91
|
+
}: {
|
|
92
|
+
results: QueryResult[];
|
|
93
|
+
query: string;
|
|
94
|
+
scorer: RelevanceScoreProvider;
|
|
95
|
+
options: RerankerFunctionOptions;
|
|
96
|
+
}) {
|
|
99
97
|
const { queryEmbedding, topK = 3 } = options;
|
|
100
98
|
const weights = {
|
|
101
99
|
...DEFAULT_WEIGHTS,
|
|
@@ -118,7 +116,7 @@ export async function rerank(
|
|
|
118
116
|
// Get semantic score from chosen provider
|
|
119
117
|
let semanticScore = 0;
|
|
120
118
|
if (result?.metadata?.text) {
|
|
121
|
-
semanticScore = await
|
|
119
|
+
semanticScore = await scorer.getRelevanceScore(query, result?.metadata?.text);
|
|
122
120
|
}
|
|
123
121
|
|
|
124
122
|
// Get existing vector score from result
|
|
@@ -156,3 +154,45 @@ export async function rerank(
|
|
|
156
154
|
// Sort by score and take top K
|
|
157
155
|
return scoredResults.sort((a, b) => b.score - a.score).slice(0, topK);
|
|
158
156
|
}
|
|
157
|
+
|
|
158
|
+
export async function rerankWithScorer({
|
|
159
|
+
results,
|
|
160
|
+
query,
|
|
161
|
+
scorer,
|
|
162
|
+
options,
|
|
163
|
+
}: {
|
|
164
|
+
results: QueryResult[];
|
|
165
|
+
query: string;
|
|
166
|
+
scorer: RelevanceScoreProvider;
|
|
167
|
+
options: RerankerFunctionOptions;
|
|
168
|
+
}): Promise<RerankResult[]> {
|
|
169
|
+
return executeRerank({
|
|
170
|
+
results,
|
|
171
|
+
query,
|
|
172
|
+
scorer,
|
|
173
|
+
options,
|
|
174
|
+
});
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
// Takes in a list of results from a vector store and reranks them based on semantic, vector, and position scores
|
|
178
|
+
export async function rerank(
|
|
179
|
+
results: QueryResult[],
|
|
180
|
+
query: string,
|
|
181
|
+
model: MastraLanguageModel,
|
|
182
|
+
options: RerankerFunctionOptions,
|
|
183
|
+
): Promise<RerankResult[]> {
|
|
184
|
+
let semanticProvider: RelevanceScoreProvider;
|
|
185
|
+
|
|
186
|
+
if (model.modelId === 'rerank-v3.5') {
|
|
187
|
+
semanticProvider = new CohereRelevanceScorer(model.modelId);
|
|
188
|
+
} else {
|
|
189
|
+
semanticProvider = new MastraAgentRelevanceScorer(model.provider, model);
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
return executeRerank({
|
|
193
|
+
results,
|
|
194
|
+
query,
|
|
195
|
+
scorer: semanticProvider,
|
|
196
|
+
options,
|
|
197
|
+
});
|
|
198
|
+
}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
2
|
+
import { CohereClient } from 'cohere-ai';
|
|
3
|
+
|
|
4
|
+
// Cohere implementation
|
|
5
|
+
export class CohereRelevanceScorer implements RelevanceScoreProvider {
|
|
6
|
+
private client: any;
|
|
7
|
+
private model: string;
|
|
8
|
+
constructor(model: string, apiKey?: string) {
|
|
9
|
+
this.client = new CohereClient({
|
|
10
|
+
token: apiKey || process.env.COHERE_API_KEY || '',
|
|
11
|
+
});
|
|
12
|
+
|
|
13
|
+
this.model = model;
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
async getRelevanceScore(query: string, text: string): Promise<number> {
|
|
17
|
+
const response = await this.client.rerank({
|
|
18
|
+
query,
|
|
19
|
+
documents: [text],
|
|
20
|
+
model: this.model,
|
|
21
|
+
topN: 1,
|
|
22
|
+
});
|
|
23
|
+
|
|
24
|
+
return response.results[0].relevanceScore;
|
|
25
|
+
}
|
|
26
|
+
}
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import { Agent } from '@mastra/core/agent';
|
|
2
|
+
import type { MastraLanguageModel } from '@mastra/core/agent';
|
|
3
|
+
import { createSimilarityPrompt } from '@mastra/core/relevance';
|
|
4
|
+
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
5
|
+
|
|
6
|
+
// Mastra Agent implementation
|
|
7
|
+
export class MastraAgentRelevanceScorer implements RelevanceScoreProvider {
|
|
8
|
+
private agent: Agent;
|
|
9
|
+
|
|
10
|
+
constructor(name: string, model: MastraLanguageModel) {
|
|
11
|
+
this.agent = new Agent({
|
|
12
|
+
name: `Relevance Scorer ${name}`,
|
|
13
|
+
instructions: `You are a specialized agent for evaluating the relevance of text to queries.
|
|
14
|
+
Your task is to rate how well a text passage answers a given query.
|
|
15
|
+
Output only a number between 0 and 1, where:
|
|
16
|
+
1.0 = Perfectly relevant, directly answers the query
|
|
17
|
+
0.0 = Completely irrelevant
|
|
18
|
+
Consider:
|
|
19
|
+
- Direct relevance to the question
|
|
20
|
+
- Completeness of information
|
|
21
|
+
- Quality and specificity
|
|
22
|
+
Always return just the number, no explanation.`,
|
|
23
|
+
model,
|
|
24
|
+
});
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
async getRelevanceScore(query: string, text: string): Promise<number> {
|
|
28
|
+
const prompt = createSimilarityPrompt(query, text);
|
|
29
|
+
const response = await this.agent.generate(prompt);
|
|
30
|
+
return parseFloat(response.text);
|
|
31
|
+
}
|
|
32
|
+
}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import type { RelevanceScoreProvider } from '@mastra/core/relevance';
|
|
2
|
+
import ZeroEntropy from 'zeroentropy';
|
|
3
|
+
|
|
4
|
+
// ZeroEntropy implementation
|
|
5
|
+
export class ZeroEntropyRelevanceScorer implements RelevanceScoreProvider {
|
|
6
|
+
private client: ZeroEntropy;
|
|
7
|
+
private model: string;
|
|
8
|
+
|
|
9
|
+
constructor(model?: string, apiKey?: string) {
|
|
10
|
+
this.client = new ZeroEntropy({
|
|
11
|
+
apiKey: apiKey || process.env.ZEROENTROPY_API_KEY || '',
|
|
12
|
+
});
|
|
13
|
+
this.model = model || 'zerank-1';
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
async getRelevanceScore(query: string, text: string): Promise<number> {
|
|
17
|
+
const response = await this.client.models.rerank({
|
|
18
|
+
query,
|
|
19
|
+
documents: [text],
|
|
20
|
+
model: this.model,
|
|
21
|
+
top_n: 1,
|
|
22
|
+
});
|
|
23
|
+
|
|
24
|
+
return response.results[0]?.relevance_score ?? 0;
|
|
25
|
+
}
|
|
26
|
+
}
|
|
@@ -2,8 +2,8 @@ import { createTool } from '@mastra/core/tools';
|
|
|
2
2
|
import type { EmbeddingModel } from 'ai';
|
|
3
3
|
import { z } from 'zod';
|
|
4
4
|
|
|
5
|
-
import { rerank } from '../rerank';
|
|
6
|
-
import type { RerankConfig } from '../rerank';
|
|
5
|
+
import { rerank, rerankWithScorer } from '../rerank';
|
|
6
|
+
import type { RerankConfig, RerankResult } from '../rerank';
|
|
7
7
|
import { vectorQuerySearch, defaultVectorQueryDescription, filterSchema, outputSchema, baseSchema } from '../utils';
|
|
8
8
|
import type { RagTool } from '../utils';
|
|
9
9
|
import { convertToSources } from '../utils/convert-sources';
|
|
@@ -94,26 +94,48 @@ export const createVectorQueryTool = (options: VectorQueryToolOptions) => {
|
|
|
94
94
|
if (logger) {
|
|
95
95
|
logger.debug('vectorQuerySearch returned results', { count: results.length });
|
|
96
96
|
}
|
|
97
|
+
|
|
97
98
|
if (reranker) {
|
|
98
99
|
if (logger) {
|
|
99
100
|
logger.debug('Reranking results', { rerankerModel: reranker.model, rerankerOptions: reranker.options });
|
|
100
101
|
}
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
102
|
+
|
|
103
|
+
let rerankedResults: RerankResult[] = [];
|
|
104
|
+
|
|
105
|
+
if (typeof reranker?.model === 'object' && 'getRelevanceScore' in reranker?.model) {
|
|
106
|
+
rerankedResults = await rerankWithScorer({
|
|
107
|
+
results,
|
|
108
|
+
query: queryText,
|
|
109
|
+
scorer: reranker.model,
|
|
110
|
+
options: {
|
|
111
|
+
...reranker.options,
|
|
112
|
+
topK: reranker.options?.topK || topKValue,
|
|
113
|
+
},
|
|
114
|
+
});
|
|
115
|
+
} else {
|
|
116
|
+
rerankedResults = await rerank(results, queryText, reranker.model, {
|
|
117
|
+
...reranker.options,
|
|
118
|
+
topK: reranker.options?.topK || topKValue,
|
|
119
|
+
});
|
|
120
|
+
}
|
|
121
|
+
|
|
105
122
|
if (logger) {
|
|
106
123
|
logger.debug('Reranking complete', { rerankedCount: rerankedResults.length });
|
|
107
124
|
}
|
|
125
|
+
|
|
108
126
|
const relevantChunks = rerankedResults.map(({ result }) => result?.metadata);
|
|
127
|
+
|
|
109
128
|
if (logger) {
|
|
110
129
|
logger.debug('Returning reranked relevant context chunks', { count: relevantChunks.length });
|
|
111
130
|
}
|
|
131
|
+
|
|
112
132
|
const sources = includeSources ? convertToSources(rerankedResults) : [];
|
|
133
|
+
|
|
113
134
|
return { relevantContext: relevantChunks, sources };
|
|
114
135
|
}
|
|
115
136
|
|
|
116
137
|
const relevantChunks = results.map(result => result?.metadata);
|
|
138
|
+
|
|
117
139
|
if (logger) {
|
|
118
140
|
logger.debug('Returning relevant context chunks', { count: relevantChunks.length });
|
|
119
141
|
}
|