@aws/ml-container-creator 0.2.5 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. package/bin/cli.js +45 -4
  2. package/config/bootstrap-stack.json +14 -0
  3. package/infra/ci-harness/package-lock.json +22 -9
  4. package/package.json +7 -8
  5. package/servers/base-image-picker/index.js +3 -3
  6. package/servers/base-image-picker/manifest.json +4 -2
  7. package/servers/instance-sizer/index.js +564 -0
  8. package/servers/instance-sizer/lib/instance-ranker.js +270 -0
  9. package/servers/instance-sizer/lib/model-resolver.js +269 -0
  10. package/servers/instance-sizer/lib/vram-estimator.js +177 -0
  11. package/servers/instance-sizer/manifest.json +17 -0
  12. package/servers/instance-sizer/package.json +15 -0
  13. package/servers/{instance-recommender → lib}/catalogs/instances.json +136 -34
  14. package/servers/{base-image-picker → lib}/catalogs/model-servers.json +302 -254
  15. package/servers/lib/catalogs/model-sizes.json +131 -0
  16. package/servers/lib/catalogs/models.json +632 -0
  17. package/servers/{model-picker → lib}/catalogs/popular-diffusors.json +32 -10
  18. package/servers/{model-picker → lib}/catalogs/popular-transformers.json +59 -26
  19. package/servers/{base-image-picker → lib}/catalogs/python-slim.json +12 -12
  20. package/servers/lib/schemas/image-catalog.schema.json +6 -12
  21. package/servers/lib/schemas/instances.schema.json +29 -0
  22. package/servers/lib/schemas/model-catalog.schema.json +12 -10
  23. package/servers/lib/schemas/unified-model-catalog.schema.json +129 -0
  24. package/servers/model-picker/index.js +4 -4
  25. package/servers/model-picker/manifest.json +2 -3
  26. package/servers/region-picker/index.js +1 -1
  27. package/servers/region-picker/manifest.json +1 -1
  28. package/src/app.js +36 -0
  29. package/src/lib/architecture-sync.js +171 -0
  30. package/src/lib/arn-detection.js +22 -0
  31. package/src/lib/bootstrap-command-handler.js +120 -0
  32. package/src/lib/cli-handler.js +3 -3
  33. package/src/lib/config-manager.js +47 -1
  34. package/src/lib/configuration-manager.js +2 -2
  35. package/src/lib/cross-cutting-checker.js +460 -0
  36. package/src/lib/deployment-entry-schema.js +1 -2
  37. package/src/lib/dry-run-validator.js +78 -0
  38. package/src/lib/generation-validator.js +102 -0
  39. package/src/lib/mcp-validator-config.js +89 -0
  40. package/src/lib/payload-builder.js +153 -0
  41. package/src/lib/prompt-runner.js +866 -149
  42. package/src/lib/prompts.js +2 -2
  43. package/src/lib/registry-command-handler.js +236 -0
  44. package/src/lib/registry-loader.js +5 -5
  45. package/src/lib/schema-sync.js +203 -0
  46. package/src/lib/schema-validation-engine.js +195 -0
  47. package/src/lib/secret-classification.js +56 -0
  48. package/src/lib/secrets-command-handler.js +550 -0
  49. package/src/lib/service-model-parser.js +102 -0
  50. package/src/lib/validate-runner.js +216 -0
  51. package/src/lib/validation-report.js +140 -0
  52. package/src/lib/validators/base-validator.js +36 -0
  53. package/src/lib/validators/catalog-validator.js +177 -0
  54. package/src/lib/validators/enum-validator.js +120 -0
  55. package/src/lib/validators/required-field-validator.js +150 -0
  56. package/src/lib/validators/type-validator.js +313 -0
  57. package/src/prompt-adapter.js +3 -2
  58. package/templates/Dockerfile +1 -1
  59. package/templates/do/build +37 -5
  60. package/templates/do/config +15 -3
  61. package/templates/do/deploy +60 -5
  62. package/templates/do/logs +18 -3
  63. package/templates/do/run +15 -1
  64. package/templates/do/validate +61 -0
  65. package/servers/instance-recommender/LICENSE +0 -202
  66. package/servers/instance-recommender/index.js +0 -284
  67. package/servers/instance-recommender/manifest.json +0 -16
  68. package/servers/instance-recommender/package.json +0 -15
  69. package/servers/{model-picker → lib}/catalogs/jumpstart-public.json +0 -0
  70. package/servers/{region-picker → lib}/catalogs/regions.json +0 -0
  71. package/servers/{base-image-picker → lib}/catalogs/triton-backends.json +0 -0
  72. package/servers/{base-image-picker → lib}/catalogs/triton.json +0 -0
@@ -0,0 +1,564 @@
+ #!/usr/bin/env node
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ /**
+  * Instance Sizer MCP Server
+  *
+  * A bundled MCP server that estimates VRAM requirements from model metadata
+  * and returns a filtered, ranked list of compatible SageMaker instances.
+  *
+  * Supports three modes:
+  * - Static (default): Uses pre-built model-sizes catalog for popular models
+  * - Smart (BEDROCK_SMART=true): Queries Bedrock for edge-case reasoning
+  * - Discover (--discover flag): Fetches model config.json from HuggingFace Hub
+  *
+  * Tool: get_instance_recommendation
+  * Accepts: { modelName, quantization?, maxSequenceLength?, batchSize?, limit?, context? }
+  * Returns: { values, choices, metadata }
+  */
+
+ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
+ import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
+ import { z } from 'zod'
+ import { readFileSync } from 'node:fs'
+ import { fileURLToPath } from 'node:url'
+ import { resolve, dirname } from 'node:path'
+ import { resolveModelMetadata } from './lib/model-resolver.js'
+ import { estimateVram } from './lib/vram-estimator.js'
+ import { filterAndRankInstances } from './lib/instance-ranker.js'
+ import { queryBedrock } from '../lib/bedrock-client.js'
+
+ // ── Path setup ───────────────────────────────────────────────────────────────
+
+ const __filename = fileURLToPath(import.meta.url)
+ const __dirname = dirname(__filename)
+
+ // ── Load instance catalog from shared lib ────────────────────────────────────
+
+ let INSTANCE_CATALOG
+
+ try {
+   const catalogPath = resolve(__dirname, '../lib/catalogs/instances.json')
+   const raw = readFileSync(catalogPath, 'utf8')
+   const data = JSON.parse(raw)
+   INSTANCE_CATALOG = data.catalog
+ } catch (err) {
+   process.stderr.write(`[instance-sizer] Fatal: Failed to load instance catalog: ${err.message}\n`)
+   process.exit(1)
+ }
+
+ // ── Mode configuration ───────────────────────────────────────────────────────
+
+ const DISCOVER_MODE = process.argv.includes('--discover') || process.env.DISCOVER_MODE === 'true'
+ const SMART_MODE = process.env.BEDROCK_SMART === 'true'
+ const BEDROCK_MODEL = process.env.BEDROCK_MODEL || 'global.anthropic.claude-sonnet-4-20250514-v1:0'
+ const BEDROCK_REGION = process.env.BEDROCK_REGION || process.env.AWS_REGION || 'us-east-1'
+
+ // ── Bedrock server config ─────────────────────────────────────────────────────
+
+ /**
+  * Per-server configuration passed to the shared Bedrock client.
+  * The system prompt provides model context and asks Bedrock to validate
+  * or adjust the static recommendation for edge cases.
+  */
+ const SERVER_CONFIG = {
+   serverName: 'instance-sizer',
+   systemPromptTemplate: `You are an AWS SageMaker instance sizing advisor specializing in GPU memory estimation for ML model deployment.
+
+ Given the following model metadata and VRAM estimate, validate or adjust the instance recommendation for edge cases (unusual architectures, custom quantization, multi-modal models, etc.).
+
+ Model context:
+ {context}
+
+ Requested parameters: {parameters}
+ Maximum recommendations: {limit}
+
+ Respond with ONLY a JSON object in this exact format, no other text:
+ {
+ "values": {
+ "instanceType": "the single best instance type as a string"
+ },
+ "reasoning": "brief explanation of why this instance is recommended or why the static recommendation was adjusted"
+ }
+
+ Rules:
+ - Only recommend real SageMaker instance types (ml.* prefix)
+ - Consider the VRAM estimate and breakdown provided
+ - If the static recommendation looks correct, return the same instance type
+ - If you detect an edge case (e.g., model needs more headroom for KV cache, unusual architecture overhead), adjust accordingly
+ - Prefer single-GPU instances when the model fits
+ - Consider tensor parallelism for models that exceed single-GPU capacity
+ - Return valid JSON only`,
+   temperature: 0.3,
+   maxTokens: 1024,
+   modelId: BEDROCK_MODEL,
+   region: BEDROCK_REGION
+ }
+
+ // ── Logging ──────────────────────────────────────────────────────────────────
+
+ /**
+  * Log to stderr so it doesn't interfere with MCP stdio protocol on stdout.
+  */
+ function log(message) {
+   process.stderr.write(`[instance-sizer] ${message}\n`)
+ }
+
+ // ── Tag-based search filtering ───────────────────────────────────────────────
+
+ /**
+  * Search instances by tag/keyword matching.
+  * Ported from instance-recommender's getStaticInstances logic.
+  *
+  * @param {string} search - Search query string
+  * @param {object} instanceCatalog - Instance catalog object
+  * @param {object} [options={}]
+  * @param {number} [options.limit=10] - Max results
+  * @returns {string[]} Matching instance type names, sorted by relevance
+  */
+ function searchInstancesByTag(search, instanceCatalog, options = {}) {
+   const { limit = 10 } = options
+   const candidates = Object.entries(instanceCatalog)
+
+   // Tokenize search into lowercase keywords
+   const tokens = search.toLowerCase().split(/[\s,\-_]+/).filter(Boolean)
+
+   // Detect compound terms
+   const rawLower = search.toLowerCase()
+   const wantsMultiGpu = rawLower.includes('multi gpu') || rawLower.includes('multi-gpu') || rawLower.includes('multigpu')
+
+   // Detect CUDA version requests: "cuda 12", "cuda 11.8", "cuda-12.1"
+   const cudaMatch = rawLower.match(/cuda[\s\-_]*(\d+(?:\.\d+)?)/)
+   const wantsCudaVersion = cudaMatch ? cudaMatch[1] : null
+
+   // Score each instance
+   const scored = candidates.map(([name, meta]) => {
+     let score = 0
+     const cudaStr = meta.cudaVersions ? meta.cudaVersions.join(' ') : ''
+     const haystack = [...(meta.tags || []), (meta.accelerator || '').toLowerCase(), name, meta.category || '', cudaStr].join(' ')
+
+     // Compound term: multi-gpu
+     if (wantsMultiGpu) {
+       if (meta.gpus > 1) {
+         score += 5
+       } else {
+         return { name, meta, score: 0 }
+       }
+     }
+
+     // Compound term: cuda version
+     if (wantsCudaVersion) {
+       if (!meta.cudaVersions) return { name, meta, score: 0 }
+       const hasExact = meta.cudaVersions.includes(wantsCudaVersion)
+       const hasMajor = meta.cudaVersions.some(v => v.startsWith(wantsCudaVersion))
+       if (hasExact) {
+         score += 4
+       } else if (hasMajor) {
+         score += 3
+       } else {
+         return { name, meta, score: 0 }
+       }
+     }
+
+     for (const token of tokens) {
+       if (wantsMultiGpu && (token === 'multi' || token === 'gpu')) continue
+       if (wantsCudaVersion && (token === 'cuda' || token === wantsCudaVersion)) continue
+
+       if (haystack.includes(token)) score += 1
+       if (meta.gpus > 1 && token === 'parallel') score += 2
+       if (token === 'gpu' && meta.gpus > 0) score += 1
+       if (token === 'cpu' && meta.gpus === 0) score += 1
+       if (token === 'cheap' || token === 'budget' || token === 'cost') {
+         if ((meta.tags || []).includes('budget') || (meta.tags || []).includes('cost-effective')) score += 1
+       }
+       if (token === 'memory' || token === 'high-memory') {
+         if (meta.memGb >= 32) score += 1
+       }
+       if (token === 'large' && meta.vcpus >= 16) score += 1
+       if (meta.cudaVersions && meta.cudaVersions.includes(token)) score += 2
+     }
+     return { name, meta, score }
+   })
+
+   const matched = scored.filter(s => s.score > 0).sort((a, b) => b.score - a.score)
+
+   if (matched.length === 0) {
+     return []
+   }
+
+   return matched.slice(0, limit).map(s => s.name)
+ }
+
+ // ── CUDA version filtering ───────────────────────────────────────────────────
+
+ /**
+  * Filter instances to only those supporting a required CUDA version.
+  *
+  * @param {object} instanceCatalog - Instance catalog object
+  * @param {string} requiredCuda - Required CUDA version (e.g., "12.1")
+  * @returns {object} Filtered instance catalog
+  */
+ function filterByCudaVersion(instanceCatalog, requiredCuda) {
+   const majorRequired = requiredCuda.split('.')[0]
+   const filtered = {}
+
+   for (const [name, meta] of Object.entries(instanceCatalog)) {
+     if (!meta.cudaVersions || meta.cudaVersions.length === 0) continue
+     const hasCompatible = meta.cudaVersions.some(v => {
+       if (v === requiredCuda) return true
+       if (v.startsWith(majorRequired + '.')) return true
+       return false
+     })
+     if (hasCompatible) {
+       filtered[name] = meta
+     }
+   }
+
+   return filtered
+ }
+
+ // ── Tool handler ─────────────────────────────────────────────────────────────
+
+ /**
+  * Handle the get_instance_recommendation tool invocation.
+  *
+  * Pipeline: resolveModelMetadata → estimateVram → filterAndRankInstances
+  *
+  * @param {object} params - Tool input parameters
+  * @returns {object} MCP tool response
+  */
+ async function handleGetInstanceRecommendation(params) {
+   const {
+     modelName,
+     instanceSearch,
+     quantization,
+     maxSequenceLength,
+     batchSize,
+     cudaVersion,
+     limit = 10,
+     context
+   } = params
+
+   // Apply profile ENV overrides to sequence length and batch size
+   let effectiveMaxSeqLen = maxSequenceLength
+   let effectiveBatchSize = batchSize
+   if (context?.profileEnvVars) {
+     if (context.profileEnvVars.VLLM_MAX_MODEL_LEN) {
+       effectiveMaxSeqLen = parseInt(context.profileEnvVars.VLLM_MAX_MODEL_LEN, 10) || effectiveMaxSeqLen
+     }
+     if (context.profileEnvVars.VLLM_MAX_NUM_SEQS) {
+       effectiveBatchSize = parseInt(context.profileEnvVars.VLLM_MAX_NUM_SEQS, 10) || effectiveBatchSize
+     }
+   }
+
+   // Apply CUDA version filtering to instance catalog
+   let effectiveCatalog = INSTANCE_CATALOG
+   if (cudaVersion) {
+     effectiveCatalog = filterByCudaVersion(INSTANCE_CATALOG, cudaVersion)
+     if (Object.keys(effectiveCatalog).length === 0) {
+       log(`CUDA version ${cudaVersion} filter eliminated all instances`)
+       return {
+         content: [{
+           type: 'text',
+           text: JSON.stringify({
+             values: { instanceType: null },
+             choices: { instanceType: [] },
+             metadata: {
+               modelName: modelName || null,
+               warning: `No instances support CUDA version ${cudaVersion}. Check base image compatibility.`,
+               cudaVersionFilter: cudaVersion
+             }
+           })
+         }]
+       }
+     }
+   }
+
+   // Mode: tag-based search only (no model name)
+   if (!modelName && instanceSearch) {
+     const searchResults = searchInstancesByTag(instanceSearch, effectiveCatalog, { limit })
+     return {
+       content: [{
+         type: 'text',
+         text: JSON.stringify({
+           values: { instanceType: searchResults[0] || null },
+           choices: { instanceType: searchResults },
+           metadata: {
+             instanceSearch,
+             source: 'tag-search',
+             cudaVersionFilter: cudaVersion || null,
+             resultCount: searchResults.length
+           }
+         })
+       }]
+     }
+   }
+
+   // Mode: no model name and no search — return all GPU instances
+   if (!modelName) {
+     const allGpuInstances = Object.keys(effectiveCatalog)
+       .filter(key => effectiveCatalog[key].category === 'gpu')
+       .slice(0, limit)
+
+     return {
+       content: [{
+         type: 'text',
+         text: JSON.stringify({
+           values: { instanceType: allGpuInstances[0] || null },
+           choices: { instanceType: allGpuInstances },
+           metadata: {
+             modelName: null,
+             source: 'unfiltered',
+             cudaVersionFilter: cudaVersion || null,
+             warning: 'No model name provided. Returning GPU instances without VRAM filtering.'
+           }
+         })
+       }]
+     }
+   }
+
+   // Step 1: Resolve model metadata
+   const modelMetadata = await resolveModelMetadata(modelName, {
+     discover: DISCOVER_MODE
+   })
+
+   // If model metadata cannot be resolved, return all GPU instances unfiltered
+   if (!modelMetadata) {
+     log(`Model metadata not found for "${modelName}", returning unfiltered GPU instances`)
+     const allGpuInstances = Object.keys(effectiveCatalog)
+       .filter(key => effectiveCatalog[key].category === 'gpu')
+       .slice(0, limit)
+
+     return {
+       content: [{
+         type: 'text',
+         text: JSON.stringify({
+           values: { instanceType: allGpuInstances[0] || null },
+           choices: { instanceType: allGpuInstances },
+           metadata: {
+             modelName,
+             parameterCount: null,
+             dtype: null,
+             quantization: quantization || null,
+             estimatedVramGb: null,
+             vramBreakdown: null,
+             recommendations: allGpuInstances.map(instanceType => ({
+               instanceType,
+               gpuCount: effectiveCatalog[instanceType]?.gpus || 0,
+               totalVramGb: null,
+               utilizationPercent: null,
+               tensorParallelism: null,
+               costTier: null
+             })),
+             source: 'unfiltered',
+             cudaVersionFilter: cudaVersion || null,
+             warning: `Could not resolve model metadata for "${modelName}". Returning all GPU instances without filtering.`
+           }
+         })
+       }]
+     }
+   }
+
+   // Step 2: Estimate VRAM
+   // Use model's max_position_embeddings as the sequence length when no explicit value is provided.
+   // This ensures KV cache is sized for the model's actual context window, not the 4096 default.
+   const resolvedMaxSeqLen = effectiveMaxSeqLen || modelMetadata.maxPositionEmbeddings || undefined
+   const vramEstimate = estimateVram({
+     parameterCount: modelMetadata.parameterCount,
+     dtype: modelMetadata.dtype,
+     quantization: quantization || undefined,
+     maxSequenceLength: resolvedMaxSeqLen,
+     batchSize: effectiveBatchSize || undefined
+   })
+
+   // Step 3: Filter and rank instances
+   let recommendations = filterAndRankInstances(
+     vramEstimate.vramGb,
+     effectiveCatalog,
+     { limit }
+   )
+
+   // Step 3b: If instanceSearch is also provided, further filter by tags
+   if (instanceSearch && recommendations.length > 0) {
+     const searchMatches = new Set(searchInstancesByTag(instanceSearch, effectiveCatalog, { limit: 100 }))
+     recommendations = recommendations.filter(r => searchMatches.has(r.instanceType))
+   }
+
+   // Step 4: Smart mode — query Bedrock for edge-case reasoning
+   let finalRecommendations = recommendations
+   let smartModeUsed = false
+
+   if (SMART_MODE && recommendations.length > 0) {
+     log('[smart] Smart mode enabled, querying Amazon Bedrock...')
+
+     const bedrockContext = {
+       modelName,
+       parameterCount: modelMetadata.parameterCount,
+       dtype: modelMetadata.dtype,
+       quantization: quantization || null,
+       estimatedVramGb: vramEstimate.vramGb,
+       vramBreakdown: vramEstimate.breakdown,
+       staticRecommendations: recommendations.slice(0, 3).map(r => ({
+         instanceType: r.instanceType,
+         gpuCount: r.gpuCount,
+         totalVramGb: r.totalVramGb,
+         utilizationPercent: r.utilizationPercent,
+         tensorParallelism: r.tensorParallelism
+       })),
+       ...(context || {})
+     }
+
+     const bedrockResult = await queryBedrock(
+       SERVER_CONFIG,
+       ['instanceType'],
+       limit,
+       bedrockContext
+     )
+
+     if (bedrockResult?.values?.instanceType) {
+       const bedrockInstance = bedrockResult.values.instanceType
+       log(`[smart] Bedrock recommendation: ${bedrockInstance}`)
+
+       // Check if Bedrock's suggestion is already in our list
+       const existingIndex = finalRecommendations.findIndex(
+         r => r.instanceType === bedrockInstance
+       )
+
+       if (existingIndex > 0) {
+         // Move Bedrock's pick to the top
+         const [picked] = finalRecommendations.splice(existingIndex, 1)
+         finalRecommendations = [picked, ...finalRecommendations]
+         smartModeUsed = true
+       } else if (existingIndex === 0) {
+         // Already at the top — Bedrock agrees with static
+         smartModeUsed = true
+         log('[smart] Bedrock agrees with static top recommendation')
+       } else {
+         // Bedrock suggested an instance not in our filtered list;
+         // verify it exists in the catalog before prepending
+         if (INSTANCE_CATALOG[bedrockInstance]) {
+           const catalogEntry = INSTANCE_CATALOG[bedrockInstance]
+           const bedrockRec = {
+             instanceType: bedrockInstance,
+             gpuCount: catalogEntry.gpus || 0,
+             totalVramGb: (catalogEntry.gpuMemoryGb || 0) * (catalogEntry.gpus || 1),
+             utilizationPercent: null,
+             tensorParallelism: catalogEntry.gpus || 1,
+             costTier: catalogEntry.costTier || null
+           }
+           finalRecommendations = [bedrockRec, ...finalRecommendations].slice(0, limit)
+           smartModeUsed = true
+         } else {
+           log(`[smart] Bedrock suggested unknown instance "${bedrockInstance}", ignoring`)
+         }
+       }
+     } else {
+       log('[smart] Bedrock did not return usable results, falling back to static recommendations')
+     }
+   }
+
+   // Build response
+   const topRecommendation = finalRecommendations.length > 0
+     ? finalRecommendations[0].instanceType
+     : null
+
+   const rankedList = finalRecommendations.map(r => r.instanceType)
+
+   return {
+     content: [{
+       type: 'text',
+       text: JSON.stringify({
+         values: { instanceType: topRecommendation },
+         choices: { instanceType: rankedList },
+         metadata: {
+           modelName,
+           parameterCount: modelMetadata.parameterCount,
+           dtype: modelMetadata.dtype,
+           quantization: quantization || null,
+           estimatedVramGb: vramEstimate.vramGb,
+           vramBreakdown: vramEstimate.breakdown,
+           recommendations: finalRecommendations,
+           source: modelMetadata.source,
+           smartModeUsed
+         }
+       })
+     }]
+   }
+ }
+
+ // ── MCP Server setup ─────────────────────────────────────────────────────────
+
+ const server = new McpServer({
+   name: 'instance-sizer',
+   version: '1.0.0'
+ })
+
+ // Register the get_instance_recommendation tool
+ server.tool(
+   'get_instance_recommendation',
+   'Estimates VRAM requirements from model metadata and returns filtered, ranked SageMaker instance recommendations. Supports VRAM-based sizing, tag-based search, or both combined.',
+   {
+     modelName: z.string().optional().describe('HuggingFace model ID or catalog key'),
+     instanceSearch: z.string().optional().describe('Tag/keyword search for instances (e.g., "multi-gpu", "cost-effective cpu")'),
+     quantization: z.string().optional().describe('Quantization method: awq, gptq, bnb-4bit, bnb-8bit'),
+     maxSequenceLength: z.number().optional().describe('Max context/sequence length (affects KV cache estimate)'),
+     batchSize: z.number().optional().describe('Expected concurrent batch size'),
+     cudaVersion: z.string().optional().describe('Required CUDA version from base image (filters incompatible instances)'),
+     limit: z.number().optional().default(10).describe('Maximum number of instance recommendations to return'),
+     context: z.object({
+       architecture: z.string().optional(),
+       backend: z.string().optional(),
+       deploymentTarget: z.string().optional(),
+       profileEnvVars: z.record(z.string()).optional().describe('Serving profile ENV overrides (e.g., VLLM_MAX_MODEL_LEN)')
+     }).optional().describe('Additional deployment context')
+   },
+   async (params) => {
+     return handleGetInstanceRecommendation(params)
+   }
+ )
+
+ // Register alias tool name for backward compatibility
+ server.tool(
+   'get_instance_types',
+   'Alias for get_instance_recommendation — recommends SageMaker instances via VRAM sizing and/or tag-based search',
+   {
+     modelName: z.string().optional().describe('HuggingFace model ID or catalog key'),
+     instanceSearch: z.string().optional().describe('Tag/keyword search for instances (e.g., "multi-gpu", "cost-effective cpu")'),
+     quantization: z.string().optional().describe('Quantization method: awq, gptq, bnb-4bit, bnb-8bit'),
+     maxSequenceLength: z.number().optional().describe('Max context/sequence length (affects KV cache estimate)'),
+     batchSize: z.number().optional().describe('Expected concurrent batch size'),
+     cudaVersion: z.string().optional().describe('Required CUDA version from base image (filters incompatible instances)'),
+     limit: z.number().optional().default(10).describe('Maximum number of instance recommendations to return'),
+     context: z.object({
+       architecture: z.string().optional(),
+       backend: z.string().optional(),
+       deploymentTarget: z.string().optional(),
+       profileEnvVars: z.record(z.string()).optional().describe('Serving profile ENV overrides (e.g., VLLM_MAX_MODEL_LEN)')
+     }).optional().describe('Additional deployment context')
+   },
+   async (params) => {
+     return handleGetInstanceRecommendation(params)
+   }
+ )
+
+ // ── Exports for testing ──────────────────────────────────────────────────────
+
+ export { handleGetInstanceRecommendation, INSTANCE_CATALOG, SERVER_CONFIG, server, searchInstancesByTag, filterByCudaVersion }
+
+ // ── Transport connection (main module only) ──────────────────────────────────
+
+ const isMain = process.argv[1] && resolve(process.argv[1]) === __filename
+
+ if (isMain) {
+   if (SMART_MODE) {
+     log(`Smart mode enabled (model: ${BEDROCK_MODEL}, region: ${BEDROCK_REGION})`)
+   } else if (DISCOVER_MODE) {
+     log('Discover mode enabled (HuggingFace API lookups active)')
+   } else {
+     log('Static mode (catalog-only, no network calls)')
+   }
+
+   const transport = new StdioServerTransport()
+   await server.connect(transport)
+ }
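
For orientation (this is not part of the published package): the handler above returns a single text content item whose body is the JSON { values, choices, metadata } payload, so a caller parses result.content[0].text. Below is a minimal sketch of exercising the new tool over stdio, assuming the standard @modelcontextprotocol/sdk client API; the script path and model ID are illustrative assumptions, not values taken from the package.

// example-client.mjs: illustrative sketch only; paths and model ID are assumptions
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'

// Spawn the bundled server as a child process speaking MCP over stdio
const transport = new StdioClientTransport({
  command: 'node',
  args: ['servers/instance-sizer/index.js'] // assumed path within an installed copy of the package
})

const client = new Client({ name: 'sizer-example', version: '0.0.0' })
await client.connect(transport)

// Call the tool with a hypothetical model and quantization method
const result = await client.callTool({
  name: 'get_instance_recommendation',
  arguments: {
    modelName: 'meta-llama/Llama-3.1-8B-Instruct', // illustrative HuggingFace model ID
    quantization: 'awq',
    limit: 5
  }
})

// The server returns one text content item containing the JSON payload
const payload = JSON.parse(result.content[0].text)
console.log(payload.values.instanceType)     // top-ranked instance type, or null
console.log(payload.choices.instanceType)    // ranked list of candidate instance types
console.log(payload.metadata.estimatedVramGb) // VRAM estimate used for filtering

await client.close()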