@aws/ml-container-creator 0.9.1 → 0.10.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. package/LICENSE-THIRD-PARTY +9304 -0
  2. package/bin/cli.js +2 -0
  3. package/config/bootstrap-e2e-stack.json +341 -0
  4. package/config/bootstrap-stack.json +40 -3
  5. package/config/parameter-schema-v2.json +2049 -0
  6. package/config/tune-catalog.json +1781 -0
  7. package/infra/ci-harness/buildspec.yml +1 -0
  8. package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
  9. package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
  10. package/infra/ci-harness/lib/ci-harness-stack.ts +837 -7
  11. package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
  12. package/package.json +53 -68
  13. package/servers/base-image-picker/index.js +121 -121
  14. package/servers/e2e-status/index.js +297 -0
  15. package/servers/e2e-status/manifest.json +14 -0
  16. package/servers/e2e-status/package.json +15 -0
  17. package/servers/endpoint-picker/LICENSE +202 -0
  18. package/servers/endpoint-picker/index.js +536 -0
  19. package/servers/endpoint-picker/manifest.json +14 -0
  20. package/servers/endpoint-picker/package.json +18 -0
  21. package/servers/hyperpod-cluster-picker/index.js +125 -125
  22. package/servers/instance-sizer/index.js +138 -138
  23. package/servers/instance-sizer/lib/instance-ranker.js +76 -76
  24. package/servers/instance-sizer/lib/model-resolver.js +61 -61
  25. package/servers/instance-sizer/lib/quota-resolver.js +113 -113
  26. package/servers/instance-sizer/lib/vram-estimator.js +31 -31
  27. package/servers/lib/bedrock-client.js +38 -38
  28. package/servers/lib/catalogs/jumpstart-public.json +101 -16
  29. package/servers/lib/catalogs/model-servers.json +201 -3
  30. package/servers/lib/catalogs/models.json +182 -26
  31. package/servers/lib/custom-validators.js +13 -13
  32. package/servers/lib/dynamic-resolver.js +4 -4
  33. package/servers/marketplace-picker/index.js +342 -0
  34. package/servers/marketplace-picker/manifest.json +14 -0
  35. package/servers/marketplace-picker/package.json +18 -0
  36. package/servers/model-picker/index.js +382 -382
  37. package/servers/region-picker/index.js +56 -56
  38. package/servers/workload-picker/LICENSE +202 -0
  39. package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
  40. package/servers/workload-picker/index.js +171 -0
  41. package/servers/workload-picker/manifest.json +16 -0
  42. package/servers/workload-picker/package.json +16 -0
  43. package/src/app.js +4 -390
  44. package/src/lib/bootstrap-command-handler.js +710 -1148
  45. package/src/lib/bootstrap-config.js +36 -0
  46. package/src/lib/bootstrap-profile-manager.js +641 -0
  47. package/src/lib/bootstrap-provisioners.js +421 -0
  48. package/src/lib/ci-register-helpers.js +74 -0
  49. package/src/lib/config-loader.js +408 -0
  50. package/src/lib/config-manager.js +66 -1685
  51. package/src/lib/config-mcp-client.js +118 -0
  52. package/src/lib/config-validator.js +634 -0
  53. package/src/lib/cuda-resolver.js +149 -0
  54. package/src/lib/e2e-catalog-validator.js +251 -3
  55. package/src/lib/e2e-ci-recorder.js +103 -0
  56. package/src/lib/generated/cli-options.js +315 -311
  57. package/src/lib/generated/parameter-matrix.js +671 -0
  58. package/src/lib/generated/validation-rules.js +71 -71
  59. package/src/lib/marketplace-flow.js +276 -0
  60. package/src/lib/mcp-query-runner.js +768 -0
  61. package/src/lib/parameter-schema-validator.js +62 -18
  62. package/src/lib/path-prover-brain.js +607 -0
  63. package/src/lib/prompt-runner.js +41 -1504
  64. package/src/lib/prompts/feature-prompts.js +172 -0
  65. package/src/lib/prompts/index.js +48 -0
  66. package/src/lib/prompts/infrastructure-prompts.js +690 -0
  67. package/src/lib/prompts/model-prompts.js +552 -0
  68. package/src/lib/prompts/project-prompts.js +82 -0
  69. package/src/lib/prompts.js +2 -1446
  70. package/src/lib/registry-command-handler.js +135 -3
  71. package/src/lib/secrets-prompt-runner.js +251 -0
  72. package/src/lib/template-variable-resolver.js +422 -0
  73. package/src/lib/tune-catalog-validator.js +37 -4
  74. package/templates/Dockerfile +9 -0
  75. package/templates/code/adapter_sidecar.py +444 -0
  76. package/templates/code/serve +6 -0
  77. package/templates/code/serve.d/vllm.ejs +1 -1
  78. package/templates/do/.benchmark_writer.py +1476 -0
  79. package/templates/do/.tune_helper.py +982 -57
  80. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  81. package/templates/do/adapter +149 -0
  82. package/templates/do/benchmark +639 -85
  83. package/templates/do/config +108 -5
  84. package/templates/do/deploy.d/managed-inference.ejs +192 -11
  85. package/templates/do/optimize +106 -37
  86. package/templates/do/register +89 -0
  87. package/templates/do/test +13 -0
  88. package/templates/do/tune +378 -59
  89. package/templates/do/validate +44 -4
  90. package/config/parameter-schema.json +0 -88
@@ -11,20 +11,20 @@
11
11
  * All methods degrade gracefully — API failures return null and log to stderr.
12
12
  */
13
13
 
14
- import { ServiceQuotasClient, ListServiceQuotasCommand } from '@aws-sdk/client-service-quotas'
15
- import { SageMakerClient, ListEndpointsCommand, ListTrainingPlansCommand } from '@aws-sdk/client-sagemaker'
14
+ import { ServiceQuotasClient, ListServiceQuotasCommand } from '@aws-sdk/client-service-quotas';
15
+ import { SageMakerClient, ListEndpointsCommand, ListTrainingPlansCommand } from '@aws-sdk/client-sagemaker';
16
16
 
17
17
  // ── Constants ────────────────────────────────────────────────────────────────
18
18
 
19
- const SAGEMAKER_SERVICE_CODE = 'sagemaker'
20
- const DEFAULT_TIMEOUT_MS = 5000
21
- const DEFAULT_CACHE_TTL_MS = 300000 // 5 minutes
22
- const QUOTA_NAME_PATTERN = /^(ml\.[a-z0-9]+\.[a-z0-9]+) for endpoint usage$/
19
+ const SAGEMAKER_SERVICE_CODE = 'sagemaker';
20
+ const DEFAULT_TIMEOUT_MS = 5000;
21
+ const DEFAULT_CACHE_TTL_MS = 300000; // 5 minutes
22
+ const QUOTA_NAME_PATTERN = /^(ml\.[a-z0-9]+\.[a-z0-9]+) for endpoint usage$/;
23
23
 
24
24
  // ── Logging ──────────────────────────────────────────────────────────────────
25
25
 
26
26
  function log(message) {
27
- process.stderr.write(`[quota-resolver] ${message}\n`)
27
+ process.stderr.write(`[quota-resolver] ${message}\n`);
28
28
  }
29
29
 
30
30
  // ── QuotaResolver Class ──────────────────────────────────────────────────────
@@ -37,20 +37,20 @@ class QuotaResolver {
37
37
  * @param {number} [options.cacheTtl=300000] - Cache TTL in ms (default 5 min)
38
38
  */
39
39
  constructor(region, options = {}) {
40
- this.region = region
41
- this.timeout = options.timeout || DEFAULT_TIMEOUT_MS
42
- this.cacheTtl = options.cacheTtl || DEFAULT_CACHE_TTL_MS
43
- this.cache = new Map()
40
+ this.region = region;
41
+ this.timeout = options.timeout || DEFAULT_TIMEOUT_MS;
42
+ this.cacheTtl = options.cacheTtl || DEFAULT_CACHE_TTL_MS;
43
+ this.cache = new Map();
44
44
 
45
45
  const clientConfig = {
46
46
  region: this.region,
47
47
  requestHandler: {
48
48
  requestTimeout: this.timeout
49
49
  }
50
- }
50
+ };
51
51
 
52
- this.quotasClient = new ServiceQuotasClient(clientConfig)
53
- this.sagemakerClient = new SageMakerClient(clientConfig)
52
+ this.quotasClient = new ServiceQuotasClient(clientConfig);
53
+ this.sagemakerClient = new SageMakerClient(clientConfig);
54
54
  }
55
55
 
56
56
  /**
@@ -59,13 +59,13 @@ class QuotaResolver {
59
59
  * @returns {*|null} Cached value or null
60
60
  */
61
61
  _getCached(key) {
62
- const entry = this.cache.get(key)
63
- if (!entry) return null
62
+ const entry = this.cache.get(key);
63
+ if (!entry) return null;
64
64
  if (Date.now() - entry.timestamp > this.cacheTtl) {
65
- this.cache.delete(key)
66
- return null
65
+ this.cache.delete(key);
66
+ return null;
67
67
  }
68
- return entry.value
68
+ return entry.value;
69
69
  }
70
70
 
71
71
  /**
@@ -74,7 +74,7 @@ class QuotaResolver {
74
74
  * @param {*} value - Value to cache
75
75
  */
76
76
  _setCache(key, value) {
77
- this.cache.set(key, { value, timestamp: Date.now() })
77
+ this.cache.set(key, { value, timestamp: Date.now() });
78
78
  }
79
79
 
80
80
  /**
@@ -85,8 +85,8 @@ class QuotaResolver {
85
85
  * @returns {string|null} Instance type or null if pattern doesn't match
86
86
  */
87
87
  _parseQuotaName(quotaName) {
88
- const match = quotaName.match(QUOTA_NAME_PATTERN)
89
- return match ? match[1] : null
88
+ const match = quotaName.match(QUOTA_NAME_PATTERN);
89
+ return match ? match[1] : null;
90
90
  }
91
91
 
92
92
  /**
@@ -100,50 +100,50 @@ class QuotaResolver {
100
100
  * @returns {Promise<Map|null>} Map: instanceType → { quota, deployed, headroom }, or null on failure
101
101
  */
102
102
  async getQuotaHeadroom(instanceTypes) {
103
- const cacheKey = 'quotaHeadroom'
104
- const cached = this._getCached(cacheKey)
105
- if (cached) return cached
103
+ const cacheKey = 'quotaHeadroom';
104
+ const cached = this._getCached(cacheKey);
105
+ if (cached) return cached;
106
106
 
107
107
  try {
108
108
  const [quotaMap, deployedMap] = await Promise.allSettled([
109
109
  this._fetchServiceQuotas(),
110
110
  this._fetchDeployedCounts()
111
- ])
111
+ ]);
112
112
 
113
- const quotas = quotaMap.status === 'fulfilled' ? quotaMap.value : null
114
- const deployed = deployedMap.status === 'fulfilled' ? deployedMap.value : null
113
+ const quotas = quotaMap.status === 'fulfilled' ? quotaMap.value : null;
114
+ const deployed = deployedMap.status === 'fulfilled' ? deployedMap.value : null;
115
115
 
116
116
  if (!quotas) {
117
- return null
117
+ return null;
118
118
  }
119
119
 
120
- const result = new Map()
121
- const deployedCounts = deployed || new Map()
120
+ const result = new Map();
121
+ const deployedCounts = deployed || new Map();
122
122
 
123
123
  for (const instanceType of instanceTypes) {
124
- const quota = quotas.get(instanceType)
125
- if (quota != null) {
126
- const deployedCount = deployedCounts.get(instanceType) || 0
127
- const headroom = quota - deployedCount
124
+ const quota = quotas.get(instanceType);
125
+ if (quota !== null && quota !== undefined) {
126
+ const deployedCount = deployedCounts.get(instanceType) || 0;
127
+ const headroom = quota - deployedCount;
128
128
  result.set(instanceType, {
129
129
  quota,
130
130
  deployed: deployedCount,
131
131
  headroom
132
- })
132
+ });
133
133
  }
134
134
  }
135
135
 
136
- this._setCache(cacheKey, result)
137
- return result
136
+ this._setCache(cacheKey, result);
137
+ return result;
138
138
  } catch (err) {
139
139
  if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException') {
140
- log(`AccessDenied: insufficient permissions for quota queries — skipping`)
140
+ log('AccessDenied: insufficient permissions for quota queries — skipping');
141
141
  } else if (err.name === 'ThrottlingException' || err.Code === 'ThrottlingException') {
142
- log(`Throttled: Service Quotas API rate limit hit — skipping`)
142
+ log('Throttled: Service Quotas API rate limit hit — skipping');
143
143
  } else {
144
- log(`Failed to get quota headroom: ${err.message}`)
144
+ log(`Failed to get quota headroom: ${err.message}`);
145
145
  }
146
- return null
146
+ return null;
147
147
  }
148
148
  }
149
149
 
@@ -154,28 +154,28 @@ class QuotaResolver {
154
154
  * @returns {Promise<Map>} Map: instanceType → quota limit (number)
155
155
  */
156
156
  async _fetchServiceQuotas() {
157
- const quotaMap = new Map()
158
- let nextToken = undefined
157
+ const quotaMap = new Map();
158
+ let nextToken = undefined;
159
159
 
160
160
  do {
161
161
  const command = new ListServiceQuotasCommand({
162
162
  ServiceCode: SAGEMAKER_SERVICE_CODE,
163
163
  ...(nextToken && { NextToken: nextToken })
164
- })
164
+ });
165
165
 
166
- const response = await this.quotasClient.send(command)
166
+ const response = await this.quotasClient.send(command);
167
167
 
168
168
  for (const quota of (response.Quotas || [])) {
169
- const instanceType = this._parseQuotaName(quota.QuotaName || '')
170
- if (instanceType && quota.Value != null) {
171
- quotaMap.set(instanceType, quota.Value)
169
+ const instanceType = this._parseQuotaName(quota.QuotaName || '');
170
+ if (instanceType && quota.Value !== null && quota.Value !== undefined) {
171
+ quotaMap.set(instanceType, quota.Value);
172
172
  }
173
173
  }
174
174
 
175
- nextToken = response.NextToken
176
- } while (nextToken)
175
+ nextToken = response.NextToken;
176
+ } while (nextToken);
177
177
 
178
- return quotaMap
178
+ return quotaMap;
179
179
  }
180
180
 
181
181
  /**
@@ -185,16 +185,16 @@ class QuotaResolver {
185
185
  * @returns {Promise<Map>} Map: instanceType → deployed count
186
186
  */
187
187
  async _fetchDeployedCounts() {
188
- const deployedMap = new Map()
189
- let nextToken = undefined
188
+ const deployedMap = new Map();
189
+ let nextToken = undefined;
190
190
 
191
191
  do {
192
192
  const command = new ListEndpointsCommand({
193
193
  StatusEquals: 'InService',
194
194
  ...(nextToken && { NextToken: nextToken })
195
- })
195
+ });
196
196
 
197
- const response = await this.sagemakerClient.send(command)
197
+ const response = await this.sagemakerClient.send(command);
198
198
 
199
199
  for (const endpoint of (response.Endpoints || [])) {
200
200
  // ListEndpoints returns endpoint summaries; instance type info
@@ -208,18 +208,18 @@ class QuotaResolver {
208
208
  if (endpoint.ProductionVariants) {
209
209
  for (const variant of endpoint.ProductionVariants) {
210
210
  if (variant.InstanceType) {
211
- const current = deployedMap.get(variant.InstanceType) || 0
212
- const count = variant.CurrentInstanceCount || 1
213
- deployedMap.set(variant.InstanceType, current + count)
211
+ const current = deployedMap.get(variant.InstanceType) || 0;
212
+ const count = variant.CurrentInstanceCount || 1;
213
+ deployedMap.set(variant.InstanceType, current + count);
214
214
  }
215
215
  }
216
216
  }
217
217
  }
218
218
 
219
- nextToken = response.NextToken
220
- } while (nextToken)
219
+ nextToken = response.NextToken;
220
+ } while (nextToken);
221
221
 
222
- return deployedMap
222
+ return deployedMap;
223
223
  }
224
224
 
225
225
  /**
@@ -234,46 +234,46 @@ class QuotaResolver {
234
234
  * @returns {Promise<Map|null>} Map: instanceType → { planName, planArn, remainingCapacity, startDate, endDate }, or null on failure
235
235
  */
236
236
  async getCapacityReservations() {
237
- const cacheKey = 'capacityReservations'
238
- const cached = this._getCached(cacheKey)
239
- if (cached) return cached
237
+ const cacheKey = 'capacityReservations';
238
+ const cached = this._getCached(cacheKey);
239
+ if (cached) return cached;
240
240
 
241
241
  try {
242
- const result = new Map()
243
- let nextToken = undefined
242
+ const result = new Map();
243
+ let nextToken = undefined;
244
244
 
245
245
  do {
246
246
  const command = new ListTrainingPlansCommand({
247
247
  StatusEquals: 'Active',
248
248
  ...(nextToken && { NextToken: nextToken })
249
- })
249
+ });
250
250
 
251
- const response = await this.sagemakerClient.send(command)
252
- const now = new Date()
251
+ const response = await this.sagemakerClient.send(command);
252
+ const now = new Date();
253
253
 
254
254
  for (const plan of (response.TrainingPlanSummaries || [])) {
255
255
  // Only include plans targeting inference endpoints
256
- const targetResources = plan.TargetResources || []
257
- if (!targetResources.includes('endpoint')) continue
256
+ const targetResources = plan.TargetResources || [];
257
+ if (!targetResources.includes('endpoint')) continue;
258
258
 
259
- const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType
260
- if (!instanceType) continue
259
+ const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType;
260
+ if (!instanceType) continue;
261
261
 
262
- const planArn = plan.TrainingPlanArn
263
- const planName = plan.TrainingPlanName || 'unknown'
262
+ const planArn = plan.TrainingPlanArn;
263
+ const planName = plan.TrainingPlanName || 'unknown';
264
264
  const remainingCapacity = plan.AvailableInstanceCount
265
265
  ?? plan.RemainingCapacity
266
266
  ?? plan.TotalInstanceCount
267
- ?? 0
268
- const startDate = plan.StartTime || null
269
- const endDate = plan.EndTime || plan.ExpirationTime || null
267
+ ?? 0;
268
+ const startDate = plan.StartTime || null;
269
+ const endDate = plan.EndTime || plan.ExpirationTime || null;
270
270
 
271
271
  // Skip plans outside their time window
272
- if (startDate && new Date(startDate) > now) continue
273
- if (endDate && new Date(endDate) < now) continue
272
+ if (startDate && new Date(startDate) > now) continue;
273
+ if (endDate && new Date(endDate) < now) continue;
274
274
 
275
275
  // Only include if there's remaining capacity
276
- if (remainingCapacity <= 0) continue
276
+ if (remainingCapacity <= 0) continue;
277
277
 
278
278
  result.set(instanceType, {
279
279
  planName,
@@ -282,25 +282,25 @@ class QuotaResolver {
282
282
  count: remainingCapacity,
283
283
  startDate: startDate ? (startDate instanceof Date ? startDate.toISOString() : startDate) : null,
284
284
  endDate: endDate ? (endDate instanceof Date ? endDate.toISOString() : endDate) : null
285
- })
285
+ });
286
286
  }
287
287
 
288
- nextToken = response.NextToken
289
- } while (nextToken)
288
+ nextToken = response.NextToken;
289
+ } while (nextToken);
290
290
 
291
- this._setCache(cacheKey, result)
292
- return result
291
+ this._setCache(cacheKey, result);
292
+ return result;
293
293
  } catch (err) {
294
294
  if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException') {
295
- log(`AccessDenied: insufficient permissions for training plan queries — skipping`)
295
+ log('AccessDenied: insufficient permissions for training plan queries — skipping');
296
296
  } else if (err.name === 'ValidationException') {
297
- log(`ListTrainingPlans not available in region ${this.region} — skipping`)
297
+ log(`ListTrainingPlans not available in region ${this.region} — skipping`);
298
298
  } else if (err.name === 'ThrottlingException' || err.Code === 'ThrottlingException') {
299
- log(`Throttled: ListTrainingPlans rate limit hit — skipping`)
299
+ log('Throttled: ListTrainingPlans rate limit hit — skipping');
300
300
  } else {
301
- log(`Failed to get capacity reservations: ${err.message}`)
301
+ log(`Failed to get capacity reservations: ${err.message}`);
302
302
  }
303
- return null
303
+ return null;
304
304
  }
305
305
  }
306
306
 
@@ -313,56 +313,56 @@ class QuotaResolver {
313
313
  * @returns {Promise<Map|null>} Map: instanceType → { planName, remainingCapacity, expiresAt }, or null on failure
314
314
  */
315
315
  async getTrainingPlans() {
316
- const cacheKey = 'trainingPlans'
317
- const cached = this._getCached(cacheKey)
318
- if (cached) return cached
316
+ const cacheKey = 'trainingPlans';
317
+ const cached = this._getCached(cacheKey);
318
+ if (cached) return cached;
319
319
 
320
320
  try {
321
- const result = new Map()
322
- let nextToken = undefined
321
+ const result = new Map();
322
+ let nextToken = undefined;
323
323
 
324
324
  do {
325
325
  const command = new ListTrainingPlansCommand({
326
326
  StatusEquals: 'Active',
327
327
  ...(nextToken && { NextToken: nextToken })
328
- })
328
+ });
329
329
 
330
- const response = await this.sagemakerClient.send(command)
330
+ const response = await this.sagemakerClient.send(command);
331
331
 
332
332
  for (const plan of (response.TrainingPlanSummaries || [])) {
333
- const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType
334
- const planName = plan.TrainingPlanName || plan.TrainingPlanArn || 'unknown'
333
+ const instanceType = plan.InstanceType || plan.ReservedCapacityInstanceType;
334
+ const planName = plan.TrainingPlanName || plan.TrainingPlanArn || 'unknown';
335
335
  const remainingCapacity = plan.AvailableInstanceCount
336
336
  ?? plan.RemainingCapacity
337
337
  ?? plan.TotalInstanceCount
338
- ?? 0
339
- const expiresAt = plan.EndTime || plan.ExpirationTime || null
338
+ ?? 0;
339
+ const expiresAt = plan.EndTime || plan.ExpirationTime || null;
340
340
 
341
341
  if (instanceType && remainingCapacity > 0) {
342
342
  result.set(instanceType, {
343
343
  planName,
344
344
  remainingCapacity,
345
345
  expiresAt
346
- })
346
+ });
347
347
  }
348
348
  }
349
349
 
350
- nextToken = response.NextToken
351
- } while (nextToken)
350
+ nextToken = response.NextToken;
351
+ } while (nextToken);
352
352
 
353
- this._setCache(cacheKey, result)
354
- return result
353
+ this._setCache(cacheKey, result);
354
+ return result;
355
355
  } catch (err) {
356
356
  if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException') {
357
- log(`AccessDenied: insufficient permissions for training plan queries — skipping`)
357
+ log('AccessDenied: insufficient permissions for training plan queries — skipping');
358
358
  } else if (err.name === 'ValidationException') {
359
- log(`ListTrainingPlans not available in region ${this.region} — skipping`)
359
+ log(`ListTrainingPlans not available in region ${this.region} — skipping`);
360
360
  } else {
361
- log(`Failed to get training plans: ${err.message}`)
361
+ log(`Failed to get training plans: ${err.message}`);
362
362
  }
363
- return null
363
+ return null;
364
364
  }
365
365
  }
366
366
  }
367
367
 
368
- export { QuotaResolver, QUOTA_NAME_PATTERN, SAGEMAKER_SERVICE_CODE, DEFAULT_TIMEOUT_MS, DEFAULT_CACHE_TTL_MS }
368
+ export { QuotaResolver, QUOTA_NAME_PATTERN, SAGEMAKER_SERVICE_CODE, DEFAULT_TIMEOUT_MS, DEFAULT_CACHE_TTL_MS };
@@ -17,20 +17,20 @@ const BYTES_PER_PARAM = {
17
17
  bfloat16: 2.0,
18
18
  int8: 1.0,
19
19
  int4: 0.5
20
- }
20
+ };
21
21
 
22
22
  const QUANTIZATION_BYTES = {
23
23
  'awq': 0.5,
24
24
  'gptq': 0.5,
25
25
  'bnb-4bit': 0.5,
26
26
  'bnb-8bit': 1.0
27
- }
27
+ };
28
28
 
29
- const BYTES_IN_GB = 1024 ** 3
29
+ const BYTES_IN_GB = 1024 ** 3;
30
30
 
31
- const DEFAULT_MAX_SEQUENCE_LENGTH = 4096
32
- const DEFAULT_BATCH_SIZE = 1
33
- const OVERHEAD_FACTOR = 0.1
31
+ const DEFAULT_MAX_SEQUENCE_LENGTH = 4096;
32
+ const DEFAULT_BATCH_SIZE = 1;
33
+ const OVERHEAD_FACTOR = 0.1;
34
34
 
35
35
  // ── Helper Functions ─────────────────────────────────────────────────────────
36
36
 
@@ -44,10 +44,10 @@ const OVERHEAD_FACTOR = 0.1
44
44
  */
45
45
  const bytesPerParam = (dtype, quantization) => {
46
46
  if (quantization && QUANTIZATION_BYTES[quantization] !== undefined) {
47
- return QUANTIZATION_BYTES[quantization]
47
+ return QUANTIZATION_BYTES[quantization];
48
48
  }
49
- return BYTES_PER_PARAM[dtype] ?? BYTES_PER_PARAM.float16
50
- }
49
+ return BYTES_PER_PARAM[dtype] ?? BYTES_PER_PARAM.float16;
50
+ };
51
51
 
52
52
  /**
53
53
  * Estimate KV cache memory usage.
@@ -66,16 +66,16 @@ const bytesPerParam = (dtype, quantization) => {
66
66
  * @returns {number} Estimated KV cache size in bytes
67
67
  */
68
68
  const estimateKvCache = (parameterCount, maxSequenceLength, batchSize) => {
69
- const seqLength = maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH
70
- const batch = batchSize ?? DEFAULT_BATCH_SIZE
69
+ const seqLength = maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH;
70
+ const batch = batchSize ?? DEFAULT_BATCH_SIZE;
71
71
 
72
72
  // Heuristic: KV cache ≈ parameterCount × (seqLength / 4096) × batch × 0.05 bytes
73
73
  // This gives ~5% of raw param count in bytes at default seq length and batch=1
74
74
  // For 7B params: 7e9 × 0.05 = 350MB at seq=4096, batch=1
75
75
  // Scales linearly with sequence length and batch size
76
- const kvBytes = parameterCount * (seqLength / DEFAULT_MAX_SEQUENCE_LENGTH) * batch * 0.05
77
- return kvBytes
78
- }
76
+ const kvBytes = parameterCount * (seqLength / DEFAULT_MAX_SEQUENCE_LENGTH) * batch * 0.05;
77
+ return kvBytes;
78
+ };
79
79
 
80
80
  // ── Main Estimation Function ─────────────────────────────────────────────────
81
81
 
@@ -97,28 +97,28 @@ const estimateVram = (modelInfo) => {
97
97
  quantization,
98
98
  maxSequenceLength,
99
99
  batchSize
100
- } = modelInfo
100
+ } = modelInfo;
101
101
 
102
102
  // Determine confidence based on what was explicitly provided
103
- const confidence = determineConfidence(modelInfo)
103
+ const confidence = determineConfidence(modelInfo);
104
104
 
105
105
  // Calculate base weight bytes
106
- const bpp = bytesPerParam(dtype, quantization)
107
- const baseWeightBytes = parameterCount * bpp
106
+ const bpp = bytesPerParam(dtype, quantization);
107
+ const baseWeightBytes = parameterCount * bpp;
108
108
 
109
109
  // Calculate KV cache
110
110
  const kvCacheBytes = estimateKvCache(
111
111
  parameterCount,
112
112
  maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH,
113
113
  batchSize ?? DEFAULT_BATCH_SIZE
114
- )
114
+ );
115
115
 
116
116
  // Calculate overhead (framework/CUDA)
117
- const overheadBytes = baseWeightBytes * OVERHEAD_FACTOR
117
+ const overheadBytes = baseWeightBytes * OVERHEAD_FACTOR;
118
118
 
119
119
  // Total VRAM
120
- const totalVramBytes = baseWeightBytes + kvCacheBytes + overheadBytes
121
- const vramGb = totalVramBytes / BYTES_IN_GB
120
+ const totalVramBytes = baseWeightBytes + kvCacheBytes + overheadBytes;
121
+ const vramGb = totalVramBytes / BYTES_IN_GB;
122
122
 
123
123
  return {
124
124
  vramGb,
@@ -129,8 +129,8 @@ const estimateVram = (modelInfo) => {
129
129
  },
130
130
  confidence,
131
131
  source: 'estimate'
132
- }
133
- }
132
+ };
133
+ };
134
134
 
135
135
  /**
136
136
  * Determine confidence level based on which parameters were explicitly provided.
@@ -143,25 +143,25 @@ const estimateVram = (modelInfo) => {
143
143
  * @returns {'high' | 'medium' | 'low'}
144
144
  */
145
145
  const determineConfidence = (modelInfo) => {
146
- const { parameterCount, dtype, maxSequenceLength, batchSize } = modelInfo
146
+ const { parameterCount, dtype, maxSequenceLength, batchSize } = modelInfo;
147
147
 
148
148
  if (!parameterCount || !dtype) {
149
- return 'low'
149
+ return 'low';
150
150
  }
151
151
 
152
152
  // If dtype is not in our known list, confidence drops
153
153
  if (!BYTES_PER_PARAM[dtype]) {
154
- return 'low'
154
+ return 'low';
155
155
  }
156
156
 
157
157
  // All key params explicitly provided
158
158
  if (maxSequenceLength !== undefined && batchSize !== undefined) {
159
- return 'high'
159
+ return 'high';
160
160
  }
161
161
 
162
162
  // Core params present but some optional ones use defaults
163
- return 'medium'
164
- }
163
+ return 'medium';
164
+ };
165
165
 
166
166
  export {
167
167
  estimateVram,
@@ -174,4 +174,4 @@ export {
174
174
  DEFAULT_BATCH_SIZE,
175
175
  OVERHEAD_FACTOR,
176
176
  BYTES_IN_GB
177
- }
177
+ };