@juspay/neurolink 7.7.1 → 7.9.0
This diff compares the contents of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
- package/CHANGELOG.md +25 -2
- package/README.md +34 -2
- package/dist/cli/commands/config.d.ts +3 -3
- package/dist/cli/commands/sagemaker.d.ts +11 -0
- package/dist/cli/commands/sagemaker.js +778 -0
- package/dist/cli/factories/commandFactory.js +7 -2
- package/dist/cli/index.js +3 -0
- package/dist/cli/utils/interactiveSetup.js +28 -0
- package/dist/core/baseProvider.d.ts +2 -2
- package/dist/core/types.d.ts +16 -4
- package/dist/core/types.js +24 -3
- package/dist/factories/providerFactory.js +10 -1
- package/dist/factories/providerRegistry.js +6 -1
- package/dist/lib/core/baseProvider.d.ts +2 -2
- package/dist/lib/core/types.d.ts +16 -4
- package/dist/lib/core/types.js +24 -3
- package/dist/lib/factories/providerFactory.js +10 -1
- package/dist/lib/factories/providerRegistry.js +6 -1
- package/dist/lib/neurolink.d.ts +15 -0
- package/dist/lib/neurolink.js +73 -1
- package/dist/lib/providers/amazonSagemaker.d.ts +67 -0
- package/dist/lib/providers/amazonSagemaker.js +149 -0
- package/dist/lib/providers/googleVertex.d.ts +4 -0
- package/dist/lib/providers/googleVertex.js +44 -3
- package/dist/lib/providers/index.d.ts +4 -0
- package/dist/lib/providers/index.js +4 -0
- package/dist/lib/providers/sagemaker/adaptive-semaphore.d.ts +86 -0
- package/dist/lib/providers/sagemaker/adaptive-semaphore.js +212 -0
- package/dist/lib/providers/sagemaker/client.d.ts +156 -0
- package/dist/lib/providers/sagemaker/client.js +462 -0
- package/dist/lib/providers/sagemaker/config.d.ts +73 -0
- package/dist/lib/providers/sagemaker/config.js +308 -0
- package/dist/lib/providers/sagemaker/detection.d.ts +176 -0
- package/dist/lib/providers/sagemaker/detection.js +596 -0
- package/dist/lib/providers/sagemaker/diagnostics.d.ts +37 -0
- package/dist/lib/providers/sagemaker/diagnostics.js +137 -0
- package/dist/lib/providers/sagemaker/error-constants.d.ts +78 -0
- package/dist/lib/providers/sagemaker/error-constants.js +227 -0
- package/dist/lib/providers/sagemaker/errors.d.ts +83 -0
- package/dist/lib/providers/sagemaker/errors.js +216 -0
- package/dist/lib/providers/sagemaker/index.d.ts +35 -0
- package/dist/lib/providers/sagemaker/index.js +67 -0
- package/dist/lib/providers/sagemaker/language-model.d.ts +182 -0
- package/dist/lib/providers/sagemaker/language-model.js +755 -0
- package/dist/lib/providers/sagemaker/parsers.d.ts +136 -0
- package/dist/lib/providers/sagemaker/parsers.js +625 -0
- package/dist/lib/providers/sagemaker/streaming.d.ts +39 -0
- package/dist/lib/providers/sagemaker/streaming.js +320 -0
- package/dist/lib/providers/sagemaker/structured-parser.d.ts +117 -0
- package/dist/lib/providers/sagemaker/structured-parser.js +625 -0
- package/dist/lib/providers/sagemaker/types.d.ts +456 -0
- package/dist/lib/providers/sagemaker/types.js +7 -0
- package/dist/lib/sdk/toolRegistration.d.ts +1 -1
- package/dist/lib/sdk/toolRegistration.js +13 -5
- package/dist/lib/types/cli.d.ts +36 -1
- package/dist/lib/utils/providerHealth.js +19 -4
- package/dist/neurolink.d.ts +15 -0
- package/dist/neurolink.js +73 -1
- package/dist/providers/amazonSagemaker.d.ts +67 -0
- package/dist/providers/amazonSagemaker.js +149 -0
- package/dist/providers/googleVertex.d.ts +4 -0
- package/dist/providers/googleVertex.js +44 -3
- package/dist/providers/index.d.ts +4 -0
- package/dist/providers/index.js +4 -0
- package/dist/providers/sagemaker/adaptive-semaphore.d.ts +86 -0
- package/dist/providers/sagemaker/adaptive-semaphore.js +212 -0
- package/dist/providers/sagemaker/client.d.ts +156 -0
- package/dist/providers/sagemaker/client.js +462 -0
- package/dist/providers/sagemaker/config.d.ts +73 -0
- package/dist/providers/sagemaker/config.js +308 -0
- package/dist/providers/sagemaker/detection.d.ts +176 -0
- package/dist/providers/sagemaker/detection.js +596 -0
- package/dist/providers/sagemaker/diagnostics.d.ts +37 -0
- package/dist/providers/sagemaker/diagnostics.js +137 -0
- package/dist/providers/sagemaker/error-constants.d.ts +78 -0
- package/dist/providers/sagemaker/error-constants.js +227 -0
- package/dist/providers/sagemaker/errors.d.ts +83 -0
- package/dist/providers/sagemaker/errors.js +216 -0
- package/dist/providers/sagemaker/index.d.ts +35 -0
- package/dist/providers/sagemaker/index.js +67 -0
- package/dist/providers/sagemaker/language-model.d.ts +182 -0
- package/dist/providers/sagemaker/language-model.js +755 -0
- package/dist/providers/sagemaker/parsers.d.ts +136 -0
- package/dist/providers/sagemaker/parsers.js +625 -0
- package/dist/providers/sagemaker/streaming.d.ts +39 -0
- package/dist/providers/sagemaker/streaming.js +320 -0
- package/dist/providers/sagemaker/structured-parser.d.ts +117 -0
- package/dist/providers/sagemaker/structured-parser.js +625 -0
- package/dist/providers/sagemaker/types.d.ts +456 -0
- package/dist/providers/sagemaker/types.js +7 -0
- package/dist/sdk/toolRegistration.d.ts +1 -1
- package/dist/sdk/toolRegistration.js +13 -5
- package/dist/types/cli.d.ts +36 -1
- package/dist/utils/providerHealth.js +19 -4
- package/package.json +8 -2
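Two of the additions are reproduced in full below: the SageMaker detection module (providers/sagemaker/detection.js, +596) and the diagnostics type declarations (providers/sagemaker/diagnostics.d.ts, +37).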
@@ -0,0 +1,596 @@
+/**
+ * SageMaker Model Detection and Streaming Capability Discovery
+ *
+ * This module provides intelligent detection of SageMaker endpoint capabilities
+ * including model type identification and streaming protocol support.
+ */
+import { SageMakerRuntimeClient } from "./client.js";
+import { logger } from "../../utils/logger.js";
+/**
+ * Configurable constants for detection timing and performance
+ */
+const DETECTION_TEST_DELAY_MS = 100; // Base delay between detection tests (ms)
+const DETECTION_STAGGER_DELAY_MS = 25; // Delay between staggered test starts (ms)
+const DETECTION_RATE_LIMIT_BACKOFF_MS = 200; // Initial backoff on rate limit detection (ms)
+/**
+ * SageMaker Model Detection and Capability Discovery Service
+ */
+export class SageMakerDetector {
+    client;
+    config;
+    constructor(config) {
+        this.client = new SageMakerRuntimeClient(config);
+        this.config = config;
+    }
+    /**
+     * Detect streaming capabilities for a given endpoint
+     */
+    async detectStreamingCapability(endpointName) {
+        logger.debug("Starting streaming capability detection", { endpointName });
+        try {
+            // Step 1: Check endpoint health and gather metadata
+            const health = await this.checkEndpointHealth(endpointName);
+            if (health.status !== "healthy") {
+                return this.createNoStreamingCapability("custom", "Endpoint not healthy");
+            }
+            // Step 2: Detect model type
+            const modelDetection = await this.detectModelType(endpointName);
+            logger.debug("Model type detection result", {
+                endpointName,
+                type: modelDetection.type,
+                confidence: modelDetection.confidence,
+            });
+            // Step 3: Test streaming support based on model type
+            const streamingSupport = await this.testStreamingSupport(endpointName, modelDetection.type);
+            // Step 4: Determine streaming protocol
+            const protocol = await this.detectStreamingProtocol(modelDetection.type);
+            return {
+                supported: streamingSupport.supported,
+                protocol,
+                modelType: modelDetection.type,
+                confidence: Math.min(modelDetection.confidence, streamingSupport.confidence),
+                parameters: streamingSupport.parameters,
+                metadata: {
+                    modelName: health.modelInfo?.name,
+                    framework: health.modelInfo?.framework,
+                    version: health.modelInfo?.version,
+                },
+            };
+        }
+        catch (error) {
+            logger.warn("Streaming capability detection failed", {
+                endpointName,
+                error: error instanceof Error ? error.message : String(error),
+            });
+            return this.createNoStreamingCapability("custom", "Detection failed, assuming custom model");
+        }
+    }
+    /**
+     * Detect the model type/framework for an endpoint
+     */
+    async detectModelType(endpointName) {
+        const evidence = [];
+        const detectionTests = [
+            () => this.testHuggingFaceSignature(endpointName, evidence),
+            () => this.testLlamaSignature(endpointName, evidence),
+            () => this.testPyTorchSignature(endpointName, evidence),
+            () => this.testTensorFlowSignature(endpointName, evidence),
+        ];
+        // Run detection tests in parallel with intelligent rate limiting
+        const testNames = ["HuggingFace", "LLaMA", "PyTorch", "TensorFlow"];
+        const results = await this.runDetectionTestsInParallel(detectionTests, testNames, endpointName);
+        // Analyze results and determine most likely model type
+        const scores = {
+            huggingface: 0,
+            llama: 0,
+            pytorch: 0,
+            tensorflow: 0,
+            custom: 0.1, // Base score for custom models
+        };
+        // Process evidence and calculate scores
+        evidence.forEach((item) => {
+            if (item.includes("huggingface") || item.includes("transformers")) {
+                scores.huggingface += 0.3;
+            }
+            if (item.includes("llama") || item.includes("openai-compatible")) {
+                scores.llama += 0.3;
+            }
+            if (item.includes("pytorch") || item.includes("torch")) {
+                scores.pytorch += 0.2;
+            }
+            if (item.includes("tensorflow") || item.includes("serving")) {
+                scores.tensorflow += 0.2;
+            }
+        });
+        // Find highest scoring model type
+        const maxScore = Math.max(...Object.values(scores));
+        const detectedType = Object.entries(scores).find(([, score]) => score === maxScore)?.[0] || "custom";
+        return {
+            type: detectedType,
+            confidence: maxScore,
+            evidence,
+            suggestedConfig: this.getSuggestedConfig(detectedType),
+        };
+    }
+    /**
+     * Check endpoint health and gather metadata
+     */
+    async checkEndpointHealth(endpointName) {
+        const startTime = Date.now();
+        try {
+            // Simple health check with minimal payload
+            const testPayload = JSON.stringify({ inputs: "test" });
+            const response = await this.client.invokeEndpoint({
+                EndpointName: endpointName,
+                Body: testPayload,
+                ContentType: "application/json",
+            });
+            const responseTime = Date.now() - startTime;
+            return {
+                status: "healthy",
+                responseTime,
+                metadata: response.CustomAttributes
+                    ? JSON.parse(response.CustomAttributes)
+                    : undefined,
+                modelInfo: this.extractModelInfo(response),
+            };
+        }
+        catch (error) {
+            const responseTime = Date.now() - startTime;
+            logger.warn("Endpoint health check failed", {
+                endpointName,
+                responseTime,
+                error: error instanceof Error ? error.message : String(error),
+            });
+            return {
+                status: "unhealthy",
+                responseTime,
+            };
+        }
+    }
+    /**
+     * Test if endpoint supports streaming for given model type
+     */
+    async testStreamingSupport(endpointName, modelType) {
+        const testCases = this.getStreamingTestCases(modelType);
+        for (const testCase of testCases) {
+            try {
+                const response = await this.client.invokeEndpoint({
+                    EndpointName: endpointName,
+                    Body: JSON.stringify(testCase.payload),
+                    ContentType: "application/json",
+                    Accept: testCase.acceptHeader,
+                });
+                // Check response headers for streaming indicators
+                if (this.indicatesStreamingSupport(response)) {
+                    return {
+                        supported: true,
+                        confidence: testCase.confidence,
+                        parameters: testCase.parameters,
+                    };
+                }
+            }
+            catch (error) {
+                // Streaming test failed, continue to next test case
+                logger.debug("Streaming test failed", {
+                    endpointName,
+                    error: error instanceof Error ? error.message : String(error),
+                });
+            }
+        }
+        return { supported: false, confidence: 0.9 };
+    }
+    /**
+     * Detect streaming protocol used by endpoint
+     */
+    async detectStreamingProtocol(modelType) {
+        // Protocol mapping based on model type
+        const protocolMap = {
+            huggingface: "sse", // Server-Sent Events
+            llama: "jsonl", // JSON Lines
+            pytorch: "none", // Usually no streaming
+            tensorflow: "none", // Usually no streaming
+            custom: "chunked", // Generic chunked transfer
+        };
+        return protocolMap[modelType] || "none";
+    }
+    /**
+     * Test for HuggingFace Transformers signature
+     */
+    async testHuggingFaceSignature(endpointName, evidence) {
+        try {
+            const testPayload = {
+                inputs: "test",
+                parameters: { return_full_text: false, max_new_tokens: 1 },
+            };
+            const response = await this.client.invokeEndpoint({
+                EndpointName: endpointName,
+                Body: JSON.stringify(testPayload),
+                ContentType: "application/json",
+            });
+            const responseText = new TextDecoder().decode(response.Body);
+            const parsedResponse = JSON.parse(responseText);
+            if (parsedResponse[0]?.generated_text !== undefined) {
+                evidence.push("huggingface: generated_text field found");
+            }
+            if (parsedResponse.error?.includes("transformers")) {
+                evidence.push("huggingface: transformers error message");
+            }
+        }
+        catch (error) {
+            // Test failed, no evidence
+        }
+    }
+    /**
+     * Test for LLaMA model signature
+     */
+    async testLlamaSignature(endpointName, evidence) {
+        try {
+            const testPayload = {
+                prompt: "test",
+                max_tokens: 1,
+                temperature: 0,
+            };
+            const response = await this.client.invokeEndpoint({
+                EndpointName: endpointName,
+                Body: JSON.stringify(testPayload),
+                ContentType: "application/json",
+            });
+            const responseText = new TextDecoder().decode(response.Body);
+            const parsedResponse = JSON.parse(responseText);
+            if (parsedResponse.choices) {
+                evidence.push("llama: openai-compatible choices field");
+            }
+            if (parsedResponse.object === "text_completion") {
+                evidence.push("llama: openai text_completion object");
+            }
+        }
+        catch (error) {
+            // Test failed, no evidence
+        }
+    }
+    /**
+     * Test for PyTorch model signature
+     */
+    async testPyTorchSignature(endpointName, evidence) {
+        try {
+            const testPayload = { input: "test" };
+            const response = await this.client.invokeEndpoint({
+                EndpointName: endpointName,
+                Body: JSON.stringify(testPayload),
+                ContentType: "application/json",
+            });
+            const responseText = new TextDecoder().decode(response.Body);
+            if (responseText.includes("prediction") ||
+                responseText.includes("output")) {
+                evidence.push("pytorch: prediction/output field pattern");
+            }
+        }
+        catch (error) {
+            // Test failed, no evidence
+        }
+    }
+    /**
+     * Test for TensorFlow Serving signature
+     */
+    async testTensorFlowSignature(endpointName, evidence) {
+        try {
+            const testPayload = {
+                instances: [{ input: "test" }],
+                signature_name: "serving_default",
+            };
+            const response = await this.client.invokeEndpoint({
+                EndpointName: endpointName,
+                Body: JSON.stringify(testPayload),
+                ContentType: "application/json",
+            });
+            const responseText = new TextDecoder().decode(response.Body);
+            const parsedResponse = JSON.parse(responseText);
+            if (parsedResponse.predictions) {
+                evidence.push("tensorflow: serving predictions field");
+            }
+        }
+        catch (error) {
+            // Test failed, no evidence
+        }
+    }
+    /**
+     * Get streaming test cases for a model type
+     */
+    getStreamingTestCases(modelType) {
+        const testCases = {
+            huggingface: [
+                {
+                    name: "HF streaming test",
+                    payload: {
+                        inputs: "test",
+                        parameters: { stream: true, max_new_tokens: 5 },
+                    },
+                    acceptHeader: "text/event-stream",
+                    confidence: 0.8,
+                    parameters: { stream: true },
+                },
+            ],
+            llama: [
+                {
+                    name: "LLaMA streaming test",
+                    payload: { prompt: "test", stream: true, max_tokens: 5 },
+                    acceptHeader: "application/x-ndjson",
+                    confidence: 0.8,
+                    parameters: { stream: true },
+                },
+            ],
+            pytorch: [],
+            tensorflow: [],
+            custom: [
+                {
+                    name: "Generic streaming test",
+                    payload: { input: "test", stream: true },
+                    acceptHeader: "application/json",
+                    confidence: 0.3,
+                    parameters: { stream: true },
+                },
+            ],
+        };
+        return testCases[modelType] || [];
+    }
+    /**
+     * Check if response indicates streaming support
+     */
+    indicatesStreamingSupport(response) {
+        // Check content type for streaming indicators
+        const contentType = response.ContentType || "";
+        if (contentType.includes("event-stream") ||
+            contentType.includes("x-ndjson") ||
+            contentType.includes("chunked")) {
+            return true;
+        }
+        // Note: InvokeEndpointResponse doesn't include headers
+        // Streaming detection is based on ContentType only
+        logger.debug("Testing streaming support", {
+            contentType,
+        });
+        return false;
+    }
+    /**
+     * Extract model information from response
+     */
+    extractModelInfo(response) {
+        try {
+            const customAttributes = response.CustomAttributes
+                ? JSON.parse(response.CustomAttributes)
+                : {};
+            return {
+                name: customAttributes.model_name,
+                version: customAttributes.model_version,
+                framework: customAttributes.framework,
+                architecture: customAttributes.architecture,
+            };
+        }
+        catch {
+            return undefined;
+        }
+    }
+    /**
+     * Get suggested configuration for detected model type
+     */
+    getSuggestedConfig(modelType) {
+        const configs = {
+            huggingface: {
+                modelType: "huggingface",
+                inputFormat: "huggingface",
+                outputFormat: "huggingface",
+                contentType: "application/json",
+                accept: "text/event-stream",
+            },
+            llama: {
+                modelType: "llama",
+                contentType: "application/json",
+                accept: "application/x-ndjson",
+            },
+            pytorch: {
+                modelType: "custom",
+                contentType: "application/json",
+                accept: "application/json",
+            },
+            tensorflow: {
+                modelType: "custom",
+                contentType: "application/json",
+                accept: "application/json",
+            },
+            custom: {
+                modelType: "custom",
+                contentType: "application/json",
+                accept: "application/json",
+            },
+        };
+        return configs[modelType] || configs.custom;
+    }
+    /**
+     * Run detection tests in parallel with intelligent rate limiting and circuit breaker
+     * Now uses configuration object for better parameter management
+     */
+    async runDetectionTestsInParallel(detectionTests, testNames, endpointName, config = {
+        maxConcurrentTests: 2,
+        maxRateLimitRetries: 2,
+        initialRateLimitCount: 0,
+    }) {
+        // Use configurable concurrency limit from config
+        const semaphore = this.createDetectionSemaphore(config.maxConcurrentTests);
+        // Use mutable object to prevent closure stale state issues
+        const rateLimitState = { count: config.initialRateLimitCount };
+        const wrappedTests = detectionTests.map((test, index) => this.wrapDetectionTest({
+            test,
+            index,
+            testName: testNames[index],
+            endpointName,
+            semaphore,
+            incrementRateLimit: () => rateLimitState.count++,
+            maxRateLimitRetries: config.maxRateLimitRetries,
+            rateLimitState,
+        }));
+        const results = await this.executeTestsWithConcurrencyControl(wrappedTests);
+        this.logDetectionResults(endpointName, testNames, results, rateLimitState.count > 0);
+        return results;
+    }
+    /**
+     * Create a semaphore for detection test concurrency control
+     */
+    createDetectionSemaphore(maxConcurrent) {
+        return {
+            count: maxConcurrent,
+            waiters: [],
+            async acquire() {
+                return new Promise((resolve) => {
+                    if (this.count > 0) {
+                        this.count--;
+                        resolve();
+                    }
+                    else {
+                        // release() transfers its permit directly to the next
+                        // waiter, so resolve without decrementing the count
+                        // again (a second decrement would leak permits).
+                        this.waiters.push(() => resolve());
+                    }
+                });
+            },
+            release() {
+                if (this.waiters.length > 0) {
+                    const waiter = this.waiters.shift();
+                    waiter();
+                }
+                else {
+                    this.count++;
+                }
+            },
+        };
+    }
+    /**
+     * Wrap a detection test with error handling, rate limiting, and retry logic
+     * Now uses configuration object instead of multiple parameters
+     */
+    wrapDetectionTest(config) {
+        return async () => {
+            await config.semaphore.acquire();
+            try {
+                await this.executeWithStaggeredStart(config.test, config.index);
+                return { status: "fulfilled", value: undefined };
+            }
+            catch (error) {
+                const result = await this.handleDetectionTestError(error, config.test, config.testName, config.endpointName, config.incrementRateLimit, config.maxRateLimitRetries, config.rateLimitState.count);
+                return result;
+            }
+            finally {
+                config.semaphore.release();
+            }
+        };
+    }
+    /**
+     * Execute a test with staggered start to spread load
+     */
+    async executeWithStaggeredStart(test, index) {
+        const staggerDelay = index * DETECTION_STAGGER_DELAY_MS;
+        if (staggerDelay > 0) {
+            await new Promise((resolve) => setTimeout(resolve, staggerDelay));
+        }
+        await test();
+    }
+    /**
+     * Handle detection test errors with rate limiting and retry logic
+     */
+    async handleDetectionTestError(error, test, testName, endpointName, incrementRateLimit, maxRateLimitRetries, rateLimitCount) {
+        const isRateLimit = this.isRateLimitError(error);
+        if (isRateLimit && rateLimitCount < maxRateLimitRetries) {
+            return await this.retryWithBackoff(test, testName, endpointName, incrementRateLimit, rateLimitCount);
+        }
+        this.logDetectionTestFailure(testName, endpointName, error);
+        return { status: "rejected", reason: error };
+    }
+    /**
+     * Check if an error indicates rate limiting
+     */
+    isRateLimitError(error) {
+        return (error instanceof Error &&
+            (error.message.toLowerCase().includes("throttl") ||
+                error.message.toLowerCase().includes("rate limit") ||
+                error.message.toLowerCase().includes("too many requests")));
+    }
+    /**
+     * Retry a test with exponential backoff
+     */
+    async retryWithBackoff(test, testName, endpointName, incrementRateLimit, rateLimitCount) {
+        incrementRateLimit();
+        logger.debug(`Rate limit detected for ${testName}, applying backoff`, {
+            endpointName,
+            attempt: rateLimitCount + 1,
+        });
+        await new Promise((resolve) => setTimeout(resolve, DETECTION_RATE_LIMIT_BACKOFF_MS * Math.pow(2, rateLimitCount)));
+        try {
+            await test();
+            return { status: "fulfilled", value: undefined };
+        }
+        catch (retryError) {
+            this.logDetectionTestRetryFailure(testName, endpointName, retryError);
+            return { status: "rejected", reason: retryError };
+        }
+    }
+    /**
+     * Execute wrapped tests with concurrency control
+     */
+    async executeTestsWithConcurrencyControl(wrappedTests) {
+        const testPromises = wrappedTests.map((wrappedTest) => wrappedTest());
+        return await Promise.all(testPromises);
+    }
+    /**
+     * Log detection test failure
+     */
+    logDetectionTestFailure(testName, endpointName, error) {
+        logger.debug(`${testName} detection test failed`, {
+            endpointName,
+            error: error instanceof Error ? error.message : String(error),
+        });
+    }
+    /**
+     * Log detection test retry failure
+     */
+    logDetectionTestRetryFailure(testName, endpointName, error) {
+        logger.debug(`${testName} detection test retry failed`, {
+            endpointName,
+            error: error instanceof Error ? error.message : String(error),
+        });
+    }
+    /**
+     * Log final detection results
+     */
+    logDetectionResults(endpointName, testNames, results, rateLimitEncountered) {
+        logger.debug("Parallel detection tests completed", {
+            endpointName,
+            totalTests: testNames.length,
+            successCount: results.filter((r) => r.status === "fulfilled").length,
+            rateLimitEncountered,
+        });
+    }
+    /**
+     * Create a no-streaming capability result
+     */
+    createNoStreamingCapability(modelType, reason) {
+        logger.debug("No streaming capability detected", { modelType, reason });
+        return {
+            supported: false,
+            protocol: "none",
+            modelType,
+            confidence: 0.9,
+            metadata: {
+                // reason property not supported in interface
+                // Store reason in framework field for debugging
+                framework: reason,
+            },
+        };
+    }
+}
+/**
+ * Create a detector instance with configuration
+ */
+export function createSageMakerDetector(config) {
+    return new SageMakerDetector(config);
+}
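
For orientation, a minimal usage sketch of the detector exported above. The deep dist import path and the shape of the config argument (standard AWS SDK client options such as region) are assumptions, not confirmed by this diff:

// Minimal sketch: probe a SageMaker endpoint for streaming support.
// Assumptions: createSageMakerDetector is importable from the dist path
// below, and config follows AWS SDK client options (hypothetical shape).
import { createSageMakerDetector } from "@juspay/neurolink/dist/providers/sagemaker/detection.js";

async function probeEndpoint(endpointName: string): Promise<void> {
  const detector = createSageMakerDetector({ region: "us-east-1" }); // hypothetical config
  const capability = await detector.detectStreamingCapability(endpointName);
  if (capability.supported) {
    console.log(`Streaming supported via ${capability.protocol}`, capability.parameters);
  } else {
    console.log(`No streaming (model type: ${capability.modelType}, confidence: ${capability.confidence})`);
  }
}

The capability fields used here (supported, protocol, modelType, confidence, parameters) all appear in the return values of detectStreamingCapability above.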
@@ -0,0 +1,37 @@
+/**
+ * SageMaker Simple Diagnostics Module
+ *
+ * Provides basic diagnostic functions for SageMaker configuration and connectivity.
+ */
+/**
+ * Simple diagnostic result interface
+ */
+export interface DiagnosticResult {
+    name: string;
+    category: "configuration" | "connectivity" | "streaming";
+    status: "pass" | "fail" | "warning";
+    message: string;
+    details?: string;
+    recommendation?: string;
+}
+/**
+ * Diagnostic report interface
+ */
+export interface DiagnosticReport {
+    overallStatus: "healthy" | "issues" | "critical";
+    results: DiagnosticResult[];
+    summary: {
+        total: number;
+        passed: number;
+        failed: number;
+        warnings: number;
+    };
+}
+/**
+ * Run quick diagnostics for SageMaker configuration
+ */
+export declare function runQuickDiagnostics(endpoint?: string): Promise<DiagnosticReport>;
+/**
+ * Format diagnostic report for console output
+ */
+export declare function formatDiagnosticReport(report: DiagnosticReport): string;
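
Similarly, a minimal sketch of how these diagnostics declarations might be consumed; again the deep dist import path is an assumption based on the file layout listed above:

// Minimal sketch: run quick diagnostics and print the formatted report.
import { runQuickDiagnostics, formatDiagnosticReport } from "@juspay/neurolink/dist/providers/sagemaker/diagnostics.js";

async function checkSageMakerSetup(): Promise<void> {
  // The endpoint argument is optional per the declaration above.
  const report = await runQuickDiagnostics("my-sagemaker-endpoint"); // hypothetical endpoint name
  console.log(formatDiagnosticReport(report));
  if (report.overallStatus === "critical") {
    process.exitCode = 1; // summary counts (total/passed/failed/warnings) are also available
  }
}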