@juspay/neurolink 7.29.1 → 7.29.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/dist/cli/commands/config.d.ts +86 -86
- package/dist/cli/commands/mcp.js +64 -9
- package/dist/cli/commands/models.js +25 -21
- package/dist/cli/commands/ollama.js +2 -2
- package/dist/cli/factories/commandFactory.d.ts +9 -0
- package/dist/cli/factories/commandFactory.js +177 -83
- package/dist/cli/factories/ollamaCommandFactory.js +3 -1
- package/dist/cli/factories/sagemakerCommandFactory.js +3 -2
- package/dist/cli/index.d.ts +1 -1
- package/dist/cli/index.js +19 -11
- package/dist/cli/utils/envManager.js +5 -5
- package/dist/cli/utils/ollamaUtils.d.ts +12 -0
- package/dist/cli/utils/ollamaUtils.js +58 -42
- package/dist/config/configManager.js +5 -2
- package/dist/core/analytics.d.ts +2 -24
- package/dist/core/analytics.js +12 -17
- package/dist/core/baseProvider.d.ts +30 -1
- package/dist/core/baseProvider.js +180 -198
- package/dist/core/dynamicModels.d.ts +4 -4
- package/dist/core/dynamicModels.js +7 -7
- package/dist/core/evaluation.d.ts +9 -9
- package/dist/core/evaluation.js +117 -65
- package/dist/core/evaluationProviders.d.ts +18 -2
- package/dist/core/evaluationProviders.js +15 -13
- package/dist/core/factory.js +77 -4
- package/dist/core/modelConfiguration.d.ts +63 -0
- package/dist/core/modelConfiguration.js +354 -290
- package/dist/core/streamAnalytics.d.ts +10 -5
- package/dist/core/streamAnalytics.js +10 -10
- package/dist/core/types.d.ts +19 -109
- package/dist/core/types.js +13 -0
- package/dist/factories/providerFactory.js +4 -1
- package/dist/factories/providerRegistry.js +2 -2
- package/dist/index.d.ts +2 -1
- package/dist/lib/config/configManager.js +5 -2
- package/dist/lib/core/analytics.d.ts +2 -24
- package/dist/lib/core/analytics.js +12 -17
- package/dist/lib/core/baseProvider.d.ts +30 -1
- package/dist/lib/core/baseProvider.js +180 -198
- package/dist/lib/core/dynamicModels.js +7 -7
- package/dist/lib/core/evaluation.d.ts +9 -9
- package/dist/lib/core/evaluation.js +117 -65
- package/dist/lib/core/evaluationProviders.d.ts +18 -2
- package/dist/lib/core/evaluationProviders.js +15 -13
- package/dist/lib/core/factory.js +77 -4
- package/dist/lib/core/modelConfiguration.d.ts +63 -0
- package/dist/lib/core/modelConfiguration.js +354 -290
- package/dist/lib/core/streamAnalytics.d.ts +10 -5
- package/dist/lib/core/streamAnalytics.js +10 -10
- package/dist/lib/core/types.d.ts +19 -109
- package/dist/lib/core/types.js +13 -0
- package/dist/lib/factories/providerFactory.js +4 -1
- package/dist/lib/factories/providerRegistry.js +2 -2
- package/dist/lib/index.d.ts +2 -1
- package/dist/lib/mcp/externalServerManager.js +14 -6
- package/dist/lib/mcp/factory.js +1 -1
- package/dist/lib/mcp/flexibleToolValidator.d.ts +50 -0
- package/dist/lib/mcp/flexibleToolValidator.js +161 -0
- package/dist/lib/mcp/index.d.ts +1 -1
- package/dist/lib/mcp/index.js +1 -1
- package/dist/lib/mcp/mcpCircuitBreaker.js +5 -1
- package/dist/lib/mcp/mcpClientFactory.js +3 -0
- package/dist/lib/mcp/registry.d.ts +3 -3
- package/dist/lib/mcp/registry.js +3 -3
- package/dist/lib/mcp/servers/aiProviders/aiAnalysisTools.js +5 -5
- package/dist/lib/mcp/servers/aiProviders/aiWorkflowTools.js +6 -6
- package/dist/lib/mcp/servers/utilities/utilityServer.js +1 -1
- package/dist/lib/mcp/toolDiscoveryService.js +8 -2
- package/dist/lib/mcp/toolRegistry.d.ts +2 -2
- package/dist/lib/mcp/toolRegistry.js +29 -54
- package/dist/lib/middleware/builtin/analytics.js +4 -4
- package/dist/lib/middleware/builtin/guardrails.js +2 -2
- package/dist/lib/middleware/registry.js +11 -2
- package/dist/lib/models/modelRegistry.d.ts +1 -1
- package/dist/lib/models/modelRegistry.js +3 -3
- package/dist/lib/models/modelResolver.d.ts +1 -1
- package/dist/lib/models/modelResolver.js +2 -2
- package/dist/lib/neurolink.d.ts +118 -0
- package/dist/lib/neurolink.js +814 -952
- package/dist/lib/providers/amazonBedrock.d.ts +47 -6
- package/dist/lib/providers/amazonBedrock.js +282 -23
- package/dist/lib/providers/amazonSagemaker.d.ts +1 -1
- package/dist/lib/providers/amazonSagemaker.js +12 -3
- package/dist/lib/providers/anthropic.d.ts +1 -1
- package/dist/lib/providers/anthropic.js +7 -6
- package/dist/lib/providers/anthropicBaseProvider.d.ts +1 -1
- package/dist/lib/providers/anthropicBaseProvider.js +4 -3
- package/dist/lib/providers/aws/credentialProvider.d.ts +58 -0
- package/dist/lib/providers/aws/credentialProvider.js +267 -0
- package/dist/lib/providers/aws/credentialTester.d.ts +49 -0
- package/dist/lib/providers/aws/credentialTester.js +394 -0
- package/dist/lib/providers/azureOpenai.d.ts +1 -1
- package/dist/lib/providers/azureOpenai.js +1 -1
- package/dist/lib/providers/googleAiStudio.d.ts +1 -1
- package/dist/lib/providers/googleAiStudio.js +2 -2
- package/dist/lib/providers/googleVertex.d.ts +40 -0
- package/dist/lib/providers/googleVertex.js +330 -274
- package/dist/lib/providers/huggingFace.js +1 -1
- package/dist/lib/providers/mistral.d.ts +1 -1
- package/dist/lib/providers/mistral.js +2 -2
- package/dist/lib/providers/ollama.d.ts +4 -0
- package/dist/lib/providers/ollama.js +38 -18
- package/dist/lib/providers/openAI.d.ts +1 -1
- package/dist/lib/providers/openAI.js +2 -2
- package/dist/lib/providers/sagemaker/adaptive-semaphore.js +7 -4
- package/dist/lib/providers/sagemaker/client.js +13 -3
- package/dist/lib/providers/sagemaker/config.js +5 -1
- package/dist/lib/providers/sagemaker/detection.js +19 -9
- package/dist/lib/providers/sagemaker/errors.d.ts +8 -1
- package/dist/lib/providers/sagemaker/errors.js +103 -20
- package/dist/lib/providers/sagemaker/language-model.d.ts +3 -3
- package/dist/lib/providers/sagemaker/language-model.js +4 -4
- package/dist/lib/providers/sagemaker/parsers.js +14 -6
- package/dist/lib/providers/sagemaker/streaming.js +14 -3
- package/dist/lib/providers/sagemaker/types.d.ts +1 -1
- package/dist/lib/proxy/awsProxyIntegration.d.ts +23 -0
- package/dist/lib/proxy/awsProxyIntegration.js +285 -0
- package/dist/lib/proxy/proxyFetch.d.ts +9 -5
- package/dist/lib/proxy/proxyFetch.js +232 -98
- package/dist/lib/proxy/utils/noProxyUtils.d.ts +39 -0
- package/dist/lib/proxy/utils/noProxyUtils.js +149 -0
- package/dist/lib/sdk/toolRegistration.d.ts +1 -1
- package/dist/lib/types/cli.d.ts +80 -8
- package/dist/lib/types/contextTypes.js +2 -2
- package/dist/lib/types/generateTypes.d.ts +4 -6
- package/dist/lib/types/providers.d.ts +124 -19
- package/dist/lib/types/providers.js +6 -6
- package/dist/lib/types/streamTypes.d.ts +4 -6
- package/dist/lib/types/typeAliases.d.ts +1 -1
- package/dist/lib/utils/analyticsUtils.d.ts +33 -0
- package/dist/lib/utils/analyticsUtils.js +76 -0
- package/dist/lib/utils/errorHandling.js +4 -1
- package/dist/lib/utils/evaluationUtils.d.ts +27 -0
- package/dist/lib/utils/evaluationUtils.js +131 -0
- package/dist/lib/utils/optionsUtils.js +10 -1
- package/dist/lib/utils/performance.d.ts +1 -1
- package/dist/lib/utils/performance.js +15 -3
- package/dist/lib/utils/providerConfig.d.ts +1 -0
- package/dist/lib/utils/providerConfig.js +2 -1
- package/dist/lib/utils/providerHealth.d.ts +48 -0
- package/dist/lib/utils/providerHealth.js +221 -158
- package/dist/lib/utils/providerUtils.js +2 -2
- package/dist/lib/utils/timeout.js +8 -3
- package/dist/mcp/externalServerManager.js +14 -6
- package/dist/mcp/factory.js +1 -1
- package/dist/mcp/flexibleToolValidator.d.ts +50 -0
- package/dist/mcp/flexibleToolValidator.js +161 -0
- package/dist/mcp/index.d.ts +1 -1
- package/dist/mcp/index.js +1 -1
- package/dist/mcp/mcpCircuitBreaker.js +5 -1
- package/dist/mcp/mcpClientFactory.js +3 -0
- package/dist/mcp/registry.d.ts +3 -3
- package/dist/mcp/registry.js +3 -3
- package/dist/mcp/servers/aiProviders/aiAnalysisTools.js +5 -5
- package/dist/mcp/servers/aiProviders/aiWorkflowTools.js +6 -6
- package/dist/mcp/servers/utilities/utilityServer.js +1 -1
- package/dist/mcp/toolDiscoveryService.js +8 -2
- package/dist/mcp/toolRegistry.d.ts +2 -2
- package/dist/mcp/toolRegistry.js +29 -54
- package/dist/middleware/builtin/analytics.js +4 -4
- package/dist/middleware/builtin/guardrails.js +2 -2
- package/dist/middleware/registry.js +11 -2
- package/dist/models/modelRegistry.d.ts +1 -1
- package/dist/models/modelRegistry.js +3 -3
- package/dist/models/modelResolver.d.ts +1 -1
- package/dist/models/modelResolver.js +2 -2
- package/dist/neurolink.d.ts +118 -0
- package/dist/neurolink.js +814 -952
- package/dist/providers/amazonBedrock.d.ts +47 -6
- package/dist/providers/amazonBedrock.js +282 -23
- package/dist/providers/amazonSagemaker.d.ts +1 -1
- package/dist/providers/amazonSagemaker.js +12 -3
- package/dist/providers/anthropic.d.ts +1 -1
- package/dist/providers/anthropic.js +7 -6
- package/dist/providers/anthropicBaseProvider.d.ts +1 -1
- package/dist/providers/anthropicBaseProvider.js +4 -3
- package/dist/providers/aws/credentialProvider.d.ts +58 -0
- package/dist/providers/aws/credentialProvider.js +267 -0
- package/dist/providers/aws/credentialTester.d.ts +49 -0
- package/dist/providers/aws/credentialTester.js +394 -0
- package/dist/providers/azureOpenai.d.ts +1 -1
- package/dist/providers/azureOpenai.js +1 -1
- package/dist/providers/googleAiStudio.d.ts +1 -1
- package/dist/providers/googleAiStudio.js +2 -2
- package/dist/providers/googleVertex.d.ts +40 -0
- package/dist/providers/googleVertex.js +330 -274
- package/dist/providers/huggingFace.js +1 -1
- package/dist/providers/mistral.d.ts +1 -1
- package/dist/providers/mistral.js +2 -2
- package/dist/providers/ollama.d.ts +4 -0
- package/dist/providers/ollama.js +38 -18
- package/dist/providers/openAI.d.ts +1 -1
- package/dist/providers/openAI.js +2 -2
- package/dist/providers/sagemaker/adaptive-semaphore.js +7 -4
- package/dist/providers/sagemaker/client.js +13 -3
- package/dist/providers/sagemaker/config.js +5 -1
- package/dist/providers/sagemaker/detection.js +19 -9
- package/dist/providers/sagemaker/errors.d.ts +8 -1
- package/dist/providers/sagemaker/errors.js +103 -20
- package/dist/providers/sagemaker/language-model.d.ts +3 -3
- package/dist/providers/sagemaker/language-model.js +4 -4
- package/dist/providers/sagemaker/parsers.js +14 -6
- package/dist/providers/sagemaker/streaming.js +14 -3
- package/dist/providers/sagemaker/types.d.ts +1 -1
- package/dist/proxy/awsProxyIntegration.d.ts +23 -0
- package/dist/proxy/awsProxyIntegration.js +285 -0
- package/dist/proxy/proxyFetch.d.ts +9 -5
- package/dist/proxy/proxyFetch.js +232 -98
- package/dist/proxy/utils/noProxyUtils.d.ts +39 -0
- package/dist/proxy/utils/noProxyUtils.js +149 -0
- package/dist/sdk/toolRegistration.d.ts +1 -1
- package/dist/types/cli.d.ts +80 -8
- package/dist/types/contextTypes.js +2 -2
- package/dist/types/generateTypes.d.ts +4 -6
- package/dist/types/providers.d.ts +124 -19
- package/dist/types/providers.js +6 -6
- package/dist/types/streamTypes.d.ts +4 -6
- package/dist/types/typeAliases.d.ts +1 -1
- package/dist/utils/analyticsUtils.d.ts +33 -0
- package/dist/utils/analyticsUtils.js +76 -0
- package/dist/utils/errorHandling.js +4 -1
- package/dist/utils/evaluationUtils.d.ts +27 -0
- package/dist/utils/evaluationUtils.js +131 -0
- package/dist/utils/optionsUtils.js +10 -1
- package/dist/utils/performance.d.ts +1 -1
- package/dist/utils/performance.js +15 -3
- package/dist/utils/providerConfig.d.ts +1 -0
- package/dist/utils/providerConfig.js +2 -1
- package/dist/utils/providerHealth.d.ts +48 -0
- package/dist/utils/providerHealth.js +221 -158
- package/dist/utils/providerUtils.js +2 -2
- package/dist/utils/timeout.js +8 -3
- package/package.json +5 -1
|
@@ -147,7 +147,7 @@ export class HuggingFaceProvider extends BaseProvider {
|
|
|
147
147
|
* Prepare stream options with HuggingFace-specific enhancements
|
|
148
148
|
* Handles tool calling optimizations and model-specific formatting
|
|
149
149
|
*/
|
|
150
|
-
prepareStreamOptions(options,
|
|
150
|
+
prepareStreamOptions(options, _analysisSchema) {
|
|
151
151
|
const modelSupportsTools = this.supportsTools();
|
|
152
152
|
// If model doesn't support tools, disable them completely
|
|
153
153
|
if (!modelSupportsTools) {
|
|
@@ -10,7 +10,7 @@ import { BaseProvider } from "../core/baseProvider.js";
|
|
|
10
10
|
export declare class MistralProvider extends BaseProvider {
|
|
11
11
|
private model;
|
|
12
12
|
constructor(modelName?: string, sdk?: unknown);
|
|
13
|
-
protected executeStream(options: StreamOptions,
|
|
13
|
+
protected executeStream(options: StreamOptions, _analysisSchema?: ValidationSchema): Promise<StreamResult>;
|
|
14
14
|
protected getProviderName(): AIProviderName;
|
|
15
15
|
protected getDefaultModel(): string;
|
|
16
16
|
/**
|
|
@@ -2,7 +2,7 @@ import { createMistral } from "@ai-sdk/mistral";
|
|
|
2
2
|
import { streamText } from "ai";
|
|
3
3
|
import { BaseProvider } from "../core/baseProvider.js";
|
|
4
4
|
import { logger } from "../utils/logger.js";
|
|
5
|
-
import { createTimeoutController, TimeoutError
|
|
5
|
+
import { createTimeoutController, TimeoutError } from "../utils/timeout.js";
|
|
6
6
|
import { DEFAULT_MAX_TOKENS, DEFAULT_MAX_STEPS } from "../core/constants.js";
|
|
7
7
|
import { validateApiKey, createMistralConfig, getProviderModel, } from "../utils/providerConfig.js";
|
|
8
8
|
import { streamAnalyticsCollector } from "../core/streamAnalytics.js";
|
|
@@ -40,7 +40,7 @@ export class MistralProvider extends BaseProvider {
|
|
|
40
40
|
});
|
|
41
41
|
}
|
|
42
42
|
// generate() method is inherited from BaseProvider; this provider uses the base implementation for generation with tools
|
|
43
|
-
async executeStream(options,
|
|
43
|
+
async executeStream(options, _analysisSchema) {
|
|
44
44
|
this.validateStreamOptions(options);
|
|
45
45
|
const startTime = Date.now();
|
|
46
46
|
const timeout = this.getTimeout(options);
|
|
@@ -62,6 +62,10 @@ export declare class OllamaProvider extends BaseProvider {
|
|
|
62
62
|
* Convert AI SDK tools format to Ollama's function calling format
|
|
63
63
|
*/
|
|
64
64
|
private convertToolsToOllamaFormat;
|
|
65
|
+
/**
|
|
66
|
+
* Process individual stream data chunk from Ollama
|
|
67
|
+
*/
|
|
68
|
+
private processOllamaStreamData;
|
|
65
69
|
/**
|
|
66
70
|
* Create stream generator for Ollama chat API with tool call support
|
|
67
71
|
*/
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import { BaseProvider } from "../core/baseProvider.js";
|
|
2
2
|
import { logger } from "../utils/logger.js";
|
|
3
|
-
import { TimeoutError } from "../utils/timeout.js";
|
|
4
3
|
import { DEFAULT_MAX_TOKENS } from "../core/constants.js";
|
|
5
4
|
import { modelConfig } from "../core/modelConfiguration.js";
|
|
6
5
|
import { createProxyFetch } from "../proxy/proxyFetch.js";
|
|
6
|
+
import { TimeoutError } from "../utils/timeout.js";
|
|
7
7
|
// Model version constants (configurable via environment)
|
|
8
8
|
const DEFAULT_OLLAMA_MODEL = "llama3.1:8b";
|
|
9
9
|
const FALLBACK_OLLAMA_MODEL = "llama3.2:latest"; // Used when primary model fails
|
|
@@ -197,7 +197,9 @@ class OllamaLanguageModel {
|
|
|
197
197
|
}
|
|
198
198
|
}
|
|
199
199
|
catch (error) {
|
|
200
|
-
|
|
200
|
+
logger.error("Error parsing Ollama stream response", {
|
|
201
|
+
error,
|
|
202
|
+
});
|
|
201
203
|
}
|
|
202
204
|
}
|
|
203
205
|
}
|
|
@@ -368,7 +370,7 @@ export class OllamaProvider extends BaseProvider {
|
|
|
368
370
|
* Execute streaming without tools using the generate API
|
|
369
371
|
* Fallback for non-tool scenarios or when chat API is unavailable
|
|
370
372
|
*/
|
|
371
|
-
async executeStreamWithoutTools(options,
|
|
373
|
+
async executeStreamWithoutTools(options, _analysisSchema) {
|
|
372
374
|
const response = await proxyFetch(`${this.baseUrl}/api/generate`, {
|
|
373
375
|
method: "POST",
|
|
374
376
|
headers: { "Content-Type": "application/json" },
|
|
@@ -423,10 +425,32 @@ export class OllamaProvider extends BaseProvider {
|
|
|
423
425
|
},
|
|
424
426
|
}));
|
|
425
427
|
}
|
|
428
|
+
/**
|
|
429
|
+
* Process individual stream data chunk from Ollama
|
|
430
|
+
*/
|
|
431
|
+
processOllamaStreamData(data) {
|
|
432
|
+
const dataRecord = data;
|
|
433
|
+
const choices = dataRecord.choices;
|
|
434
|
+
const delta = choices?.[0]?.delta;
|
|
435
|
+
let content = "";
|
|
436
|
+
if (delta?.content && typeof delta.content === "string") {
|
|
437
|
+
content += delta.content;
|
|
438
|
+
}
|
|
439
|
+
if (delta?.tool_calls) {
|
|
440
|
+
// Handle tool calls - for now, we'll include them as content
|
|
441
|
+
// Future enhancement: Execute tools and return results
|
|
442
|
+
const toolCallDescription = this.formatToolCallForDisplay(delta.tool_calls);
|
|
443
|
+
if (toolCallDescription) {
|
|
444
|
+
content += toolCallDescription;
|
|
445
|
+
}
|
|
446
|
+
}
|
|
447
|
+
const shouldReturn = !!choices?.[0]?.finish_reason;
|
|
448
|
+
return content ? { content, shouldReturn } : { shouldReturn };
|
|
449
|
+
}
|
|
426
450
|
/**
|
|
427
451
|
* Create stream generator for Ollama chat API with tool call support
|
|
428
452
|
*/
|
|
429
|
-
async *createOllamaChatStream(response,
|
|
453
|
+
async *createOllamaChatStream(response, _tools) {
|
|
430
454
|
const reader = response.body?.getReader();
|
|
431
455
|
if (!reader) {
|
|
432
456
|
throw new Error("No response body");
|
|
@@ -450,24 +474,18 @@ export class OllamaProvider extends BaseProvider {
|
|
|
450
474
|
}
|
|
451
475
|
try {
|
|
452
476
|
const data = JSON.parse(dataLine);
|
|
453
|
-
const
|
|
454
|
-
if (
|
|
455
|
-
yield { content:
|
|
456
|
-
}
|
|
457
|
-
if (delta?.tool_calls) {
|
|
458
|
-
// Handle tool calls - for now, we'll include them as content
|
|
459
|
-
// Future enhancement: Execute tools and return results
|
|
460
|
-
const toolCallDescription = this.formatToolCallForDisplay(delta.tool_calls);
|
|
461
|
-
if (toolCallDescription) {
|
|
462
|
-
yield { content: toolCallDescription };
|
|
463
|
-
}
|
|
477
|
+
const result = this.processOllamaStreamData(data);
|
|
478
|
+
if (result?.content) {
|
|
479
|
+
yield { content: result.content };
|
|
464
480
|
}
|
|
465
|
-
if (
|
|
481
|
+
if (result?.shouldReturn) {
|
|
466
482
|
return;
|
|
467
483
|
}
|
|
468
484
|
}
|
|
469
485
|
catch (error) {
|
|
470
|
-
|
|
486
|
+
logger.error("Error parsing Ollama stream response", {
|
|
487
|
+
error,
|
|
488
|
+
});
|
|
471
489
|
}
|
|
472
490
|
}
|
|
473
491
|
}
|
|
@@ -536,7 +554,9 @@ export class OllamaProvider extends BaseProvider {
|
|
|
536
554
|
}
|
|
537
555
|
}
|
|
538
556
|
catch (error) {
|
|
539
|
-
|
|
557
|
+
logger.error("Error parsing Ollama stream response", {
|
|
558
|
+
error,
|
|
559
|
+
});
|
|
540
560
|
}
|
|
541
561
|
}
|
|
542
562
|
}
|
|
@@ -23,6 +23,6 @@ export declare class OpenAIProvider extends BaseProvider {
|
|
|
23
23
|
* For details on the changes and migration steps, refer to the BaseProvider documentation
|
|
24
24
|
* and the migration guide in the project repository.
|
|
25
25
|
*/
|
|
26
|
-
protected executeStream(options: StreamOptions,
|
|
26
|
+
protected executeStream(options: StreamOptions, _analysisSchema?: ValidationSchema): Promise<StreamResult>;
|
|
27
27
|
}
|
|
28
28
|
export default OpenAIProvider;
|
|
@@ -3,7 +3,7 @@ import { streamText } from "ai";
|
|
|
3
3
|
import { AIProviderName } from "../core/types.js";
|
|
4
4
|
import { BaseProvider } from "../core/baseProvider.js";
|
|
5
5
|
import { logger } from "../utils/logger.js";
|
|
6
|
-
import { createTimeoutController, TimeoutError
|
|
6
|
+
import { createTimeoutController, TimeoutError } from "../utils/timeout.js";
|
|
7
7
|
import { AuthenticationError, InvalidModelError, NetworkError, ProviderError, RateLimitError, } from "../types/errors.js";
|
|
8
8
|
import { DEFAULT_MAX_TOKENS, DEFAULT_MAX_STEPS } from "../core/constants.js";
|
|
9
9
|
import { validateApiKey, createOpenAIConfig, getProviderModel, } from "../utils/providerConfig.js";
|
|
@@ -82,7 +82,7 @@ export class OpenAIProvider extends BaseProvider {
|
|
|
82
82
|
* For details on the changes and migration steps, refer to the BaseProvider documentation
|
|
83
83
|
* and the migration guide in the project repository.
|
|
84
84
|
*/
|
|
85
|
-
async executeStream(options,
|
|
85
|
+
async executeStream(options, _analysisSchema) {
|
|
86
86
|
this.validateStreamOptions(options);
|
|
87
87
|
const startTime = Date.now();
|
|
88
88
|
const timeout = this.getTimeout(options);
|
|
@@ -55,7 +55,10 @@ export class AdaptiveSemaphore {
|
|
|
55
55
|
this.activeRequests--;
|
|
56
56
|
if (this.waiters.length > 0) {
|
|
57
57
|
const waiter = this.waiters.shift();
|
|
58
|
-
waiter
|
|
58
|
+
if (waiter) {
|
|
59
|
+
this.count++; // Increment count before calling waiter so waiter can decrement it
|
|
60
|
+
waiter();
|
|
61
|
+
}
|
|
59
62
|
}
|
|
60
63
|
else {
|
|
61
64
|
this.count++;
|
|
@@ -103,9 +106,9 @@ export class AdaptiveSemaphore {
|
|
|
103
106
|
// Wake up waiting requests if we increased concurrency
|
|
104
107
|
while (this.count > 0 && this.waiters.length > 0) {
|
|
105
108
|
const waiter = this.waiters.shift();
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
+
if (waiter) {
|
|
110
|
+
waiter();
|
|
111
|
+
}
|
|
109
112
|
}
|
|
110
113
|
}
|
|
111
114
|
/**
|
|
@@ -142,7 +142,11 @@ export class SageMakerRuntimeClient {
|
|
|
142
142
|
});
|
|
143
143
|
// Return the response with streaming body
|
|
144
144
|
if (!response.Body) {
|
|
145
|
-
throw new SageMakerError("No response body received from streaming endpoint",
|
|
145
|
+
throw new SageMakerError("No response body received from streaming endpoint", {
|
|
146
|
+
code: "MODEL_ERROR",
|
|
147
|
+
statusCode: 500,
|
|
148
|
+
endpoint: params.EndpointName,
|
|
149
|
+
});
|
|
146
150
|
}
|
|
147
151
|
// Convert AWS response stream to async iterable of Uint8Array
|
|
148
152
|
const streamIterable = this.convertAWSStreamToIterable(response.Body);
|
|
@@ -345,7 +349,10 @@ export class SageMakerRuntimeClient {
|
|
|
345
349
|
error: error instanceof Error ? error.message : String(error),
|
|
346
350
|
streamType: typeof awsStream,
|
|
347
351
|
});
|
|
348
|
-
throw new SageMakerError(`Stream conversion failed: ${error instanceof Error ? error.message : String(error)}`,
|
|
352
|
+
throw new SageMakerError(`Stream conversion failed: ${error instanceof Error ? error.message : String(error)}`, {
|
|
353
|
+
code: "NETWORK_ERROR",
|
|
354
|
+
statusCode: 500,
|
|
355
|
+
});
|
|
349
356
|
}
|
|
350
357
|
}
|
|
351
358
|
/**
|
|
@@ -417,7 +424,10 @@ export class SageMakerRuntimeClient {
|
|
|
417
424
|
*/
|
|
418
425
|
ensureNotDisposed() {
|
|
419
426
|
if (this.isDisposed) {
|
|
420
|
-
throw new SageMakerError("Cannot perform operation on disposed SageMaker client",
|
|
427
|
+
throw new SageMakerError("Cannot perform operation on disposed SageMaker client", {
|
|
428
|
+
code: "VALIDATION_ERROR",
|
|
429
|
+
statusCode: 400,
|
|
430
|
+
});
|
|
421
431
|
}
|
|
422
432
|
}
|
|
423
433
|
}
|
|
@@ -99,7 +99,11 @@ export function getSageMakerModelConfig(endpointName) {
|
|
|
99
99
|
const endpoint = endpointName || getDefaultSageMakerEndpoint();
|
|
100
100
|
// Check cache first
|
|
101
101
|
if (modelConfigCache.has(endpoint)) {
|
|
102
|
-
|
|
102
|
+
const cachedConfig = modelConfigCache.get(endpoint);
|
|
103
|
+
if (!cachedConfig) {
|
|
104
|
+
throw new Error(`Model config for endpoint ${endpoint} not found in cache after existence check`);
|
|
105
|
+
}
|
|
106
|
+
return cachedConfig;
|
|
103
107
|
}
|
|
104
108
|
const config = {
|
|
105
109
|
endpointName: endpoint,
|
|
@@ -9,7 +9,6 @@ import { logger } from "../../utils/logger.js";
|
|
|
9
9
|
/**
|
|
10
10
|
* Configurable constants for detection timing and performance
|
|
11
11
|
*/
|
|
12
|
-
const DETECTION_TEST_DELAY_MS = 100; // Base delay between detection tests (ms)
|
|
13
12
|
const DETECTION_STAGGER_DELAY_MS = 25; // Delay between staggered test starts (ms)
|
|
14
13
|
const DETECTION_RATE_LIMIT_BACKOFF_MS = 200; // Initial backoff on rate limit detection (ms)
|
|
15
14
|
/**
|
|
@@ -78,7 +77,7 @@ export class SageMakerDetector {
|
|
|
78
77
|
];
|
|
79
78
|
// Run detection tests in parallel with intelligent rate limiting
|
|
80
79
|
const testNames = ["HuggingFace", "LLaMA", "PyTorch", "TensorFlow"];
|
|
81
|
-
const
|
|
80
|
+
const _results = await this.runDetectionTestsInParallel(detectionTests, testNames, endpointName);
|
|
82
81
|
// Analyze results and determine most likely model type
|
|
83
82
|
const scores = {
|
|
84
83
|
huggingface: 0,
|
|
@@ -218,7 +217,7 @@ export class SageMakerDetector {
|
|
|
218
217
|
}
|
|
219
218
|
}
|
|
220
219
|
catch (error) {
|
|
221
|
-
|
|
220
|
+
logger.debug("HuggingFace signature test failed", { error });
|
|
222
221
|
}
|
|
223
222
|
}
|
|
224
223
|
/**
|
|
@@ -245,7 +244,7 @@ export class SageMakerDetector {
|
|
|
245
244
|
evidence.push("llama: openai text_completion object");
|
|
246
245
|
}
|
|
247
246
|
}
|
|
248
|
-
catch
|
|
247
|
+
catch {
|
|
249
248
|
// Test failed, no evidence
|
|
250
249
|
}
|
|
251
250
|
}
|
|
@@ -266,7 +265,7 @@ export class SageMakerDetector {
|
|
|
266
265
|
evidence.push("pytorch: prediction/output field pattern");
|
|
267
266
|
}
|
|
268
267
|
}
|
|
269
|
-
catch
|
|
268
|
+
catch {
|
|
270
269
|
// Test failed, no evidence
|
|
271
270
|
}
|
|
272
271
|
}
|
|
@@ -290,7 +289,7 @@ export class SageMakerDetector {
|
|
|
290
289
|
evidence.push("tensorflow: serving predictions field");
|
|
291
290
|
}
|
|
292
291
|
}
|
|
293
|
-
catch
|
|
292
|
+
catch {
|
|
294
293
|
// Test failed, no evidence
|
|
295
294
|
}
|
|
296
295
|
}
|
|
@@ -457,7 +456,10 @@ export class SageMakerDetector {
|
|
|
457
456
|
release() {
|
|
458
457
|
if (this.waiters.length > 0) {
|
|
459
458
|
const waiter = this.waiters.shift();
|
|
460
|
-
waiter
|
|
459
|
+
if (waiter) {
|
|
460
|
+
this.count++; // Increment count before calling waiter so waiter can decrement it
|
|
461
|
+
waiter();
|
|
462
|
+
}
|
|
461
463
|
}
|
|
462
464
|
else {
|
|
463
465
|
this.count++;
|
|
@@ -477,7 +479,14 @@ export class SageMakerDetector {
|
|
|
477
479
|
return { status: "fulfilled", value: undefined };
|
|
478
480
|
}
|
|
479
481
|
catch (error) {
|
|
480
|
-
const result = await this.handleDetectionTestError(error,
|
|
482
|
+
const result = await this.handleDetectionTestError(error, {
|
|
483
|
+
test: config.test,
|
|
484
|
+
testName: config.testName,
|
|
485
|
+
endpointName: config.endpointName,
|
|
486
|
+
incrementRateLimit: config.incrementRateLimit,
|
|
487
|
+
maxRateLimitRetries: config.maxRateLimitRetries,
|
|
488
|
+
rateLimitCount: config.rateLimitState.count,
|
|
489
|
+
});
|
|
481
490
|
return result;
|
|
482
491
|
}
|
|
483
492
|
finally {
|
|
@@ -498,8 +507,9 @@ export class SageMakerDetector {
|
|
|
498
507
|
/**
|
|
499
508
|
* Handle detection test errors with rate limiting and retry logic
|
|
500
509
|
*/
|
|
501
|
-
async handleDetectionTestError(error,
|
|
510
|
+
async handleDetectionTestError(error, options) {
|
|
502
511
|
const isRateLimit = this.isRateLimitError(error);
|
|
512
|
+
const { test, testName, endpointName, incrementRateLimit, maxRateLimitRetries, rateLimitCount, } = options;
|
|
503
513
|
if (isRateLimit && rateLimitCount < maxRateLimitRetries) {
|
|
504
514
|
return await this.retryWithBackoff(test, testName, endpointName, incrementRateLimit, rateLimitCount);
|
|
505
515
|
}
|
|
@@ -15,7 +15,14 @@ export declare class SageMakerError extends Error {
|
|
|
15
15
|
readonly endpoint?: string;
|
|
16
16
|
readonly requestId?: string;
|
|
17
17
|
readonly retryable: boolean;
|
|
18
|
-
constructor(message: string,
|
|
18
|
+
constructor(message: string, options?: {
|
|
19
|
+
code?: SageMakerErrorCode;
|
|
20
|
+
statusCode?: number;
|
|
21
|
+
cause?: Error;
|
|
22
|
+
endpoint?: string;
|
|
23
|
+
requestId?: string;
|
|
24
|
+
retryable?: boolean;
|
|
25
|
+
});
|
|
19
26
|
/**
|
|
20
27
|
* Convert error to JSON for logging/serialization
|
|
21
28
|
*/
|
|
@@ -15,15 +15,15 @@ export class SageMakerError extends Error {
|
|
|
15
15
|
endpoint;
|
|
16
16
|
requestId;
|
|
17
17
|
retryable;
|
|
18
|
-
constructor(message,
|
|
18
|
+
constructor(message, options = {}) {
|
|
19
19
|
super(message);
|
|
20
20
|
this.name = "SageMakerError";
|
|
21
|
-
this.code = code;
|
|
22
|
-
this.statusCode = statusCode;
|
|
23
|
-
this.cause = cause;
|
|
24
|
-
this.endpoint = endpoint;
|
|
25
|
-
this.requestId = requestId;
|
|
26
|
-
this.retryable = retryable;
|
|
21
|
+
this.code = options.code ?? "UNKNOWN_ERROR";
|
|
22
|
+
this.statusCode = options.statusCode;
|
|
23
|
+
this.cause = options.cause;
|
|
24
|
+
this.endpoint = options.endpoint;
|
|
25
|
+
this.requestId = options.requestId;
|
|
26
|
+
this.retryable = options.retryable ?? false;
|
|
27
27
|
// Capture stack trace if available
|
|
28
28
|
if (Error.captureStackTrace) {
|
|
29
29
|
Error.captureStackTrace(this, SageMakerError);
|
|
@@ -83,41 +83,111 @@ export function handleSageMakerError(error, endpoint) {
|
|
|
83
83
|
// AWS SDK specific errors using centralized constants
|
|
84
84
|
if (errorName === "ValidationException" ||
|
|
85
85
|
ERROR_KEYWORDS.VALIDATION.some((keyword) => errorMessage.includes(keyword))) {
|
|
86
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.VALIDATION}: ${error.message}`,
|
|
86
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.VALIDATION}: ${error.message}`, {
|
|
87
|
+
code: "VALIDATION_ERROR",
|
|
88
|
+
statusCode: 400,
|
|
89
|
+
cause: error,
|
|
90
|
+
endpoint,
|
|
91
|
+
requestId: extractRequestId(error),
|
|
92
|
+
retryable: false,
|
|
93
|
+
});
|
|
87
94
|
}
|
|
88
95
|
if (errorName === "ModelError" ||
|
|
89
96
|
ERROR_KEYWORDS.MODEL.some((keyword) => errorMessage.includes(keyword))) {
|
|
90
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.MODEL}: ${error.message}`,
|
|
97
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.MODEL}: ${error.message}`, {
|
|
98
|
+
code: "MODEL_ERROR",
|
|
99
|
+
statusCode: 500,
|
|
100
|
+
cause: error,
|
|
101
|
+
endpoint,
|
|
102
|
+
requestId: extractRequestId(error),
|
|
103
|
+
retryable: false,
|
|
104
|
+
});
|
|
91
105
|
}
|
|
92
106
|
if (errorName === "InternalFailure" ||
|
|
93
107
|
ERROR_KEYWORDS.INTERNAL.some((keyword) => errorMessage.includes(keyword))) {
|
|
94
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.INTERNAL}: ${error.message}`,
|
|
108
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.INTERNAL}: ${error.message}`, {
|
|
109
|
+
code: "INTERNAL_ERROR",
|
|
110
|
+
statusCode: 500,
|
|
111
|
+
cause: error,
|
|
112
|
+
endpoint,
|
|
113
|
+
requestId: extractRequestId(error),
|
|
114
|
+
retryable: true,
|
|
115
|
+
});
|
|
95
116
|
}
|
|
96
117
|
if (errorName === "ServiceUnavailable" ||
|
|
97
118
|
ERROR_KEYWORDS.SERVICE_UNAVAILABLE.some((keyword) => errorMessage.includes(keyword))) {
|
|
98
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.SERVICE_UNAVAILABLE}: ${error.message}`,
|
|
119
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.SERVICE_UNAVAILABLE}: ${error.message}`, {
|
|
120
|
+
code: "SERVICE_UNAVAILABLE",
|
|
121
|
+
statusCode: 503,
|
|
122
|
+
cause: error,
|
|
123
|
+
endpoint,
|
|
124
|
+
requestId: extractRequestId(error),
|
|
125
|
+
retryable: true,
|
|
126
|
+
});
|
|
99
127
|
}
|
|
100
128
|
if (errorName === "ThrottlingException" ||
|
|
101
129
|
ERROR_KEYWORDS.THROTTLING.some((keyword) => errorMessage.includes(keyword))) {
|
|
102
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.THROTTLING}: ${error.message}`,
|
|
130
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.THROTTLING}: ${error.message}`, {
|
|
131
|
+
code: "THROTTLING_ERROR",
|
|
132
|
+
statusCode: 429,
|
|
133
|
+
cause: error,
|
|
134
|
+
endpoint,
|
|
135
|
+
requestId: extractRequestId(error),
|
|
136
|
+
retryable: true,
|
|
137
|
+
});
|
|
103
138
|
}
|
|
104
139
|
if (errorName === "CredentialsError" ||
|
|
105
140
|
ERROR_KEYWORDS.CREDENTIALS.some((keyword) => errorMessage.includes(keyword))) {
|
|
106
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.CREDENTIALS}: ${error.message}`,
|
|
141
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.CREDENTIALS}: ${error.message}`, {
|
|
142
|
+
code: "CREDENTIALS_ERROR",
|
|
143
|
+
statusCode: 401,
|
|
144
|
+
cause: error,
|
|
145
|
+
endpoint,
|
|
146
|
+
requestId: undefined,
|
|
147
|
+
retryable: false,
|
|
148
|
+
});
|
|
107
149
|
}
|
|
108
150
|
if (errorName === "NetworkingError" ||
|
|
109
151
|
ERROR_KEYWORDS.NETWORK.some((keyword) => errorMessage.includes(keyword))) {
|
|
110
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.NETWORK}: ${error.message}`,
|
|
152
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.NETWORK}: ${error.message}`, {
|
|
153
|
+
code: "NETWORK_ERROR",
|
|
154
|
+
statusCode: 0,
|
|
155
|
+
cause: error,
|
|
156
|
+
endpoint,
|
|
157
|
+
requestId: undefined,
|
|
158
|
+
retryable: true,
|
|
159
|
+
});
|
|
111
160
|
}
|
|
112
161
|
if (ERROR_KEYWORDS.ENDPOINT_NOT_FOUND.every((keyword) => errorMessage.includes(keyword))) {
|
|
113
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.ENDPOINT_NOT_FOUND}: ${error.message}`,
|
|
162
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.ENDPOINT_NOT_FOUND}: ${error.message}`, {
|
|
163
|
+
code: "ENDPOINT_NOT_FOUND",
|
|
164
|
+
statusCode: 404,
|
|
165
|
+
cause: error,
|
|
166
|
+
endpoint,
|
|
167
|
+
requestId: extractRequestId(error),
|
|
168
|
+
retryable: false,
|
|
169
|
+
});
|
|
114
170
|
}
|
|
115
171
|
// Generic error handling
|
|
116
|
-
return new SageMakerError(error.message,
|
|
172
|
+
return new SageMakerError(error.message, {
|
|
173
|
+
code: "UNKNOWN_ERROR",
|
|
174
|
+
statusCode: 500,
|
|
175
|
+
cause: error,
|
|
176
|
+
endpoint,
|
|
177
|
+
requestId: extractRequestId(error),
|
|
178
|
+
retryable: false,
|
|
179
|
+
});
|
|
117
180
|
}
|
|
118
181
|
// Handle non-Error objects
|
|
119
182
|
const errorMessage = typeof error === "string" ? error : "Unknown error occurred";
|
|
120
|
-
return new SageMakerError(errorMessage,
|
|
183
|
+
return new SageMakerError(errorMessage, {
|
|
184
|
+
code: "UNKNOWN_ERROR",
|
|
185
|
+
statusCode: 500,
|
|
186
|
+
cause: undefined,
|
|
187
|
+
endpoint,
|
|
188
|
+
requestId: undefined,
|
|
189
|
+
retryable: false,
|
|
190
|
+
});
|
|
121
191
|
}
|
|
122
192
|
/**
|
|
123
193
|
* Extract request ID from AWS SDK error for debugging
|
|
@@ -160,7 +230,11 @@ export function createValidationError(message, field) {
|
|
|
160
230
|
const fullMessage = field
|
|
161
231
|
? `${ERROR_MESSAGE_PREFIXES.VALIDATION_FIELD} '${field}': ${message}`
|
|
162
232
|
: message;
|
|
163
|
-
return new SageMakerError(fullMessage,
|
|
233
|
+
return new SageMakerError(fullMessage, {
|
|
234
|
+
code: "VALIDATION_ERROR",
|
|
235
|
+
statusCode: 400,
|
|
236
|
+
retryable: false,
|
|
237
|
+
});
|
|
164
238
|
}
|
|
165
239
|
/**
|
|
166
240
|
* Create a credentials error with setup guidance
|
|
@@ -169,7 +243,11 @@ export function createValidationError(message, field) {
|
|
|
169
243
|
* @returns SageMakerError with credentials guidance
|
|
170
244
|
*/
|
|
171
245
|
export function createCredentialsError(message) {
|
|
172
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.CREDENTIALS_SETUP}: ${message}`,
|
|
246
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.CREDENTIALS_SETUP}: ${message}`, {
|
|
247
|
+
code: "CREDENTIALS_ERROR",
|
|
248
|
+
statusCode: 401,
|
|
249
|
+
retryable: false,
|
|
250
|
+
});
|
|
173
251
|
}
|
|
174
252
|
/**
|
|
175
253
|
* Create a network error with connectivity guidance
|
|
@@ -179,7 +257,12 @@ export function createCredentialsError(message) {
|
|
|
179
257
|
* @returns SageMakerError with network guidance
|
|
180
258
|
*/
|
|
181
259
|
export function createNetworkError(message, endpoint) {
|
|
182
|
-
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.NETWORK_CONNECTION}: ${message}`,
|
|
260
|
+
return new SageMakerError(`${ERROR_MESSAGE_PREFIXES.NETWORK_CONNECTION}: ${message}`, {
|
|
261
|
+
code: "NETWORK_ERROR",
|
|
262
|
+
statusCode: 0,
|
|
263
|
+
endpoint,
|
|
264
|
+
retryable: true,
|
|
265
|
+
});
|
|
183
266
|
}
|
|
184
267
|
/**
|
|
185
268
|
* Check if an error is retryable based on its characteristics
|
|
@@ -126,7 +126,7 @@ export declare class SageMakerLanguageModel implements LanguageModelV1 {
|
|
|
126
126
|
provider: string;
|
|
127
127
|
specificationVersion: string;
|
|
128
128
|
endpointName: string;
|
|
129
|
-
modelType: "
|
|
129
|
+
modelType: "custom" | "huggingface" | "mistral" | "claude" | "llama" | "jumpstart" | undefined;
|
|
130
130
|
region: string;
|
|
131
131
|
};
|
|
132
132
|
/**
|
|
@@ -146,7 +146,7 @@ export declare class SageMakerLanguageModel implements LanguageModelV1 {
|
|
|
146
146
|
usage: {
|
|
147
147
|
promptTokens: number;
|
|
148
148
|
completionTokens: number;
|
|
149
|
-
|
|
149
|
+
total: number;
|
|
150
150
|
};
|
|
151
151
|
finishReason: "stop" | "length" | "content-filter" | "tool-calls" | "error" | "unknown";
|
|
152
152
|
}>>;
|
|
@@ -173,7 +173,7 @@ export declare class SageMakerLanguageModel implements LanguageModelV1 {
|
|
|
173
173
|
provider: string;
|
|
174
174
|
specificationVersion: string;
|
|
175
175
|
endpointName: string;
|
|
176
|
-
modelType: "
|
|
176
|
+
modelType: "custom" | "huggingface" | "mistral" | "claude" | "llama" | "jumpstart" | undefined;
|
|
177
177
|
region: string;
|
|
178
178
|
};
|
|
179
179
|
}
|
|
@@ -8,7 +8,7 @@ import { randomUUID } from "crypto";
|
|
|
8
8
|
import { SageMakerRuntimeClient } from "./client.js";
|
|
9
9
|
import { handleSageMakerError } from "./errors.js";
|
|
10
10
|
import { estimateTokenUsage, createSageMakerStream } from "./streaming.js";
|
|
11
|
-
import { createAdaptiveSemaphore
|
|
11
|
+
import { createAdaptiveSemaphore } from "./adaptive-semaphore.js";
|
|
12
12
|
import { logger } from "../../utils/logger.js";
|
|
13
13
|
/**
|
|
14
14
|
* Base synthetic streaming delay in milliseconds for simulating real-time response
|
|
@@ -147,7 +147,7 @@ export class SageMakerLanguageModel {
|
|
|
147
147
|
usage: {
|
|
148
148
|
promptTokens: usage.promptTokens,
|
|
149
149
|
completionTokens: usage.completionTokens,
|
|
150
|
-
totalTokens: usage.
|
|
150
|
+
totalTokens: usage.total,
|
|
151
151
|
},
|
|
152
152
|
finishReason,
|
|
153
153
|
rawCall: {
|
|
@@ -672,7 +672,7 @@ export class SageMakerLanguageModel {
|
|
|
672
672
|
usage: {
|
|
673
673
|
promptTokens: result.usage.promptTokens,
|
|
674
674
|
completionTokens: result.usage.completionTokens,
|
|
675
|
-
|
|
675
|
+
total: result.usage.totalTokens ??
|
|
676
676
|
result.usage.promptTokens + result.usage.completionTokens,
|
|
677
677
|
},
|
|
678
678
|
finishReason: result.finishReason,
|
|
@@ -692,7 +692,7 @@ export class SageMakerLanguageModel {
|
|
|
692
692
|
// Create error result
|
|
693
693
|
results[index] = {
|
|
694
694
|
text: "",
|
|
695
|
-
usage: { promptTokens: 0, completionTokens: 0,
|
|
695
|
+
usage: { promptTokens: 0, completionTokens: 0, total: 0 },
|
|
696
696
|
finishReason: "error",
|
|
697
697
|
index,
|
|
698
698
|
};
|