@aws/ml-container-creator 0.2.5 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. package/bin/cli.js +7 -2
  2. package/package.json +7 -8
  3. package/servers/base-image-picker/index.js +3 -3
  4. package/servers/base-image-picker/manifest.json +4 -2
  5. package/servers/instance-sizer/index.js +561 -0
  6. package/servers/instance-sizer/lib/instance-ranker.js +245 -0
  7. package/servers/instance-sizer/lib/model-resolver.js +265 -0
  8. package/servers/instance-sizer/lib/vram-estimator.js +177 -0
  9. package/servers/instance-sizer/manifest.json +17 -0
  10. package/servers/instance-sizer/package.json +15 -0
  11. package/servers/{instance-recommender → lib}/catalogs/instances.json +136 -34
  12. package/servers/{base-image-picker → lib}/catalogs/model-servers.json +19 -249
  13. package/servers/lib/catalogs/model-sizes.json +131 -0
  14. package/servers/lib/catalogs/models.json +602 -0
  15. package/servers/{model-picker → lib}/catalogs/popular-diffusors.json +32 -10
  16. package/servers/{model-picker → lib}/catalogs/popular-transformers.json +59 -26
  17. package/servers/{base-image-picker → lib}/catalogs/python-slim.json +12 -12
  18. package/servers/lib/schemas/image-catalog.schema.json +0 -12
  19. package/servers/lib/schemas/instances.schema.json +29 -0
  20. package/servers/lib/schemas/model-catalog.schema.json +12 -10
  21. package/servers/lib/schemas/unified-model-catalog.schema.json +129 -0
  22. package/servers/model-picker/index.js +2 -3
  23. package/servers/model-picker/manifest.json +2 -3
  24. package/servers/region-picker/index.js +1 -1
  25. package/servers/region-picker/manifest.json +1 -1
  26. package/src/app.js +17 -0
  27. package/src/lib/bootstrap-command-handler.js +38 -0
  28. package/src/lib/cli-handler.js +3 -3
  29. package/src/lib/config-manager.js +4 -1
  30. package/src/lib/configuration-manager.js +2 -2
  31. package/src/lib/cross-cutting-checker.js +341 -0
  32. package/src/lib/dry-run-validator.js +78 -0
  33. package/src/lib/generation-validator.js +102 -0
  34. package/src/lib/mcp-validator-config.js +89 -0
  35. package/src/lib/payload-builder.js +153 -0
  36. package/src/lib/prompt-runner.js +445 -135
  37. package/src/lib/prompts.js +1 -1
  38. package/src/lib/registry-loader.js +5 -5
  39. package/src/lib/schema-sync.js +203 -0
  40. package/src/lib/schema-validation-engine.js +195 -0
  41. package/src/lib/service-model-parser.js +102 -0
  42. package/src/lib/validate-runner.js +167 -0
  43. package/src/lib/validation-report.js +133 -0
  44. package/src/lib/validators/base-validator.js +36 -0
  45. package/src/lib/validators/catalog-validator.js +177 -0
  46. package/src/lib/validators/enum-validator.js +120 -0
  47. package/src/lib/validators/required-field-validator.js +150 -0
  48. package/src/lib/validators/type-validator.js +313 -0
  49. package/templates/Dockerfile +1 -1
  50. package/templates/do/build +15 -5
  51. package/templates/do/run +5 -1
  52. package/templates/do/validate +61 -0
  53. package/servers/instance-recommender/LICENSE +0 -202
  54. package/servers/instance-recommender/index.js +0 -284
  55. package/servers/instance-recommender/manifest.json +0 -16
  56. package/servers/instance-recommender/package.json +0 -15
  57. /package/servers/{model-picker → lib}/catalogs/jumpstart-public.json +0 -0
  58. /package/servers/{region-picker → lib}/catalogs/regions.json +0 -0
  59. /package/servers/{base-image-picker → lib}/catalogs/triton-backends.json +0 -0
  60. /package/servers/{base-image-picker → lib}/catalogs/triton.json +0 -0
package/bin/cli.js CHANGED
@@ -101,6 +101,7 @@ program
101
101
  .addOption(new Option('--discover', 'Enable live registry lookups via MCP discovery'))
102
102
 
103
103
  // --- Validation ---
104
+ .addOption(new Option('--no-validate', 'Skip schema-driven validation at generation time'))
104
105
  .addOption(new Option('--validate-env-vars', 'Enable environment variable validation (default: true)'))
105
106
  .addOption(new Option('--validate-with-docker', 'Enable Docker introspection validation (opt-in)'))
106
107
  .addOption(new Option('--offline', 'Disable HuggingFace API lookups'))
@@ -179,7 +180,7 @@ program.configureHelp({
179
180
  groups.features.push(opt);
180
181
  } else if (['--smart', '--discover'].includes(long)) {
181
182
  groups.mcp.push(opt);
182
- } else if (['--validate-env-vars', '--validate-with-docker', '--offline'].includes(long)) {
183
+ } else if (['--validate-env-vars', '--validate-with-docker', '--offline', '--no-validate'].includes(long)) {
183
184
  groups.validation.push(opt);
184
185
  } else {
185
186
  groups.general.push(opt);
@@ -241,7 +242,7 @@ program
241
242
  .command('bootstrap')
242
243
  .description('Set up AWS infrastructure (IAM role, ECR repo, S3 buckets)')
243
244
  .passThroughOptions()
244
- .argument('[action]', 'Bootstrap action (status, use, list, remove, scan, prune, update)')
245
+ .argument('[action]', 'Bootstrap action (status, use, list, remove, scan, prune, update, sync-schemas)')
245
246
  .argument('[args...]', 'Additional arguments')
246
247
  .option('--profile <profile>', 'AWS profile name')
247
248
  .option('--region <region>', 'AWS region')
@@ -250,6 +251,10 @@ program
250
251
  .option('--force', 'Force removal without confirmation')
251
252
  .option('--verify', 'Verify resources exist (for status)')
252
253
  .option('--delete-stack', 'Delete CloudFormation stack on remove')
254
+ .option('--ignore-staleness', 'Suppress schema staleness warnings')
255
+ .option('--ci', 'Provision CI integration infrastructure')
256
+ .option('--skip-ci', 'Skip CI infrastructure provisioning')
257
+ .option('--skip-s3', 'Skip S3 bucket creation')
253
258
  .action(async (action, args, options) => {
254
259
  const { default: BootstrapCommandHandler } = await import('../src/lib/bootstrap-command-handler.js');
255
260
  const handler = new BootstrapCommandHandler();
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@aws/ml-container-creator",
3
- "version": "0.2.5",
3
+ "version": "0.2.6",
4
4
  "description": "Generator for SageMaker AI BYOC paradigm for predictive inference use-cases.",
5
5
  "type": "module",
6
6
  "main": "src/app.js",
@@ -19,25 +19,23 @@
19
19
  "bin/",
20
20
  "src/",
21
21
  "templates/",
22
- "servers/base-image-picker/catalogs/",
22
+ "servers/lib/catalogs/",
23
23
  "servers/base-image-picker/index.js",
24
24
  "servers/base-image-picker/manifest.json",
25
25
  "servers/base-image-picker/package.json",
26
- "servers/instance-recommender/catalogs/",
27
- "servers/instance-recommender/index.js",
28
- "servers/instance-recommender/manifest.json",
29
- "servers/instance-recommender/package.json",
30
- "servers/model-picker/catalogs/",
31
26
  "servers/model-picker/index.js",
32
27
  "servers/model-picker/manifest.json",
33
28
  "servers/model-picker/package.json",
34
- "servers/region-picker/catalogs/",
35
29
  "servers/region-picker/index.js",
36
30
  "servers/region-picker/manifest.json",
37
31
  "servers/region-picker/package.json",
38
32
  "servers/hyperpod-cluster-picker/index.js",
39
33
  "servers/hyperpod-cluster-picker/manifest.json",
40
34
  "servers/hyperpod-cluster-picker/package.json",
35
+ "servers/instance-sizer/lib/",
36
+ "servers/instance-sizer/index.js",
37
+ "servers/instance-sizer/manifest.json",
38
+ "servers/instance-sizer/package.json",
41
39
  "servers/lib/bedrock-client.js",
42
40
  "servers/lib/custom-validators.js",
43
41
  "servers/lib/dynamic-resolver.js",
@@ -91,6 +89,7 @@
91
89
  "dev": "npm link && ml-container-creator",
92
90
  "clean": "rm -rf test-output-*",
93
91
  "validate": "npm run lint && npm run test:all",
92
+ "validate:catalogs": "node scripts/validate-catalog-enums.js",
94
93
  "validate:namespaces": "node scripts/validate-namespaces.js",
95
94
  "docs:serve": "mkdocs serve",
96
95
  "docs:build": "mkdocs build",
package/servers/base-image-picker/index.js CHANGED
@@ -104,9 +104,9 @@ let PYTHON_SLIM_CATALOG
104
104
  let TRITON_IMAGE_CATALOG
105
105
 
106
106
  try {
107
- TRANSFORMER_IMAGE_CATALOG = loadCatalog('./catalogs/model-servers.json')
108
- PYTHON_SLIM_CATALOG = loadCatalog('./catalogs/python-slim.json')
109
- TRITON_IMAGE_CATALOG = loadCatalog('./catalogs/triton.json')
107
+ TRANSFORMER_IMAGE_CATALOG = loadCatalog('../lib/catalogs/model-servers.json')
108
+ PYTHON_SLIM_CATALOG = loadCatalog('../lib/catalogs/python-slim.json')
109
+ TRITON_IMAGE_CATALOG = loadCatalog('../lib/catalogs/triton.json')
110
110
  } catch (err) {
111
111
  process.stderr.write(`[base-image-picker] Fatal: ${err.message}\n`)
112
112
  process.exit(1)
package/servers/base-image-picker/manifest.json CHANGED
@@ -8,8 +8,10 @@
8
8
  "discover": true
9
9
  },
10
10
  "catalogs": {
11
- "model-servers": "./catalogs/model-servers.json",
12
- "python-slim": "./catalogs/python-slim.json"
11
+ "model-servers": "../lib/catalogs/model-servers.json",
12
+ "python-slim": "../lib/catalogs/python-slim.json",
13
+ "triton": "../lib/catalogs/triton.json",
14
+ "triton-backends": "../lib/catalogs/triton-backends.json"
13
15
  },
14
16
  "tool": {
15
17
  "name": "get_base_images"
package/servers/instance-sizer/index.js ADDED
@@ -0,0 +1,561 @@
1
+ #!/usr/bin/env node
2
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ /**
6
+ * Instance Sizer MCP Server
7
+ *
8
+ * A bundled MCP server that estimates VRAM requirements from model metadata
9
+ * and returns a filtered, ranked list of compatible SageMaker instances.
10
+ *
11
+ * Supports three modes:
12
+ * - Static (default): Uses pre-built model-sizes catalog for popular models
13
+ * - Smart (BEDROCK_SMART=true): Queries Bedrock for edge-case reasoning
14
+ * - Discover (--discover flag): Fetches model config.json from HuggingFace Hub
15
+ *
16
+ * Tool: get_instance_recommendation
17
+ * Accepts: { modelName, quantization?, maxSequenceLength?, batchSize?, limit?, context? }
18
+ * Returns: { values, choices, metadata }
19
+ */
20
+
21
+ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
22
+ import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
23
+ import { z } from 'zod'
24
+ import { readFileSync } from 'node:fs'
25
+ import { fileURLToPath } from 'node:url'
26
+ import { resolve, dirname } from 'node:path'
27
+ import { resolveModelMetadata } from './lib/model-resolver.js'
28
+ import { estimateVram } from './lib/vram-estimator.js'
29
+ import { filterAndRankInstances } from './lib/instance-ranker.js'
30
+ import { queryBedrock } from '../lib/bedrock-client.js'
31
+
32
+ // ── Path setup ───────────────────────────────────────────────────────────────
33
+
34
+ const __filename = fileURLToPath(import.meta.url)
35
+ const __dirname = dirname(__filename)
36
+
37
+ // ── Load instance catalog from shared lib ────────────────────────────────────
38
+
39
+ let INSTANCE_CATALOG
40
+
41
+ try {
42
+ const catalogPath = resolve(__dirname, '../lib/catalogs/instances.json')
43
+ const raw = readFileSync(catalogPath, 'utf8')
44
+ const data = JSON.parse(raw)
45
+ INSTANCE_CATALOG = data.catalog
46
+ } catch (err) {
47
+ process.stderr.write(`[instance-sizer] Fatal: Failed to load instance catalog: ${err.message}\n`)
48
+ process.exit(1)
49
+ }
50
+
51
+ // ── Mode configuration ───────────────────────────────────────────────────────
52
+
53
+ const DISCOVER_MODE = process.argv.includes('--discover') || process.env.DISCOVER_MODE === 'true'
54
+ const SMART_MODE = process.env.BEDROCK_SMART === 'true'
55
+ const BEDROCK_MODEL = process.env.BEDROCK_MODEL || 'global.anthropic.claude-sonnet-4-20250514-v1:0'
56
+ const BEDROCK_REGION = process.env.BEDROCK_REGION || process.env.AWS_REGION || 'us-east-1'
57
+
58
+ // ── Bedrock server config ─────────────────────────────────────────────────────
59
+
60
+ /**
61
+ * Per-server configuration passed to the shared Bedrock client.
62
+ * The system prompt provides model context and asks Bedrock to validate
63
+ * or adjust the static recommendation for edge cases.
64
+ */
65
+ const SERVER_CONFIG = {
66
+ serverName: 'instance-sizer',
67
+ systemPromptTemplate: `You are an AWS SageMaker instance sizing advisor specializing in GPU memory estimation for ML model deployment.
68
+
69
+ Given the following model metadata and VRAM estimate, validate or adjust the instance recommendation for edge cases (unusual architectures, custom quantization, multi-modal models, etc.).
70
+
71
+ Model context:
72
+ {context}
73
+
74
+ Requested parameters: {parameters}
75
+ Maximum recommendations: {limit}
76
+
77
+ Respond with ONLY a JSON object in this exact format, no other text:
78
+ {
79
+ "values": {
80
+ "instanceType": "the single best instance type as a string"
81
+ },
82
+ "reasoning": "brief explanation of why this instance is recommended or why the static recommendation was adjusted"
83
+ }
84
+
85
+ Rules:
86
+ - Only recommend real SageMaker instance types (ml.* prefix)
87
+ - Consider the VRAM estimate and breakdown provided
88
+ - If the static recommendation looks correct, return the same instance type
89
+ - If you detect an edge case (e.g., model needs more headroom for KV cache, unusual architecture overhead), adjust accordingly
90
+ - Prefer single-GPU instances when the model fits
91
+ - Consider tensor parallelism for models that exceed single-GPU capacity
92
+ - Return valid JSON only`,
93
+ temperature: 0.3,
94
+ maxTokens: 1024,
95
+ modelId: BEDROCK_MODEL,
96
+ region: BEDROCK_REGION
97
+ }
98
+
99
+ // ── Logging ──────────────────────────────────────────────────────────────────
100
+
101
+ /**
102
+ * Log to stderr so it doesn't interfere with MCP stdio protocol on stdout.
103
+ */
104
+ function log(message) {
105
+ process.stderr.write(`[instance-sizer] ${message}\n`)
106
+ }
107
+
108
+ // ── Tag-based search filtering ───────────────────────────────────────────────
109
+
110
+ /**
111
+ * Search instances by tag/keyword matching.
112
+ * Ported from instance-recommender's getStaticInstances logic.
113
+ *
114
+ * @param {string} search - Search query string
115
+ * @param {object} instanceCatalog - Instance catalog object
116
+ * @param {object} [options={}]
117
+ * @param {number} [options.limit=8] - Max results
118
+ * @returns {string[]} Matching instance type names, sorted by relevance
119
+ */
120
+ function searchInstancesByTag(search, instanceCatalog, options = {}) {
121
+ const { limit = 8 } = options
122
+ const candidates = Object.entries(instanceCatalog)
123
+
124
+ // Tokenize search into lowercase keywords
125
+ const tokens = search.toLowerCase().split(/[\s,\-_]+/).filter(Boolean)
126
+
127
+ // Detect compound terms
128
+ const rawLower = search.toLowerCase()
129
+ const wantsMultiGpu = rawLower.includes('multi gpu') || rawLower.includes('multi-gpu') || rawLower.includes('multigpu')
130
+
131
+ // Detect CUDA version requests: "cuda 12", "cuda 11.8", "cuda-12.1"
132
+ const cudaMatch = rawLower.match(/cuda[\s\-_]*(\d+(?:\.\d+)?)/)
133
+ const wantsCudaVersion = cudaMatch ? cudaMatch[1] : null
134
+
135
+ // Score each instance
136
+ const scored = candidates.map(([name, meta]) => {
137
+ let score = 0
138
+ const cudaStr = meta.cudaVersions ? meta.cudaVersions.join(' ') : ''
139
+ const haystack = [...(meta.tags || []), (meta.accelerator || '').toLowerCase(), name, meta.category || '', cudaStr].join(' ')
140
+
141
+ // Compound term: multi-gpu
142
+ if (wantsMultiGpu) {
143
+ if (meta.gpus > 1) {
144
+ score += 5
145
+ } else {
146
+ return { name, meta, score: 0 }
147
+ }
148
+ }
149
+
150
+ // Compound term: cuda version
151
+ if (wantsCudaVersion) {
152
+ if (!meta.cudaVersions) return { name, meta, score: 0 }
153
+ const hasExact = meta.cudaVersions.includes(wantsCudaVersion)
154
+ const hasMajor = meta.cudaVersions.some(v => v.startsWith(wantsCudaVersion))
155
+ if (hasExact) {
156
+ score += 4
157
+ } else if (hasMajor) {
158
+ score += 3
159
+ } else {
160
+ return { name, meta, score: 0 }
161
+ }
162
+ }
163
+
164
+ for (const token of tokens) {
165
+ if (wantsMultiGpu && (token === 'multi' || token === 'gpu')) continue
166
+ if (wantsCudaVersion && (token === 'cuda' || token === wantsCudaVersion)) continue
167
+
168
+ if (haystack.includes(token)) score += 1
169
+ if (meta.gpus > 1 && token === 'parallel') score += 2
170
+ if (token === 'gpu' && meta.gpus > 0) score += 1
171
+ if (token === 'cpu' && meta.gpus === 0) score += 1
172
+ if (token === 'cheap' || token === 'budget' || token === 'cost') {
173
+ if ((meta.tags || []).includes('budget') || (meta.tags || []).includes('cost-effective')) score += 1
174
+ }
175
+ if (token === 'memory' || token === 'high-memory') {
176
+ if (meta.memGb >= 32) score += 1
177
+ }
178
+ if (token === 'large' && meta.vcpus >= 16) score += 1
179
+ if (meta.cudaVersions && meta.cudaVersions.includes(token)) score += 2
180
+ }
181
+ return { name, meta, score }
182
+ })
183
+
184
+ const matched = scored.filter(s => s.score > 0).sort((a, b) => b.score - a.score)
185
+
186
+ if (matched.length === 0) {
187
+ return []
188
+ }
189
+
190
+ return matched.slice(0, limit).map(s => s.name)
191
+ }
192
+
193
+ // ── CUDA version filtering ───────────────────────────────────────────────────
194
+
195
+ /**
196
+ * Filter instances to only those supporting a required CUDA version.
197
+ *
198
+ * @param {object} instanceCatalog - Instance catalog object
199
+ * @param {string} requiredCuda - Required CUDA version (e.g., "12.1")
200
+ * @returns {object} Filtered instance catalog
201
+ */
202
+ function filterByCudaVersion(instanceCatalog, requiredCuda) {
203
+ const majorRequired = requiredCuda.split('.')[0]
204
+ const filtered = {}
205
+
206
+ for (const [name, meta] of Object.entries(instanceCatalog)) {
207
+ if (!meta.cudaVersions || meta.cudaVersions.length === 0) continue
208
+ const hasCompatible = meta.cudaVersions.some(v => {
209
+ if (v === requiredCuda) return true
210
+ if (v.startsWith(majorRequired + '.')) return true
211
+ return false
212
+ })
213
+ if (hasCompatible) {
214
+ filtered[name] = meta
215
+ }
216
+ }
217
+
218
+ return filtered
219
+ }
220
+
221
+ // ── Tool handler ─────────────────────────────────────────────────────────────
222
+
223
+ /**
224
+ * Handle the get_instance_recommendation tool invocation.
225
+ *
226
+ * Pipeline: resolveModelMetadata → estimateVram → filterAndRankInstances
227
+ *
228
+ * @param {object} params - Tool input parameters
229
+ * @returns {object} MCP tool response
230
+ */
231
+ async function handleGetInstanceRecommendation(params) {
232
+ const {
233
+ modelName,
234
+ instanceSearch,
235
+ quantization,
236
+ maxSequenceLength,
237
+ batchSize,
238
+ cudaVersion,
239
+ limit = 8,
240
+ context
241
+ } = params
242
+
243
+ // Apply profile ENV overrides to sequence length and batch size
244
+ let effectiveMaxSeqLen = maxSequenceLength
245
+ let effectiveBatchSize = batchSize
246
+ if (context?.profileEnvVars) {
247
+ if (context.profileEnvVars.VLLM_MAX_MODEL_LEN) {
248
+ effectiveMaxSeqLen = parseInt(context.profileEnvVars.VLLM_MAX_MODEL_LEN, 10) || effectiveMaxSeqLen
249
+ }
250
+ if (context.profileEnvVars.VLLM_MAX_NUM_SEQS) {
251
+ effectiveBatchSize = parseInt(context.profileEnvVars.VLLM_MAX_NUM_SEQS, 10) || effectiveBatchSize
252
+ }
253
+ }
254
+
255
+ // Apply CUDA version filtering to instance catalog
256
+ let effectiveCatalog = INSTANCE_CATALOG
257
+ if (cudaVersion) {
258
+ effectiveCatalog = filterByCudaVersion(INSTANCE_CATALOG, cudaVersion)
259
+ if (Object.keys(effectiveCatalog).length === 0) {
260
+ log(`CUDA version ${cudaVersion} filter eliminated all instances`)
261
+ return {
262
+ content: [{
263
+ type: 'text',
264
+ text: JSON.stringify({
265
+ values: { instanceType: null },
266
+ choices: { instanceType: [] },
267
+ metadata: {
268
+ modelName: modelName || null,
269
+ warning: `No instances support CUDA version ${cudaVersion}. Check base image compatibility.`,
270
+ cudaVersionFilter: cudaVersion
271
+ }
272
+ })
273
+ }]
274
+ }
275
+ }
276
+ }
277
+
278
+ // Mode: tag-based search only (no model name)
279
+ if (!modelName && instanceSearch) {
280
+ const searchResults = searchInstancesByTag(instanceSearch, effectiveCatalog, { limit })
281
+ return {
282
+ content: [{
283
+ type: 'text',
284
+ text: JSON.stringify({
285
+ values: { instanceType: searchResults[0] || null },
286
+ choices: { instanceType: searchResults },
287
+ metadata: {
288
+ instanceSearch,
289
+ source: 'tag-search',
290
+ cudaVersionFilter: cudaVersion || null,
291
+ resultCount: searchResults.length
292
+ }
293
+ })
294
+ }]
295
+ }
296
+ }
297
+
298
+ // Mode: no model name and no search — return all GPU instances
299
+ if (!modelName) {
300
+ const allGpuInstances = Object.keys(effectiveCatalog)
301
+ .filter(key => effectiveCatalog[key].category === 'gpu')
302
+ .slice(0, limit)
303
+
304
+ return {
305
+ content: [{
306
+ type: 'text',
307
+ text: JSON.stringify({
308
+ values: { instanceType: allGpuInstances[0] || null },
309
+ choices: { instanceType: allGpuInstances },
310
+ metadata: {
311
+ modelName: null,
312
+ source: 'unfiltered',
313
+ cudaVersionFilter: cudaVersion || null,
314
+ warning: 'No model name provided. Returning GPU instances without VRAM filtering.'
315
+ }
316
+ })
317
+ }]
318
+ }
319
+ }
320
+
321
+ // Step 1: Resolve model metadata
322
+ const modelMetadata = await resolveModelMetadata(modelName, {
323
+ discover: DISCOVER_MODE
324
+ })
325
+
326
+ // If model metadata cannot be resolved, return all GPU instances unfiltered
327
+ if (!modelMetadata) {
328
+ log(`Model metadata not found for "${modelName}", returning unfiltered GPU instances`)
329
+ const allGpuInstances = Object.keys(effectiveCatalog)
330
+ .filter(key => effectiveCatalog[key].category === 'gpu')
331
+ .slice(0, limit)
332
+
333
+ return {
334
+ content: [{
335
+ type: 'text',
336
+ text: JSON.stringify({
337
+ values: { instanceType: allGpuInstances[0] || null },
338
+ choices: { instanceType: allGpuInstances },
339
+ metadata: {
340
+ modelName,
341
+ parameterCount: null,
342
+ dtype: null,
343
+ quantization: quantization || null,
344
+ estimatedVramGb: null,
345
+ vramBreakdown: null,
346
+ recommendations: allGpuInstances.map(instanceType => ({
347
+ instanceType,
348
+ gpuCount: effectiveCatalog[instanceType]?.gpus || 0,
349
+ totalVramGb: null,
350
+ utilizationPercent: null,
351
+ tensorParallelism: null,
352
+ costTier: null
353
+ })),
354
+ source: 'unfiltered',
355
+ cudaVersionFilter: cudaVersion || null,
356
+ warning: `Could not resolve model metadata for "${modelName}". Returning all GPU instances without filtering.`
357
+ }
358
+ })
359
+ }]
360
+ }
361
+ }
362
+
363
+ // Step 2: Estimate VRAM
364
+ const vramEstimate = estimateVram({
365
+ parameterCount: modelMetadata.parameterCount,
366
+ dtype: modelMetadata.dtype,
367
+ quantization: quantization || undefined,
368
+ maxSequenceLength: effectiveMaxSeqLen || undefined,
369
+ batchSize: effectiveBatchSize || undefined
370
+ })
371
+
372
+ // Step 3: Filter and rank instances
373
+ let recommendations = filterAndRankInstances(
374
+ vramEstimate.vramGb,
375
+ effectiveCatalog,
376
+ { limit }
377
+ )
378
+
379
+ // Step 3b: If instanceSearch is also provided, further filter by tags
380
+ if (instanceSearch && recommendations.length > 0) {
381
+ const searchMatches = new Set(searchInstancesByTag(instanceSearch, effectiveCatalog, { limit: 100 }))
382
+ recommendations = recommendations.filter(r => searchMatches.has(r.instanceType))
383
+ }
384
+
385
+ // Step 4: Smart mode — query Bedrock for edge-case reasoning
386
+ let finalRecommendations = recommendations
387
+ let smartModeUsed = false
388
+
389
+ if (SMART_MODE && recommendations.length > 0) {
390
+ log('[smart] Smart mode enabled, querying Amazon Bedrock...')
391
+
392
+ const bedrockContext = {
393
+ modelName,
394
+ parameterCount: modelMetadata.parameterCount,
395
+ dtype: modelMetadata.dtype,
396
+ quantization: quantization || null,
397
+ estimatedVramGb: vramEstimate.vramGb,
398
+ vramBreakdown: vramEstimate.breakdown,
399
+ staticRecommendations: recommendations.slice(0, 3).map(r => ({
400
+ instanceType: r.instanceType,
401
+ gpuCount: r.gpuCount,
402
+ totalVramGb: r.totalVramGb,
403
+ utilizationPercent: r.utilizationPercent,
404
+ tensorParallelism: r.tensorParallelism
405
+ })),
406
+ ...(context || {})
407
+ }
408
+
409
+ const bedrockResult = await queryBedrock(
410
+ SERVER_CONFIG,
411
+ ['instanceType'],
412
+ limit,
413
+ bedrockContext
414
+ )
415
+
416
+ if (bedrockResult?.values?.instanceType) {
417
+ const bedrockInstance = bedrockResult.values.instanceType
418
+ log(`[smart] Bedrock recommendation: ${bedrockInstance}`)
419
+
420
+ // Check if Bedrock's suggestion is already in our list
421
+ const existingIndex = finalRecommendations.findIndex(
422
+ r => r.instanceType === bedrockInstance
423
+ )
424
+
425
+ if (existingIndex > 0) {
426
+ // Move Bedrock's pick to the top
427
+ const [picked] = finalRecommendations.splice(existingIndex, 1)
428
+ finalRecommendations = [picked, ...finalRecommendations]
429
+ smartModeUsed = true
430
+ } else if (existingIndex === 0) {
431
+ // Already at the top — Bedrock agrees with static
432
+ smartModeUsed = true
433
+ log('[smart] Bedrock agrees with static top recommendation')
434
+ } else {
435
+ // Bedrock suggested an instance not in our filtered list;
436
+ // verify it exists in the catalog before prepending
437
+ if (INSTANCE_CATALOG[bedrockInstance]) {
438
+ const catalogEntry = INSTANCE_CATALOG[bedrockInstance]
439
+ const bedrockRec = {
440
+ instanceType: bedrockInstance,
441
+ gpuCount: catalogEntry.gpus || 0,
442
+ totalVramGb: (catalogEntry.gpuMemoryGb || 0) * (catalogEntry.gpus || 1),
443
+ utilizationPercent: null,
444
+ tensorParallelism: catalogEntry.gpus || 1,
445
+ costTier: catalogEntry.costTier || null
446
+ }
447
+ finalRecommendations = [bedrockRec, ...finalRecommendations].slice(0, limit)
448
+ smartModeUsed = true
449
+ } else {
450
+ log(`[smart] Bedrock suggested unknown instance "${bedrockInstance}", ignoring`)
451
+ }
452
+ }
453
+ } else {
454
+ log('[smart] Bedrock did not return usable results, falling back to static recommendations')
455
+ }
456
+ }
457
+
458
+ // Build response
459
+ const topRecommendation = finalRecommendations.length > 0
460
+ ? finalRecommendations[0].instanceType
461
+ : null
462
+
463
+ const rankedList = finalRecommendations.map(r => r.instanceType)
464
+
465
+ return {
466
+ content: [{
467
+ type: 'text',
468
+ text: JSON.stringify({
469
+ values: { instanceType: topRecommendation },
470
+ choices: { instanceType: rankedList },
471
+ metadata: {
472
+ modelName,
473
+ parameterCount: modelMetadata.parameterCount,
474
+ dtype: modelMetadata.dtype,
475
+ quantization: quantization || null,
476
+ estimatedVramGb: vramEstimate.vramGb,
477
+ vramBreakdown: vramEstimate.breakdown,
478
+ recommendations: finalRecommendations,
479
+ source: modelMetadata.source,
480
+ smartModeUsed
481
+ }
482
+ })
483
+ }]
484
+ }
485
+ }
486
+
487
+ // ── MCP Server setup ─────────────────────────────────────────────────────────
488
+
489
+ const server = new McpServer({
490
+ name: 'instance-sizer',
491
+ version: '1.0.0'
492
+ })
493
+
494
+ // Register the get_instance_recommendation tool
495
+ server.tool(
496
+ 'get_instance_recommendation',
497
+ 'Estimates VRAM requirements from model metadata and returns filtered, ranked SageMaker instance recommendations. Supports VRAM-based sizing, tag-based search, or both combined.',
498
+ {
499
+ modelName: z.string().optional().describe('HuggingFace model ID or catalog key'),
500
+ instanceSearch: z.string().optional().describe('Tag/keyword search for instances (e.g., "multi-gpu", "cost-effective cpu")'),
501
+ quantization: z.string().optional().describe('Quantization method: awq, gptq, bnb-4bit, bnb-8bit'),
502
+ maxSequenceLength: z.number().optional().describe('Max context/sequence length (affects KV cache estimate)'),
503
+ batchSize: z.number().optional().describe('Expected concurrent batch size'),
504
+ cudaVersion: z.string().optional().describe('Required CUDA version from base image (filters incompatible instances)'),
505
+ limit: z.number().optional().default(8).describe('Maximum number of instance recommendations to return'),
506
+ context: z.object({
507
+ architecture: z.string().optional(),
508
+ backend: z.string().optional(),
509
+ deploymentTarget: z.string().optional(),
510
+ profileEnvVars: z.record(z.string()).optional().describe('Serving profile ENV overrides (e.g., VLLM_MAX_MODEL_LEN)')
511
+ }).optional().describe('Additional deployment context')
512
+ },
513
+ async (params) => {
514
+ return handleGetInstanceRecommendation(params)
515
+ }
516
+ )
517
+
518
+ // Register alias tool name for backward compatibility
519
+ server.tool(
520
+ 'get_instance_types',
521
+ 'Alias for get_instance_recommendation — recommends SageMaker instances via VRAM sizing and/or tag-based search',
522
+ {
523
+ modelName: z.string().optional().describe('HuggingFace model ID or catalog key'),
524
+ instanceSearch: z.string().optional().describe('Tag/keyword search for instances (e.g., "multi-gpu", "cost-effective cpu")'),
525
+ quantization: z.string().optional().describe('Quantization method: awq, gptq, bnb-4bit, bnb-8bit'),
526
+ maxSequenceLength: z.number().optional().describe('Max context/sequence length (affects KV cache estimate)'),
527
+ batchSize: z.number().optional().describe('Expected concurrent batch size'),
528
+ cudaVersion: z.string().optional().describe('Required CUDA version from base image (filters incompatible instances)'),
529
+ limit: z.number().optional().default(8).describe('Maximum number of instance recommendations to return'),
530
+ context: z.object({
531
+ architecture: z.string().optional(),
532
+ backend: z.string().optional(),
533
+ deploymentTarget: z.string().optional(),
534
+ profileEnvVars: z.record(z.string()).optional().describe('Serving profile ENV overrides (e.g., VLLM_MAX_MODEL_LEN)')
535
+ }).optional().describe('Additional deployment context')
536
+ },
537
+ async (params) => {
538
+ return handleGetInstanceRecommendation(params)
539
+ }
540
+ )
541
+
542
+ // ── Exports for testing ──────────────────────────────────────────────────────
543
+
544
+ export { handleGetInstanceRecommendation, INSTANCE_CATALOG, SERVER_CONFIG, server, searchInstancesByTag, filterByCudaVersion }
545
+
546
+ // ── Transport connection (main module only) ──────────────────────────────────
547
+
548
+ const isMain = process.argv[1] && resolve(process.argv[1]) === __filename
549
+
550
+ if (isMain) {
551
+ if (SMART_MODE) {
552
+ log(`Smart mode enabled (model: ${BEDROCK_MODEL}, region: ${BEDROCK_REGION})`)
553
+ } else if (DISCOVER_MODE) {
554
+ log('Discover mode enabled (HuggingFace API lookups active)')
555
+ } else {
556
+ log('Static mode (catalog-only, no network calls)')
557
+ }
558
+
559
+ const transport = new StdioServerTransport()
560
+ await server.connect(transport)
561
+ }