@use-lattice/litmus 0.121.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +19 -0
- package/dist/src/accounts-Bt1oJb1Z.cjs +219 -0
- package/dist/src/accounts-DjOU8Rm3.js +178 -0
- package/dist/src/agentic-utils-D03IiXQc.js +153 -0
- package/dist/src/agentic-utils-Dh7xaMQM.cjs +180 -0
- package/dist/src/agents-C6BIMlZa.js +231 -0
- package/dist/src/agents-DvIpNX1L.cjs +666 -0
- package/dist/src/agents-ZP0RP9vV.cjs +231 -0
- package/dist/src/agents-maJXdjbR.js +665 -0
- package/dist/src/aimlapi-BTbQjG2E.cjs +30 -0
- package/dist/src/aimlapi-CwMxqfXP.js +30 -0
- package/dist/src/audio-BBUdvsde.cjs +97 -0
- package/dist/src/audio-D5DPZ7I-.js +97 -0
- package/dist/src/base-BEysXrkq.cjs +222 -0
- package/dist/src/base-C451JQfq.js +193 -0
- package/dist/src/blobs-BY8MDmpo.js +230 -0
- package/dist/src/blobs-BgcNn97m.cjs +256 -0
- package/dist/src/cache-BBE_lsTA.cjs +4 -0
- package/dist/src/cache-BkrqU5Ba.js +237 -0
- package/dist/src/cache-DsCxFlsZ.cjs +297 -0
- package/dist/src/chat-CPJWDP6a.cjs +289 -0
- package/dist/src/chat-CXX3xzkk.cjs +811 -0
- package/dist/src/chat-CcDgZFJ4.js +787 -0
- package/dist/src/chat-Dz5ZeGO2.js +289 -0
- package/dist/src/chatkit-Dw0mKkML.cjs +1158 -0
- package/dist/src/chatkit-swAIVuea.js +1157 -0
- package/dist/src/chunk-DEq-mXcV.js +15 -0
- package/dist/src/claude-agent-sdk-BXZJtOg6.js +379 -0
- package/dist/src/claude-agent-sdk-CkfyjDoG.cjs +383 -0
- package/dist/src/cloudflare-ai-BzpJcqUH.js +161 -0
- package/dist/src/cloudflare-ai-Cmy_R1y2.cjs +161 -0
- package/dist/src/cloudflare-gateway-B9tVQKok.cjs +272 -0
- package/dist/src/cloudflare-gateway-DrD3ew3H.js +272 -0
- package/dist/src/codex-sdk-Dezj9Nwm.js +1056 -0
- package/dist/src/codex-sdk-Dl9D4k5B.cjs +1060 -0
- package/dist/src/cometapi-C-9YvCHC.js +54 -0
- package/dist/src/cometapi-DHgDKoO2.cjs +54 -0
- package/dist/src/completion-B8Ctyxpr.js +120 -0
- package/dist/src/completion-Cxrt08sj.cjs +131 -0
- package/dist/src/createHash-BwgE13yv.cjs +27 -0
- package/dist/src/createHash-DmPQkvBh.js +15 -0
- package/dist/src/docker-BiqcTwLv.js +80 -0
- package/dist/src/docker-C7tEJnP-.cjs +80 -0
- package/dist/src/esm-C62Zofr1.cjs +409 -0
- package/dist/src/esm-DMVc93eh.js +379 -0
- package/dist/src/evalResult-C3NJPQOo.cjs +301 -0
- package/dist/src/evalResult-C7JJAPBb.js +295 -0
- package/dist/src/evalResult-DoVTZZWI.cjs +2 -0
- package/dist/src/extractor-DnMD3fwt.cjs +391 -0
- package/dist/src/extractor-DtlL28vL.js +374 -0
- package/dist/src/fetch-BTxakTSg.cjs +1133 -0
- package/dist/src/fetch-DQckpUFz.js +928 -0
- package/dist/src/fileExtensions-DnqA1y9x.js +85 -0
- package/dist/src/fileExtensions-bYh77CN8.cjs +114 -0
- package/dist/src/genaiTracer-CyZrmaK0.cjs +268 -0
- package/dist/src/genaiTracer-D3fD9dNV.js +256 -0
- package/dist/src/graders-BNscxFrU.js +13644 -0
- package/dist/src/graders-D2oE9Msq.js +2 -0
- package/dist/src/graders-c0Ez_w-9.cjs +2 -0
- package/dist/src/graders-d0F2M3e9.cjs +14056 -0
- package/dist/src/image-0ZhE0VlR.cjs +280 -0
- package/dist/src/image-CWE1pdNv.js +257 -0
- package/dist/src/image-D9ZK6hwL.js +163 -0
- package/dist/src/image-DKZgZITg.cjs +163 -0
- package/dist/src/index.cjs +11366 -0
- package/dist/src/index.d.cts +19640 -0
- package/dist/src/index.d.ts +19641 -0
- package/dist/src/index.js +11306 -0
- package/dist/src/invariant-Ddh24eXh.js +25 -0
- package/dist/src/invariant-kfQ8Bu82.cjs +30 -0
- package/dist/src/knowledgeBase-BgPyGFUd.cjs +122 -0
- package/dist/src/knowledgeBase-DyHilYaP.js +122 -0
- package/dist/src/litellm-CyMeneHS.js +135 -0
- package/dist/src/litellm-DWDF73yF.cjs +135 -0
- package/dist/src/logger-C40ZGil9.js +717 -0
- package/dist/src/logger-DyfK9PBt.cjs +917 -0
- package/dist/src/luma-ray-BAU9X_ep.cjs +315 -0
- package/dist/src/luma-ray-nwVseBbv.js +313 -0
- package/dist/src/messages-B5ADWTTv.js +245 -0
- package/dist/src/messages-BCnZfqrS.cjs +257 -0
- package/dist/src/meteor-DLZZ3osF.cjs +134 -0
- package/dist/src/meteor-DUiCJRC-.js +134 -0
- package/dist/src/modelslab-00cveB8L.cjs +163 -0
- package/dist/src/modelslab-D9sCU_L7.js +163 -0
- package/dist/src/nova-reel-CTapvqYH.js +276 -0
- package/dist/src/nova-reel-DlWuuroF.cjs +278 -0
- package/dist/src/nova-sonic-5UPWfeMv.cjs +363 -0
- package/dist/src/nova-sonic-BhSwQNym.js +363 -0
- package/dist/src/openai-BWrJK9d8.cjs +52 -0
- package/dist/src/openai-DumO8WQn.js +47 -0
- package/dist/src/openclaw-B8brrjC_.cjs +577 -0
- package/dist/src/openclaw-Bkayww9q.js +571 -0
- package/dist/src/opencode-sdk-7xjoDNiM.cjs +562 -0
- package/dist/src/opencode-sdk-SGwAPxht.js +558 -0
- package/dist/src/otlpReceiver-CoAHfAN9.cjs +15 -0
- package/dist/src/otlpReceiver-oO3EQwI9.js +14 -0
- package/dist/src/providerRegistry-4yjhaEM8.js +45 -0
- package/dist/src/providerRegistry-DhV4rJIc.cjs +50 -0
- package/dist/src/providers-B5RJVG-7.cjs +33609 -0
- package/dist/src/providers-BdmZCLzV.js +33262 -0
- package/dist/src/providers-CxtRxn8e.js +2 -0
- package/dist/src/providers-DnQLNbx1.cjs +3 -0
- package/dist/src/pythonUtils-BD0druiM.cjs +275 -0
- package/dist/src/pythonUtils-IBhn5YGR.js +249 -0
- package/dist/src/quiverai-BDOwZBsM.cjs +213 -0
- package/dist/src/quiverai-D3JTF5lD.js +213 -0
- package/dist/src/responses-B2LCDCXZ.js +667 -0
- package/dist/src/responses-BvNm4Xv9.cjs +685 -0
- package/dist/src/rubyUtils-B0NwnfpY.cjs +245 -0
- package/dist/src/rubyUtils-BroxzZ7c.cjs +2 -0
- package/dist/src/rubyUtils-hqVw5UvJ.js +222 -0
- package/dist/src/sagemaker-Cno2V-Sx.js +689 -0
- package/dist/src/sagemaker-fV_KUgs5.cjs +691 -0
- package/dist/src/server-BOuAXb06.cjs +238 -0
- package/dist/src/server-CtI-EWzm.cjs +2 -0
- package/dist/src/server-Cy3DZymt.js +189 -0
- package/dist/src/slack-CP8xBePa.js +135 -0
- package/dist/src/slack-DSQ1yXVb.cjs +135 -0
- package/dist/src/store-BwDDaBjb.cjs +246 -0
- package/dist/src/store-DcbLC593.cjs +2 -0
- package/dist/src/store-IGpqMIkv.js +240 -0
- package/dist/src/tables-3Q2cL7So.cjs +373 -0
- package/dist/src/tables-Bi2fjr4W.js +288 -0
- package/dist/src/telemetry-Bg2WqF79.js +161 -0
- package/dist/src/telemetry-D0x6u5kX.cjs +166 -0
- package/dist/src/telemetry-DXNimrI0.cjs +2 -0
- package/dist/src/text-B_UCRPp2.js +22 -0
- package/dist/src/text-CW1cyrwj.cjs +33 -0
- package/dist/src/tokenUsageUtils-NYT-WKS6.js +138 -0
- package/dist/src/tokenUsageUtils-bVa1ga6f.cjs +173 -0
- package/dist/src/transcription-Cl_W16Pr.js +122 -0
- package/dist/src/transcription-yt1EecY8.cjs +124 -0
- package/dist/src/transform-BCtGrl_W.cjs +228 -0
- package/dist/src/transform-Bv6gG2MJ.cjs +1688 -0
- package/dist/src/transform-CY1wbpRy.js +1507 -0
- package/dist/src/transform-DU8rUL9P.cjs +2 -0
- package/dist/src/transform-yWaShiKr.js +216 -0
- package/dist/src/transformersAvailability-BGkzavwb.js +35 -0
- package/dist/src/transformersAvailability-DKoRtQLy.cjs +35 -0
- package/dist/src/types-5aqHpBwE.cjs +3769 -0
- package/dist/src/types-Bn6D9c4U.js +3300 -0
- package/dist/src/util-BkKlTkI2.js +293 -0
- package/dist/src/util-CTh0bfOm.cjs +1119 -0
- package/dist/src/util-D17oBwo7.cjs +328 -0
- package/dist/src/util-DsS_-v4p.js +613 -0
- package/dist/src/util-DuntT1Ga.js +951 -0
- package/dist/src/util-aWjdCYMI.cjs +667 -0
- package/dist/src/utils-CisQwpjA.js +94 -0
- package/dist/src/utils-yWamDvmz.cjs +123 -0
- package/dist/tsconfig.tsbuildinfo +1 -0
- package/drizzle/0000_lush_hellion.sql +36 -0
- package/drizzle/0001_wide_calypso.sql +3 -0
- package/drizzle/0002_tidy_juggernaut.sql +1 -0
- package/drizzle/0003_lively_naoko.sql +8 -0
- package/drizzle/0004_minor_peter_quill.sql +19 -0
- package/drizzle/0005_silky_millenium_guard.sql +2 -0
- package/drizzle/0006_harsh_caretaker.sql +42 -0
- package/drizzle/0007_cloudy_wong.sql +1 -0
- package/drizzle/0008_broad_boomer.sql +2 -0
- package/drizzle/0009_strong_marten_broadcloak.sql +19 -0
- package/drizzle/0010_needy_bishop.sql +11 -0
- package/drizzle/0011_moaning_millenium_guard.sql +1 -0
- package/drizzle/0012_late_marten_broadcloak.sql +2 -0
- package/drizzle/0013_previous_dormammu.sql +9 -0
- package/drizzle/0014_lazy_captain_universe.sql +2 -0
- package/drizzle/0015_zippy_wallop.sql +29 -0
- package/drizzle/0016_jazzy_zemo.sql +2 -0
- package/drizzle/0017_reflective_praxagora.sql +4 -0
- package/drizzle/0018_fat_vanisher.sql +22 -0
- package/drizzle/0019_new_clint_barton.sql +8 -0
- package/drizzle/0020_skinny_maverick.sql +1 -0
- package/drizzle/0021_mysterious_madelyne_pryor.sql +13 -0
- package/drizzle/0022_sleepy_ultimo.sql +25 -0
- package/drizzle/0023_wooden_mandrill.sql +2 -0
- package/drizzle/AGENTS.md +68 -0
- package/drizzle/CLAUDE.md +1 -0
- package/drizzle/meta/0000_snapshot.json +221 -0
- package/drizzle/meta/0001_snapshot.json +214 -0
- package/drizzle/meta/0002_snapshot.json +221 -0
- package/drizzle/meta/0005_snapshot.json +369 -0
- package/drizzle/meta/0006_snapshot.json +638 -0
- package/drizzle/meta/0007_snapshot.json +640 -0
- package/drizzle/meta/0008_snapshot.json +649 -0
- package/drizzle/meta/0009_snapshot.json +554 -0
- package/drizzle/meta/0010_snapshot.json +619 -0
- package/drizzle/meta/0011_snapshot.json +627 -0
- package/drizzle/meta/0012_snapshot.json +639 -0
- package/drizzle/meta/0013_snapshot.json +717 -0
- package/drizzle/meta/0014_snapshot.json +717 -0
- package/drizzle/meta/0015_snapshot.json +897 -0
- package/drizzle/meta/0016_snapshot.json +1031 -0
- package/drizzle/meta/0018_snapshot.json +1210 -0
- package/drizzle/meta/0019_snapshot.json +1165 -0
- package/drizzle/meta/0020_snapshot.json +1232 -0
- package/drizzle/meta/0021_snapshot.json +1311 -0
- package/drizzle/meta/0022_snapshot.json +1481 -0
- package/drizzle/meta/0023_snapshot.json +1496 -0
- package/drizzle/meta/_journal.json +174 -0
- package/package.json +240 -0
|
@@ -0,0 +1,691 @@
|
|
|
1
|
+
const require_logger = require("./logger-DyfK9PBt.cjs");
|
|
2
|
+
const require_transform = require("./transform-BCtGrl_W.cjs");
|
|
3
|
+
const require_telemetry = require("./telemetry-D0x6u5kX.cjs");
|
|
4
|
+
let zod = require("zod");
|
|
5
|
+
let crypto = require("crypto");
|
|
6
|
+
crypto = require_logger.__toESM(crypto);
|
|
7
|
+
//#region src/providers/sagemaker.ts
|
|
8
|
+
/**
 * Pause asynchronously for a fixed amount of time.
 * @param {number} ms - Delay in milliseconds, handed directly to setTimeout.
 * @returns {Promise<void>} Resolves with `undefined` once the timer fires.
 */
function sleep(ms) {
  return new Promise((done) => {
    setTimeout(done, ms);
  });
}
|
|
14
|
+
/**
 * Model families this provider can format requests for and parse responses
 * from. Also the accepted values of `config.modelType`.
 */
const SUPPORTED_MODEL_TYPES = ["openai", "llama", "huggingface", "jumpstart", "custom"];
/**
 * Zod schema for validating SageMaker options.
 * Built with `strictObject`, so keys not listed here fail validation.
 */
const SageMakerConfigSchema = (() => {
  const { z } = zod;
  // Fresh optional leaf schemas, to keep the field list below compact.
  const optionalString = () => z.string().optional();
  const optionalNumber = () => z.number().optional();
  return z.strictObject({
    // AWS credentials and region
    accessKeyId: optionalString(),
    profile: optionalString(),
    region: optionalString(),
    secretAccessKey: optionalString(),
    sessionToken: optionalString(),
    // Endpoint name and request/response content types
    endpoint: optionalString(),
    contentType: optionalString(),
    acceptType: optionalString(),
    // Generation parameters
    maxTokens: optionalNumber(),
    temperature: optionalNumber(),
    topP: optionalNumber(),
    stopSequences: z.array(z.string()).optional(),
    // Request behaviour
    delay: optionalNumber(),
    transform: optionalString(),
    modelType: z.enum(SUPPORTED_MODEL_TYPES).optional(),
    responseFormat: z.strictObject({
      type: optionalString(),
      path: optionalString(),
    }).optional(),
    basePath: optionalString(),
  });
})();
|
|
46
|
+
/**
 * Base class for SageMaker providers with common functionality.
 *
 * Handles config validation, AWS credential and region resolution, lazy
 * creation of the SageMaker runtime client, optional prompt transforms and
 * extraction of data from JSON responses via a path expression.
 */
var SageMakerGenericProvider = class {
  env;
  sagemakerRuntime;
  config;
  endpointName;
  delay;
  transform;
  providerId;
  /**
   * @param {string} endpointName - Endpoint name; used unless `config.endpoint` overrides it.
   * @param {object} options - `{ config, id, env, delay, transform }`; `config` may be omitted.
   */
  constructor(endpointName, options) {
    const { config, id, env, delay, transform } = options;
    this.env = env;
    this.endpointName = endpointName;
    // A missing config is legal (it becomes `{}`), so validate the effective
    // value: parsing `undefined` would log a spurious validation warning.
    // Validation failures are only logged; construction continues regardless.
    const effectiveConfig = config ?? {};
    try {
      SageMakerConfigSchema.parse(effectiveConfig);
    } catch (error) {
      require_logger.logger.warn(`Error validating SageMaker config\nConfig: ${JSON.stringify(effectiveConfig)}\n${error instanceof zod.z.ZodError ? zod.z.prettifyError(error) : error}`);
    }
    this.config = effectiveConfig;
    // Explicit options win over the config (note: `||` also skips 0 / "").
    this.delay = delay || this.config.delay;
    this.transform = transform || this.config.transform;
    this.providerId = id;
    require_telemetry.telemetry.record("feature_used", { feature: "sagemaker" });
  }
  /** @returns {string} The explicit provider id, or `sagemaker:<endpointName>`. */
  id() {
    return this.providerId || `sagemaker:${this.endpointName}`;
  }
  toString() {
    return `[Amazon SageMaker Provider ${this.endpointName}]`;
  }
  /**
   * Get AWS credentials from config or environment.
   *
   * @returns {Promise<object | Function | undefined>} Static credentials when both
   *   `accessKeyId` and `secretAccessKey` are configured; otherwise the value
   *   returned by `fromSSO` when `profile` is set; otherwise `undefined`, which
   *   leaves `credentials` unset on the runtime client so the AWS SDK resolves them.
   * @throws {Error} If a profile is configured but the SSO credential provider cannot be loaded.
   */
  async getCredentials() {
    if (this.config.accessKeyId && this.config.secretAccessKey) {
      require_logger.logger.debug("Using explicit credentials from config");
      return {
        accessKeyId: this.config.accessKeyId,
        secretAccessKey: this.config.secretAccessKey,
        sessionToken: this.config.sessionToken,
      };
    }
    if (this.config.profile) {
      require_logger.logger.debug(`Using AWS profile: ${this.config.profile}`);
      try {
        const { fromSSO } = await import("@aws-sdk/credential-provider-sso");
        return fromSSO({ profile: this.config.profile });
      } catch (err) {
        throw new Error(`Failed to load AWS SSO profile. Please install @aws-sdk/credential-provider-sso`, { cause: err });
      }
    }
    require_logger.logger.debug("Using default AWS credentials from environment");
    return undefined;
  }
  /**
   * Initialize (once) and return the SageMaker runtime client.
   *
   * @returns {Promise<object>} The cached SageMakerRuntimeClient instance.
   * @throws {Error} If @aws-sdk/client-sagemaker-runtime cannot be imported, or
   *   if credential resolution fails (that error is propagated unchanged).
   */
  async getSageMakerRuntimeInstance() {
    if (this.sagemakerRuntime) return this.sagemakerRuntime;
    let SageMakerRuntimeClient;
    try {
      ({ SageMakerRuntimeClient } = await import("@aws-sdk/client-sagemaker-runtime"));
    } catch (err) {
      throw new Error("The @aws-sdk/client-sagemaker-runtime package is required. Please install it with: npm install @aws-sdk/client-sagemaker-runtime", { cause: err });
    }
    // Resolved outside the import guard so credential failures (e.g. a missing
    // SSO provider package) surface with their own message instead of being
    // reported as a missing runtime package.
    const credentials = await this.getCredentials();
    this.sagemakerRuntime = new SageMakerRuntimeClient({
      region: this.getRegion(),
      maxAttempts: require_logger.getEnvInt("AWS_SAGEMAKER_MAX_RETRIES", 3),
      retryMode: "adaptive",
      ...(credentials ? { credentials } : {}),
    });
    require_logger.logger.debug(`SageMaker client initialized for region ${this.getRegion()}`);
    return this.sagemakerRuntime;
  }
  /**
   * Get AWS region: config, then `env.AWS_REGION`, then the AWS_REGION /
   * AWS_DEFAULT_REGION environment variables, then "us-east-1".
   */
  getRegion() {
    return this.config?.region || this.env?.AWS_REGION || require_logger.getEnvString("AWS_REGION") || require_logger.getEnvString("AWS_DEFAULT_REGION") || "us-east-1";
  }
  /** Get SageMaker endpoint name (`config.endpoint` overrides the constructor argument). */
  getEndpointName() {
    return this.config?.endpoint || this.endpointName;
  }
  /** Get content type for request (defaults to "application/json"). */
  getContentType() {
    return this.config?.contentType || "application/json";
  }
  /** Get accept type for response (defaults to "application/json"). */
  getAcceptType() {
    return this.config?.acceptType || "application/json";
  }
  /**
   * Apply transformation to a prompt if a transform is configured.
   *
   * Inline string transforms (not starting with `file://`) are evaluated here:
   * strings containing "=>" are called as a function expression, anything else
   * is used as a function body. Other transforms go through the shared
   * transform utility. Non-string results are JSON-stringified (objects) or
   * converted with String(); null/undefined results and any failure fall back
   * to the original prompt.
   *
   * @param {string} prompt - The original prompt to transform.
   * @param {object} [context] - Optional call context; `vars` and `prompt` are forwarded to the transform.
   * @returns {Promise<string>} The transformed prompt, or the original one.
   */
  async applyTransformation(prompt, context) {
    if (!this.transform) return prompt;
    // Converts a transform's return value into the prompt string to send.
    const toPromptString = (value) => {
      if (value === void 0 || value === null) {
        require_logger.logger.debug("Transform function returned null or undefined, using original prompt");
        return prompt;
      }
      if (typeof value === "string") return value;
      if (typeof value === "object") return JSON.stringify(value);
      return String(value);
    };
    try {
      const transformContext = {
        vars: context?.vars || {},
        prompt: context?.prompt || { raw: prompt },
        uuid: `sagemaker-${this.endpointName}-${Date.now()}`,
      };
      // Guaranteed truthy by the early return above.
      const transformFn = this.transform;
      require_logger.logger.debug(`Applying transform to prompt for SageMaker endpoint ${this.getEndpointName()}`);
      if (typeof transformFn === "string" && !transformFn.startsWith("file://")) {
        try {
          // NOTE(security): inline transforms are compiled with `new Function`,
          // i.e. executed as arbitrary code. They must come from trusted config.
          const source = transformFn.includes("=>")
            ? `try { return (${transformFn})(prompt, context); } catch(e) { throw new Error("Transform function error: " + e.message); }`
            : `try { ${transformFn} } catch(e) { throw new Error("Transform function error: " + e.message); }`;
          return toPromptString(new Function("prompt", "context", source)(prompt, transformContext));
        } catch (transformError) {
          require_logger.logger.error(`Error executing inline transform: ${transformError}`);
        }
      } else {
        try {
          const { TransformInputType } = await Promise.resolve().then(() => require("./transform-DU8rUL9P.cjs"));
          const transformed = await require_transform.transform(transformFn, prompt, transformContext, false, TransformInputType.OUTPUT);
          return toPromptString(transformed);
        } catch (transformError) {
          require_logger.logger.error(`Error using transform utility: ${transformError}`);
        }
      }
      require_logger.logger.warn(`Transform did not produce a valid result, using original prompt`);
      return prompt;
    } catch (_) {
      require_logger.logger.error(`Error applying transform to prompt: ${_}`);
      return prompt;
    }
  }
  /**
   * Extracts data from a response using a path expression.
   *
   * `file://` expressions go through the shared transform utility. Any other
   * expression is evaluated as JavaScript with the parsed response bound to
   * `json` (e.g. "json.choices[0].text"). Whenever extraction fails or yields
   * `undefined`, the whole response is returned and a warning is logged.
   *
   * @param {unknown} responseJson - Parsed response body.
   * @param {string} pathExpression - Expression or `file://` reference.
   * @returns {Promise<unknown>} The extracted value, or `responseJson`.
   */
  async extractFromPath(responseJson, pathExpression) {
    if (!pathExpression) return responseJson;
    try {
      if (pathExpression.startsWith("file://")) {
        try {
          const { TransformInputType } = await Promise.resolve().then(() => require("./transform-DU8rUL9P.cjs"));
          const transformedResult = await require_transform.transform(pathExpression, responseJson, { prompt: {} }, false, TransformInputType.OUTPUT);
          return transformedResult !== void 0 && transformedResult !== null ? transformedResult : responseJson;
        } catch (error) {
          require_logger.logger.warn(`Failed to transform response using file: ${error}`);
          return responseJson;
        }
      }
      try {
        // NOTE(security): the path expression is executed as code via `new Function`.
        const result = new Function("json", `try { return ${pathExpression}; } catch(e) { return undefined; }`)(responseJson);
        if (result === void 0) {
          require_logger.logger.warn(`Path expression "${pathExpression}" did not match any data in the response`);
          require_logger.logger.debug(`Response JSON structure: ${JSON.stringify(responseJson).substring(0, 200)}...`);
          return responseJson;
        }
        return result;
      } catch (error) {
        require_logger.logger.warn(`Failed to evaluate expression "${pathExpression}": ${error}`);
        return responseJson;
      }
    } catch (error) {
      require_logger.logger.warn(`Failed to extract data using path expression "${pathExpression}": ${error}`);
      require_logger.logger.debug(`Response JSON structure: ${JSON.stringify(responseJson).substring(0, 200)}...`);
      return responseJson;
    }
  }
};
|
|
238
|
+
/**
|
|
239
|
+
* Provider for text generation with SageMaker endpoints
|
|
240
|
+
*/
|
|
241
|
+
var SageMakerCompletionProvider = class extends SageMakerGenericProvider {
|
|
242
|
+
modelType;
|
|
243
|
+
constructor(endpointName, options) {
|
|
244
|
+
super(endpointName, options);
|
|
245
|
+
this.modelType = this.parseModelType(options.config?.modelType);
|
|
246
|
+
}
|
|
247
|
+
/**
|
|
248
|
+
* Model type must be specified within the id or the `config.modelType` field.
|
|
249
|
+
*/
|
|
250
|
+
parseModelType(modelType) {
|
|
251
|
+
const match = this.id().match(/^sagemaker:(?<modelType>.+):.+$/);
|
|
252
|
+
if (match) {
|
|
253
|
+
const modelTypeFromId = match.groups.modelType;
|
|
254
|
+
if (SUPPORTED_MODEL_TYPES.includes(modelTypeFromId)) return modelTypeFromId;
|
|
255
|
+
else throw new Error(`Invalid model type "${modelTypeFromId}" in provider ID. Valid types are: ${SUPPORTED_MODEL_TYPES.join(", ")}`);
|
|
256
|
+
}
|
|
257
|
+
if (modelType) if (SUPPORTED_MODEL_TYPES.includes(modelType)) return modelType;
|
|
258
|
+
else throw new Error(`Invalid model type "${modelType}" in \`config.modelType\`. Valid types are: ${SUPPORTED_MODEL_TYPES.join(", ")}`);
|
|
259
|
+
throw new Error("Model type must be set either in `config.modelType` or as part of the Provider ID, for example: \"sagemaker:<model_type>:<endpoint>\"");
|
|
260
|
+
}
|
|
261
|
+
/**
|
|
262
|
+
* Format the request payload based on model type
|
|
263
|
+
*/
|
|
264
|
+
formatPayload(prompt) {
|
|
265
|
+
const maxTokens = this.config.maxTokens ?? require_logger.getEnvInt("AWS_SAGEMAKER_MAX_TOKENS") ?? 1024;
|
|
266
|
+
const temperature = typeof this.config.temperature === "number" ? this.config.temperature : require_logger.getEnvFloat("AWS_SAGEMAKER_TEMPERATURE") ?? .7;
|
|
267
|
+
const topP = typeof this.config.topP === "number" ? this.config.topP : require_logger.getEnvFloat("AWS_SAGEMAKER_TOP_P") ?? 1;
|
|
268
|
+
const stopSequences = this.config.stopSequences || [];
|
|
269
|
+
let payload;
|
|
270
|
+
require_logger.logger.debug(`Formatting payload for model type: ${this.modelType}`);
|
|
271
|
+
switch (this.modelType) {
|
|
272
|
+
case "openai":
|
|
273
|
+
try {
|
|
274
|
+
const messages = JSON.parse(prompt);
|
|
275
|
+
if (Array.isArray(messages)) payload = {
|
|
276
|
+
messages,
|
|
277
|
+
max_tokens: maxTokens,
|
|
278
|
+
temperature,
|
|
279
|
+
top_p: topP,
|
|
280
|
+
stop: stopSequences.length > 0 ? stopSequences : void 0
|
|
281
|
+
};
|
|
282
|
+
else throw new Error("Not valid messages format");
|
|
283
|
+
} catch {
|
|
284
|
+
payload = {
|
|
285
|
+
prompt,
|
|
286
|
+
max_tokens: maxTokens,
|
|
287
|
+
temperature,
|
|
288
|
+
top_p: topP,
|
|
289
|
+
stop: stopSequences.length > 0 ? stopSequences : void 0
|
|
290
|
+
};
|
|
291
|
+
}
|
|
292
|
+
break;
|
|
293
|
+
case "llama":
|
|
294
|
+
try {
|
|
295
|
+
const messages = JSON.parse(prompt);
|
|
296
|
+
if (Array.isArray(messages)) payload = {
|
|
297
|
+
inputs: messages,
|
|
298
|
+
parameters: {
|
|
299
|
+
max_new_tokens: maxTokens,
|
|
300
|
+
temperature,
|
|
301
|
+
top_p: topP,
|
|
302
|
+
stop: stopSequences.length > 0 ? stopSequences : void 0
|
|
303
|
+
}
|
|
304
|
+
};
|
|
305
|
+
else throw new Error("Not valid messages format");
|
|
306
|
+
} catch {
|
|
307
|
+
payload = {
|
|
308
|
+
inputs: prompt,
|
|
309
|
+
parameters: {
|
|
310
|
+
max_new_tokens: maxTokens,
|
|
311
|
+
temperature,
|
|
312
|
+
top_p: topP,
|
|
313
|
+
stop: stopSequences.length > 0 ? stopSequences : void 0
|
|
314
|
+
}
|
|
315
|
+
};
|
|
316
|
+
}
|
|
317
|
+
break;
|
|
318
|
+
case "jumpstart":
|
|
319
|
+
payload = {
|
|
320
|
+
inputs: prompt,
|
|
321
|
+
parameters: {
|
|
322
|
+
max_new_tokens: maxTokens,
|
|
323
|
+
temperature,
|
|
324
|
+
top_p: topP,
|
|
325
|
+
do_sample: temperature > 0
|
|
326
|
+
}
|
|
327
|
+
};
|
|
328
|
+
break;
|
|
329
|
+
case "huggingface":
|
|
330
|
+
payload = {
|
|
331
|
+
inputs: prompt,
|
|
332
|
+
parameters: {
|
|
333
|
+
max_new_tokens: maxTokens,
|
|
334
|
+
temperature,
|
|
335
|
+
top_p: topP,
|
|
336
|
+
do_sample: temperature > 0,
|
|
337
|
+
return_full_text: false
|
|
338
|
+
}
|
|
339
|
+
};
|
|
340
|
+
break;
|
|
341
|
+
default:
|
|
342
|
+
try {
|
|
343
|
+
payload = JSON.parse(prompt);
|
|
344
|
+
} catch {
|
|
345
|
+
payload = { prompt };
|
|
346
|
+
}
|
|
347
|
+
break;
|
|
348
|
+
}
|
|
349
|
+
return JSON.stringify(payload);
|
|
350
|
+
}
|
|
351
|
+
/**
|
|
352
|
+
* Parse the response from SageMaker endpoint
|
|
353
|
+
*/
|
|
354
|
+
async parseResponse(responseBody) {
|
|
355
|
+
let responseJson;
|
|
356
|
+
require_logger.logger.debug(`Parsing response for model type: ${this.modelType}`);
|
|
357
|
+
try {
|
|
358
|
+
responseJson = JSON.parse(responseBody);
|
|
359
|
+
} catch {
|
|
360
|
+
require_logger.logger.debug("Response is not JSON, returning as-is");
|
|
361
|
+
return responseBody;
|
|
362
|
+
}
|
|
363
|
+
if (this.config.responseFormat?.path) try {
|
|
364
|
+
const pathExpression = this.config.responseFormat.path;
|
|
365
|
+
return await this.extractFromPath(responseJson, pathExpression);
|
|
366
|
+
} catch (error) {
|
|
367
|
+
require_logger.logger.warn(`Failed to extract from path: ${this.config.responseFormat.path}, Error: ${error}`);
|
|
368
|
+
require_logger.logger.debug(`Response JSON structure: ${JSON.stringify(responseJson).substring(0, 200)}...`);
|
|
369
|
+
return responseJson;
|
|
370
|
+
}
|
|
371
|
+
if (responseJson.generated_text) {
|
|
372
|
+
require_logger.logger.debug("Detected JumpStart model response format with generated_text field");
|
|
373
|
+
return responseJson.generated_text;
|
|
374
|
+
}
|
|
375
|
+
switch (this.modelType) {
|
|
376
|
+
case "openai": return responseJson.choices?.[0]?.message?.content || responseJson.choices?.[0]?.text || responseJson.generation || responseJson;
|
|
377
|
+
case "llama": return responseJson.generation || responseJson.choices?.[0]?.message?.content || responseJson.choices?.[0]?.text || responseJson;
|
|
378
|
+
case "huggingface": return Array.isArray(responseJson) ? responseJson[0]?.generated_text || responseJson[0] : responseJson.generated_text || responseJson;
|
|
379
|
+
case "jumpstart": return responseJson.generated_text || responseJson;
|
|
380
|
+
default: return responseJson.output || responseJson.generation || responseJson.response || responseJson.text || responseJson.generated_text || responseJson.choices?.[0]?.message?.content || responseJson.choices?.[0]?.text || responseJson;
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
/**
|
|
384
|
+
* Generate a consistent cache key for SageMaker requests
|
|
385
|
+
* Uses crypto.createHash to generate a shorter, more efficient key
|
|
386
|
+
*/
|
|
387
|
+
getCacheKey(prompt) {
|
|
388
|
+
const configForKey = {
|
|
389
|
+
endpoint: this.getEndpointName(),
|
|
390
|
+
modelType: this.config.modelType,
|
|
391
|
+
contentType: this.getContentType(),
|
|
392
|
+
acceptType: this.getAcceptType(),
|
|
393
|
+
maxTokens: this.config.maxTokens,
|
|
394
|
+
temperature: this.config.temperature,
|
|
395
|
+
topP: this.config.topP,
|
|
396
|
+
region: this.getRegion()
|
|
397
|
+
};
|
|
398
|
+
const configStr = JSON.stringify(configForKey);
|
|
399
|
+
const promptHash = crypto.default.createHash("sha256").update(prompt).digest("hex").substring(0, 16);
|
|
400
|
+
const configHash = crypto.default.createHash("sha256").update(configStr).digest("hex").substring(0, 8);
|
|
401
|
+
return `sagemaker:v1:${this.getEndpointName()}:${promptHash}:${configHash}`;
|
|
402
|
+
}
|
|
403
|
+
/**
|
|
404
|
+
* Invoke SageMaker endpoint for text generation with caching, delay support, and transformations
|
|
405
|
+
*/
|
|
406
|
+
async callApi(prompt, context, _options) {
|
|
407
|
+
const { isCacheEnabled, getCache } = await Promise.resolve().then(() => require("./cache-BBE_lsTA.cjs"));
|
|
408
|
+
const delayMs = context?.originalProvider?.delay || this.delay;
|
|
409
|
+
const transformedPrompt = await this.applyTransformation(prompt, context);
|
|
410
|
+
const isTransformed = transformedPrompt !== prompt;
|
|
411
|
+
if (isTransformed) {
|
|
412
|
+
require_logger.logger.debug(`Prompt transformed for SageMaker endpoint ${this.getEndpointName()}`);
|
|
413
|
+
require_logger.logger.debug(`Original: ${prompt.substring(0, 100)}${prompt.length > 100 ? "..." : ""}`);
|
|
414
|
+
require_logger.logger.debug(`Transformed: ${transformedPrompt.substring(0, 100)}${transformedPrompt.length > 100 ? "..." : ""}`);
|
|
415
|
+
}
|
|
416
|
+
const bustCache = context?.bustCache ?? context?.debug === true;
|
|
417
|
+
if (isCacheEnabled() && !bustCache) {
|
|
418
|
+
const cacheKey = this.getCacheKey(transformedPrompt);
|
|
419
|
+
const cachedResult = await (getCache ? getCache() : await Promise.resolve().then(() => require("./cache-BBE_lsTA.cjs")).then((m) => m.getCache())).get(cacheKey);
|
|
420
|
+
if (cachedResult) {
|
|
421
|
+
require_logger.logger.debug(`Using cached SageMaker response for ${this.getEndpointName()}`);
|
|
422
|
+
try {
|
|
423
|
+
const parsedResult = JSON.parse(cachedResult);
|
|
424
|
+
if (parsedResult.tokenUsage) parsedResult.tokenUsage.cached = parsedResult.tokenUsage.total || 0;
|
|
425
|
+
if (isTransformed && parsedResult.metadata) {
|
|
426
|
+
parsedResult.metadata.transformed = true;
|
|
427
|
+
parsedResult.metadata.originalPrompt = prompt;
|
|
428
|
+
}
|
|
429
|
+
return {
|
|
430
|
+
...parsedResult,
|
|
431
|
+
cached: true
|
|
432
|
+
};
|
|
433
|
+
} catch (_) {
|
|
434
|
+
require_logger.logger.warn(`Failed to parse cached SageMaker response: ${_}`);
|
|
435
|
+
}
|
|
436
|
+
}
|
|
437
|
+
}
|
|
438
|
+
if (delayMs && delayMs > 0) {
|
|
439
|
+
require_logger.logger.debug(`Applying delay of ${delayMs}ms before calling SageMaker endpoint ${this.getEndpointName()}`);
|
|
440
|
+
await sleep(delayMs);
|
|
441
|
+
}
|
|
442
|
+
const runtime = await this.getSageMakerRuntimeInstance();
|
|
443
|
+
const payload = this.formatPayload(transformedPrompt);
|
|
444
|
+
require_logger.logger.debug(`Calling SageMaker endpoint ${this.getEndpointName()}`);
|
|
445
|
+
require_logger.logger.debug(`With payload: ${payload.length > 1e3 ? payload.substring(0, 1e3) + "..." : payload}`);
|
|
446
|
+
try {
|
|
447
|
+
const { InvokeEndpointCommand } = await import("@aws-sdk/client-sagemaker-runtime");
|
|
448
|
+
const command = new InvokeEndpointCommand({
|
|
449
|
+
EndpointName: this.getEndpointName(),
|
|
450
|
+
ContentType: this.getContentType(),
|
|
451
|
+
Accept: this.getAcceptType(),
|
|
452
|
+
Body: payload
|
|
453
|
+
});
|
|
454
|
+
const startTime = Date.now();
|
|
455
|
+
const response = await runtime.send(command);
|
|
456
|
+
const _latency = Date.now() - startTime;
|
|
457
|
+
if (!response.Body) {
|
|
458
|
+
require_logger.logger.error("No response body returned from SageMaker endpoint");
|
|
459
|
+
return { error: "No response body returned from SageMaker endpoint" };
|
|
460
|
+
}
|
|
461
|
+
const responseBody = new TextDecoder().decode(response.Body);
|
|
462
|
+
require_logger.logger.debug(`SageMaker response (truncated): ${responseBody.length > 1e3 ? responseBody.substring(0, 1e3) + "..." : responseBody}`);
|
|
463
|
+
const output = await this.parseResponse(responseBody);
|
|
464
|
+
if (typeof output === "object" && output !== null && "code" in output) {
|
|
465
|
+
const code = output.code;
|
|
466
|
+
if (Number.isInteger(code) && code === 424) {
|
|
467
|
+
const errorMessage = `API Error: 424${output?.message ? ` ${output.message}` : ""}\n${JSON.stringify(output)}`;
|
|
468
|
+
require_logger.logger.error(errorMessage);
|
|
469
|
+
return { error: errorMessage };
|
|
470
|
+
}
|
|
471
|
+
}
|
|
472
|
+
const promptTokens = Math.ceil(payload.length / 4);
|
|
473
|
+
const completionTokens = Math.ceil((typeof output === "string" ? output.length : 0) / 4);
|
|
474
|
+
const result = {
|
|
475
|
+
output,
|
|
476
|
+
raw: responseBody,
|
|
477
|
+
tokenUsage: {
|
|
478
|
+
prompt: promptTokens,
|
|
479
|
+
completion: completionTokens,
|
|
480
|
+
total: promptTokens + completionTokens,
|
|
481
|
+
cached: 0,
|
|
482
|
+
numRequests: 1
|
|
483
|
+
},
|
|
484
|
+
metadata: {
|
|
485
|
+
latencyMs: _latency,
|
|
486
|
+
modelType: this.config.modelType || "custom",
|
|
487
|
+
transformed: isTransformed,
|
|
488
|
+
originalPrompt: isTransformed ? prompt : void 0
|
|
489
|
+
}
|
|
490
|
+
};
|
|
491
|
+
if (isCacheEnabled() && !bustCache && result.output && !result.error) {
|
|
492
|
+
const cacheKey = this.getCacheKey(transformedPrompt);
|
|
493
|
+
const cache = getCache ? getCache() : await Promise.resolve().then(() => require("./cache-BBE_lsTA.cjs")).then((m) => m.getCache());
|
|
494
|
+
const resultToCache = JSON.stringify(result);
|
|
495
|
+
try {
|
|
496
|
+
await cache.set(cacheKey, resultToCache);
|
|
497
|
+
require_logger.logger.debug(`Stored SageMaker response in cache with key: ${cacheKey.substring(0, 100)}...`);
|
|
498
|
+
} catch (_) {
|
|
499
|
+
require_logger.logger.warn(`Failed to store SageMaker response in cache: ${_}`);
|
|
500
|
+
}
|
|
501
|
+
}
|
|
502
|
+
return result;
|
|
503
|
+
} catch (error) {
|
|
504
|
+
require_logger.logger.error(`SageMaker API error: ${error}`);
|
|
505
|
+
return { error: `SageMaker API error: ${error.message || String(error)}` };
|
|
506
|
+
}
|
|
507
|
+
}
|
|
508
|
+
};
|
|
509
|
+
/**
 * Provider for embeddings with SageMaker endpoints.
 *
 * Only `callEmbeddingApi` is supported; `callApi` always throws.
 */
var SageMakerEmbeddingProvider = class extends SageMakerGenericProvider {
	/**
	 * Text completion is not available on the embedding provider.
	 *
	 * @throws {Error} Always - callers should use `callEmbeddingApi` instead.
	 */
	async callApi() {
		throw new Error("callApi is not implemented for embedding provider. Use callEmbeddingApi instead.");
	}
	/**
	 * Generate a consistent cache key for SageMaker embedding requests.
	 *
	 * The text and the request-shaping config are hashed separately with SHA-256
	 * (truncated to 16 and 8 hex characters). This keeps the key short regardless of
	 * input length, while still changing whenever endpoint or format settings change.
	 *
	 * @param {string} text - The text actually sent to the endpoint (after any transformation).
	 * @returns {string} Key of the form `sagemaker:embedding:v1:<endpoint>:<textHash>:<configHash>`.
	 */
	getCacheKey(text) {
		const configForKey = {
			endpoint: this.getEndpointName(),
			modelType: this.config.modelType,
			contentType: this.getContentType(),
			acceptType: this.getAcceptType(),
			region: this.getRegion(),
			responseFormat: this.config.responseFormat
		};
		const configStr = JSON.stringify(configForKey);
		const textHash = crypto.default.createHash("sha256").update(text).digest("hex").substring(0, 16);
		const configHash = crypto.default.createHash("sha256").update(configStr).digest("hex").substring(0, 8);
		return `sagemaker:embedding:v1:${this.getEndpointName()}:${textHash}:${configHash}`;
	}
	/**
	 * Invoke the SageMaker endpoint for embeddings, with caching, delay support and transformations.
	 *
	 * Flow:
	 * 1. Transform the input text.
	 * 2. Return a cached result if one exists.
	 * 3. Otherwise wait `delayMs` (if configured) and call the endpoint.
	 * 4. Extract a numeric vector from the JSON response and cache the result.
	 *
	 * @param {string} text - Text to embed.
	 * @param {object} [context] - Call context. `bustCache` (or, when it is unset, `debug === true`)
	 *   bypasses the cache; `originalProvider.delay` overrides the provider's own delay.
	 * @returns {Promise<object>} `{ embedding, tokenUsage, metadata }` on success. A cached hit also
	 *   carries `cached: true`. Endpoint, parse and extraction failures are reported as `{ error }`.
	 */
	async callEmbeddingApi(text, context) {
		const { isCacheEnabled, getCache } = await Promise.resolve().then(() => require("./cache-BBE_lsTA.cjs"));
		const delayMs = context?.originalProvider?.delay || this.delay;
		const transformedText = await this.applyTransformation(text, context);
		const isTransformed = transformedText !== text;
		if (isTransformed) {
			require_logger.logger.debug(`Text transformed for SageMaker embedding endpoint ${this.getEndpointName()}`);
			require_logger.logger.debug(`Original: ${text.substring(0, 100)}${text.length > 100 ? "..." : ""}`);
			require_logger.logger.debug(`Transformed: ${transformedText.substring(0, 100)}${transformedText.length > 100 ? "..." : ""}`);
		}
		// Same bypass rule as the completion provider: an explicit `bustCache` wins, otherwise debug mode bypasses.
		const bustCache = context?.bustCache ?? context?.debug === true;
		if (isCacheEnabled() && !bustCache) {
			const cache = await getCache();
			const cachedResult = await cache.get(this.getCacheKey(transformedText));
			if (cachedResult) {
				require_logger.logger.debug(`Using cached SageMaker embedding response for ${this.getEndpointName()}`);
				try {
					const parsedResult = JSON.parse(cachedResult);
					if (parsedResult.tokenUsage) parsedResult.tokenUsage.cached = parsedResult.tokenUsage.prompt || 0;
					return {
						...parsedResult,
						cached: true
					};
				} catch (parseError) {
					// A corrupt cache entry is not fatal: fall through and call the endpoint.
					require_logger.logger.warn(`Failed to parse cached SageMaker embedding response: ${parseError}`);
				}
			}
		}
		if (delayMs && delayMs > 0) {
			require_logger.logger.debug(`Applying delay of ${delayMs}ms before calling SageMaker embedding endpoint ${this.getEndpointName()}`);
			await sleep(delayMs);
		}
		const runtime = await this.getSageMakerRuntimeInstance();
		const modelType = this.config.modelType || "custom";
		require_logger.logger.debug(`Formatting embedding payload for model type: ${modelType}`);
		let payload;
		switch (modelType) {
			case "openai":
				payload = JSON.stringify({
					input: transformedText,
					model: "embedding"
				});
				break;
			case "huggingface":
				payload = JSON.stringify({ inputs: transformedText });
				break;
			default:
				// Unknown container: send the text under every commonly used input key.
				payload = JSON.stringify({
					input: transformedText,
					text: transformedText,
					inputs: transformedText
				});
				break;
		}
		require_logger.logger.debug(`Calling SageMaker embedding endpoint ${this.getEndpointName()}`);
		// Truncated like the completion provider, because embedding inputs can be arbitrarily long.
		require_logger.logger.debug(`With payload: ${payload.length > 1e3 ? payload.substring(0, 1e3) + "..." : payload}`);
		// A usable embedding is a non-empty array made only of numbers.
		const isNumericVector = (value) => Array.isArray(value) && value.length > 0 && value.every((val) => typeof val === "number");
		// Well-known response shapes: { embedding }, { embeddings }, OpenAI-style { data: [{ embedding }] },
		// or a bare array. A single-item batch such as `[[...]]` or `{ embeddings: [[...]] }`
		// holds the vector one level down, so it is unwrapped once.
		const findWellKnownEmbedding = (json) => {
			let candidate = json?.embedding || json?.embeddings || json?.data?.[0]?.embedding || json;
			if (Array.isArray(candidate) && Array.isArray(candidate[0])) candidate = candidate[0];
			return candidate;
		};
		// Token usage is a rough estimate (about 4 characters per token) based on the caller's input text.
		const buildResult = (embedding) => ({
			embedding,
			tokenUsage: {
				prompt: Math.ceil(text.length / 4),
				cached: 0,
				numRequests: 1
			},
			metadata: {
				transformed: isTransformed,
				originalText: isTransformed ? text : void 0
			}
		});
		try {
			const { InvokeEndpointCommand } = await import("@aws-sdk/client-sagemaker-runtime");
			const command = new InvokeEndpointCommand({
				EndpointName: this.getEndpointName(),
				ContentType: this.getContentType(),
				Accept: this.getAcceptType(),
				Body: payload
			});
			const response = await runtime.send(command);
			if (!response.Body) {
				require_logger.logger.error("No response body returned from SageMaker embedding endpoint");
				return { error: "No response body returned from SageMaker embedding endpoint" };
			}
			const responseBody = new TextDecoder().decode(response.Body);
			require_logger.logger.debug(`SageMaker embedding response: ${responseBody.substring(0, 200)}...`);
			let responseJson;
			try {
				responseJson = JSON.parse(responseBody);
			} catch (parseError) {
				return { error: `Failed to parse embedding response as JSON: ${parseError}` };
			}
			let embedding;
			const pathExpression = this.config.responseFormat?.path;
			if (pathExpression) {
				// A user-supplied path takes precedence. If it fails or yields a non-vector,
				// fall back to the well-known shapes.
				try {
					const extracted = await this.extractFromPath(responseJson, pathExpression);
					if (isNumericVector(extracted)) embedding = extracted;
					else require_logger.logger.warn("Extracted data is not a valid embedding array, trying other extraction methods");
				} catch (error) {
					require_logger.logger.warn(`Failed to extract embedding from path expression: ${pathExpression}, Error: ${error}`);
					require_logger.logger.debug(`Response JSON structure: ${JSON.stringify(responseJson).substring(0, 200)}...`);
				}
			}
			if (embedding === void 0) embedding = findWellKnownEmbedding(responseJson);
			if (!isNumericVector(embedding)) return { error: `Invalid embedding response format. Could not find embedding array in: ${JSON.stringify(responseJson).substring(0, 100)}...` };
			const result = buildResult(embedding);
			await this.cacheEmbeddingResult(result, transformedText, context, isTransformed, isTransformed ? text : void 0);
			return result;
		} catch (error) {
			require_logger.logger.error(`SageMaker embedding API error: ${error}`);
			return { error: `SageMaker embedding API error: ${error?.message || String(error)}` };
		}
	}
	/**
	 * Store a successful embedding result in the cache (best effort).
	 *
	 * The write is skipped when:
	 * - caching is disabled;
	 * - it is bypassed via `context.bustCache` (or `context.debug === true` when `bustCache` is unset);
	 * - `result` has no embedding or carries an error.
	 *
	 * When the input was transformed, `result.metadata` is updated in place with
	 * `transformed` and `originalText` before serialisation. Serialisation and cache
	 * write failures are logged, not thrown.
	 *
	 * @param {object} result - Result returned to the caller; its `metadata` may be mutated.
	 * @param {string} text - Text used to build the cache key (the transformed text).
	 * @param {object} [context] - Call context (`bustCache`, `debug`).
	 * @param {boolean} [isTransformed=false] - Whether the sent text differs from the caller's input.
	 * @param {string} [originalText] - The caller's untransformed input.
	 * @returns {Promise<void>}
	 */
	async cacheEmbeddingResult(result, text, context, isTransformed = false, originalText) {
		const { isCacheEnabled, getCache } = await Promise.resolve().then(() => require("./cache-BBE_lsTA.cjs"));
		const bustCache = context?.bustCache ?? context?.debug === true;
		if (!isCacheEnabled() || bustCache || !result.embedding || result.error) return;
		const cacheKey = this.getCacheKey(text);
		const cache = await getCache();
		if (isTransformed && originalText) {
			if (!result.metadata) result.metadata = {};
			result.metadata.transformed = true;
			result.metadata.originalText = originalText;
		}
		try {
			await cache.set(cacheKey, JSON.stringify(result));
			require_logger.logger.debug(`Stored SageMaker embedding response in cache with key: ${cacheKey.substring(0, 100)}...`);
		} catch (storeError) {
			require_logger.logger.warn(`Failed to store SageMaker embedding response in cache: ${storeError}`);
		}
	}
};
|
|
687
|
+
//#endregion
|
|
688
|
+
// Public CommonJS exports of this chunk: the completion and embedding SageMaker providers.
// NOTE(review): keep the literal `exports.Name = value` form. Node's static detection of
// CommonJS named exports (used when this file is imported from ESM) is believed to rely on it.
exports.SageMakerCompletionProvider = SageMakerCompletionProvider;
exports.SageMakerEmbeddingProvider = SageMakerEmbeddingProvider;
|
|
690
|
+
|
|
691
|
+
//# sourceMappingURL=sagemaker-fV_KUgs5.cjs.map
|