@aws/ml-container-creator 0.10.0 → 0.12.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +9304 -0
- package/bin/cli.js +2 -0
- package/config/bootstrap-e2e-stack.json +341 -0
- package/config/bootstrap-stack.json +40 -3
- package/config/parameter-schema-v2.json +33 -22
- package/config/tune-catalog.json +1781 -0
- package/infra/ci-harness/buildspec.yml +1 -0
- package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
- package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +851 -7
- package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
- package/package.json +53 -67
- package/servers/base-image-picker/index.js +121 -121
- package/servers/e2e-status/index.js +297 -0
- package/servers/e2e-status/manifest.json +14 -0
- package/servers/e2e-status/package.json +15 -0
- package/servers/endpoint-picker/LICENSE +202 -0
- package/servers/endpoint-picker/index.js +536 -0
- package/servers/endpoint-picker/manifest.json +14 -0
- package/servers/endpoint-picker/package.json +18 -0
- package/servers/hyperpod-cluster-picker/index.js +125 -125
- package/servers/instance-sizer/index.js +166 -153
- package/servers/instance-sizer/lib/instance-ranker.js +120 -76
- package/servers/instance-sizer/lib/model-resolver.js +61 -61
- package/servers/instance-sizer/lib/quota-resolver.js +113 -113
- package/servers/instance-sizer/lib/vram-estimator.js +31 -31
- package/servers/lib/bedrock-client.js +38 -38
- package/servers/lib/catalogs/instances.json +27 -0
- package/servers/lib/catalogs/model-servers.json +201 -3
- package/servers/lib/custom-validators.js +13 -13
- package/servers/lib/dynamic-resolver.js +4 -4
- package/servers/marketplace-picker/index.js +342 -0
- package/servers/marketplace-picker/manifest.json +14 -0
- package/servers/marketplace-picker/package.json +18 -0
- package/servers/model-picker/index.js +382 -382
- package/servers/region-picker/index.js +56 -56
- package/servers/workload-picker/LICENSE +202 -0
- package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
- package/servers/workload-picker/index.js +171 -0
- package/servers/workload-picker/manifest.json +16 -0
- package/servers/workload-picker/package.json +16 -0
- package/src/app.js +12 -3
- package/src/lib/bootstrap-command-handler.js +609 -15
- package/src/lib/bootstrap-config.js +36 -0
- package/src/lib/bootstrap-profile-manager.js +48 -41
- package/src/lib/ci-register-helpers.js +74 -0
- package/src/lib/config-loader.js +3 -0
- package/src/lib/config-manager.js +7 -0
- package/src/lib/config-validator.js +1 -1
- package/src/lib/cuda-resolver.js +17 -8
- package/src/lib/generated/cli-options.js +319 -314
- package/src/lib/generated/parameter-matrix.js +672 -661
- package/src/lib/generated/validation-rules.js +76 -72
- package/src/lib/path-prover-brain.js +664 -0
- package/src/lib/prompts/infrastructure-prompts.js +2 -2
- package/src/lib/prompts/model-prompts.js +6 -0
- package/src/lib/prompts/project-prompts.js +12 -0
- package/src/lib/secrets-prompt-runner.js +4 -0
- package/src/lib/template-manager.js +1 -1
- package/src/lib/template-variable-resolver.js +87 -1
- package/src/lib/tune-catalog-validator.js +37 -4
- package/templates/Dockerfile +9 -0
- package/templates/code/adapter_sidecar.py +444 -0
- package/templates/code/serve +6 -0
- package/templates/code/serve.d/vllm.ejs +1 -1
- package/templates/do/.benchmark_writer.py +1476 -0
- package/templates/do/.tune_helper.py +982 -57
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/adapter +154 -0
- package/templates/do/benchmark +639 -85
- package/templates/do/build +5 -0
- package/templates/do/clean.d/async-inference.ejs +5 -0
- package/templates/do/clean.d/batch-transform.ejs +5 -0
- package/templates/do/clean.d/hyperpod-eks.ejs +5 -0
- package/templates/do/clean.d/managed-inference.ejs +5 -0
- package/templates/do/config +115 -45
- package/templates/do/deploy.d/async-inference.ejs +30 -3
- package/templates/do/deploy.d/batch-transform.ejs +29 -3
- package/templates/do/deploy.d/hyperpod-eks.ejs +4 -0
- package/templates/do/deploy.d/managed-inference.ejs +216 -14
- package/templates/do/lib/endpoint-config.sh +1 -1
- package/templates/do/lib/profile.sh +44 -0
- package/templates/do/optimize +106 -37
- package/templates/do/push +5 -0
- package/templates/do/register +94 -0
- package/templates/do/stage +567 -0
- package/templates/do/submit +7 -0
- package/templates/do/test +14 -0
- package/templates/do/tune +382 -59
- package/templates/do/validate +44 -4
|
@@ -11,20 +11,20 @@
|
|
|
11
11
|
* All methods degrade gracefully — API failures return null and log to stderr.
|
|
12
12
|
*/
|
|
13
13
|
|
|
14
|
-
import { ServiceQuotasClient, ListServiceQuotasCommand } from '@aws-sdk/client-service-quotas'
|
|
15
|
-
import { SageMakerClient, ListEndpointsCommand, ListTrainingPlansCommand } from '@aws-sdk/client-sagemaker'
|
|
14
|
+
import { ServiceQuotasClient, ListServiceQuotasCommand } from '@aws-sdk/client-service-quotas';
|
|
15
|
+
import { SageMakerClient, ListEndpointsCommand, ListTrainingPlansCommand } from '@aws-sdk/client-sagemaker';
|
|
16
16
|
|
|
17
17
|
// ── Constants ────────────────────────────────────────────────────────────────
|
|
18
18
|
|
|
19
|
-
const SAGEMAKER_SERVICE_CODE = 'sagemaker'
|
|
20
|
-
const DEFAULT_TIMEOUT_MS = 5000
|
|
21
|
-
const DEFAULT_CACHE_TTL_MS = 300000 // 5 minutes
|
|
22
|
-
const QUOTA_NAME_PATTERN = /^(ml\.[a-z0-9]+\.[a-z0-9]+) for endpoint usage
|
|
19
|
+
const SAGEMAKER_SERVICE_CODE = 'sagemaker';
|
|
20
|
+
const DEFAULT_TIMEOUT_MS = 5000;
|
|
21
|
+
const DEFAULT_CACHE_TTL_MS = 300000; // 5 minutes
|
|
22
|
+
const QUOTA_NAME_PATTERN = /^(ml\.[a-z0-9]+\.[a-z0-9]+) for endpoint usage$/;
|
|
23
23
|
|
|
24
24
|
// ── Logging ──────────────────────────────────────────────────────────────────
|
|
25
25
|
|
|
26
26
|
function log(message) {
|
|
27
|
-
process.stderr.write(`[quota-resolver] ${message}\n`)
|
|
27
|
+
process.stderr.write(`[quota-resolver] ${message}\n`);
|
|
28
28
|
}
|
|
29
29
|
|
|
30
30
|
// ── QuotaResolver Class ──────────────────────────────────────────────────────
|
|
@@ -37,20 +37,20 @@ class QuotaResolver {
|
|
|
37
37
|
* @param {number} [options.cacheTtl=300000] - Cache TTL in ms (default 5 min)
|
|
38
38
|
*/
|
|
39
39
|
constructor(region, options = {}) {
|
|
40
|
-
this.region = region
|
|
41
|
-
this.timeout = options.timeout || DEFAULT_TIMEOUT_MS
|
|
42
|
-
this.cacheTtl = options.cacheTtl || DEFAULT_CACHE_TTL_MS
|
|
43
|
-
this.cache = new Map()
|
|
40
|
+
this.region = region;
|
|
41
|
+
this.timeout = options.timeout || DEFAULT_TIMEOUT_MS;
|
|
42
|
+
this.cacheTtl = options.cacheTtl || DEFAULT_CACHE_TTL_MS;
|
|
43
|
+
this.cache = new Map();
|
|
44
44
|
|
|
45
45
|
const clientConfig = {
|
|
46
46
|
region: this.region,
|
|
47
47
|
requestHandler: {
|
|
48
48
|
requestTimeout: this.timeout
|
|
49
49
|
}
|
|
50
|
-
}
|
|
50
|
+
};
|
|
51
51
|
|
|
52
|
-
this.quotasClient = new ServiceQuotasClient(clientConfig)
|
|
53
|
-
this.sagemakerClient = new SageMakerClient(clientConfig)
|
|
52
|
+
this.quotasClient = new ServiceQuotasClient(clientConfig);
|
|
53
|
+
this.sagemakerClient = new SageMakerClient(clientConfig);
|
|
54
54
|
}
|
|
55
55
|
|
|
56
56
|
/**
|
|
@@ -59,13 +59,13 @@ class QuotaResolver {
|
|
|
59
59
|
* @returns {*|null} Cached value or null
|
|
60
60
|
*/
|
|
61
61
|
_getCached(key) {
|
|
62
|
-
const entry = this.cache.get(key)
|
|
63
|
-
if (!entry) return null
|
|
62
|
+
const entry = this.cache.get(key);
|
|
63
|
+
if (!entry) return null;
|
|
64
64
|
if (Date.now() - entry.timestamp > this.cacheTtl) {
|
|
65
|
-
this.cache.delete(key)
|
|
66
|
-
return null
|
|
65
|
+
this.cache.delete(key);
|
|
66
|
+
return null;
|
|
67
67
|
}
|
|
68
|
-
return entry.value
|
|
68
|
+
return entry.value;
|
|
69
69
|
}
|
|
70
70
|
|
|
71
71
|
/**
|
|
@@ -74,7 +74,7 @@ class QuotaResolver {
|
|
|
74
74
|
* @param {*} value - Value to cache
|
|
75
75
|
*/
|
|
76
76
|
_setCache(key, value) {
|
|
77
|
-
this.cache.set(key, { value, timestamp: Date.now() })
|
|
77
|
+
this.cache.set(key, { value, timestamp: Date.now() });
|
|
78
78
|
}
|
|
79
79
|
|
|
80
80
|
/**
|
|
@@ -85,8 +85,8 @@ class QuotaResolver {
|
|
|
85
85
|
* @returns {string|null} Instance type or null if pattern doesn't match
|
|
86
86
|
*/
|
|
87
87
|
_parseQuotaName(quotaName) {
|
|
88
|
-
const match = quotaName.match(QUOTA_NAME_PATTERN)
|
|
89
|
-
return match ? match[1] : null
|
|
88
|
+
const match = quotaName.match(QUOTA_NAME_PATTERN);
|
|
89
|
+
return match ? match[1] : null;
|
|
90
90
|
}
|
|
91
91
|
|
|
92
92
|
/**
|
|
@@ -100,50 +100,50 @@ class QuotaResolver {
|
|
|
100
100
|
* @returns {Promise<Map|null>} Map: instanceType → { quota, deployed, headroom }, or null on failure
|
|
101
101
|
*/
|
|
102
102
|
async getQuotaHeadroom(instanceTypes) {
|
|
103
|
-
const cacheKey = 'quotaHeadroom'
|
|
104
|
-
const cached = this._getCached(cacheKey)
|
|
105
|
-
if (cached) return cached
|
|
103
|
+
const cacheKey = 'quotaHeadroom';
|
|
104
|
+
const cached = this._getCached(cacheKey);
|
|
105
|
+
if (cached) return cached;
|
|
106
106
|
|
|
107
107
|
try {
|
|
108
108
|
const [quotaMap, deployedMap] = await Promise.allSettled([
|
|
109
109
|
this._fetchServiceQuotas(),
|
|
110
110
|
this._fetchDeployedCounts()
|
|
111
|
-
])
|
|
111
|
+
]);
|
|
112
112
|
|
|
113
|
-
const quotas = quotaMap.status === 'fulfilled' ? quotaMap.value : null
|
|
114
|
-
const deployed = deployedMap.status === 'fulfilled' ? deployedMap.value : null
|
|
113
|
+
const quotas = quotaMap.status === 'fulfilled' ? quotaMap.value : null;
|
|
114
|
+
const deployed = deployedMap.status === 'fulfilled' ? deployedMap.value : null;
|
|
115
115
|
|
|
116
116
|
if (!quotas) {
|
|
117
|
-
return null
|
|
117
|
+
return null;
|
|
118
118
|
}
|
|
119
119
|
|
|
120
|
-
const result = new Map()
|
|
121
|
-
const deployedCounts = deployed || new Map()
|
|
120
|
+
const result = new Map();
|
|
121
|
+
const deployedCounts = deployed || new Map();
|
|
122
122
|
|
|
123
123
|
for (const instanceType of instanceTypes) {
|
|
124
|
-
const quota = quotas.get(instanceType)
|
|
125
|
-
if (quota
|
|
126
|
-
const deployedCount = deployedCounts.get(instanceType) || 0
|
|
127
|
-
const headroom = quota - deployedCount
|
|
124
|
+
const quota = quotas.get(instanceType);
|
|
125
|
+
if (quota !== null && quota !== undefined) {
|
|
126
|
+
const deployedCount = deployedCounts.get(instanceType) || 0;
|
|
127
|
+
const headroom = quota - deployedCount;
|
|
128
128
|
result.set(instanceType, {
|
|
129
129
|
quota,
|
|
130
130
|
deployed: deployedCount,
|
|
131
131
|
headroom
|
|
132
|
-
})
|
|
132
|
+
});
|
|
133
133
|
}
|
|
134
134
|
}
|
|
135
135
|
|
|
136
|
-
this._setCache(cacheKey, result)
|
|
137
|
-
return result
|
|
136
|
+
this._setCache(cacheKey, result);
|
|
137
|
+
return result;
|
|
138
138
|
} catch (err) {
|
|
139
139
|
if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException') {
|
|
140
|
-
log(
|
|
140
|
+
log('AccessDenied: insufficient permissions for quota queries — skipping');
|
|
141
141
|
} else if (err.name === 'ThrottlingException' || err.Code === 'ThrottlingException') {
|
|
142
|
-
log(
|
|
142
|
+
log('Throttled: Service Quotas API rate limit hit — skipping');
|
|
143
143
|
} else {
|
|
144
|
-
log(`Failed to get quota headroom: ${err.message}`)
|
|
144
|
+
log(`Failed to get quota headroom: ${err.message}`);
|
|
145
145
|
}
|
|
146
|
-
return null
|
|
146
|
+
return null;
|
|
147
147
|
}
|
|
148
148
|
}
|
|
149
149
|
|
|
@@ -154,28 +154,28 @@ class QuotaResolver {
|
|
|
154
154
|
* @returns {Promise<Map>} Map: instanceType → quota limit (number)
|
|
155
155
|
*/
|
|
156
156
|
async _fetchServiceQuotas() {
|
|
157
|
-
const quotaMap = new Map()
|
|
158
|
-
let nextToken = undefined
|
|
157
|
+
const quotaMap = new Map();
|
|
158
|
+
let nextToken = undefined;
|
|
159
159
|
|
|
160
160
|
do {
|
|
161
161
|
const command = new ListServiceQuotasCommand({
|
|
162
162
|
ServiceCode: SAGEMAKER_SERVICE_CODE,
|
|
163
163
|
...(nextToken && { NextToken: nextToken })
|
|
164
|
-
})
|
|
164
|
+
});
|
|
165
165
|
|
|
166
|
-
const response = await this.quotasClient.send(command)
|
|
166
|
+
const response = await this.quotasClient.send(command);
|
|
167
167
|
|
|
168
168
|
for (const quota of (response.Quotas || [])) {
|
|
169
|
-
const instanceType = this._parseQuotaName(quota.QuotaName || '')
|
|
170
|
-
if (instanceType && quota.Value
|
|
171
|
-
quotaMap.set(instanceType, quota.Value)
|
|
169
|
+
const instanceType = this._parseQuotaName(quota.QuotaName || '');
|
|
170
|
+
if (instanceType && quota.Value !== null && quota.Value !== undefined) {
|
|
171
|
+
quotaMap.set(instanceType, quota.Value);
|
|
172
172
|
}
|
|
173
173
|
}
|
|
174
174
|
|
|
175
|
-
nextToken = response.NextToken
|
|
176
|
-
} while (nextToken)
|
|
175
|
+
nextToken = response.NextToken;
|
|
176
|
+
} while (nextToken);
|
|
177
177
|
|
|
178
|
-
return quotaMap
|
|
178
|
+
return quotaMap;
|
|
179
179
|
}
|
|
180
180
|
|
|
181
181
|
/**
|
|
@@ -185,16 +185,16 @@ class QuotaResolver {
|
|
|
185
185
|
* @returns {Promise<Map>} Map: instanceType → deployed count
|
|
186
186
|
*/
|
|
187
187
|
async _fetchDeployedCounts() {
|
|
188
|
-
const deployedMap = new Map()
|
|
189
|
-
let nextToken = undefined
|
|
188
|
+
const deployedMap = new Map();
|
|
189
|
+
let nextToken = undefined;
|
|
190
190
|
|
|
191
191
|
do {
|
|
192
192
|
const command = new ListEndpointsCommand({
|
|
193
193
|
StatusEquals: 'InService',
|
|
194
194
|
...(nextToken && { NextToken: nextToken })
|
|
195
|
-
})
|
|
195
|
+
});
|
|
196
196
|
|
|
197
|
-
const response = await this.sagemakerClient.send(command)
|
|
197
|
+
const response = await this.sagemakerClient.send(command);
|
|
198
198
|
|
|
199
199
|
for (const endpoint of (response.Endpoints || [])) {
|
|
200
200
|
// ListEndpoints returns endpoint summaries; instance type info
|
|
@@ -208,18 +208,18 @@ class QuotaResolver {
|
|
|
208
208
|
if (endpoint.ProductionVariants) {
|
|
209
209
|
for (const variant of endpoint.ProductionVariants) {
|
|
210
210
|
if (variant.InstanceType) {
|
|
211
|
-
const current = deployedMap.get(variant.InstanceType) || 0
|
|
212
|
-
const count = variant.CurrentInstanceCount || 1
|
|
213
|
-
deployedMap.set(variant.InstanceType, current + count)
|
|
211
|
+
const current = deployedMap.get(variant.InstanceType) || 0;
|
|
212
|
+
const count = variant.CurrentInstanceCount || 1;
|
|
213
|
+
deployedMap.set(variant.InstanceType, current + count);
|
|
214
214
|
}
|
|
215
215
|
}
|
|
216
216
|
}
|
|
217
217
|
}
|
|
218
218
|
|
|
219
|
-
nextToken = response.NextToken
|
|
220
|
-
} while (nextToken)
|
|
219
|
+
nextToken = response.NextToken;
|
|
220
|
+
} while (nextToken);
|
|
221
221
|
|
|
222
|
-
return deployedMap
|
|
222
|
+
return deployedMap;
|
|
223
223
|
}
|
|
224
224
|
|
|
225
225
|
/**
|
|
@@ -234,46 +234,46 @@ class QuotaResolver {
|
|
|
234
234
|
* @returns {Promise<Map|null>} Map: instanceType → { planName, planArn, remainingCapacity, startDate, endDate }, or null on failure
|
|
235
235
|
*/
|
|
236
236
|
async getCapacityReservations() {
|
|
237
|
-
const cacheKey = 'capacityReservations'
|
|
238
|
-
const cached = this._getCached(cacheKey)
|
|
239
|
-
if (cached) return cached
|
|
237
|
+
const cacheKey = 'capacityReservations';
|
|
238
|
+
const cached = this._getCached(cacheKey);
|
|
239
|
+
if (cached) return cached;
|
|
240
240
|
|
|
241
241
|
try {
|
|
242
|
-
const result = new Map()
|
|
243
|
-
let nextToken = undefined
|
|
242
|
+
const result = new Map();
|
|
243
|
+
let nextToken = undefined;
|
|
244
244
|
|
|
245
245
|
do {
|
|
246
246
|
const command = new ListTrainingPlansCommand({
|
|
247
247
|
StatusEquals: 'Active',
|
|
248
248
|
...(nextToken && { NextToken: nextToken })
|
|
249
|
-
})
|
|
249
|
+
});
|
|
250
250
|
|
|
251
|
-
const response = await this.sagemakerClient.send(command)
|
|
252
|
-
const now = new Date()
|
|
251
|
+
const response = await this.sagemakerClient.send(command);
|
|
252
|
+
const now = new Date();
|
|
253
253
|
|
|
254
254
|
for (const plan of (response.TrainingPlanSummaries || [])) {
|
|
255
255
|
// Only include plans targeting inference endpoints
|
|
256
|
-
const targetResources = plan.TargetResources || []
|
|
257
|
-
if (!targetResources.includes('endpoint')) continue
|
|
256
|
+
const targetResources = plan.TargetResources || [];
|
|
257
|
+
if (!targetResources.includes('endpoint')) continue;
|
|
258
258
|
|
|
259
|
-
const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType
|
|
260
|
-
if (!instanceType) continue
|
|
259
|
+
const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType;
|
|
260
|
+
if (!instanceType) continue;
|
|
261
261
|
|
|
262
|
-
const planArn = plan.TrainingPlanArn
|
|
263
|
-
const planName = plan.TrainingPlanName || 'unknown'
|
|
262
|
+
const planArn = plan.TrainingPlanArn;
|
|
263
|
+
const planName = plan.TrainingPlanName || 'unknown';
|
|
264
264
|
const remainingCapacity = plan.AvailableInstanceCount
|
|
265
265
|
?? plan.RemainingCapacity
|
|
266
266
|
?? plan.TotalInstanceCount
|
|
267
|
-
?? 0
|
|
268
|
-
const startDate = plan.StartTime || null
|
|
269
|
-
const endDate = plan.EndTime || plan.ExpirationTime || null
|
|
267
|
+
?? 0;
|
|
268
|
+
const startDate = plan.StartTime || null;
|
|
269
|
+
const endDate = plan.EndTime || plan.ExpirationTime || null;
|
|
270
270
|
|
|
271
271
|
// Skip plans outside their time window
|
|
272
|
-
if (startDate && new Date(startDate) > now) continue
|
|
273
|
-
if (endDate && new Date(endDate) < now) continue
|
|
272
|
+
if (startDate && new Date(startDate) > now) continue;
|
|
273
|
+
if (endDate && new Date(endDate) < now) continue;
|
|
274
274
|
|
|
275
275
|
// Only include if there's remaining capacity
|
|
276
|
-
if (remainingCapacity <= 0) continue
|
|
276
|
+
if (remainingCapacity <= 0) continue;
|
|
277
277
|
|
|
278
278
|
result.set(instanceType, {
|
|
279
279
|
planName,
|
|
@@ -282,25 +282,25 @@ class QuotaResolver {
|
|
|
282
282
|
count: remainingCapacity,
|
|
283
283
|
startDate: startDate ? (startDate instanceof Date ? startDate.toISOString() : startDate) : null,
|
|
284
284
|
endDate: endDate ? (endDate instanceof Date ? endDate.toISOString() : endDate) : null
|
|
285
|
-
})
|
|
285
|
+
});
|
|
286
286
|
}
|
|
287
287
|
|
|
288
|
-
nextToken = response.NextToken
|
|
289
|
-
} while (nextToken)
|
|
288
|
+
nextToken = response.NextToken;
|
|
289
|
+
} while (nextToken);
|
|
290
290
|
|
|
291
|
-
this._setCache(cacheKey, result)
|
|
292
|
-
return result
|
|
291
|
+
this._setCache(cacheKey, result);
|
|
292
|
+
return result;
|
|
293
293
|
} catch (err) {
|
|
294
294
|
if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException') {
|
|
295
|
-
log(
|
|
295
|
+
log('AccessDenied: insufficient permissions for training plan queries — skipping');
|
|
296
296
|
} else if (err.name === 'ValidationException') {
|
|
297
|
-
log(`ListTrainingPlans not available in region ${this.region} — skipping`)
|
|
297
|
+
log(`ListTrainingPlans not available in region ${this.region} — skipping`);
|
|
298
298
|
} else if (err.name === 'ThrottlingException' || err.Code === 'ThrottlingException') {
|
|
299
|
-
log(
|
|
299
|
+
log('Throttled: ListTrainingPlans rate limit hit — skipping');
|
|
300
300
|
} else {
|
|
301
|
-
log(`Failed to get capacity reservations: ${err.message}`)
|
|
301
|
+
log(`Failed to get capacity reservations: ${err.message}`);
|
|
302
302
|
}
|
|
303
|
-
return null
|
|
303
|
+
return null;
|
|
304
304
|
}
|
|
305
305
|
}
|
|
306
306
|
|
|
@@ -313,56 +313,56 @@ class QuotaResolver {
|
|
|
313
313
|
* @returns {Promise<Map|null>} Map: instanceType → { planName, remainingCapacity, expiresAt }, or null on failure
|
|
314
314
|
*/
|
|
315
315
|
async getTrainingPlans() {
|
|
316
|
-
const cacheKey = 'trainingPlans'
|
|
317
|
-
const cached = this._getCached(cacheKey)
|
|
318
|
-
if (cached) return cached
|
|
316
|
+
const cacheKey = 'trainingPlans';
|
|
317
|
+
const cached = this._getCached(cacheKey);
|
|
318
|
+
if (cached) return cached;
|
|
319
319
|
|
|
320
320
|
try {
|
|
321
|
-
const result = new Map()
|
|
322
|
-
let nextToken = undefined
|
|
321
|
+
const result = new Map();
|
|
322
|
+
let nextToken = undefined;
|
|
323
323
|
|
|
324
324
|
do {
|
|
325
325
|
const command = new ListTrainingPlansCommand({
|
|
326
326
|
StatusEquals: 'Active',
|
|
327
327
|
...(nextToken && { NextToken: nextToken })
|
|
328
|
-
})
|
|
328
|
+
});
|
|
329
329
|
|
|
330
|
-
const response = await this.sagemakerClient.send(command)
|
|
330
|
+
const response = await this.sagemakerClient.send(command);
|
|
331
331
|
|
|
332
332
|
for (const plan of (response.TrainingPlanSummaries || [])) {
|
|
333
|
-
const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType
|
|
334
|
-
const planName = plan.TrainingPlanName || plan.TrainingPlanArn || 'unknown'
|
|
333
|
+
const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType;
|
|
334
|
+
const planName = plan.TrainingPlanName || plan.TrainingPlanArn || 'unknown';
|
|
335
335
|
const remainingCapacity = plan.AvailableInstanceCount
|
|
336
336
|
?? plan.RemainingCapacity
|
|
337
337
|
?? plan.TotalInstanceCount
|
|
338
|
-
?? 0
|
|
339
|
-
const expiresAt = plan.EndTime || plan.ExpirationTime || null
|
|
338
|
+
?? 0;
|
|
339
|
+
const expiresAt = plan.EndTime || plan.ExpirationTime || null;
|
|
340
340
|
|
|
341
341
|
if (instanceType && remainingCapacity > 0) {
|
|
342
342
|
result.set(instanceType, {
|
|
343
343
|
planName,
|
|
344
344
|
remainingCapacity,
|
|
345
345
|
expiresAt
|
|
346
|
-
})
|
|
346
|
+
});
|
|
347
347
|
}
|
|
348
348
|
}
|
|
349
349
|
|
|
350
|
-
nextToken = response.NextToken
|
|
351
|
-
} while (nextToken)
|
|
350
|
+
nextToken = response.NextToken;
|
|
351
|
+
} while (nextToken);
|
|
352
352
|
|
|
353
|
-
this._setCache(cacheKey, result)
|
|
354
|
-
return result
|
|
353
|
+
this._setCache(cacheKey, result);
|
|
354
|
+
return result;
|
|
355
355
|
} catch (err) {
|
|
356
356
|
if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException') {
|
|
357
|
-
log(
|
|
357
|
+
log('AccessDenied: insufficient permissions for training plan queries — skipping');
|
|
358
358
|
} else if (err.name === 'ValidationException') {
|
|
359
|
-
log(`ListTrainingPlans not available in region ${this.region} — skipping`)
|
|
359
|
+
log(`ListTrainingPlans not available in region ${this.region} — skipping`);
|
|
360
360
|
} else {
|
|
361
|
-
log(`Failed to get training plans: ${err.message}`)
|
|
361
|
+
log(`Failed to get training plans: ${err.message}`);
|
|
362
362
|
}
|
|
363
|
-
return null
|
|
363
|
+
return null;
|
|
364
364
|
}
|
|
365
365
|
}
|
|
366
366
|
}
|
|
367
367
|
|
|
368
|
-
export { QuotaResolver, QUOTA_NAME_PATTERN, SAGEMAKER_SERVICE_CODE, DEFAULT_TIMEOUT_MS, DEFAULT_CACHE_TTL_MS }
|
|
368
|
+
export { QuotaResolver, QUOTA_NAME_PATTERN, SAGEMAKER_SERVICE_CODE, DEFAULT_TIMEOUT_MS, DEFAULT_CACHE_TTL_MS };
|
|
@@ -17,20 +17,20 @@ const BYTES_PER_PARAM = {
|
|
|
17
17
|
bfloat16: 2.0,
|
|
18
18
|
int8: 1.0,
|
|
19
19
|
int4: 0.5
|
|
20
|
-
}
|
|
20
|
+
};
|
|
21
21
|
|
|
22
22
|
const QUANTIZATION_BYTES = {
|
|
23
23
|
'awq': 0.5,
|
|
24
24
|
'gptq': 0.5,
|
|
25
25
|
'bnb-4bit': 0.5,
|
|
26
26
|
'bnb-8bit': 1.0
|
|
27
|
-
}
|
|
27
|
+
};
|
|
28
28
|
|
|
29
|
-
const BYTES_IN_GB = 1024 ** 3
|
|
29
|
+
const BYTES_IN_GB = 1024 ** 3;
|
|
30
30
|
|
|
31
|
-
const DEFAULT_MAX_SEQUENCE_LENGTH = 4096
|
|
32
|
-
const DEFAULT_BATCH_SIZE = 1
|
|
33
|
-
const OVERHEAD_FACTOR = 0.1
|
|
31
|
+
const DEFAULT_MAX_SEQUENCE_LENGTH = 4096;
|
|
32
|
+
const DEFAULT_BATCH_SIZE = 1;
|
|
33
|
+
const OVERHEAD_FACTOR = 0.1;
|
|
34
34
|
|
|
35
35
|
// ── Helper Functions ─────────────────────────────────────────────────────────
|
|
36
36
|
|
|
@@ -44,10 +44,10 @@ const OVERHEAD_FACTOR = 0.1
|
|
|
44
44
|
*/
|
|
45
45
|
const bytesPerParam = (dtype, quantization) => {
|
|
46
46
|
if (quantization && QUANTIZATION_BYTES[quantization] !== undefined) {
|
|
47
|
-
return QUANTIZATION_BYTES[quantization]
|
|
47
|
+
return QUANTIZATION_BYTES[quantization];
|
|
48
48
|
}
|
|
49
|
-
return BYTES_PER_PARAM[dtype] ?? BYTES_PER_PARAM.float16
|
|
50
|
-
}
|
|
49
|
+
return BYTES_PER_PARAM[dtype] ?? BYTES_PER_PARAM.float16;
|
|
50
|
+
};
|
|
51
51
|
|
|
52
52
|
/**
|
|
53
53
|
* Estimate KV cache memory usage.
|
|
@@ -66,16 +66,16 @@ const bytesPerParam = (dtype, quantization) => {
|
|
|
66
66
|
* @returns {number} Estimated KV cache size in bytes
|
|
67
67
|
*/
|
|
68
68
|
const estimateKvCache = (parameterCount, maxSequenceLength, batchSize) => {
|
|
69
|
-
const seqLength = maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH
|
|
70
|
-
const batch = batchSize ?? DEFAULT_BATCH_SIZE
|
|
69
|
+
const seqLength = maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH;
|
|
70
|
+
const batch = batchSize ?? DEFAULT_BATCH_SIZE;
|
|
71
71
|
|
|
72
72
|
// Heuristic: KV cache ≈ parameterCount × (seqLength / 4096) × batch × 0.05 bytes
|
|
73
73
|
// This gives ~5% of raw param count in bytes at default seq length and batch=1
|
|
74
74
|
// For 7B params: 7e9 × 0.05 = 350MB at seq=4096, batch=1
|
|
75
75
|
// Scales linearly with sequence length and batch size
|
|
76
|
-
const kvBytes = parameterCount * (seqLength / DEFAULT_MAX_SEQUENCE_LENGTH) * batch * 0.05
|
|
77
|
-
return kvBytes
|
|
78
|
-
}
|
|
76
|
+
const kvBytes = parameterCount * (seqLength / DEFAULT_MAX_SEQUENCE_LENGTH) * batch * 0.05;
|
|
77
|
+
return kvBytes;
|
|
78
|
+
};
|
|
79
79
|
|
|
80
80
|
// ── Main Estimation Function ─────────────────────────────────────────────────
|
|
81
81
|
|
|
@@ -97,28 +97,28 @@ const estimateVram = (modelInfo) => {
|
|
|
97
97
|
quantization,
|
|
98
98
|
maxSequenceLength,
|
|
99
99
|
batchSize
|
|
100
|
-
} = modelInfo
|
|
100
|
+
} = modelInfo;
|
|
101
101
|
|
|
102
102
|
// Determine confidence based on what was explicitly provided
|
|
103
|
-
const confidence = determineConfidence(modelInfo)
|
|
103
|
+
const confidence = determineConfidence(modelInfo);
|
|
104
104
|
|
|
105
105
|
// Calculate base weight bytes
|
|
106
|
-
const bpp = bytesPerParam(dtype, quantization)
|
|
107
|
-
const baseWeightBytes = parameterCount * bpp
|
|
106
|
+
const bpp = bytesPerParam(dtype, quantization);
|
|
107
|
+
const baseWeightBytes = parameterCount * bpp;
|
|
108
108
|
|
|
109
109
|
// Calculate KV cache
|
|
110
110
|
const kvCacheBytes = estimateKvCache(
|
|
111
111
|
parameterCount,
|
|
112
112
|
maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH,
|
|
113
113
|
batchSize ?? DEFAULT_BATCH_SIZE
|
|
114
|
-
)
|
|
114
|
+
);
|
|
115
115
|
|
|
116
116
|
// Calculate overhead (framework/CUDA)
|
|
117
|
-
const overheadBytes = baseWeightBytes * OVERHEAD_FACTOR
|
|
117
|
+
const overheadBytes = baseWeightBytes * OVERHEAD_FACTOR;
|
|
118
118
|
|
|
119
119
|
// Total VRAM
|
|
120
|
-
const totalVramBytes = baseWeightBytes + kvCacheBytes + overheadBytes
|
|
121
|
-
const vramGb = totalVramBytes / BYTES_IN_GB
|
|
120
|
+
const totalVramBytes = baseWeightBytes + kvCacheBytes + overheadBytes;
|
|
121
|
+
const vramGb = totalVramBytes / BYTES_IN_GB;
|
|
122
122
|
|
|
123
123
|
return {
|
|
124
124
|
vramGb,
|
|
@@ -129,8 +129,8 @@ const estimateVram = (modelInfo) => {
|
|
|
129
129
|
},
|
|
130
130
|
confidence,
|
|
131
131
|
source: 'estimate'
|
|
132
|
-
}
|
|
133
|
-
}
|
|
132
|
+
};
|
|
133
|
+
};
|
|
134
134
|
|
|
135
135
|
/**
|
|
136
136
|
* Determine confidence level based on which parameters were explicitly provided.
|
|
@@ -143,25 +143,25 @@ const estimateVram = (modelInfo) => {
|
|
|
143
143
|
* @returns {'high' | 'medium' | 'low'}
|
|
144
144
|
*/
|
|
145
145
|
const determineConfidence = (modelInfo) => {
|
|
146
|
-
const { parameterCount, dtype, maxSequenceLength, batchSize } = modelInfo
|
|
146
|
+
const { parameterCount, dtype, maxSequenceLength, batchSize } = modelInfo;
|
|
147
147
|
|
|
148
148
|
if (!parameterCount || !dtype) {
|
|
149
|
-
return 'low'
|
|
149
|
+
return 'low';
|
|
150
150
|
}
|
|
151
151
|
|
|
152
152
|
// If dtype is not in our known list, confidence drops
|
|
153
153
|
if (!BYTES_PER_PARAM[dtype]) {
|
|
154
|
-
return 'low'
|
|
154
|
+
return 'low';
|
|
155
155
|
}
|
|
156
156
|
|
|
157
157
|
// All key params explicitly provided
|
|
158
158
|
if (maxSequenceLength !== undefined && batchSize !== undefined) {
|
|
159
|
-
return 'high'
|
|
159
|
+
return 'high';
|
|
160
160
|
}
|
|
161
161
|
|
|
162
162
|
// Core params present but some optional ones use defaults
|
|
163
|
-
return 'medium'
|
|
164
|
-
}
|
|
163
|
+
return 'medium';
|
|
164
|
+
};
|
|
165
165
|
|
|
166
166
|
export {
|
|
167
167
|
estimateVram,
|
|
@@ -174,4 +174,4 @@ export {
|
|
|
174
174
|
DEFAULT_BATCH_SIZE,
|
|
175
175
|
OVERHEAD_FACTOR,
|
|
176
176
|
BYTES_IN_GB
|
|
177
|
-
}
|
|
177
|
+
};
|