@juspay/neurolink 7.7.1 → 7.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/CHANGELOG.md +12 -0
  2. package/README.md +33 -2
  3. package/dist/cli/commands/config.d.ts +3 -3
  4. package/dist/cli/commands/sagemaker.d.ts +11 -0
  5. package/dist/cli/commands/sagemaker.js +778 -0
  6. package/dist/cli/factories/commandFactory.js +1 -0
  7. package/dist/cli/index.js +3 -0
  8. package/dist/cli/utils/interactiveSetup.js +28 -0
  9. package/dist/core/baseProvider.d.ts +2 -2
  10. package/dist/core/types.d.ts +1 -0
  11. package/dist/core/types.js +1 -0
  12. package/dist/factories/providerRegistry.js +5 -0
  13. package/dist/lib/core/baseProvider.d.ts +2 -2
  14. package/dist/lib/core/types.d.ts +1 -0
  15. package/dist/lib/core/types.js +1 -0
  16. package/dist/lib/factories/providerRegistry.js +5 -0
  17. package/dist/lib/providers/amazonSagemaker.d.ts +67 -0
  18. package/dist/lib/providers/amazonSagemaker.js +149 -0
  19. package/dist/lib/providers/index.d.ts +4 -0
  20. package/dist/lib/providers/index.js +4 -0
  21. package/dist/lib/providers/sagemaker/adaptive-semaphore.d.ts +86 -0
  22. package/dist/lib/providers/sagemaker/adaptive-semaphore.js +212 -0
  23. package/dist/lib/providers/sagemaker/client.d.ts +156 -0
  24. package/dist/lib/providers/sagemaker/client.js +462 -0
  25. package/dist/lib/providers/sagemaker/config.d.ts +73 -0
  26. package/dist/lib/providers/sagemaker/config.js +308 -0
  27. package/dist/lib/providers/sagemaker/detection.d.ts +176 -0
  28. package/dist/lib/providers/sagemaker/detection.js +596 -0
  29. package/dist/lib/providers/sagemaker/diagnostics.d.ts +37 -0
  30. package/dist/lib/providers/sagemaker/diagnostics.js +137 -0
  31. package/dist/lib/providers/sagemaker/error-constants.d.ts +78 -0
  32. package/dist/lib/providers/sagemaker/error-constants.js +227 -0
  33. package/dist/lib/providers/sagemaker/errors.d.ts +83 -0
  34. package/dist/lib/providers/sagemaker/errors.js +216 -0
  35. package/dist/lib/providers/sagemaker/index.d.ts +35 -0
  36. package/dist/lib/providers/sagemaker/index.js +67 -0
  37. package/dist/lib/providers/sagemaker/language-model.d.ts +182 -0
  38. package/dist/lib/providers/sagemaker/language-model.js +755 -0
  39. package/dist/lib/providers/sagemaker/parsers.d.ts +136 -0
  40. package/dist/lib/providers/sagemaker/parsers.js +625 -0
  41. package/dist/lib/providers/sagemaker/streaming.d.ts +39 -0
  42. package/dist/lib/providers/sagemaker/streaming.js +320 -0
  43. package/dist/lib/providers/sagemaker/structured-parser.d.ts +117 -0
  44. package/dist/lib/providers/sagemaker/structured-parser.js +625 -0
  45. package/dist/lib/providers/sagemaker/types.d.ts +456 -0
  46. package/dist/lib/providers/sagemaker/types.js +7 -0
  47. package/dist/lib/types/cli.d.ts +36 -1
  48. package/dist/providers/amazonSagemaker.d.ts +67 -0
  49. package/dist/providers/amazonSagemaker.js +149 -0
  50. package/dist/providers/index.d.ts +4 -0
  51. package/dist/providers/index.js +4 -0
  52. package/dist/providers/sagemaker/adaptive-semaphore.d.ts +86 -0
  53. package/dist/providers/sagemaker/adaptive-semaphore.js +212 -0
  54. package/dist/providers/sagemaker/client.d.ts +156 -0
  55. package/dist/providers/sagemaker/client.js +462 -0
  56. package/dist/providers/sagemaker/config.d.ts +73 -0
  57. package/dist/providers/sagemaker/config.js +308 -0
  58. package/dist/providers/sagemaker/detection.d.ts +176 -0
  59. package/dist/providers/sagemaker/detection.js +596 -0
  60. package/dist/providers/sagemaker/diagnostics.d.ts +37 -0
  61. package/dist/providers/sagemaker/diagnostics.js +137 -0
  62. package/dist/providers/sagemaker/error-constants.d.ts +78 -0
  63. package/dist/providers/sagemaker/error-constants.js +227 -0
  64. package/dist/providers/sagemaker/errors.d.ts +83 -0
  65. package/dist/providers/sagemaker/errors.js +216 -0
  66. package/dist/providers/sagemaker/index.d.ts +35 -0
  67. package/dist/providers/sagemaker/index.js +67 -0
  68. package/dist/providers/sagemaker/language-model.d.ts +182 -0
  69. package/dist/providers/sagemaker/language-model.js +755 -0
  70. package/dist/providers/sagemaker/parsers.d.ts +136 -0
  71. package/dist/providers/sagemaker/parsers.js +625 -0
  72. package/dist/providers/sagemaker/streaming.d.ts +39 -0
  73. package/dist/providers/sagemaker/streaming.js +320 -0
  74. package/dist/providers/sagemaker/structured-parser.d.ts +117 -0
  75. package/dist/providers/sagemaker/structured-parser.js +625 -0
  76. package/dist/providers/sagemaker/types.d.ts +456 -0
  77. package/dist/providers/sagemaker/types.js +7 -0
  78. package/dist/types/cli.d.ts +36 -1
  79. package/package.json +4 -1
@@ -0,0 +1,755 @@
1
+ /**
2
+ * SageMaker Language Model Implementation
3
+ *
4
+ * This module implements the LanguageModelV1 interface for Amazon SageMaker
5
+ * integration with the Vercel AI SDK.
6
+ */
7
+ import { randomUUID } from "crypto";
8
+ import { SageMakerRuntimeClient } from "./client.js";
9
+ import { handleSageMakerError } from "./errors.js";
10
+ import { estimateTokenUsage, createSageMakerStream } from "./streaming.js";
11
+ import { createAdaptiveSemaphore, } from "./adaptive-semaphore.js";
12
+ import { logger } from "../../utils/logger.js";
13
+ /**
14
+ * Base synthetic streaming delay in milliseconds for simulating real-time response
15
+ * Can be configured via SAGEMAKER_BASE_STREAMING_DELAY_MS environment variable
16
+ */
17
+ const BASE_SYNTHETIC_STREAMING_DELAY_MS = process.env
18
+ .SAGEMAKER_BASE_STREAMING_DELAY_MS
19
+ ? parseInt(process.env.SAGEMAKER_BASE_STREAMING_DELAY_MS, 10)
20
+ : 50;
21
+ /**
22
+ * Maximum synthetic streaming delay in milliseconds to prevent excessively slow streaming
23
+ * Can be configured via SAGEMAKER_MAX_STREAMING_DELAY_MS environment variable
24
+ */
25
+ const MAX_SYNTHETIC_STREAMING_DELAY_MS = process.env
26
+ .SAGEMAKER_MAX_STREAMING_DELAY_MS
27
+ ? parseInt(process.env.SAGEMAKER_MAX_STREAMING_DELAY_MS, 10)
28
+ : 200;
29
+ /**
30
+ * Calculate adaptive delay based on text size to avoid slow streaming for large texts
31
+ * Smaller texts get longer delays for realistic feel, larger texts get shorter delays for performance
32
+ */
33
+ function calculateAdaptiveDelay(textLength, chunkCount) {
34
+ // Base calculation: smaller delay for larger texts
35
+ const adaptiveDelay = Math.max(10, // Minimum 10ms delay
36
+ Math.min(MAX_SYNTHETIC_STREAMING_DELAY_MS, BASE_SYNTHETIC_STREAMING_DELAY_MS * (1000 / Math.max(textLength, 100))));
37
+ // Further reduce delay if there are many chunks to process
38
+ if (chunkCount > 20) {
39
+ return Math.max(10, adaptiveDelay * 0.5); // Half delay for many chunks
40
+ }
41
+ else if (chunkCount > 10) {
42
+ return Math.max(15, adaptiveDelay * 0.7); // Reduced delay for moderate chunks
43
+ }
44
+ return adaptiveDelay;
45
+ }
46
+ /**
47
+ * Create an async iterator for text chunks with adaptive delay between chunks
48
+ * Used for synthetic streaming simulation with performance optimization for large texts
49
+ */
50
+ async function* createTextChunkIterator(text) {
51
+ if (!text) {
52
+ return; // No text to emit
53
+ }
54
+ const words = text.split(/\s+/);
55
+ const chunkSize = Math.max(1, Math.floor(words.length / 10));
56
+ const totalChunks = Math.ceil(words.length / chunkSize);
57
+ // Calculate adaptive delay based on text size and chunk count
58
+ const adaptiveDelay = calculateAdaptiveDelay(text.length, totalChunks);
59
+ for (let i = 0; i < words.length; i += chunkSize) {
60
+ const chunk = words.slice(i, i + chunkSize).join(" ");
61
+ const deltaText = i === 0 ? chunk : " " + chunk;
62
+ // Add adaptive delay between chunks for realistic streaming simulation
63
+ // Delay is shorter for larger texts to improve performance
64
+ if (i > 0) {
65
+ await new Promise((resolve) => setTimeout(resolve, adaptiveDelay));
66
+ }
67
+ yield deltaText;
68
+ }
69
+ }
70
+ /**
71
+ * Batch processing concurrency constants
72
+ */
73
+ const DEFAULT_INITIAL_CONCURRENCY = 5;
74
+ const DEFAULT_MAX_CONCURRENCY = 10;
75
+ const DEFAULT_MIN_CONCURRENCY = 1;
76
+ /**
77
+ * SageMaker Language Model implementing LanguageModelV1 interface
78
+ */
79
+ export class SageMakerLanguageModel {
80
+ specificationVersion = "v1";
81
+ provider = "sagemaker";
82
+ modelId;
83
+ supportsStreaming = true;
84
+ defaultObjectGenerationMode = "json";
85
+ client;
86
+ config;
87
+ modelConfig;
88
+ constructor(modelId, config, modelConfig) {
89
+ this.modelId = modelId;
90
+ this.config = config;
91
+ this.modelConfig = modelConfig;
92
+ this.client = new SageMakerRuntimeClient(config);
93
+ logger.debug("SageMaker Language Model initialized", {
94
+ modelId: this.modelId,
95
+ endpointName: this.modelConfig.endpointName,
96
+ provider: this.provider,
97
+ specificationVersion: this.specificationVersion,
98
+ });
99
+ }
100
+ /**
101
+ * Generate text synchronously using SageMaker endpoint
102
+ */
103
+ async doGenerate(options) {
104
+ const startTime = Date.now();
105
+ try {
106
+ const promptText = this.extractPromptText(options);
107
+ logger.debug("SageMaker doGenerate called", {
108
+ endpointName: this.modelConfig.endpointName,
109
+ promptLength: promptText.length,
110
+ maxTokens: options.maxTokens,
111
+ temperature: options.temperature,
112
+ });
113
+ // Convert AI SDK options to SageMaker request format
114
+ const sagemakerRequest = this.convertToSageMakerRequest(options);
115
+ // Invoke SageMaker endpoint
116
+ const response = await this.client.invokeEndpoint({
117
+ EndpointName: this.modelConfig.endpointName,
118
+ Body: JSON.stringify(sagemakerRequest),
119
+ ContentType: "application/json",
120
+ Accept: "application/json",
121
+ });
122
+ // Parse SageMaker response
123
+ const responseBody = JSON.parse(new TextDecoder().decode(response.Body));
124
+ const generatedText = this.extractTextFromResponse(responseBody);
125
+ // Extract tool calls if present (Phase 4 enhancement)
126
+ const toolCalls = this.extractToolCallsFromResponse(responseBody);
127
+ // Calculate token usage
128
+ const usage = estimateTokenUsage(promptText, generatedText);
129
+ // Determine finish reason based on response content
130
+ let finishReason = "stop";
131
+ if (toolCalls && toolCalls.length > 0) {
132
+ finishReason = "tool-calls";
133
+ }
134
+ else if (responseBody.finish_reason) {
135
+ finishReason = this.mapSageMakerFinishReason(responseBody.finish_reason);
136
+ }
137
+ const duration = Date.now() - startTime;
138
+ logger.debug("SageMaker doGenerate completed", {
139
+ duration,
140
+ outputLength: generatedText.length,
141
+ usage,
142
+ toolCallsCount: toolCalls?.length || 0,
143
+ finishReason,
144
+ });
145
+ const result = {
146
+ text: generatedText,
147
+ usage: {
148
+ promptTokens: usage.promptTokens,
149
+ completionTokens: usage.completionTokens,
150
+ totalTokens: usage.totalTokens,
151
+ },
152
+ finishReason,
153
+ rawCall: {
154
+ rawPrompt: options.prompt,
155
+ rawSettings: {
156
+ maxTokens: options.maxTokens,
157
+ temperature: options.temperature,
158
+ topP: options.topP,
159
+ endpointName: this.modelConfig.endpointName,
160
+ },
161
+ },
162
+ rawResponse: {
163
+ headers: {
164
+ "content-type": response.ContentType || "application/json",
165
+ "invoked-variant": response.InvokedProductionVariant || "",
166
+ },
167
+ },
168
+ request: {
169
+ body: JSON.stringify(sagemakerRequest),
170
+ },
171
+ };
172
+ // Add tool calls to result if present
173
+ if (toolCalls && toolCalls.length > 0) {
174
+ result.toolCalls = toolCalls;
175
+ }
176
+ // Add structured data if response format was specified (Phase 4)
177
+ const responseFormat = sagemakerRequest
178
+ .response_format;
179
+ if (responseFormat &&
180
+ (responseFormat.type === "json_object" ||
181
+ responseFormat.type === "json_schema")) {
182
+ try {
183
+ const parsedData = JSON.parse(generatedText);
184
+ result.object = parsedData;
185
+ logger.debug("Extracted structured data from response", {
186
+ responseFormat: responseFormat.type,
187
+ hasObject: !!result.object,
188
+ });
189
+ }
190
+ catch (parseError) {
191
+ logger.warn("Failed to parse structured response as JSON", {
192
+ error: parseError instanceof Error
193
+ ? parseError.message
194
+ : String(parseError),
195
+ responseText: generatedText.substring(0, 200),
196
+ });
197
+ // Keep the text response as fallback
198
+ }
199
+ }
200
+ return result;
201
+ }
202
+ catch (error) {
203
+ const duration = Date.now() - startTime;
204
+ logger.error("SageMaker doGenerate failed", {
205
+ duration,
206
+ error: error instanceof Error ? error.message : String(error),
207
+ });
208
+ throw handleSageMakerError(error, this.modelConfig.endpointName);
209
+ }
210
+ }
211
+ /**
212
+ * Generate text with streaming using SageMaker endpoint
213
+ */
214
+ async doStream(options) {
215
+ try {
216
+ const promptText = this.extractPromptText(options);
217
+ logger.debug("SageMaker doStream called", {
218
+ endpointName: this.modelConfig.endpointName,
219
+ promptLength: promptText.length,
220
+ });
221
+ // Phase 2: Full streaming implementation with automatic detection
222
+ const sagemakerRequest = this.convertToSageMakerRequest(options);
223
+ // Add streaming parameter if model supports it
224
+ const requestWithStreaming = {
225
+ ...sagemakerRequest,
226
+ parameters: {
227
+ ...(typeof sagemakerRequest.parameters === "object" &&
228
+ sagemakerRequest.parameters !== null
229
+ ? sagemakerRequest.parameters
230
+ : {}),
231
+ stream: true, // Will be validated by detection system
232
+ },
233
+ };
234
+ logger.debug("Attempting streaming generation", {
235
+ endpointName: this.modelConfig.endpointName,
236
+ hasStreamingFlag: true,
237
+ });
238
+ try {
239
+ // First, try to invoke with streaming
240
+ const response = await this.client.invokeEndpointWithStreaming({
241
+ EndpointName: this.modelConfig.endpointName,
242
+ Body: JSON.stringify(requestWithStreaming),
243
+ ContentType: this.modelConfig.contentType || "application/json",
244
+ Accept: this.modelConfig.accept || "application/json",
245
+ });
246
+ // Create intelligent streaming response
247
+ const stream = await createSageMakerStream(response.Body, this.modelConfig.endpointName, this.config, {
248
+ prompt: promptText,
249
+ onChunk: (chunk) => {
250
+ logger.debug("Streaming chunk received", {
251
+ contentLength: chunk.content?.length || 0,
252
+ done: chunk.done,
253
+ });
254
+ },
255
+ onComplete: (usage) => {
256
+ logger.debug("Streaming completed", {
257
+ usage,
258
+ endpointName: this.modelConfig.endpointName,
259
+ });
260
+ },
261
+ onError: (error) => {
262
+ logger.error("Streaming error", {
263
+ error: error.message,
264
+ endpointName: this.modelConfig.endpointName,
265
+ });
266
+ },
267
+ });
268
+ return {
269
+ stream: stream,
270
+ rawCall: {
271
+ rawPrompt: sagemakerRequest,
272
+ rawSettings: this.modelConfig,
273
+ },
274
+ rawResponse: {
275
+ headers: {
276
+ "Content-Type": response.ContentType || "application/json",
277
+ "X-Invoked-Production-Variant": response.InvokedProductionVariant || "unknown",
278
+ },
279
+ },
280
+ };
281
+ }
282
+ catch (streamingError) {
283
+ logger.warn("Streaming failed, falling back to non-streaming", {
284
+ endpointName: this.modelConfig.endpointName,
285
+ error: streamingError instanceof Error
286
+ ? streamingError.message
287
+ : String(streamingError),
288
+ });
289
+ // Fallback: Generate normally and create synthetic stream
290
+ const result = await this.doGenerate(options);
291
+ // Create synthetic stream from complete result using async iterator pattern
292
+ const syntheticStream = new ReadableStream({
293
+ async start(controller) {
294
+ try {
295
+ // Create async iterator for text chunks
296
+ const textChunks = createTextChunkIterator(result.text);
297
+ // Process chunks with async iterator pattern
298
+ for await (const deltaText of textChunks) {
299
+ controller.enqueue({
300
+ type: "text-delta",
301
+ textDelta: deltaText,
302
+ });
303
+ }
304
+ // Emit completion
305
+ controller.enqueue({
306
+ type: "finish",
307
+ finishReason: result.finishReason,
308
+ usage: result.usage,
309
+ });
310
+ controller.close();
311
+ }
312
+ catch (error) {
313
+ controller.error(error);
314
+ }
315
+ },
316
+ });
317
+ return {
318
+ stream: syntheticStream,
319
+ rawCall: result.rawCall,
320
+ rawResponse: result.rawResponse,
321
+ request: result.request,
322
+ warnings: [
323
+ ...(result.warnings || []),
324
+ {
325
+ type: "other",
326
+ message: "Streaming not supported, using synthetic stream",
327
+ },
328
+ ],
329
+ };
330
+ }
331
+ }
332
+ catch (error) {
333
+ logger.error("SageMaker doStream failed", {
334
+ error: error instanceof Error ? error.message : String(error),
335
+ });
336
+ throw handleSageMakerError(error, this.modelConfig.endpointName);
337
+ }
338
+ }
339
+ /**
340
+ * Convert AI SDK options to SageMaker request format
341
+ */
342
+ convertToSageMakerRequest(options) {
343
+ const promptText = this.extractPromptText(options);
344
+ // Enhanced SageMaker request format with tool support (Phase 4)
345
+ const request = {
346
+ inputs: promptText,
347
+ parameters: {
348
+ max_new_tokens: options.maxTokens || 512,
349
+ temperature: options.temperature || 0.7,
350
+ top_p: options.topP || 0.9,
351
+ stop: options.stopSequences || [],
352
+ },
353
+ };
354
+ // Add tool support if tools are present
355
+ const tools = options.tools;
356
+ if (tools && Array.isArray(tools) && tools.length > 0) {
357
+ request.tools = this.convertToolsToSageMakerFormat(tools);
358
+ // Add tool choice if specified
359
+ const toolChoice = options.toolChoice;
360
+ if (toolChoice) {
361
+ request.tool_choice =
362
+ this.convertToolChoiceToSageMakerFormat(toolChoice);
363
+ }
364
+ logger.debug("Added tool support to SageMaker request", {
365
+ toolCount: tools.length,
366
+ toolChoice: toolChoice,
367
+ });
368
+ }
369
+ // Add structured output support (Phase 4)
370
+ const responseFormat = options
371
+ .responseFormat;
372
+ if (responseFormat) {
373
+ request.response_format =
374
+ this.convertResponseFormatToSageMakerFormat(responseFormat);
375
+ logger.debug("Added structured output support to SageMaker request", {
376
+ responseFormat: responseFormat.type,
377
+ });
378
+ }
379
+ logger.debug("Converted to SageMaker request format", {
380
+ inputLength: promptText.length,
381
+ parameters: request.parameters,
382
+ hasTools: !!request.tools,
383
+ });
384
+ return request;
385
+ }
386
+ /**
387
+ * Convert Vercel AI SDK tools to SageMaker format
388
+ */
389
+ convertToolsToSageMakerFormat(tools) {
390
+ return tools.map((tool) => {
391
+ if (tool.type === "function") {
392
+ return {
393
+ type: "function",
394
+ function: {
395
+ name: tool.function.name,
396
+ description: tool.function.description || "",
397
+ parameters: tool.function.parameters || {},
398
+ },
399
+ };
400
+ }
401
+ return tool; // Pass through other tool types
402
+ });
403
+ }
404
+ /**
405
+ * Convert Vercel AI SDK tool choice to SageMaker format
406
+ */
407
+ convertToolChoiceToSageMakerFormat(toolChoice) {
408
+ if (typeof toolChoice === "string") {
409
+ return toolChoice; // 'auto', 'none', etc.
410
+ }
411
+ if (toolChoice?.type === "function") {
412
+ return {
413
+ type: "function",
414
+ function: {
415
+ name: toolChoice.function.name,
416
+ },
417
+ };
418
+ }
419
+ return toolChoice;
420
+ }
421
+ /**
422
+ * Convert Vercel AI SDK response format to SageMaker format (Phase 4)
423
+ */
424
+ convertResponseFormatToSageMakerFormat(responseFormat) {
425
+ if (responseFormat.type === "json_object") {
426
+ return {
427
+ type: "json_object",
428
+ schema: responseFormat.schema || undefined,
429
+ };
430
+ }
431
+ if (responseFormat.type === "json_schema") {
432
+ return {
433
+ type: "json_schema",
434
+ json_schema: {
435
+ name: responseFormat.json_schema?.name || "response",
436
+ description: responseFormat.json_schema?.description ||
437
+ "Generated response",
438
+ schema: responseFormat.json_schema?.schema || {},
439
+ },
440
+ };
441
+ }
442
+ // Default to text
443
+ return {
444
+ type: "text",
445
+ };
446
+ }
447
+ /**
448
+ * Extract text content from AI SDK prompt format
449
+ */
450
+ extractPromptText(options) {
451
+ // Check for messages first (like Ollama)
452
+ const messages = options.messages;
453
+ if (messages && Array.isArray(messages)) {
454
+ return messages
455
+ .filter((msg) => msg.role && msg.content)
456
+ .map((msg) => {
457
+ if (typeof msg.content === "string") {
458
+ return `${msg.role}: ${msg.content}`;
459
+ }
460
+ return `${msg.role}: ${JSON.stringify(msg.content)}`;
461
+ })
462
+ .join("\n");
463
+ }
464
+ // Fallback to prompt property
465
+ const prompt = options.prompt;
466
+ if (typeof prompt === "string") {
467
+ return prompt;
468
+ }
469
+ if (Array.isArray(prompt)) {
470
+ return prompt
471
+ .filter((msg) => msg.role && msg.content)
472
+ .map((msg) => {
473
+ if (typeof msg.content === "string") {
474
+ return `${msg.role}: ${msg.content}`;
475
+ }
476
+ return `${msg.role}: ${JSON.stringify(msg.content)}`;
477
+ })
478
+ .join("\n");
479
+ }
480
+ return String(prompt);
481
+ }
482
+ /**
483
+ * Extract generated text from SageMaker response
484
+ */
485
+ extractTextFromResponse(responseBody) {
486
+ // Handle common SageMaker response formats
487
+ if (typeof responseBody === "string") {
488
+ return responseBody;
489
+ }
490
+ if (responseBody.generated_text) {
491
+ return responseBody.generated_text;
492
+ }
493
+ if (responseBody.outputs) {
494
+ return responseBody.outputs;
495
+ }
496
+ if (responseBody.text) {
497
+ return responseBody.text;
498
+ }
499
+ if (Array.isArray(responseBody) && responseBody[0]?.generated_text) {
500
+ return responseBody[0].generated_text;
501
+ }
502
+ // Handle response with tool calls
503
+ if (responseBody.choices && Array.isArray(responseBody.choices)) {
504
+ const choice = responseBody.choices[0];
505
+ if (choice?.message?.content) {
506
+ return choice.message.content;
507
+ }
508
+ }
509
+ // Fallback: stringify the entire response
510
+ return JSON.stringify(responseBody);
511
+ }
512
+ /**
513
+ * Extract tool calls from SageMaker response (Phase 4)
514
+ */
515
+ extractToolCallsFromResponse(responseBody) {
516
+ // Handle OpenAI-compatible format (common for many SageMaker models)
517
+ if (responseBody.choices && Array.isArray(responseBody.choices)) {
518
+ const choice = responseBody.choices[0];
519
+ if (choice?.message?.tool_calls) {
520
+ return choice.message.tool_calls.map((toolCall) => ({
521
+ type: "function",
522
+ id: String(toolCall.id || `call_${randomUUID()}`),
523
+ function: {
524
+ name: String(toolCall.function.name),
525
+ arguments: String(toolCall.function.arguments),
526
+ },
527
+ }));
528
+ }
529
+ }
530
+ // Handle custom SageMaker tool call format
531
+ if (responseBody.tool_calls && Array.isArray(responseBody.tool_calls)) {
532
+ return responseBody.tool_calls;
533
+ }
534
+ // Handle Anthropic-style tool use
535
+ if (responseBody.content && Array.isArray(responseBody.content)) {
536
+ const toolUses = responseBody.content.filter((item) => item.type === "tool_use");
537
+ if (toolUses.length > 0) {
538
+ return toolUses.map((toolUse) => ({
539
+ type: "function",
540
+ id: String(toolUse.id || `call_${randomUUID()}`),
541
+ function: {
542
+ name: String(toolUse.name),
543
+ arguments: JSON.stringify(toolUse.input || {}),
544
+ },
545
+ }));
546
+ }
547
+ }
548
+ return undefined;
549
+ }
550
+ /**
551
+ * Map SageMaker finish reason to standardized format
552
+ */
553
+ mapSageMakerFinishReason(sagemakerReason) {
554
+ switch (sagemakerReason?.toLowerCase()) {
555
+ case "stop":
556
+ case "end_turn":
557
+ case "stop_sequence":
558
+ return "stop";
559
+ case "length":
560
+ case "max_tokens":
561
+ case "max_length":
562
+ return "length";
563
+ case "content_filter":
564
+ case "content_filtered":
565
+ return "content-filter";
566
+ case "tool_calls":
567
+ case "function_call":
568
+ return "tool-calls";
569
+ case "error":
570
+ return "error";
571
+ default:
572
+ return "unknown";
573
+ }
574
+ }
575
+ /**
576
+ * Get model configuration summary for debugging
577
+ */
578
+ getModelInfo() {
579
+ return {
580
+ modelId: this.modelId,
581
+ provider: this.provider,
582
+ specificationVersion: this.specificationVersion,
583
+ endpointName: this.modelConfig.endpointName,
584
+ modelType: this.modelConfig.modelType,
585
+ region: this.config.region,
586
+ };
587
+ }
588
+ /**
589
+ * Test basic connectivity to the SageMaker endpoint
590
+ */
591
+ async testConnectivity() {
592
+ try {
593
+ // Use the same pattern as Ollama - pass messages directly
594
+ const result = await this.doGenerate({
595
+ inputFormat: "messages",
596
+ mode: { type: "regular" },
597
+ prompt: [
598
+ { role: "user", content: [{ type: "text", text: "Hello" }] },
599
+ ],
600
+ maxTokens: 10,
601
+ });
602
+ return {
603
+ success: !!result.text,
604
+ };
605
+ }
606
+ catch (error) {
607
+ return {
608
+ success: false,
609
+ error: error instanceof Error ? error.message : String(error),
610
+ };
611
+ }
612
+ }
613
+ /**
614
+ * Batch inference support (Phase 4)
615
+ * Process multiple prompts in a single request for efficiency
616
+ */
617
+ async doBatchGenerate(prompts, options) {
618
+ try {
619
+ logger.debug("SageMaker batch generate called", {
620
+ batchSize: prompts.length,
621
+ endpointName: this.modelConfig.endpointName,
622
+ });
623
+ // Advanced parallel processing with dynamic concurrency and error handling
624
+ const results = await this.processPromptsInParallel(prompts, options);
625
+ logger.debug("SageMaker batch generate completed", {
626
+ batchSize: prompts.length,
627
+ successCount: results.length,
628
+ });
629
+ return results;
630
+ }
631
+ catch (error) {
632
+ logger.error("SageMaker batch generate failed", {
633
+ error: error instanceof Error ? error.message : String(error),
634
+ batchSize: prompts.length,
635
+ });
636
+ throw handleSageMakerError(error, this.modelConfig.endpointName);
637
+ }
638
+ }
639
+ /**
640
+ * Process prompts in parallel with advanced concurrency control and error handling
641
+ */
642
+ async processPromptsInParallel(prompts, options) {
643
+ // Dynamic concurrency based on batch size and endpoint capacity
644
+ const INITIAL_CONCURRENCY = Math.min(this.modelConfig.initialConcurrency ?? DEFAULT_INITIAL_CONCURRENCY, prompts.length);
645
+ const MAX_CONCURRENCY = this.modelConfig.maxConcurrency ?? DEFAULT_MAX_CONCURRENCY;
646
+ const MIN_CONCURRENCY = this.modelConfig.minConcurrency ?? DEFAULT_MIN_CONCURRENCY;
647
+ const results = new Array(prompts.length);
648
+ const errors = [];
649
+ // Use adaptive semaphore utility for concurrency control
650
+ const semaphore = createAdaptiveSemaphore(INITIAL_CONCURRENCY, MAX_CONCURRENCY, MIN_CONCURRENCY);
651
+ // Process each prompt with adaptive concurrency
652
+ const processPrompt = async (prompt, index) => {
653
+ await semaphore.acquire();
654
+ const startTime = Date.now();
655
+ try {
656
+ const result = await this.doGenerate({
657
+ inputFormat: "messages",
658
+ mode: { type: "regular" },
659
+ prompt: [
660
+ {
661
+ role: "user",
662
+ content: [{ type: "text", text: prompt }],
663
+ },
664
+ ],
665
+ maxTokens: options?.maxTokens,
666
+ temperature: options?.temperature,
667
+ topP: options?.topP,
668
+ });
669
+ const duration = Date.now() - startTime;
670
+ results[index] = {
671
+ text: result.text || "",
672
+ usage: {
673
+ promptTokens: result.usage.promptTokens,
674
+ completionTokens: result.usage.completionTokens,
675
+ totalTokens: result.usage.totalTokens ||
676
+ result.usage.promptTokens + result.usage.completionTokens,
677
+ },
678
+ finishReason: result.finishReason,
679
+ index,
680
+ };
681
+ // Record successful completion for adaptive concurrency adjustment
682
+ semaphore.recordSuccess(duration);
683
+ }
684
+ catch (error) {
685
+ errors.push({
686
+ index,
687
+ error: error instanceof Error ? error : new Error(String(error)),
688
+ });
689
+ // Record error for adaptive concurrency adjustment
690
+ const duration = Date.now() - startTime;
691
+ semaphore.recordError(duration);
692
+ // Create error result
693
+ results[index] = {
694
+ text: "",
695
+ usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 },
696
+ finishReason: "error",
697
+ index,
698
+ };
699
+ }
700
+ finally {
701
+ semaphore.release();
702
+ }
703
+ };
704
+ // Start all requests with concurrency control
705
+ const allPromises = prompts.map((prompt, index) => processPrompt(prompt, index));
706
+ // Wait for all requests to complete
707
+ await Promise.all(allPromises);
708
+ // Log final statistics using semaphore metrics
709
+ const metrics = semaphore.getMetrics();
710
+ logger.debug("Parallel batch processing completed", {
711
+ totalPrompts: prompts.length,
712
+ successCount: metrics.completedCount,
713
+ errorCount: metrics.errorCount,
714
+ finalConcurrency: metrics.currentConcurrency,
715
+ errorRate: metrics.errorCount / prompts.length,
716
+ averageResponseTime: metrics.averageResponseTime,
717
+ });
718
+ // If we have too many errors, log them for debugging
719
+ if (errors.length > 0) {
720
+ logger.warn("Batch processing encountered errors", {
721
+ errorCount: errors.length,
722
+ sampleErrors: errors.slice(0, 3).map((e) => ({
723
+ index: e.index,
724
+ message: e.error.message,
725
+ })),
726
+ });
727
+ }
728
+ // Return results in original order (already sorted by index)
729
+ return results.map(({ text, usage, finishReason }) => ({
730
+ text,
731
+ usage,
732
+ finishReason,
733
+ }));
734
+ }
735
+ /**
736
+ * Enhanced model information with batch capabilities
737
+ */
738
+ getModelCapabilities() {
739
+ return {
740
+ ...this.getModelInfo(),
741
+ capabilities: {
742
+ streaming: true,
743
+ toolCalling: true,
744
+ structuredOutput: true,
745
+ batchInference: true,
746
+ supportedResponseFormats: ["text", "json_object", "json_schema"],
747
+ supportedToolTypes: ["function"],
748
+ maxBatchSize: 100, // Increased limit with parallel processing
749
+ adaptiveConcurrency: true,
750
+ errorRecovery: true,
751
+ },
752
+ };
753
+ }
754
+ }
755
+ export default SageMakerLanguageModel;