@aws/ml-container-creator 0.2.5 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. package/bin/cli.js +7 -2
  2. package/package.json +7 -8
  3. package/servers/base-image-picker/index.js +3 -3
  4. package/servers/base-image-picker/manifest.json +4 -2
  5. package/servers/instance-sizer/index.js +561 -0
  6. package/servers/instance-sizer/lib/instance-ranker.js +245 -0
  7. package/servers/instance-sizer/lib/model-resolver.js +265 -0
  8. package/servers/instance-sizer/lib/vram-estimator.js +177 -0
  9. package/servers/instance-sizer/manifest.json +17 -0
  10. package/servers/instance-sizer/package.json +15 -0
  11. package/servers/{instance-recommender → lib}/catalogs/instances.json +136 -34
  12. package/servers/{base-image-picker → lib}/catalogs/model-servers.json +19 -249
  13. package/servers/lib/catalogs/model-sizes.json +131 -0
  14. package/servers/lib/catalogs/models.json +602 -0
  15. package/servers/{model-picker → lib}/catalogs/popular-diffusors.json +32 -10
  16. package/servers/{model-picker → lib}/catalogs/popular-transformers.json +59 -26
  17. package/servers/{base-image-picker → lib}/catalogs/python-slim.json +12 -12
  18. package/servers/lib/schemas/image-catalog.schema.json +0 -12
  19. package/servers/lib/schemas/instances.schema.json +29 -0
  20. package/servers/lib/schemas/model-catalog.schema.json +12 -10
  21. package/servers/lib/schemas/unified-model-catalog.schema.json +129 -0
  22. package/servers/model-picker/index.js +2 -3
  23. package/servers/model-picker/manifest.json +2 -3
  24. package/servers/region-picker/index.js +1 -1
  25. package/servers/region-picker/manifest.json +1 -1
  26. package/src/app.js +17 -0
  27. package/src/lib/bootstrap-command-handler.js +38 -0
  28. package/src/lib/cli-handler.js +3 -3
  29. package/src/lib/config-manager.js +4 -1
  30. package/src/lib/configuration-manager.js +2 -2
  31. package/src/lib/cross-cutting-checker.js +341 -0
  32. package/src/lib/dry-run-validator.js +78 -0
  33. package/src/lib/generation-validator.js +102 -0
  34. package/src/lib/mcp-validator-config.js +89 -0
  35. package/src/lib/payload-builder.js +153 -0
  36. package/src/lib/prompt-runner.js +445 -135
  37. package/src/lib/prompts.js +1 -1
  38. package/src/lib/registry-loader.js +5 -5
  39. package/src/lib/schema-sync.js +203 -0
  40. package/src/lib/schema-validation-engine.js +195 -0
  41. package/src/lib/service-model-parser.js +102 -0
  42. package/src/lib/validate-runner.js +167 -0
  43. package/src/lib/validation-report.js +133 -0
  44. package/src/lib/validators/base-validator.js +36 -0
  45. package/src/lib/validators/catalog-validator.js +177 -0
  46. package/src/lib/validators/enum-validator.js +120 -0
  47. package/src/lib/validators/required-field-validator.js +150 -0
  48. package/src/lib/validators/type-validator.js +313 -0
  49. package/templates/Dockerfile +1 -1
  50. package/templates/do/build +15 -5
  51. package/templates/do/run +5 -1
  52. package/templates/do/validate +61 -0
  53. package/servers/instance-recommender/LICENSE +0 -202
  54. package/servers/instance-recommender/index.js +0 -284
  55. package/servers/instance-recommender/manifest.json +0 -16
  56. package/servers/instance-recommender/package.json +0 -15
  57. /package/servers/{model-picker → lib}/catalogs/jumpstart-public.json +0 -0
  58. /package/servers/{region-picker → lib}/catalogs/regions.json +0 -0
  59. /package/servers/{base-image-picker → lib}/catalogs/triton-backends.json +0 -0
  60. /package/servers/{base-image-picker → lib}/catalogs/triton.json +0 -0
@@ -0,0 +1,245 @@
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ /**
+  * Instance Filter & Ranker
+  *
+  * Filters and ranks SageMaker instances by compatibility with a model's
+  * VRAM requirement. Considers tensor parallelism for multi-GPU instances
+  * and applies cost-efficiency ranking within each TP tier.
+  */
+
+ // ── Constants ────────────────────────────────────────────────────────────────
+
+ /**
+  * GPU memory per chip (in GB) by hardware type.
+  * Used when the catalog doesn't have a direct gpuMemoryGb field.
+  */
+ const GPU_MEMORY_MAP = {
+   'NVIDIA T4': 16,
+   'NVIDIA A10G': 24,
+   'NVIDIA V100': 16,
+   'NVIDIA L4': 24,
+   'NVIDIA A100': 40,
+   'NVIDIA H100': 80,
+   'AWS Inferentia2': 32,
+   'AWS Trainium': 32
+ }
+
+ /**
+  * Cost tier classification by instance family.
+  */
+ const COST_TIER_MAP = {
+   'g4dn': 'low',
+   'inf2': 'low',
+   'g5': 'medium',
+   'g6': 'medium',
+   'trn1': 'medium',
+   'p3': 'high',
+   'p4d': 'high',
+   'p4de': 'high',
+   'p5': 'high'
+ }
+
+ /**
+  * Relative cost weight by tier for sorting within TP groups.
+  * Lower is better (more cost-efficient).
+  */
+ const COST_TIER_WEIGHT = {
+   'low': 1,
+   'medium': 2,
+   'high': 3
+ }
+
+ /**
+  * TP overhead penalty: 10% per additional GPU beyond the first.
+  * Effective VRAM = totalVram - perGpuMemory × 0.10 × (gpuCount - 1)
+  */
+ const TP_OVERHEAD_PER_GPU = 0.10
+
+ // ── Helper Functions ─────────────────────────────────────────────────────────
+
+ /**
+  * Extract per-GPU memory in GB from an instance catalog entry.
+  *
+  * Tries these approaches in order:
+  * 1. Direct gpuMemoryGb field (if catalog has been extended)
+  * 2. Parse from accelerator string (e.g., "4x A10G 96GB" → 24 per GPU)
+  * 3. Lookup by hardware type from GPU_MEMORY_MAP
+  *
+  * @param {object} instance - Instance catalog entry
+  * @returns {number|null} Per-GPU memory in GB, or null if not determinable
+  */
+ const getPerGpuMemoryGb = (instance) => {
+   // 1. Direct field
+   if (instance.gpuMemoryGb) {
+     return instance.gpuMemoryGb
+   }
+
+   // 2. Parse from accelerator string
+   if (instance.accelerator) {
+     // Match patterns like "A10G 24GB", "4x A10G 96GB", "T4 16GB"
+     const totalMatch = instance.accelerator.match(/(\d+)GB/)
+     if (totalMatch) {
+       const totalGb = parseInt(totalMatch[1], 10)
+       const gpuCount = instance.gpus || 1
+       // If the string has a multiplier prefix like "4x", the GB is total
+       const hasMultiplier = instance.accelerator.match(/^(\d+)x\s/)
+       if (hasMultiplier) {
+         return totalGb / gpuCount
+       }
+       // Single GPU entry — the GB value is per-GPU
+       return totalGb
+     }
+   }
+
+   // 3. Lookup by hardware type
+   if (instance.hardware && GPU_MEMORY_MAP[instance.hardware]) {
+     return GPU_MEMORY_MAP[instance.hardware]
+   }
+
+   return null
+ }
+
+ /**
+  * Determine cost tier for an instance based on its family.
+  *
+  * @param {object} instance - Instance catalog entry
+  * @returns {string} 'low', 'medium', or 'high'
+  */
+ const getCostTier = (instance) => {
+   if (instance.costTier) {
+     return instance.costTier
+   }
+   const family = instance.family || ''
+   return COST_TIER_MAP[family] || 'medium'
+ }
+
+ /**
+  * Calculate effective VRAM available after TP overhead penalty.
+  *
+  * Each additional GPU beyond the first loses 10% of its per-GPU capacity
+  * to communication overhead. The first GPU contributes its full capacity.
+  *
+  * Formula: perGpuMemory + (gpuCount - 1) × perGpuMemory × (1 - TP_OVERHEAD_PER_GPU)
+  * Simplified: perGpuMemory × (1 + (gpuCount - 1) × 0.9)
+  * Or equivalently: totalVram - perGpuMemory × 0.10 × (gpuCount - 1)
+  *
+  * @param {number} totalVramGb - Total GPU VRAM in GB
+  * @param {number} gpuCount - Number of GPUs (TP degree)
+  * @returns {number} Effective usable VRAM in GB
+  */
+ const effectiveVram = (totalVramGb, gpuCount) => {
+   if (gpuCount <= 1) return totalVramGb
+   const perGpuMemory = totalVramGb / gpuCount
+   const overhead = perGpuMemory * TP_OVERHEAD_PER_GPU * (gpuCount - 1)
+   return totalVramGb - overhead
+ }
+
+ // ── Main Function ────────────────────────────────────────────────────────────
+
+ /**
+  * Filter and rank instances by compatibility with VRAM requirement.
+  *
+  * @param {number} vramRequired - Required VRAM in GB
+  * @param {object} instanceCatalog - Object keyed by instance type, values are metadata
+  * @param {object} [options={}]
+  * @param {number} [options.limit=8] - Max results to return
+  * @param {boolean} [options.allowTensorParallelism=true] - Consider multi-GPU splits
+  * @returns {object[]} Ranked list of compatible instances
+  */
+ const filterAndRankInstances = (vramRequired, instanceCatalog, options = {}) => {
+   const { limit = 8, allowTensorParallelism = true } = options
+
+   if (!vramRequired || vramRequired <= 0) {
+     return []
+   }
+
+   if (!instanceCatalog || typeof instanceCatalog !== 'object') {
+     return []
+   }
+
+   const candidates = []
+
+   for (const [instanceType, meta] of Object.entries(instanceCatalog)) {
+     // Skip non-GPU instances
+     if (!meta.gpus || meta.gpus <= 0) continue
+     if (meta.category !== 'gpu') continue
+
+     const perGpuMemory = getPerGpuMemoryGb(meta)
+     if (!perGpuMemory) continue
+
+     const gpuCount = meta.gpus
+     const totalVramGb = perGpuMemory * gpuCount
+
+     // Determine if model fits on a single GPU
+     if (gpuCount === 1) {
+       if (perGpuMemory >= vramRequired) {
+         const utilizationPercent = Math.round((vramRequired / perGpuMemory) * 100)
+         candidates.push({
+           instanceType,
+           gpuCount,
+           totalVramGb,
+           utilizationPercent,
+           tensorParallelism: 1,
+           costTier: getCostTier(meta)
+         })
+       }
+     } else if (allowTensorParallelism) {
+       // Multi-GPU: check if model fits with TP across all GPUs
+       const effectiveTotal = effectiveVram(totalVramGb, gpuCount)
+       if (effectiveTotal >= vramRequired) {
+         const utilizationPercent = Math.round((vramRequired / effectiveTotal) * 100)
+         candidates.push({
+           instanceType,
+           gpuCount,
+           totalVramGb,
+           utilizationPercent,
+           tensorParallelism: gpuCount,
+           costTier: getCostTier(meta)
+         })
+       }
+     }
+   }
+
+   // Sort candidates by ranking criteria:
+   // 1. Single-GPU first (TP=1), then multi-GPU by lowest TP degree
+   // 2. Within each TP tier, sort by cost-efficiency (lowest cost tier first,
+   //    then by lowest utilization — more headroom is better for the same cost)
+   candidates.sort((a, b) => {
+     // Primary: TP degree (lower is better)
+     if (a.tensorParallelism !== b.tensorParallelism) {
+       return a.tensorParallelism - b.tensorParallelism
+     }
+
+     // Secondary: cost tier (lower is better)
+     const costA = COST_TIER_WEIGHT[a.costTier] || 2
+     const costB = COST_TIER_WEIGHT[b.costTier] || 2
+     if (costA !== costB) {
+       return costA - costB
+     }
+
+     // Tertiary: cost-efficiency — lower $/GB approximated by
+     // lower cost tier with higher total VRAM (more GB per dollar)
+     // Since cost tier is equal here, prefer higher total VRAM (better value)
+     if (a.totalVramGb !== b.totalVramGb) {
+       return a.totalVramGb - b.totalVramGb
+     }
+
+     // Final tiebreaker: instance type name for deterministic ordering
+     return a.instanceType.localeCompare(b.instanceType)
+   })
+
+   return candidates.slice(0, limit)
+ }
+
+ export {
+   filterAndRankInstances,
+   getPerGpuMemoryGb,
+   getCostTier,
+   effectiveVram,
+   GPU_MEMORY_MAP,
+   COST_TIER_MAP,
+   COST_TIER_WEIGHT,
+   TP_OVERHEAD_PER_GPU
+ }
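
For orientation, here is a minimal usage sketch of the ranker added above (servers/instance-sizer/lib/instance-ranker.js). The catalog slice is illustrative only; the real entries live in servers/lib/catalogs/instances.json, which this diff does not reproduce, but the fields used here (gpus, category, accelerator, hardware, family) are exactly the ones the code reads:

import { filterAndRankInstances, effectiveVram } from './lib/instance-ranker.js'

// Hypothetical catalog slice, shaped like the fields instance-ranker.js reads
const catalog = {
  'ml.g5.xlarge': { gpus: 1, category: 'gpu', accelerator: 'A10G 24GB', hardware: 'NVIDIA A10G', family: 'g5' },
  'ml.g5.12xlarge': { gpus: 4, category: 'gpu', accelerator: '4x A10G 96GB', hardware: 'NVIDIA A10G', family: 'g5' }
}

// A 15 GB requirement ranks the single-GPU g5.xlarge first (24 GB, ~63% utilization),
// then the 4-GPU g5.12xlarge with tensorParallelism: 4
console.log(filterAndRankInstances(15, catalog, { limit: 3 }))

// TP overhead check: 4 x 24 GB loses 10% of one GPU per extra GPU, so 96 - 24 * 0.10 * 3
console.log(effectiveVram(96, 4)) // 88.8
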
@@ -0,0 +1,265 @@
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ /**
+  * Model Metadata Resolver
+  *
+  * Three-tier resolution strategy for model metadata:
+  * 1. Check model-sizes catalog (exact match or glob pattern match)
+  * 2. If discover mode enabled, fetch HuggingFace config.json
+  * 3. If neither available, return null (caller handles fallback)
+  */
+
+ import { readFile } from 'node:fs/promises'
+ import { fileURLToPath } from 'node:url'
+ import { dirname, join } from 'node:path'
+
+ // ── Constants ────────────────────────────────────────────────────────────────
+
+ const __filename = fileURLToPath(import.meta.url)
+ const __dirname = dirname(__filename)
+
+ const DEFAULT_CATALOG_PATH = join(__dirname, '..', '..', 'lib', 'catalogs', 'models.json')
+ const HUGGINGFACE_BASE_URL = 'https://huggingface.co'
+ const HUGGINGFACE_TIMEOUT_MS = 5000
+
+ // ── Glob Pattern Matching ────────────────────────────────────────────────────
+
+ /**
+  * Simple glob pattern matcher supporting * wildcards.
+  * Case-insensitive matching.
+  *
+  * @param {string} pattern - Glob pattern (e.g., 'meta-llama/Llama-2-7b*')
+  * @param {string} text - Text to match against
+  * @returns {boolean} Whether the text matches the pattern
+  */
+ const globMatch = (pattern, text) => {
+   const regexStr = pattern
+     .replace(/[.+^${}()|[\]\\]/g, '\\$&')
+     .replace(/\*/g, '.*')
+   const regex = new RegExp(`^${regexStr}$`, 'i')
+   return regex.test(text)
+ }
+
+ // ── Catalog Lookup ───────────────────────────────────────────────────────────
+
+ /**
+  * Load the model-sizes catalog from disk.
+  *
+  * @param {string} [catalogPath] - Path to catalog JSON file
+  * @returns {Promise<object|null>} Parsed catalog or null on failure
+  */
+ const loadCatalog = async (catalogPath) => {
+   try {
+     const raw = await readFile(catalogPath || DEFAULT_CATALOG_PATH, 'utf-8')
+     return JSON.parse(raw)
+   } catch {
+     return null
+   }
+ }
+
+ /**
+  * Look up a model in the catalog by exact match or glob pattern.
+  *
+  * @param {string} modelName - HuggingFace model ID or catalog key
+  * @param {object} catalog - Parsed catalog object (flat or with .models wrapper)
+  * @returns {object|null} Catalog entry or null if not found
+  */
+ const catalogLookup = (modelName, catalog) => {
+   if (!catalog) {
+     return null
+   }
+
+   // Support both flat catalog (models.json) and wrapped format ({ models: {...} })
+   const models = catalog.models || catalog
+
+   // Try exact match first
+   if (models[modelName]) {
+     return models[modelName]
+   }
+
+   // Try glob pattern matching
+   for (const pattern of Object.keys(models)) {
+     if (globMatch(pattern, modelName)) {
+       return models[pattern]
+     }
+   }
+
+   return null
+ }
+
+ // ── HuggingFace API ──────────────────────────────────────────────────────────
+
+ /**
+  * Fetch model config.json from HuggingFace Hub.
+  *
+  * @param {string} modelName - HuggingFace model ID (e.g., 'meta-llama/Llama-2-7b-chat-hf')
+  * @returns {Promise<object|null>} Parsed config or null on failure
+  */
+ const fetchHuggingFaceConfig = async (modelName) => {
+   const url = `${HUGGINGFACE_BASE_URL}/${modelName}/resolve/main/config.json`
+
+   try {
+     const controller = new AbortController()
+     const timeout = setTimeout(() => controller.abort(), HUGGINGFACE_TIMEOUT_MS)
+
+     const response = await fetch(url, {
+       signal: controller.signal,
+       headers: { 'Accept': 'application/json' }
+     })
+
+     clearTimeout(timeout)
+
+     if (!response.ok) {
+       return null
+     }
+
+     return await response.json()
+   } catch {
+     return null
+   }
+ }
+
+ /**
+  * Estimate parameter count from architecture dimensions.
+  * Uses the approximation: hidden_size² × num_hidden_layers × 12
+  *
+  * This accounts for:
+  * - Attention weights (Q, K, V, O projections = 4 × hidden_size²)
+  * - FFN weights (typically 8 × hidden_size²)
+  * - Embeddings and other components
+  *
+  * @param {object} config - HuggingFace config.json contents
+  * @returns {number|null} Estimated parameter count or null if dimensions unavailable
+  */
+ const estimateParamsFromConfig = (config) => {
+   const hiddenSize = config.hidden_size
+   const numLayers = config.num_hidden_layers
+
+   if (!hiddenSize || !numLayers) {
+     return null
+   }
+
+   return hiddenSize * hiddenSize * numLayers * 12
+ }
+
+ /**
+  * Extract model metadata from a HuggingFace config.json.
+  *
+  * @param {object} config - Parsed HuggingFace config.json
+  * @returns {object} Extracted metadata
+  */
+ const extractFromHuggingFaceConfig = (config) => {
+   const parameterCount = config.num_parameters
+     ?? estimateParamsFromConfig(config)
+
+   const dtype = config.torch_dtype || 'float16'
+   const architecture = config.architectures?.[0] || 'unknown'
+   const maxPositionEmbeddings = config.max_position_embeddings || 4096
+
+   return {
+     parameterCount,
+     dtype,
+     architecture,
+     maxPositionEmbeddings,
+     source: 'huggingface_api'
+   }
+ }
+
+ // ── In-memory cache for discover mode ────────────────────────────────────────
+
+ const discoverCache = new Map()
+
+ // ── Protocol prefix detection ────────────────────────────────────────────────
+
+ const PROTOCOL_PREFIXES = ['jumpstart://', 'jumpstart-hub://', 's3://', 'registry://']
+
+ /**
+  * Check if a model name matches the HuggingFace org/model-name pattern.
+  * Must contain exactly one `/` and no protocol prefix.
+  *
+  * @param {string} modelName - Model identifier to check
+  * @returns {boolean} True if it matches the HuggingFace pattern
+  */
+ const isHuggingFacePattern = (modelName) => {
+   if (!modelName || typeof modelName !== 'string') return false
+   // Must not have a protocol prefix
+   if (PROTOCOL_PREFIXES.some(prefix => modelName.startsWith(prefix))) return false
+   // Must contain exactly one `/` (org/model-name)
+   const slashCount = (modelName.match(/\//g) || []).length
+   return slashCount === 1
+ }
+
+ // ── Main Resolver ────────────────────────────────────────────────────────────
+
+ /**
+  * Resolve model metadata from available sources.
+  *
+  * Three-tier resolution:
+  * 1. Check model-sizes catalog (exact match or pattern match)
+  * 2. If discover mode enabled AND model matches HuggingFace pattern, fetch config.json
+  * 3. If neither available, return null
+  *
+  * @param {string} modelName - HuggingFace model ID or catalog key
+  * @param {object} [options={}]
+  * @param {boolean} [options.discover=false] - Enable HuggingFace API lookups
+  * @param {string} [options.catalogPath] - Path to model-sizes catalog (for testing)
+  * @returns {Promise<{ parameterCount: number, dtype: string, architecture: string, maxPositionEmbeddings: number, source: string } | null>}
+  */
+ const resolveModelMetadata = async (modelName, options = {}) => {
+   const { discover = false, catalogPath } = options
+
+   // Tier 1: Catalog lookup
+   const catalog = await loadCatalog(catalogPath)
+   const catalogEntry = catalogLookup(modelName, catalog)
+
+   if (catalogEntry) {
+     return {
+       parameterCount: catalogEntry.parameterCount,
+       dtype: catalogEntry.defaultDtype,
+       architecture: catalogEntry.architecture,
+       maxPositionEmbeddings: catalogEntry.maxPositionEmbeddings,
+       source: 'catalog'
+     }
+   }
+
+   // Tier 2: HuggingFace API (only in discover mode, only for org/model-name pattern)
+   if (discover && isHuggingFacePattern(modelName)) {
+     // Check in-memory cache first
+     if (discoverCache.has(modelName)) {
+       return discoverCache.get(modelName)
+     }
+
+     const config = await fetchHuggingFaceConfig(modelName)
+
+     if (config) {
+       const metadata = extractFromHuggingFaceConfig(config)
+
+       // Only return if we got a usable parameter count
+       if (metadata.parameterCount) {
+         // Cache for session duration
+         discoverCache.set(modelName, metadata)
+         return metadata
+       }
+     }
+   }
+
+   // Tier 3: No metadata available
+   return null
+ }
+
+ export {
+   resolveModelMetadata,
+   globMatch,
+   loadCatalog,
+   catalogLookup,
+   fetchHuggingFaceConfig,
+   estimateParamsFromConfig,
+   extractFromHuggingFaceConfig,
+   isHuggingFacePattern,
+   discoverCache,
+   PROTOCOL_PREFIXES,
+   DEFAULT_CATALOG_PATH,
+   HUGGINGFACE_BASE_URL,
+   HUGGINGFACE_TIMEOUT_MS
+ }
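
A similarly hedged sketch of the resolver above (servers/instance-sizer/lib/model-resolver.js). Whether a given ID resolves from the catalog tier or falls through to the HuggingFace tier depends on what servers/lib/catalogs/models.json actually contains, so treat the model ID and the printed source value here as examples rather than guaranteed outputs:

import { resolveModelMetadata, globMatch } from './lib/model-resolver.js'

// Tier 1 is the bundled catalog; tier 2 (HuggingFace config.json) only runs with discover: true
const meta = await resolveModelMetadata('meta-llama/Llama-2-7b-chat-hf', { discover: true })
if (meta) {
  console.log(meta.parameterCount, meta.dtype, meta.source) // source: 'catalog' or 'huggingface_api'
} else {
  // Tier 3: no metadata found, so the caller supplies its own fallback
}

// Catalog keys may be glob patterns; matching is case-insensitive
console.log(globMatch('meta-llama/Llama-2-7b*', 'meta-llama/llama-2-7b-chat-hf')) // true
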
@@ -0,0 +1,177 @@
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ /**
+  * VRAM Estimation Engine
+  *
+  * Converts model metadata (parameter count, dtype, quantization) into a
+  * memory requirement estimate. Used by the instance-sizer MCP server to
+  * filter and rank compatible SageMaker instances.
+  */
+
+ // ── Constants ────────────────────────────────────────────────────────────────
+
+ const BYTES_PER_PARAM = {
+   float32: 4.0,
+   float16: 2.0,
+   bfloat16: 2.0,
+   int8: 1.0,
+   int4: 0.5
+ }
+
+ const QUANTIZATION_BYTES = {
+   'awq': 0.5,
+   'gptq': 0.5,
+   'bnb-4bit': 0.5,
+   'bnb-8bit': 1.0
+ }
+
+ const BYTES_IN_GB = 1024 ** 3
+
+ const DEFAULT_MAX_SEQUENCE_LENGTH = 4096
+ const DEFAULT_BATCH_SIZE = 1
+ const OVERHEAD_FACTOR = 0.1
+
+ // ── Helper Functions ─────────────────────────────────────────────────────────
+
+ /**
+  * Look up bytes per parameter based on dtype and optional quantization.
+  * Quantization takes precedence over dtype when provided.
+  *
+  * @param {string} dtype - Data type: 'float32', 'float16', 'bfloat16', 'int8', 'int4'
+  * @param {string} [quantization] - Quantization method: 'awq', 'gptq', 'bnb-4bit', 'bnb-8bit'
+  * @returns {number} Bytes per parameter
+  */
+ const bytesPerParam = (dtype, quantization) => {
+   if (quantization && QUANTIZATION_BYTES[quantization] !== undefined) {
+     return QUANTIZATION_BYTES[quantization]
+   }
+   return BYTES_PER_PARAM[dtype] ?? BYTES_PER_PARAM.float16
+ }
+
+ /**
+  * Estimate KV cache memory usage.
+  *
+  * The KV cache scales with sequence length and batch size. This uses a
+  * simplified heuristic based on the observation that KV cache for a typical
+  * transformer is roughly proportional to (numLayers × hiddenSize × seqLen × batch × 2 keys+values × 2 bytes).
+  * We approximate numLayers × hiddenSize as sqrt(parameterCount) × scaling factor.
+  *
+  * For a 7B model at seq=4096, batch=1, this yields ~0.5GB which matches
+  * real-world observations for Llama-2-7B.
+  *
+  * @param {number} parameterCount - Total model parameters
+  * @param {number} maxSequenceLength - Maximum context/sequence length
+  * @param {number} batchSize - Expected concurrent batch size
+  * @returns {number} Estimated KV cache size in bytes
+  */
+ const estimateKvCache = (parameterCount, maxSequenceLength, batchSize) => {
+   const seqLength = maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH
+   const batch = batchSize ?? DEFAULT_BATCH_SIZE
+
+   // Heuristic: KV cache ≈ parameterCount × (seqLength / 4096) × batch × 0.05 bytes
+   // This gives ~5% of raw param count in bytes at default seq length and batch=1
+   // For 7B params: 7e9 × 0.05 = 350MB at seq=4096, batch=1
+   // Scales linearly with sequence length and batch size
+   const kvBytes = parameterCount * (seqLength / DEFAULT_MAX_SEQUENCE_LENGTH) * batch * 0.05
+   return kvBytes
+ }
+
+ // ── Main Estimation Function ─────────────────────────────────────────────────
+
+ /**
+  * Estimate VRAM required to serve a model.
+  *
+  * @param {object} modelInfo
+  * @param {number} modelInfo.parameterCount - Total parameters (e.g., 7_000_000_000)
+  * @param {string} modelInfo.dtype - Data type: 'float32', 'float16', 'bfloat16', 'int8', 'int4'
+  * @param {string} [modelInfo.quantization] - Quantization method: 'awq', 'gptq', 'bnb-4bit', 'bnb-8bit'
+  * @param {number} [modelInfo.maxSequenceLength] - Max context length (affects KV cache)
+  * @param {number} [modelInfo.batchSize] - Expected concurrent batch size
+  * @returns {{ vramGb: number, breakdown: { weightsGb: number, kvCacheGb: number, overheadGb: number }, confidence: string, source: string }}
+  */
+ const estimateVram = (modelInfo) => {
+   const {
+     parameterCount,
+     dtype,
+     quantization,
+     maxSequenceLength,
+     batchSize
+   } = modelInfo
+
+   // Determine confidence based on what was explicitly provided
+   const confidence = determineConfidence(modelInfo)
+
+   // Calculate base weight bytes
+   const bpp = bytesPerParam(dtype, quantization)
+   const baseWeightBytes = parameterCount * bpp
+
+   // Calculate KV cache
+   const kvCacheBytes = estimateKvCache(
+     parameterCount,
+     maxSequenceLength ?? DEFAULT_MAX_SEQUENCE_LENGTH,
+     batchSize ?? DEFAULT_BATCH_SIZE
+   )
+
+   // Calculate overhead (framework/CUDA)
+   const overheadBytes = baseWeightBytes * OVERHEAD_FACTOR
+
+   // Total VRAM
+   const totalVramBytes = baseWeightBytes + kvCacheBytes + overheadBytes
+   const vramGb = totalVramBytes / BYTES_IN_GB
+
+   return {
+     vramGb,
+     breakdown: {
+       weightsGb: baseWeightBytes / BYTES_IN_GB,
+       kvCacheGb: kvCacheBytes / BYTES_IN_GB,
+       overheadGb: overheadBytes / BYTES_IN_GB
+     },
+     confidence,
+     source: 'estimate'
+   }
+ }
+
+ /**
+  * Determine confidence level based on which parameters were explicitly provided.
+  *
+  * - 'high': All key parameters (parameterCount, dtype) are explicitly provided
+  * - 'medium': Some parameters are provided but others use defaults
+  * - 'low': Using fallback values for critical parameters
+  *
+  * @param {object} modelInfo
+  * @returns {'high' | 'medium' | 'low'}
+  */
+ const determineConfidence = (modelInfo) => {
+   const { parameterCount, dtype, maxSequenceLength, batchSize } = modelInfo
+
+   if (!parameterCount || !dtype) {
+     return 'low'
+   }
+
+   // If dtype is not in our known list, confidence drops
+   if (!BYTES_PER_PARAM[dtype]) {
+     return 'low'
+   }
+
+   // All key params explicitly provided
+   if (maxSequenceLength !== undefined && batchSize !== undefined) {
+     return 'high'
+   }
+
+   // Core params present but some optional ones use defaults
+   return 'medium'
+ }
+
+ export {
+   estimateVram,
+   bytesPerParam,
+   estimateKvCache,
+   determineConfidence,
+   BYTES_PER_PARAM,
+   QUANTIZATION_BYTES,
+   DEFAULT_MAX_SEQUENCE_LENGTH,
+   DEFAULT_BATCH_SIZE,
+   OVERHEAD_FACTOR,
+   BYTES_IN_GB
+ }
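
As a worked check of the arithmetic in the estimator above (servers/instance-sizer/lib/vram-estimator.js): a 7B-parameter model in float16 at the defaults (sequence 4096, batch 1) gives weights of 7e9 × 2 B ≈ 13.0 GB, a KV cache of 7e9 × 0.05 B ≈ 0.33 GB, and 10% framework overhead ≈ 1.3 GB, roughly 14.7 GB in total. The sketch below simply feeds those inputs to estimateVram; the import path assumes the server's own directory layout:

import { estimateVram } from './lib/vram-estimator.js'

const result = estimateVram({
  parameterCount: 7_000_000_000,
  dtype: 'float16'
})

// Roughly: { vramGb: 14.67,
//   breakdown: { weightsGb: 13.04, kvCacheGb: 0.33, overheadGb: 1.30 },
//   confidence: 'medium', source: 'estimate' }
// confidence is 'medium' because maxSequenceLength and batchSize fall back to defaults
console.log(result)
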
@@ -0,0 +1,17 @@
+ {
+   "name": "@amzn/ml-container-creator-instance-sizer",
+   "version": "1.0.0",
+   "description": "MCP server that estimates VRAM requirements and recommends compatible SageMaker instances.",
+   "modes": {
+     "static": true,
+     "smart": true,
+     "discover": true
+   },
+   "catalogs": {
+     "models": "../lib/catalogs/models.json",
+     "instances": "../lib/catalogs/instances.json"
+   },
+   "tool": {
+     "name": "get_instance_recommendation"
+   }
+ }
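
The manifest wires the server's three modes and the shared catalogs to a single MCP tool. The tool's real input schema lives in servers/instance-sizer/index.js (the +561 file listed above but not reproduced here), so the argument names in this call sketch are assumptions, not the published schema:

// Hypothetical MCP tools/call payload; the field names under "arguments" are illustrative only
const callRequest = {
  name: 'get_instance_recommendation',
  arguments: {
    model: 'meta-llama/Llama-2-7b-chat-hf', // resolved via models.json, or HuggingFace in discover mode
    discover: true // corresponds to the "discover" mode flag declared in the manifest
  }
}
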