@juspay/neurolink 7.7.1 → 7.9.0
This diff compares the contents of two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- package/CHANGELOG.md +25 -2
- package/README.md +34 -2
- package/dist/cli/commands/config.d.ts +3 -3
- package/dist/cli/commands/sagemaker.d.ts +11 -0
- package/dist/cli/commands/sagemaker.js +778 -0
- package/dist/cli/factories/commandFactory.js +7 -2
- package/dist/cli/index.js +3 -0
- package/dist/cli/utils/interactiveSetup.js +28 -0
- package/dist/core/baseProvider.d.ts +2 -2
- package/dist/core/types.d.ts +16 -4
- package/dist/core/types.js +24 -3
- package/dist/factories/providerFactory.js +10 -1
- package/dist/factories/providerRegistry.js +6 -1
- package/dist/lib/core/baseProvider.d.ts +2 -2
- package/dist/lib/core/types.d.ts +16 -4
- package/dist/lib/core/types.js +24 -3
- package/dist/lib/factories/providerFactory.js +10 -1
- package/dist/lib/factories/providerRegistry.js +6 -1
- package/dist/lib/neurolink.d.ts +15 -0
- package/dist/lib/neurolink.js +73 -1
- package/dist/lib/providers/amazonSagemaker.d.ts +67 -0
- package/dist/lib/providers/amazonSagemaker.js +149 -0
- package/dist/lib/providers/googleVertex.d.ts +4 -0
- package/dist/lib/providers/googleVertex.js +44 -3
- package/dist/lib/providers/index.d.ts +4 -0
- package/dist/lib/providers/index.js +4 -0
- package/dist/lib/providers/sagemaker/adaptive-semaphore.d.ts +86 -0
- package/dist/lib/providers/sagemaker/adaptive-semaphore.js +212 -0
- package/dist/lib/providers/sagemaker/client.d.ts +156 -0
- package/dist/lib/providers/sagemaker/client.js +462 -0
- package/dist/lib/providers/sagemaker/config.d.ts +73 -0
- package/dist/lib/providers/sagemaker/config.js +308 -0
- package/dist/lib/providers/sagemaker/detection.d.ts +176 -0
- package/dist/lib/providers/sagemaker/detection.js +596 -0
- package/dist/lib/providers/sagemaker/diagnostics.d.ts +37 -0
- package/dist/lib/providers/sagemaker/diagnostics.js +137 -0
- package/dist/lib/providers/sagemaker/error-constants.d.ts +78 -0
- package/dist/lib/providers/sagemaker/error-constants.js +227 -0
- package/dist/lib/providers/sagemaker/errors.d.ts +83 -0
- package/dist/lib/providers/sagemaker/errors.js +216 -0
- package/dist/lib/providers/sagemaker/index.d.ts +35 -0
- package/dist/lib/providers/sagemaker/index.js +67 -0
- package/dist/lib/providers/sagemaker/language-model.d.ts +182 -0
- package/dist/lib/providers/sagemaker/language-model.js +755 -0
- package/dist/lib/providers/sagemaker/parsers.d.ts +136 -0
- package/dist/lib/providers/sagemaker/parsers.js +625 -0
- package/dist/lib/providers/sagemaker/streaming.d.ts +39 -0
- package/dist/lib/providers/sagemaker/streaming.js +320 -0
- package/dist/lib/providers/sagemaker/structured-parser.d.ts +117 -0
- package/dist/lib/providers/sagemaker/structured-parser.js +625 -0
- package/dist/lib/providers/sagemaker/types.d.ts +456 -0
- package/dist/lib/providers/sagemaker/types.js +7 -0
- package/dist/lib/sdk/toolRegistration.d.ts +1 -1
- package/dist/lib/sdk/toolRegistration.js +13 -5
- package/dist/lib/types/cli.d.ts +36 -1
- package/dist/lib/utils/providerHealth.js +19 -4
- package/dist/neurolink.d.ts +15 -0
- package/dist/neurolink.js +73 -1
- package/dist/providers/amazonSagemaker.d.ts +67 -0
- package/dist/providers/amazonSagemaker.js +149 -0
- package/dist/providers/googleVertex.d.ts +4 -0
- package/dist/providers/googleVertex.js +44 -3
- package/dist/providers/index.d.ts +4 -0
- package/dist/providers/index.js +4 -0
- package/dist/providers/sagemaker/adaptive-semaphore.d.ts +86 -0
- package/dist/providers/sagemaker/adaptive-semaphore.js +212 -0
- package/dist/providers/sagemaker/client.d.ts +156 -0
- package/dist/providers/sagemaker/client.js +462 -0
- package/dist/providers/sagemaker/config.d.ts +73 -0
- package/dist/providers/sagemaker/config.js +308 -0
- package/dist/providers/sagemaker/detection.d.ts +176 -0
- package/dist/providers/sagemaker/detection.js +596 -0
- package/dist/providers/sagemaker/diagnostics.d.ts +37 -0
- package/dist/providers/sagemaker/diagnostics.js +137 -0
- package/dist/providers/sagemaker/error-constants.d.ts +78 -0
- package/dist/providers/sagemaker/error-constants.js +227 -0
- package/dist/providers/sagemaker/errors.d.ts +83 -0
- package/dist/providers/sagemaker/errors.js +216 -0
- package/dist/providers/sagemaker/index.d.ts +35 -0
- package/dist/providers/sagemaker/index.js +67 -0
- package/dist/providers/sagemaker/language-model.d.ts +182 -0
- package/dist/providers/sagemaker/language-model.js +755 -0
- package/dist/providers/sagemaker/parsers.d.ts +136 -0
- package/dist/providers/sagemaker/parsers.js +625 -0
- package/dist/providers/sagemaker/streaming.d.ts +39 -0
- package/dist/providers/sagemaker/streaming.js +320 -0
- package/dist/providers/sagemaker/structured-parser.d.ts +117 -0
- package/dist/providers/sagemaker/structured-parser.js +625 -0
- package/dist/providers/sagemaker/types.d.ts +456 -0
- package/dist/providers/sagemaker/types.js +7 -0
- package/dist/sdk/toolRegistration.d.ts +1 -1
- package/dist/sdk/toolRegistration.js +13 -5
- package/dist/types/cli.d.ts +36 -1
- package/dist/utils/providerHealth.js +19 -4
- package/package.json +8 -2
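
The headline change in this release is a new Amazon SageMaker provider, along with SageMaker CLI commands and small Google Vertex updates. The largest new file is the SageMaker language model implementation, language-model.js (755 added lines, shipped in matching copies under package/dist/lib/providers/sagemaker/ and package/dist/providers/sagemaker/). That file's hunk is reproduced below, followed by a short usage sketch after the listing.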
@@ -0,0 +1,755 @@
/**
 * SageMaker Language Model Implementation
 *
 * This module implements the LanguageModelV1 interface for Amazon SageMaker
 * integration with the Vercel AI SDK.
 */
import { randomUUID } from "crypto";
import { SageMakerRuntimeClient } from "./client.js";
import { handleSageMakerError } from "./errors.js";
import { estimateTokenUsage, createSageMakerStream } from "./streaming.js";
import { AdaptiveSemaphore, createAdaptiveSemaphore } from "./adaptive-semaphore.js";
import { logger } from "../../utils/logger.js";
/**
 * Base synthetic streaming delay in milliseconds for simulating real-time response
 * Can be configured via SAGEMAKER_BASE_STREAMING_DELAY_MS environment variable
 */
const BASE_SYNTHETIC_STREAMING_DELAY_MS = process.env.SAGEMAKER_BASE_STREAMING_DELAY_MS
    ? parseInt(process.env.SAGEMAKER_BASE_STREAMING_DELAY_MS, 10)
    : 50;
/**
 * Maximum synthetic streaming delay in milliseconds to prevent excessively slow streaming
 * Can be configured via SAGEMAKER_MAX_STREAMING_DELAY_MS environment variable
 */
const MAX_SYNTHETIC_STREAMING_DELAY_MS = process.env.SAGEMAKER_MAX_STREAMING_DELAY_MS
    ? parseInt(process.env.SAGEMAKER_MAX_STREAMING_DELAY_MS, 10)
    : 200;
/**
 * Calculate adaptive delay based on text size to avoid slow streaming for large texts
 * Smaller texts get longer delays for realistic feel, larger texts get shorter delays for performance
 */
function calculateAdaptiveDelay(textLength, chunkCount) {
    // Base calculation: smaller delay for larger texts
    const adaptiveDelay = Math.max(10, // Minimum 10ms delay
        Math.min(MAX_SYNTHETIC_STREAMING_DELAY_MS, BASE_SYNTHETIC_STREAMING_DELAY_MS * (1000 / Math.max(textLength, 100))));
    // Further reduce delay if there are many chunks to process
    if (chunkCount > 20) {
        return Math.max(10, adaptiveDelay * 0.5); // Half delay for many chunks
    }
    else if (chunkCount > 10) {
        return Math.max(15, adaptiveDelay * 0.7); // Reduced delay for moderate chunks
    }
    return adaptiveDelay;
}
/**
 * Create an async iterator for text chunks with adaptive delay between chunks
 * Used for synthetic streaming simulation with performance optimization for large texts
 */
async function* createTextChunkIterator(text) {
    if (!text) {
        return; // No text to emit
    }
    const words = text.split(/\s+/);
    const chunkSize = Math.max(1, Math.floor(words.length / 10));
    const totalChunks = Math.ceil(words.length / chunkSize);
    // Calculate adaptive delay based on text size and chunk count
    const adaptiveDelay = calculateAdaptiveDelay(text.length, totalChunks);
    for (let i = 0; i < words.length; i += chunkSize) {
        const chunk = words.slice(i, i + chunkSize).join(" ");
        const deltaText = i === 0 ? chunk : " " + chunk;
        // Add adaptive delay between chunks for realistic streaming simulation
        // Delay is shorter for larger texts to improve performance
        if (i > 0) {
            await new Promise((resolve) => setTimeout(resolve, adaptiveDelay));
        }
        yield deltaText;
    }
}
/**
 * Batch processing concurrency constants
 */
const DEFAULT_INITIAL_CONCURRENCY = 5;
const DEFAULT_MAX_CONCURRENCY = 10;
const DEFAULT_MIN_CONCURRENCY = 1;
/**
 * SageMaker Language Model implementing LanguageModelV1 interface
 */
export class SageMakerLanguageModel {
    specificationVersion = "v1";
    provider = "sagemaker";
    modelId;
    supportsStreaming = true;
    defaultObjectGenerationMode = "json";
    client;
    config;
    modelConfig;
    constructor(modelId, config, modelConfig) {
        this.modelId = modelId;
        this.config = config;
        this.modelConfig = modelConfig;
        this.client = new SageMakerRuntimeClient(config);
        logger.debug("SageMaker Language Model initialized", {
            modelId: this.modelId,
            endpointName: this.modelConfig.endpointName,
            provider: this.provider,
            specificationVersion: this.specificationVersion,
        });
    }
    /**
     * Generate text synchronously using SageMaker endpoint
     */
    async doGenerate(options) {
        const startTime = Date.now();
        try {
            const promptText = this.extractPromptText(options);
            logger.debug("SageMaker doGenerate called", {
                endpointName: this.modelConfig.endpointName,
                promptLength: promptText.length,
                maxTokens: options.maxTokens,
                temperature: options.temperature,
            });
            // Convert AI SDK options to SageMaker request format
            const sagemakerRequest = this.convertToSageMakerRequest(options);
            // Invoke SageMaker endpoint
            const response = await this.client.invokeEndpoint({
                EndpointName: this.modelConfig.endpointName,
                Body: JSON.stringify(sagemakerRequest),
                ContentType: "application/json",
                Accept: "application/json",
            });
            // Parse SageMaker response
            const responseBody = JSON.parse(new TextDecoder().decode(response.Body));
            const generatedText = this.extractTextFromResponse(responseBody);
            // Extract tool calls if present (Phase 4 enhancement)
            const toolCalls = this.extractToolCallsFromResponse(responseBody);
            // Calculate token usage
            const usage = estimateTokenUsage(promptText, generatedText);
            // Determine finish reason based on response content
            let finishReason = "stop";
            if (toolCalls && toolCalls.length > 0) {
                finishReason = "tool-calls";
            }
            else if (responseBody.finish_reason) {
                finishReason = this.mapSageMakerFinishReason(responseBody.finish_reason);
            }
            const duration = Date.now() - startTime;
            logger.debug("SageMaker doGenerate completed", {
                duration,
                outputLength: generatedText.length,
                usage,
                toolCallsCount: toolCalls?.length || 0,
                finishReason,
            });
            const result = {
                text: generatedText,
                usage: {
                    promptTokens: usage.promptTokens,
                    completionTokens: usage.completionTokens,
                    totalTokens: usage.totalTokens,
                },
                finishReason,
                rawCall: {
                    rawPrompt: options.prompt,
                    rawSettings: {
                        maxTokens: options.maxTokens,
                        temperature: options.temperature,
                        topP: options.topP,
                        endpointName: this.modelConfig.endpointName,
                    },
                },
                rawResponse: {
                    headers: {
                        "content-type": response.ContentType || "application/json",
                        "invoked-variant": response.InvokedProductionVariant || "",
                    },
                },
                request: {
                    body: JSON.stringify(sagemakerRequest),
                },
            };
            // Add tool calls to result if present
            if (toolCalls && toolCalls.length > 0) {
                result.toolCalls = toolCalls;
            }
            // Add structured data if response format was specified (Phase 4)
            const responseFormat = sagemakerRequest.response_format;
            if (responseFormat &&
                (responseFormat.type === "json_object" ||
                    responseFormat.type === "json_schema")) {
                try {
                    const parsedData = JSON.parse(generatedText);
                    result.object = parsedData;
                    logger.debug("Extracted structured data from response", {
                        responseFormat: responseFormat.type,
                        hasObject: !!result.object,
                    });
                }
                catch (parseError) {
                    logger.warn("Failed to parse structured response as JSON", {
                        error: parseError instanceof Error
                            ? parseError.message
                            : String(parseError),
                        responseText: generatedText.substring(0, 200),
                    });
                    // Keep the text response as fallback
                }
            }
            return result;
        }
        catch (error) {
            const duration = Date.now() - startTime;
            logger.error("SageMaker doGenerate failed", {
                duration,
                error: error instanceof Error ? error.message : String(error),
            });
            throw handleSageMakerError(error, this.modelConfig.endpointName);
        }
    }
    /**
     * Generate text with streaming using SageMaker endpoint
     */
    async doStream(options) {
        try {
            const promptText = this.extractPromptText(options);
            logger.debug("SageMaker doStream called", {
                endpointName: this.modelConfig.endpointName,
                promptLength: promptText.length,
            });
            // Phase 2: Full streaming implementation with automatic detection
            const sagemakerRequest = this.convertToSageMakerRequest(options);
            // Add streaming parameter if model supports it
            const requestWithStreaming = {
                ...sagemakerRequest,
                parameters: {
                    ...(typeof sagemakerRequest.parameters === "object" &&
                        sagemakerRequest.parameters !== null
                        ? sagemakerRequest.parameters
                        : {}),
                    stream: true, // Will be validated by detection system
                },
            };
            logger.debug("Attempting streaming generation", {
                endpointName: this.modelConfig.endpointName,
                hasStreamingFlag: true,
            });
            try {
                // First, try to invoke with streaming
                const response = await this.client.invokeEndpointWithStreaming({
                    EndpointName: this.modelConfig.endpointName,
                    Body: JSON.stringify(requestWithStreaming),
                    ContentType: this.modelConfig.contentType || "application/json",
                    Accept: this.modelConfig.accept || "application/json",
                });
                // Create intelligent streaming response
                const stream = await createSageMakerStream(response.Body, this.modelConfig.endpointName, this.config, {
                    prompt: promptText,
                    onChunk: (chunk) => {
                        logger.debug("Streaming chunk received", {
                            contentLength: chunk.content?.length || 0,
                            done: chunk.done,
                        });
                    },
                    onComplete: (usage) => {
                        logger.debug("Streaming completed", {
                            usage,
                            endpointName: this.modelConfig.endpointName,
                        });
                    },
                    onError: (error) => {
                        logger.error("Streaming error", {
                            error: error.message,
                            endpointName: this.modelConfig.endpointName,
                        });
                    },
                });
                return {
                    stream: stream,
                    rawCall: {
                        rawPrompt: sagemakerRequest,
                        rawSettings: this.modelConfig,
                    },
                    rawResponse: {
                        headers: {
                            "Content-Type": response.ContentType || "application/json",
                            "X-Invoked-Production-Variant": response.InvokedProductionVariant || "unknown",
                        },
                    },
                };
            }
            catch (streamingError) {
                logger.warn("Streaming failed, falling back to non-streaming", {
                    endpointName: this.modelConfig.endpointName,
                    error: streamingError instanceof Error
                        ? streamingError.message
                        : String(streamingError),
                });
                // Fallback: Generate normally and create synthetic stream
                const result = await this.doGenerate(options);
                // Create synthetic stream from complete result using async iterator pattern
                const syntheticStream = new ReadableStream({
                    async start(controller) {
                        try {
                            // Create async iterator for text chunks
                            const textChunks = createTextChunkIterator(result.text);
                            // Process chunks with async iterator pattern
                            for await (const deltaText of textChunks) {
                                controller.enqueue({
                                    type: "text-delta",
                                    textDelta: deltaText,
                                });
                            }
                            // Emit completion
                            controller.enqueue({
                                type: "finish",
                                finishReason: result.finishReason,
                                usage: result.usage,
                            });
                            controller.close();
                        }
                        catch (error) {
                            controller.error(error);
                        }
                    },
                });
                return {
                    stream: syntheticStream,
                    rawCall: result.rawCall,
                    rawResponse: result.rawResponse,
                    request: result.request,
                    warnings: [
                        ...(result.warnings || []),
                        {
                            type: "other",
                            message: "Streaming not supported, using synthetic stream",
                        },
                    ],
                };
            }
        }
        catch (error) {
            logger.error("SageMaker doStream failed", {
                error: error instanceof Error ? error.message : String(error),
            });
            throw handleSageMakerError(error, this.modelConfig.endpointName);
        }
    }
    /**
     * Convert AI SDK options to SageMaker request format
     */
    convertToSageMakerRequest(options) {
        const promptText = this.extractPromptText(options);
        // Enhanced SageMaker request format with tool support (Phase 4)
        const request = {
            inputs: promptText,
            parameters: {
                max_new_tokens: options.maxTokens || 512,
                temperature: options.temperature || 0.7,
                top_p: options.topP || 0.9,
                stop: options.stopSequences || [],
            },
        };
        // Add tool support if tools are present
        const tools = options.tools;
        if (tools && Array.isArray(tools) && tools.length > 0) {
            request.tools = this.convertToolsToSageMakerFormat(tools);
            // Add tool choice if specified
            const toolChoice = options.toolChoice;
            if (toolChoice) {
                request.tool_choice = this.convertToolChoiceToSageMakerFormat(toolChoice);
            }
            logger.debug("Added tool support to SageMaker request", {
                toolCount: tools.length,
                toolChoice: toolChoice,
            });
        }
        // Add structured output support (Phase 4)
        const responseFormat = options.responseFormat;
        if (responseFormat) {
            request.response_format = this.convertResponseFormatToSageMakerFormat(responseFormat);
            logger.debug("Added structured output support to SageMaker request", {
                responseFormat: responseFormat.type,
            });
        }
        logger.debug("Converted to SageMaker request format", {
            inputLength: promptText.length,
            parameters: request.parameters,
            hasTools: !!request.tools,
        });
        return request;
    }
    /**
     * Convert Vercel AI SDK tools to SageMaker format
     */
    convertToolsToSageMakerFormat(tools) {
        return tools.map((tool) => {
            if (tool.type === "function") {
                return {
                    type: "function",
                    function: {
                        name: tool.function.name,
                        description: tool.function.description || "",
                        parameters: tool.function.parameters || {},
                    },
                };
            }
            return tool; // Pass through other tool types
        });
    }
    /**
     * Convert Vercel AI SDK tool choice to SageMaker format
     */
    convertToolChoiceToSageMakerFormat(toolChoice) {
        if (typeof toolChoice === "string") {
            return toolChoice; // 'auto', 'none', etc.
        }
        if (toolChoice?.type === "function") {
            return {
                type: "function",
                function: {
                    name: toolChoice.function.name,
                },
            };
        }
        return toolChoice;
    }
    /**
     * Convert Vercel AI SDK response format to SageMaker format (Phase 4)
     */
    convertResponseFormatToSageMakerFormat(responseFormat) {
        if (responseFormat.type === "json_object") {
            return {
                type: "json_object",
                schema: responseFormat.schema || undefined,
            };
        }
        if (responseFormat.type === "json_schema") {
            return {
                type: "json_schema",
                json_schema: {
                    name: responseFormat.json_schema?.name || "response",
                    description: responseFormat.json_schema?.description || "Generated response",
                    schema: responseFormat.json_schema?.schema || {},
                },
            };
        }
        // Default to text
        return {
            type: "text",
        };
    }
    /**
     * Extract text content from AI SDK prompt format
     */
    extractPromptText(options) {
        // Check for messages first (like Ollama)
        const messages = options.messages;
        if (messages && Array.isArray(messages)) {
            return messages
                .filter((msg) => msg.role && msg.content)
                .map((msg) => {
                    if (typeof msg.content === "string") {
                        return `${msg.role}: ${msg.content}`;
                    }
                    return `${msg.role}: ${JSON.stringify(msg.content)}`;
                })
                .join("\n");
        }
        // Fallback to prompt property
        const prompt = options.prompt;
        if (typeof prompt === "string") {
            return prompt;
        }
        if (Array.isArray(prompt)) {
            return prompt
                .filter((msg) => msg.role && msg.content)
                .map((msg) => {
                    if (typeof msg.content === "string") {
                        return `${msg.role}: ${msg.content}`;
                    }
                    return `${msg.role}: ${JSON.stringify(msg.content)}`;
                })
                .join("\n");
        }
        return String(prompt);
    }
    /**
     * Extract generated text from SageMaker response
     */
    extractTextFromResponse(responseBody) {
        // Handle common SageMaker response formats
        if (typeof responseBody === "string") {
            return responseBody;
        }
        if (responseBody.generated_text) {
            return responseBody.generated_text;
        }
        if (responseBody.outputs) {
            return responseBody.outputs;
        }
        if (responseBody.text) {
            return responseBody.text;
        }
        if (Array.isArray(responseBody) && responseBody[0]?.generated_text) {
            return responseBody[0].generated_text;
        }
        // Handle response with tool calls
        if (responseBody.choices && Array.isArray(responseBody.choices)) {
            const choice = responseBody.choices[0];
            if (choice?.message?.content) {
                return choice.message.content;
            }
        }
        // Fallback: stringify the entire response
        return JSON.stringify(responseBody);
    }
    /**
     * Extract tool calls from SageMaker response (Phase 4)
     */
    extractToolCallsFromResponse(responseBody) {
        // Handle OpenAI-compatible format (common for many SageMaker models)
        if (responseBody.choices && Array.isArray(responseBody.choices)) {
            const choice = responseBody.choices[0];
            if (choice?.message?.tool_calls) {
                return choice.message.tool_calls.map((toolCall) => ({
                    type: "function",
                    id: String(toolCall.id || `call_${randomUUID()}`),
                    function: {
                        name: String(toolCall.function.name),
                        arguments: String(toolCall.function.arguments),
                    },
                }));
            }
        }
        // Handle custom SageMaker tool call format
        if (responseBody.tool_calls && Array.isArray(responseBody.tool_calls)) {
            return responseBody.tool_calls;
        }
        // Handle Anthropic-style tool use
        if (responseBody.content && Array.isArray(responseBody.content)) {
            const toolUses = responseBody.content.filter((item) => item.type === "tool_use");
            if (toolUses.length > 0) {
                return toolUses.map((toolUse) => ({
                    type: "function",
                    id: String(toolUse.id || `call_${randomUUID()}`),
                    function: {
                        name: String(toolUse.name),
                        arguments: JSON.stringify(toolUse.input || {}),
                    },
                }));
            }
        }
        return undefined;
    }
    /**
     * Map SageMaker finish reason to standardized format
     */
    mapSageMakerFinishReason(sagemakerReason) {
        switch (sagemakerReason?.toLowerCase()) {
            case "stop":
            case "end_turn":
            case "stop_sequence":
                return "stop";
            case "length":
            case "max_tokens":
            case "max_length":
                return "length";
            case "content_filter":
            case "content_filtered":
                return "content-filter";
            case "tool_calls":
            case "function_call":
                return "tool-calls";
            case "error":
                return "error";
            default:
                return "unknown";
        }
    }
    /**
     * Get model configuration summary for debugging
     */
    getModelInfo() {
        return {
            modelId: this.modelId,
            provider: this.provider,
            specificationVersion: this.specificationVersion,
            endpointName: this.modelConfig.endpointName,
            modelType: this.modelConfig.modelType,
            region: this.config.region,
        };
    }
    /**
     * Test basic connectivity to the SageMaker endpoint
     */
    async testConnectivity() {
        try {
            // Use the same pattern as Ollama - pass messages directly
            const result = await this.doGenerate({
                inputFormat: "messages",
                mode: { type: "regular" },
                prompt: [
                    { role: "user", content: [{ type: "text", text: "Hello" }] },
                ],
                maxTokens: 10,
            });
            return {
                success: !!result.text,
            };
        }
        catch (error) {
            return {
                success: false,
                error: error instanceof Error ? error.message : String(error),
            };
        }
    }
    /**
     * Batch inference support (Phase 4)
     * Process multiple prompts in a single request for efficiency
     */
    async doBatchGenerate(prompts, options) {
        try {
            logger.debug("SageMaker batch generate called", {
                batchSize: prompts.length,
                endpointName: this.modelConfig.endpointName,
            });
            // Advanced parallel processing with dynamic concurrency and error handling
            const results = await this.processPromptsInParallel(prompts, options);
            logger.debug("SageMaker batch generate completed", {
                batchSize: prompts.length,
                successCount: results.length,
            });
            return results;
        }
        catch (error) {
            logger.error("SageMaker batch generate failed", {
                error: error instanceof Error ? error.message : String(error),
                batchSize: prompts.length,
            });
            throw handleSageMakerError(error, this.modelConfig.endpointName);
        }
    }
    /**
     * Process prompts in parallel with advanced concurrency control and error handling
     */
    async processPromptsInParallel(prompts, options) {
        // Dynamic concurrency based on batch size and endpoint capacity
        const INITIAL_CONCURRENCY = Math.min(this.modelConfig.initialConcurrency ?? DEFAULT_INITIAL_CONCURRENCY, prompts.length);
        const MAX_CONCURRENCY = this.modelConfig.maxConcurrency ?? DEFAULT_MAX_CONCURRENCY;
        const MIN_CONCURRENCY = this.modelConfig.minConcurrency ?? DEFAULT_MIN_CONCURRENCY;
        const results = new Array(prompts.length);
        const errors = [];
        // Use adaptive semaphore utility for concurrency control
        const semaphore = createAdaptiveSemaphore(INITIAL_CONCURRENCY, MAX_CONCURRENCY, MIN_CONCURRENCY);
        // Process each prompt with adaptive concurrency
        const processPrompt = async (prompt, index) => {
            await semaphore.acquire();
            const startTime = Date.now();
            try {
                const result = await this.doGenerate({
                    inputFormat: "messages",
                    mode: { type: "regular" },
                    prompt: [
                        {
                            role: "user",
                            content: [{ type: "text", text: prompt }],
                        },
                    ],
                    maxTokens: options?.maxTokens,
                    temperature: options?.temperature,
                    topP: options?.topP,
                });
                const duration = Date.now() - startTime;
                results[index] = {
                    text: result.text || "",
                    usage: {
                        promptTokens: result.usage.promptTokens,
                        completionTokens: result.usage.completionTokens,
                        totalTokens: result.usage.totalTokens ||
                            result.usage.promptTokens + result.usage.completionTokens,
                    },
                    finishReason: result.finishReason,
                    index,
                };
                // Record successful completion for adaptive concurrency adjustment
                semaphore.recordSuccess(duration);
            }
            catch (error) {
                errors.push({
                    index,
                    error: error instanceof Error ? error : new Error(String(error)),
                });
                // Record error for adaptive concurrency adjustment
                const duration = Date.now() - startTime;
                semaphore.recordError(duration);
                // Create error result
                results[index] = {
                    text: "",
                    usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 },
                    finishReason: "error",
                    index,
                };
            }
            finally {
                semaphore.release();
            }
        };
        // Start all requests with concurrency control
        const allPromises = prompts.map((prompt, index) => processPrompt(prompt, index));
        // Wait for all requests to complete
        await Promise.all(allPromises);
        // Log final statistics using semaphore metrics
        const metrics = semaphore.getMetrics();
        logger.debug("Parallel batch processing completed", {
            totalPrompts: prompts.length,
            successCount: metrics.completedCount,
            errorCount: metrics.errorCount,
            finalConcurrency: metrics.currentConcurrency,
            errorRate: metrics.errorCount / prompts.length,
            averageResponseTime: metrics.averageResponseTime,
        });
        // If we have too many errors, log them for debugging
        if (errors.length > 0) {
            logger.warn("Batch processing encountered errors", {
                errorCount: errors.length,
                sampleErrors: errors.slice(0, 3).map((e) => ({
                    index: e.index,
                    message: e.error.message,
                })),
            });
        }
        // Return results in original order (already sorted by index)
        return results.map(({ text, usage, finishReason }) => ({
            text,
            usage,
            finishReason,
        }));
    }
    /**
     * Enhanced model information with batch capabilities
     */
    getModelCapabilities() {
        return {
            ...this.getModelInfo(),
            capabilities: {
                streaming: true,
                toolCalling: true,
                structuredOutput: true,
                batchInference: true,
                supportedResponseFormats: ["text", "json_object", "json_schema"],
                supportedToolTypes: ["function"],
                maxBatchSize: 100, // Increased limit with parallel processing
                adaptiveConcurrency: true,
                errorRecovery: true,
            },
        };
    }
}
export default SageMakerLanguageModel;
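
For orientation, here is a minimal, hypothetical usage sketch of the class above. It is not taken from the package's documentation: the shapes of config and modelConfig are defined in the package's sagemaker/types.d.ts (not shown in this diff), so the sketch limits itself to fields the class itself reads, and the model id, region, and endpoint name are placeholders.

import { SageMakerLanguageModel } from "./language-model.js";

// Placeholder identifiers; real values depend on your SageMaker deployment.
const model = new SageMakerLanguageModel(
    "my-sagemaker-model",            // modelId (placeholder)
    { region: "us-east-1" },         // client config; shape assumed from getModelInfo()
    { endpointName: "my-endpoint" }, // endpoint config; endpointName is read by every call above
);

// Mirrors the LanguageModelV1-style call that testConnectivity() makes internally.
const result = await model.doGenerate({
    inputFormat: "messages",
    mode: { type: "regular" },
    prompt: [{ role: "user", content: [{ type: "text", text: "Hello" }] }],
    maxTokens: 10,
});
console.log(result.text, result.finishReason, result.usage);

// Batch inference fans prompts out under the adaptive semaphore.
const batchResults = await model.doBatchGenerate(["Prompt A", "Prompt B"], {
    maxTokens: 128,
});

Two behaviors of the listing are worth noting. First, doStream degrades gracefully: if the streaming invocation fails, it runs doGenerate and replays the full result as a synthetic stream, pacing chunks with calculateAdaptiveDelay. With the defaults (SAGEMAKER_BASE_STREAMING_DELAY_MS = 50, SAGEMAKER_MAX_STREAMING_DELAY_MS = 200), a 500-character reply split into ten chunks is paced at max(10, min(200, 50 * 1000/500)) = 100 ms per chunk, and texts with more than 20 chunks have that delay halved. Second, token usage comes from estimateTokenUsage over the prompt and output text rather than from the endpoint, so the reported usage figures are approximations.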