@aws/ml-container-creator 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/LICENSE-THIRD-PARTY +68620 -0
- package/NOTICE +2 -0
- package/README.md +106 -0
- package/bin/cli.js +365 -0
- package/config/defaults.json +32 -0
- package/config/presets/transformers-djl.json +26 -0
- package/config/presets/transformers-gpu.json +24 -0
- package/config/presets/transformers-lmi.json +27 -0
- package/package.json +129 -0
- package/servers/README.md +419 -0
- package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
- package/servers/base-image-picker/catalogs/python-slim.json +38 -0
- package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
- package/servers/base-image-picker/catalogs/triton.json +38 -0
- package/servers/base-image-picker/index.js +495 -0
- package/servers/base-image-picker/manifest.json +17 -0
- package/servers/base-image-picker/package.json +15 -0
- package/servers/hyperpod-cluster-picker/LICENSE +202 -0
- package/servers/hyperpod-cluster-picker/index.js +424 -0
- package/servers/hyperpod-cluster-picker/manifest.json +14 -0
- package/servers/hyperpod-cluster-picker/package.json +17 -0
- package/servers/instance-recommender/LICENSE +202 -0
- package/servers/instance-recommender/catalogs/instances.json +852 -0
- package/servers/instance-recommender/index.js +284 -0
- package/servers/instance-recommender/manifest.json +16 -0
- package/servers/instance-recommender/package.json +15 -0
- package/servers/lib/LICENSE +202 -0
- package/servers/lib/bedrock-client.js +160 -0
- package/servers/lib/custom-validators.js +46 -0
- package/servers/lib/dynamic-resolver.js +36 -0
- package/servers/lib/package.json +11 -0
- package/servers/lib/schemas/image-catalog.schema.json +185 -0
- package/servers/lib/schemas/instances.schema.json +124 -0
- package/servers/lib/schemas/manifest.schema.json +64 -0
- package/servers/lib/schemas/model-catalog.schema.json +91 -0
- package/servers/lib/schemas/regions.schema.json +26 -0
- package/servers/lib/schemas/triton-backends.schema.json +51 -0
- package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
- package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
- package/servers/model-picker/catalogs/popular-transformers.json +226 -0
- package/servers/model-picker/index.js +1693 -0
- package/servers/model-picker/manifest.json +18 -0
- package/servers/model-picker/package.json +20 -0
- package/servers/region-picker/LICENSE +202 -0
- package/servers/region-picker/catalogs/regions.json +263 -0
- package/servers/region-picker/index.js +230 -0
- package/servers/region-picker/manifest.json +16 -0
- package/servers/region-picker/package.json +15 -0
- package/src/app.js +1007 -0
- package/src/copy-tpl.js +77 -0
- package/src/lib/accelerator-validator.js +39 -0
- package/src/lib/asset-manager.js +385 -0
- package/src/lib/aws-profile-parser.js +181 -0
- package/src/lib/bootstrap-command-handler.js +1647 -0
- package/src/lib/bootstrap-config.js +238 -0
- package/src/lib/ci-register-helpers.js +124 -0
- package/src/lib/ci-report-helpers.js +158 -0
- package/src/lib/ci-stage-helpers.js +268 -0
- package/src/lib/cli-handler.js +529 -0
- package/src/lib/comment-generator.js +544 -0
- package/src/lib/community-reports-validator.js +91 -0
- package/src/lib/config-manager.js +2106 -0
- package/src/lib/configuration-exporter.js +204 -0
- package/src/lib/configuration-manager.js +695 -0
- package/src/lib/configuration-matcher.js +221 -0
- package/src/lib/cpu-validator.js +36 -0
- package/src/lib/cuda-validator.js +57 -0
- package/src/lib/deployment-config-resolver.js +103 -0
- package/src/lib/deployment-entry-schema.js +125 -0
- package/src/lib/deployment-registry.js +598 -0
- package/src/lib/docker-introspection-validator.js +51 -0
- package/src/lib/engine-prefix-resolver.js +60 -0
- package/src/lib/huggingface-client.js +172 -0
- package/src/lib/key-value-parser.js +37 -0
- package/src/lib/known-flags-validator.js +200 -0
- package/src/lib/manifest-cli.js +280 -0
- package/src/lib/mcp-client.js +303 -0
- package/src/lib/mcp-command-handler.js +532 -0
- package/src/lib/neuron-validator.js +80 -0
- package/src/lib/parameter-schema-validator.js +284 -0
- package/src/lib/prompt-runner.js +1349 -0
- package/src/lib/prompts.js +1138 -0
- package/src/lib/registry-command-handler.js +519 -0
- package/src/lib/registry-loader.js +198 -0
- package/src/lib/rocm-validator.js +80 -0
- package/src/lib/schema-validator.js +157 -0
- package/src/lib/sensitive-redactor.js +59 -0
- package/src/lib/template-engine.js +156 -0
- package/src/lib/template-manager.js +341 -0
- package/src/lib/validation-engine.js +314 -0
- package/src/prompt-adapter.js +63 -0
- package/templates/Dockerfile +300 -0
- package/templates/IAM_PERMISSIONS.md +84 -0
- package/templates/MIGRATION.md +488 -0
- package/templates/PROJECT_README.md +439 -0
- package/templates/TEMPLATE_SYSTEM.md +243 -0
- package/templates/buildspec.yml +64 -0
- package/templates/code/chat_template.jinja +1 -0
- package/templates/code/flask/gunicorn_config.py +35 -0
- package/templates/code/flask/wsgi.py +10 -0
- package/templates/code/model_handler.py +387 -0
- package/templates/code/serve +300 -0
- package/templates/code/serve.py +175 -0
- package/templates/code/serving.properties +105 -0
- package/templates/code/start_server.py +39 -0
- package/templates/code/start_server.sh +39 -0
- package/templates/diffusors/Dockerfile +72 -0
- package/templates/diffusors/patch_image_api.py +35 -0
- package/templates/diffusors/serve +115 -0
- package/templates/diffusors/start_server.sh +114 -0
- package/templates/do/.gitkeep +1 -0
- package/templates/do/README.md +541 -0
- package/templates/do/build +83 -0
- package/templates/do/ci +681 -0
- package/templates/do/clean +811 -0
- package/templates/do/config +260 -0
- package/templates/do/deploy +1560 -0
- package/templates/do/export +306 -0
- package/templates/do/logs +319 -0
- package/templates/do/manifest +12 -0
- package/templates/do/push +119 -0
- package/templates/do/register +580 -0
- package/templates/do/run +113 -0
- package/templates/do/submit +417 -0
- package/templates/do/test +1147 -0
- package/templates/hyperpod/configmap.yaml +24 -0
- package/templates/hyperpod/deployment.yaml +71 -0
- package/templates/hyperpod/pvc.yaml +42 -0
- package/templates/hyperpod/service.yaml +17 -0
- package/templates/nginx-diffusors.conf +74 -0
- package/templates/nginx-predictors.conf +47 -0
- package/templates/nginx-tensorrt.conf +74 -0
- package/templates/requirements.txt +61 -0
- package/templates/sample_model/test_inference.py +123 -0
- package/templates/sample_model/train_abalone.py +252 -0
- package/templates/test/test_endpoint.sh +79 -0
- package/templates/test/test_local_image.sh +80 -0
- package/templates/test/test_model_handler.py +180 -0
- package/templates/triton/Dockerfile +128 -0
- package/templates/triton/config.pbtxt +163 -0
- package/templates/triton/model.py +130 -0
- package/templates/triton/requirements.txt +11 -0
|
@@ -0,0 +1,1693 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
/**
|
|
6
|
+
* Model Picker MCP Server
|
|
7
|
+
*
|
|
8
|
+
* A bundled MCP server that returns model metadata for ML Container Creator.
|
|
9
|
+
* Supports two operating modes:
|
|
10
|
+
* - Static: Returns metadata from local catalog files (popular-transformers.json, popular-diffusors.json)
|
|
11
|
+
* - Discover: Queries HuggingFace Hub API for live metadata, merging with static catalog
|
|
12
|
+
*
|
|
13
|
+
* Uses a pluggable ModelResolver architecture. V1 ships with HuggingFaceResolver
|
|
14
|
+
* and StaticCatalogResolver.
|
|
15
|
+
*
|
|
16
|
+
* Tool: get_models
|
|
17
|
+
* Accepts: { model_id: string, fields?: string[], mode?: string, context?: object }
|
|
18
|
+
* Returns: { values, choices, message }
|
|
19
|
+
*/
|
|
20
|
+
|
|
21
|
+
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
|
|
22
|
+
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
|
23
|
+
import { z } from 'zod'
|
|
24
|
+
import { readFileSync } from 'node:fs'
|
|
25
|
+
import { fileURLToPath } from 'node:url'
|
|
26
|
+
import { resolve, dirname } from 'node:path'
|
|
27
|
+
import { DynamicResolver } from '../lib/dynamic-resolver.js'
|
|
28
|
+
|
|
29
|
+
// ── Catalog loader ───────────────────────────────────────────────────────────
|
|
30
|
+
|
|
31
|
+
const __filename = fileURLToPath(import.meta.url)
|
|
32
|
+
const __dirname = dirname(__filename)
|
|
33
|
+
|
|
34
|
+
/**
 * Load and parse a JSON catalog file relative to the server directory.
 *
 * The path is resolved against this module's directory (`__dirname`), read
 * synchronously as UTF-8 and parsed with `JSON.parse`. Every thrown error
 * carries the resolved file path in its message and the underlying error
 * as `cause`, so callers can still inspect `err.cause.code`.
 *
 * @param {string} relativePath - Path relative to server dir (e.g. './catalogs/popular-transformers.json')
 * @returns {any} Parsed JSON content
 * @throws {Error} `Catalog file not found: <path>` when the file does not exist (ENOENT)
 * @throws {Error} `Failed to read catalog <path>: <reason>` for any other read failure
 * @throws {Error} `Failed to parse catalog <path>: <reason>` when the content is not valid JSON
 */
function loadCatalog(relativePath) {
    const fullPath = resolve(__dirname, relativePath)

    let raw
    try {
        raw = readFileSync(fullPath, 'utf8')
    } catch (err) {
        // Only a genuinely missing file is reported as "not found"; permission
        // problems, directories, etc. get their real reason surfaced instead.
        if (err?.code === 'ENOENT') {
            throw new Error(`Catalog file not found: ${fullPath}`, { cause: err })
        }
        throw new Error(
            `Failed to read catalog ${fullPath}: ${err?.message ?? err}`,
            { cause: err }
        )
    }

    try {
        return JSON.parse(raw)
    } catch (err) {
        throw new Error(
            `Failed to parse catalog ${fullPath}: ${err.message}`,
            { cause: err }
        )
    }
}
|
|
55
|
+
|
|
56
|
+
// ── ModelResolver interface ──────────────────────────────────────────────────
|
|
57
|
+
|
|
58
|
+
/**
 * ModelResolver — abstract base class for model-source resolvers.
 *
 * Sits on top of DynamicResolver and exposes model-flavoured method names
 * (`fetchModelMetadata`, `supportedPatterns`). The generic `fetch` and
 * `supportedKeys` methods defined here simply forward to those, so a
 * subclass only has to implement the two model-specific methods.
 */
class ModelResolver extends DynamicResolver {
    // ── Generic entry points (forward to the model-specific methods) ────

    /**
     * Generic lookup entry point; forwards to `fetchModelMetadata`.
     * @param {string} key - Model ID
     * @param {object} options - Passed through unchanged
     * @returns {Promise<object|null>} Whatever `fetchModelMetadata` resolves to
     */
    async fetch(key, options = {}) {
        return this.fetchModelMetadata(key, options)
    }

    /**
     * Generic pattern listing; forwards to `supportedPatterns`.
     * @returns {string[]} Whatever `supportedPatterns` returns
     */
    supportedKeys() {
        return this.supportedPatterns()
    }

    // ── Model-specific contract (subclasses must override) ──────────────

    /**
     * Fetch metadata for a model ID.
     * The base implementation always rejects; subclasses provide the lookup.
     *
     * @param {string} modelId - e.g. 'meta-llama/Llama-2-7b-chat-hf'
     * @param {object} options - { fields, limit, context }
     * @returns {Promise<object|null>} Model metadata or null
     */
    async fetchModelMetadata(modelId, options = {}) {
        throw new Error('fetchModelMetadata() must be implemented by subclass')
    }

    /**
     * Returns glob patterns this resolver handles.
     * The base implementation always throws; subclasses provide the list.
     *
     * @returns {string[]} e.g. ['hf:org/model'] for HuggingFace org/model pattern
     */
    supportedPatterns() {
        throw new Error('supportedPatterns() must be implemented by subclass')
    }
}
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
// ── StaticCatalogResolver ────────────────────────────────────────────────────
|
|
99
|
+
|
|
100
|
+
/**
 * StaticCatalogResolver — fallback resolver.
 *
 * Returns model metadata from an in-memory catalog object whose keys are
 * either exact model IDs or glob patterns (`*` and `?` wildcards).
 * No network calls, no auth, no external dependencies.
 */
class StaticCatalogResolver extends ModelResolver {
    /**
     * @param {object} catalog - Map of model ID / glob pattern → metadata object
     */
    constructor(catalog) {
        super()
        this._catalog = catalog
    }

    /**
     * This resolver accepts any model ID.
     * @returns {string[]} Always ['*']
     */
    supportedPatterns() {
        return ['*']
    }

    /**
     * Look up a model in the catalog: exact key first, then the first glob
     * pattern key (in catalog insertion order) that matches.
     *
     * @param {string} modelId - Model ID to look up
     * @param {object} options - Unused by this resolver
     * @returns {Promise<object|null>} Shallow copy of the catalog entry, or null
     */
    async fetchModelMetadata(modelId, options = {}) {
        const catalog = this._catalog

        // Exact match first. Own-property check so IDs such as 'constructor'
        // or 'toString' never resolve to inherited Object.prototype members.
        if (Object.hasOwn(catalog, modelId) && catalog[modelId]) {
            return { ...catalog[modelId] }
        }

        // Glob pattern match (e.g., 'meta-llama/Llama-2-*')
        for (const [pattern, metadata] of Object.entries(catalog)) {
            const isGlob = pattern.includes('*') || pattern.includes('?')
            if (isGlob && this._globMatch(modelId, pattern)) {
                return { ...metadata }
            }
        }

        return null
    }

    /**
     * Match a string against a glob pattern.
     *
     * `*` matches any run of characters and `?` matches a single character.
     * Every other character is matched literally: regex metacharacters in
     * the pattern (notably the `.` common in model names such as
     * 'Qwen2.5-*') are escaped before the regex is built.
     *
     * @param {string} str - The string to test
     * @param {string} pattern - Glob pattern with * and ? wildcards
     * @returns {boolean}
     */
    _globMatch(str, pattern) {
        const source = pattern
            .replace(/[.+^${}()|[\]\\]/g, '\\$&') // escape everything except * and ?
            .replace(/\*/g, '.*')
            .replace(/\?/g, '.')
        return new RegExp(`^${source}$`).test(str)
    }
}
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
// ── HuggingFaceResolver ──────────────────────────────────────────────────────
|
|
153
|
+
|
|
154
|
+
/**
 * HuggingFaceResolver — fetches live model metadata from HuggingFace Hub API.
 *
 * Handles model IDs matching the org/model-name pattern. Queries up to three
 * endpoints, concurrently:
 *   - /api/models/{modelId} — model info (always)
 *   - /{modelId}/resolve/main/tokenizer_config.json — chat template (conditional)
 *   - /{modelId}/resolve/main/config.json — architecture (conditional)
 *
 * All HTTP errors are non-fatal: the affected fields come back as null.
 * Everything except a plain 404 is logged to stderr.
 */
class HuggingFaceResolver extends ModelResolver {
    /**
     * @param {object} [options]
     * @param {string} [options.baseUrl] - Hub base URL (default 'https://huggingface.co')
     * @param {number} [options.timeout] - Per-request timeout in ms (default 5000)
     */
    constructor(options = {}) {
        super()
        this.baseUrl = options.baseUrl || 'https://huggingface.co'
        this.timeout = options.timeout || 5000
    }

    /**
     * @returns {string[]} Always ['hf:*\/*']
     */
    supportedPatterns() {
        return ['hf:*/*']
    }

    /**
     * Fetch metadata for a HuggingFace model.
     *
     * `tags`, `gated` and `pipeline_tag` are set only when the model-info
     * request succeeds. `chat_template` and `architecture` are set (possibly
     * to null) whenever they are requested, i.e. when `options.fields` is
     * absent or lists them.
     *
     * @param {string} modelId - e.g. 'meta-llama/Llama-2-7b-chat-hf'
     * @param {object} options - { fields }
     * @returns {Promise<object|null>} Metadata, or null when nothing was collected
     */
    async fetchModelMetadata(modelId, options = {}) {
        const { fields } = options
        const wants = (field) => !fields || fields.includes(field)
        const wantsChatTemplate = wants('chat_template')
        const wantsArchitecture = wants('architecture')
        const repoFilesUrl = `${this.baseUrl}/${modelId}/resolve/main`

        // The requests are independent, so issue them together instead of
        // paying up to three timeouts back to back. _fetchJson never rejects.
        const [modelInfo, tokenizerConfig, modelConfig] = await Promise.all([
            this._fetchJson(`${this.baseUrl}/api/models/${modelId}`),
            wantsChatTemplate
                ? this._fetchJson(`${repoFilesUrl}/tokenizer_config.json`)
                : null,
            wantsArchitecture
                ? this._fetchJson(`${repoFilesUrl}/config.json`)
                : null
        ])

        const metadata = {}

        if (modelInfo) {
            metadata.tags = modelInfo.tags || []
            metadata.gated = modelInfo.gated || false
            metadata.pipeline_tag = modelInfo.pipeline_tag || null
        }
        if (wantsChatTemplate) {
            metadata.chat_template = tokenizerConfig?.chat_template || null
        }
        if (wantsArchitecture) {
            metadata.architecture = modelConfig?.architectures?.[0] || null
        }

        return Object.keys(metadata).length > 0 ? metadata : null
    }

    /**
     * Fetch JSON from a URL with timeout and error handling.
     * Returns null on any error (429, 404, other non-2xx, network, timeout,
     * invalid JSON).
     *
     * The timeout stays armed until the body has been read and parsed, so a
     * response whose headers arrive but whose body stalls is still aborted.
     *
     * @param {string} url - URL to fetch
     * @returns {Promise<object|null>}
     */
    async _fetchJson(url) {
        const controller = new AbortController()
        const timer = setTimeout(() => controller.abort(), this.timeout)
        try {
            const response = await fetch(url, { signal: controller.signal })
            if (response.status === 429) {
                process.stderr.write(
                    `[model-picker] Rate limited: ${url}\n`
                )
                return null
            }
            // A missing file/model is an expected outcome — stay quiet.
            if (response.status === 404) return null
            if (!response.ok) {
                process.stderr.write(
                    `[model-picker] HTTP ${response.status}: ${url}\n`
                )
                return null
            }
            return await response.json()
        } catch (err) {
            process.stderr.write(
                `[model-picker] Fetch failed: ${url} — ${err?.message ?? err}\n`
            )
            return null
        } finally {
            clearTimeout(timer)
        }
    }
}
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
// ── JumpStartPublicResolver ───────────────────────────────────────────────────
|
|
242
|
+
|
|
243
|
+
/**
 * Error names/codes treated as "AWS credentials are missing or expired".
 * JumpStartPublicResolver checks caught errors against this set and, on a
 * match, falls back to the static catalog.
 */
const CREDENTIAL_ERROR_NAMES = new Set([
    // No usable credentials could be resolved
    'CredentialsError',
    'CredentialsProviderError',
    // Credentials were found but are expired
    'ExpiredTokenException',
    'ExpiredToken',
    // Token/identity rejected by the service
    'InvalidIdentityToken',
    'NoSuchTokenException',
    'UnrecognizedClientException'
])
|
|
256
|
+
|
|
257
|
+
/**
 * JumpStartPublicResolver — fetches model metadata from the JumpStart public
 * S3 cache bucket (`jumpstart-cache-prod-{region}`).
 *
 * Handles model IDs matching the `jumpstart://` URI prefix. Retrieves the
 * `models_manifest.json` to find the model's `spec_key`, then fetches the
 * full model spec JSON at that key (e.g.
 * `community_models/{model-id}/specs_v{version}.json`).
 *
 * Requests go through a pass-through signer (nothing is signed), since the
 * bucket is expected to be publicly readable.
 * On any S3 error, timeout, unusable manifest or unknown model, falls back
 * to the static catalog supplied via `options.staticCatalog`.
 * AWS SDK is lazy-imported to keep the server fast in static mode.
 */
class JumpStartPublicResolver extends ModelResolver {
    /**
     * @param {object} [options]
     * @param {number} [options.timeout] - S3 request timeout in ms (default 10000)
     * @param {string} [options.region] - AWS region; defaults to $AWS_REGION, then 'us-east-1'
     * @param {object} [options.staticCatalog] - Fallback catalog keyed by full `jumpstart://` model ID
     */
    constructor(options = {}) {
        super()
        this.timeout = options.timeout ?? 10000
        this.region = options.region || process.env.AWS_REGION || 'us-east-1'
        this._client = null
        this._sdkModule = null
        this._staticCatalog = options.staticCatalog || null
    }

    /**
     * @returns {string[]} Always ['jumpstart://*']
     */
    supportedPatterns() {
        return ['jumpstart://*']
    }

    /**
     * Fetch metadata for a JumpStart public model.
     *
     * For a specific model ID, fetches `models_manifest.json` from the
     * JumpStart S3 cache bucket, picks the manifest entry for the requested
     * model (see `_findLatestEntry`), then fetches the full spec JSON using
     * the entry's `spec_key`.
     *
     * For list mode (bare ID empty or '*'), returns metadata built from the
     * first manifest entry.
     *
     * Any failure — SDK load, S3 request, JSON parsing, empty/malformed
     * manifest, model absent from the manifest — resolves to the static
     * catalog entry for `modelId`, or null when there is none.
     *
     * @param {string} modelId - e.g. 'jumpstart://huggingface-llm-falcon-7b'
     * @param {object} options - { fields, context } (not used by this resolver)
     * @returns {Promise<object|null>} ModelMetadata or null
     */
    async fetchModelMetadata(modelId, options = {}) {
        const bareId = modelId.replace(/^jumpstart:\/\//, '')

        try {
            const sdk = await this._loadSdk()
            const client = this._createClient(sdk)

            const manifest = await this._getJsonObject(sdk, client, 'models_manifest.json')

            // An unusable manifest is handled like every other lookup failure:
            // give the static catalog a chance instead of returning null outright.
            if (!Array.isArray(manifest) || manifest.length === 0) {
                process.stderr.write(
                    `[jumpstart] Manifest is empty or malformed. Falling back to static catalog.\n`
                )
                return this._fallbackToStaticCatalog(modelId)
            }

            // List mode — return metadata from the first manifest entry
            if (!bareId || bareId === '*') {
                const first = manifest[0]
                return this._mapToMetadata(first, first?.model_id || '*')
            }

            // Find the manifest entry for the requested model
            const entry = this._findLatestEntry(manifest, bareId)
            if (!entry || !entry.spec_key) {
                process.stderr.write(
                    `[jumpstart] Model not found in manifest: ${bareId}\n`
                )
                return this._fallbackToStaticCatalog(modelId)
            }

            // Fetch the full spec using the spec_key from the manifest
            const spec = await this._getJsonObject(sdk, client, entry.spec_key)
            return this._mapToMetadata(spec, bareId)
        } catch (err) {
            if (this._isCredentialError(err)) {
                process.stderr.write(
                    `[jumpstart] AWS credentials not available. Falling back to static catalog.\n`
                )
                return this._fallbackToStaticCatalog(modelId)
            }

            process.stderr.write(
                `[jumpstart] JumpStart S3 bucket unreachable: ${err?.name || err?.code || 'Unknown'}. Falling back to static catalog.\n`
            )
            return this._fallbackToStaticCatalog(modelId)
        }
    }

    /**
     * Download one object from the JumpStart bucket and parse it as JSON.
     * Errors (S3, body read, JSON syntax) propagate to the caller.
     *
     * @param {object} sdk - The loaded @aws-sdk/client-s3 module
     * @param {object} client - S3Client instance
     * @param {string} key - Object key inside the bucket
     * @returns {Promise<any>} Parsed JSON content
     */
    async _getJsonObject(sdk, client, key) {
        const command = new sdk.GetObjectCommand({
            Bucket: this._bucketName(),
            Key: key
        })
        const response = await client.send(command)
        const body = await response.Body.transformToString()
        return JSON.parse(body)
    }

    /**
     * Pick the manifest entry to use for a given model ID.
     *
     * Returns the first non-deprecated entry whose `model_id` matches; when
     * every match is deprecated, returns the first of those. "Latest" holds
     * only if the manifest lists newer versions first — that ordering is
     * assumed, not checked here.
     *
     * @param {Array} manifest - Parsed models_manifest.json array
     * @param {string} bareId - Model ID without the jumpstart:// prefix
     * @returns {object|null} Manifest entry or null
     */
    _findLatestEntry(manifest, bareId) {
        let firstDeprecated = null
        for (const entry of manifest) {
            // Tolerate null/garbage rows instead of throwing on them.
            if (entry?.model_id !== bareId) continue
            if (!entry.deprecated) return entry
            if (firstDeprecated === null) firstDeprecated = entry
        }
        return firstDeprecated
    }

    /**
     * Lazy-load the @aws-sdk/client-s3 module (cached after the first call).
     * @returns {Promise<object>} The SDK module
     */
    async _loadSdk() {
        if (!this._sdkModule) {
            this._sdkModule = await import('@aws-sdk/client-s3')
        }
        return this._sdkModule
    }

    /**
     * Create an S3Client configured for anonymous (unsigned) access to the
     * JumpStart public cache bucket. Reuses the client across calls.
     *
     * The configured signer returns each request unchanged, so nothing is
     * signed and no AWS credentials are needed — the intent is the same as
     * `--no-sign-request` in the CLI.
     *
     * @param {object} sdk - The loaded @aws-sdk/client-s3 module
     * @returns {object} S3Client instance
     */
    _createClient(sdk) {
        if (!this._client) {
            this._client = new sdk.S3Client({
                region: this.region,
                requestHandler: {
                    requestTimeout: this.timeout
                },
                signer: { sign: async (request) => request }
            })
        }
        return this._client
    }

    /**
     * Return the JumpStart public cache bucket name for the configured region.
     * @returns {string} Bucket name
     */
    _bucketName() {
        return `jumpstart-cache-prod-${this.region}`
    }

    /**
     * Map a JumpStart model spec JSON object (or manifest entry) to the
     * common ModelMetadata shape.
     *
     * Always sets `provider`, `modelId` and `description`. Sets `framework`,
     * `tags`, `defaultInstanceType`, `supportedInstanceTypes` and
     * `artifactUri` only when the corresponding source fields are present.
     *
     * @param {object} spec - JumpStart model spec JSON or manifest entry
     * @param {string} bareId - The model ID without the jumpstart:// prefix
     * @returns {object|null} ModelMetadata, or null when `spec` is falsy
     */
    _mapToMetadata(spec, bareId) {
        if (!spec) return null

        const modelId = spec.model_id || bareId
        const metadata = {
            provider: 'jumpstart',
            modelId: `jumpstart://${modelId}`,
            description: this._humanReadableId(modelId)
        }

        // Extract framework from hosting_ecr_specs (full spec) or spec.framework
        const framework = spec.hosting_ecr_specs?.framework ||
            spec.hosting_ecr_specs?.Framework ||
            spec.framework
        if (framework) {
            metadata.framework = framework
        }

        // Extract tags from search_keywords or task-related fields
        const tags = []
        if (Array.isArray(spec.search_keywords)) {
            tags.push(...spec.search_keywords)
        }
        if (spec.model_type) tags.push(spec.model_type)
        if (spec.inference_task) tags.push(spec.inference_task)
        if (tags.length > 0) {
            metadata.tags = [...new Set(tags)] // de-duplicate, keep first-seen order
        }

        // Extract default instance type if available
        if (spec.default_inference_instance_type) {
            metadata.defaultInstanceType = spec.default_inference_instance_type
        }

        // Extract supported instance types if available
        if (Array.isArray(spec.supported_inference_instance_types) &&
            spec.supported_inference_instance_types.length > 0) {
            metadata.supportedInstanceTypes = spec.supported_inference_instance_types
        }

        // Extract artifact URI from hosting artifact keys
        // Prefer hosting_prepacked_artifact_key (pre-packaged model ready for serving)
        // Fall back to hosting_artifact_key (raw model artifacts)
        const artifactKey = spec.hosting_prepacked_artifact_key || spec.hosting_artifact_key
        if (artifactKey) {
            metadata.artifactUri = `s3://${this._bucketName()}/${artifactKey}`
        } else {
            process.stderr.write(
                `[jumpstart] No artifact key found for model ${modelId}. artifactUri will be undefined.\n`
            )
        }

        return metadata
    }

    /**
     * Convert a model ID like "huggingface-reasoning-qwen3-8b" into a
     * human-readable description: "Huggingface Reasoning Qwen3 8b".
     *
     * @param {string} id - Raw model ID
     * @returns {string} Space-separated words with the first letter upper-cased; '' for a falsy id
     */
    _humanReadableId(id) {
        if (!id) return ''
        return id
            .split('-')
            .map(w => w.charAt(0).toUpperCase() + w.slice(1))
            .join(' ')
    }

    /**
     * Check if an error is a credential-related error: its `name`, `Code` or
     * `code` is listed in CREDENTIAL_ERROR_NAMES, or its message mentions
     * 'credentials'.
     *
     * @param {Error} err
     * @returns {boolean} Always a strict boolean
     */
    _isCredentialError(err) {
        if (!err) return false
        return CREDENTIAL_ERROR_NAMES.has(err.name) ||
            CREDENTIAL_ERROR_NAMES.has(err.Code) ||
            CREDENTIAL_ERROR_NAMES.has(err.code) ||
            (typeof err.message === 'string' && err.message.includes('credentials'))
    }

    /**
     * Fall back to the static catalog for a given model ID.
     *
     * Only the catalog's own keys count; inherited Object.prototype members
     * are never returned as entries.
     *
     * @param {string} modelId - Full model ID with jumpstart:// prefix
     * @returns {object|null} Shallow copy of the static catalog entry, or null
     */
    _fallbackToStaticCatalog(modelId) {
        const catalog = this._staticCatalog
        if (catalog && Object.hasOwn(catalog, modelId) && catalog[modelId]) {
            return { ...catalog[modelId] }
        }
        return null
    }
}
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
// ── JumpStartPrivateResolver ──────────────────────────────────────────────────
|
|
525
|
+
|
|
526
|
+
/**
|
|
527
|
+
* JumpStartPrivateResolver — fetches model metadata from a private SageMaker
|
|
528
|
+
* JumpStart model hub via the SageMaker API.
|
|
529
|
+
*
|
|
530
|
+
* Handles model IDs matching the `jumpstart-hub://` URI prefix. Parses the URI
|
|
531
|
+
* into hub name and model name, then queries:
|
|
532
|
+
* - ListHubContents — browse models in a private hub
|
|
533
|
+
* - DescribeHubContent — get detailed metadata for a specific model in a hub
|
|
534
|
+
*
|
|
535
|
+
* Distinct error handling:
|
|
536
|
+
* - ResourceNotFoundException for hub → "Hub not found: {hubName}"
|
|
537
|
+
* - ResourceNotFoundException for model → "Model not found in hub: {hubName}/{modelName}"
|
|
538
|
+
* - AccessDeniedException → "Access denied to hub: {hubName}" (no credential details)
|
|
539
|
+
* - Credential failure → return null + log to stderr
|
|
540
|
+
*
|
|
541
|
+
* AWS SDK is lazy-imported to keep the server fast in static mode.
|
|
542
|
+
*/
|
|
543
|
+
class JumpStartPrivateResolver extends ModelResolver {
|
|
544
|
+
constructor(options = {}) {
|
|
545
|
+
super()
|
|
546
|
+
this.timeout = options.timeout ?? 10000
|
|
547
|
+
this.region = options.region || process.env.AWS_REGION || 'us-east-1'
|
|
548
|
+
this._client = null
|
|
549
|
+
this._sdkModule = null
|
|
550
|
+
}
|
|
551
|
+
|
|
552
|
+
supportedPatterns() {
|
|
553
|
+
return ['jumpstart-hub://*']
|
|
554
|
+
}
|
|
555
|
+
|
|
556
|
+
/**
|
|
557
|
+
* Parse a jumpstart-hub:// URI into hub name and model name.
|
|
558
|
+
*
|
|
559
|
+
* @param {string} modelId - e.g. 'jumpstart-hub://my-hub/my-model'
|
|
560
|
+
* @returns {{ hubName: string, modelName: string } | null}
|
|
561
|
+
*/
|
|
562
|
+
_parseHubUri(modelId) {
|
|
563
|
+
const withoutPrefix = modelId.replace(/^jumpstart-hub:\/\//, '')
|
|
564
|
+
if (!withoutPrefix) return null
|
|
565
|
+
|
|
566
|
+
const slashIndex = withoutPrefix.indexOf('/')
|
|
567
|
+
if (slashIndex === -1) {
|
|
568
|
+
// Only hub name, no model name — list mode
|
|
569
|
+
return { hubName: withoutPrefix, modelName: null }
|
|
570
|
+
}
|
|
571
|
+
|
|
572
|
+
const hubName = withoutPrefix.slice(0, slashIndex)
|
|
573
|
+
const modelName = withoutPrefix.slice(slashIndex + 1) || null
|
|
574
|
+
|
|
575
|
+
if (!hubName) return null
|
|
576
|
+
return { hubName, modelName }
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
/**
|
|
580
|
+
* Fetch metadata for a model in a private JumpStart hub.
|
|
581
|
+
*
|
|
582
|
+
* @param {string} modelId - e.g. 'jumpstart-hub://my-hub/my-model'
|
|
583
|
+
* @param {object} options - { fields, context }
|
|
584
|
+
* @returns {Promise<object|null>} ModelMetadata or null
|
|
585
|
+
*/
|
|
586
|
+
async fetchModelMetadata(modelId, options = {}) {
|
|
587
|
+
const parsed = this._parseHubUri(modelId)
|
|
588
|
+
if (!parsed) {
|
|
589
|
+
process.stderr.write(
|
|
590
|
+
`[jumpstart-hub] Invalid hub URI: ${modelId}\n`
|
|
591
|
+
)
|
|
592
|
+
return null
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
const { hubName, modelName } = parsed
|
|
596
|
+
|
|
597
|
+
try {
|
|
598
|
+
const sdk = await this._loadSdk()
|
|
599
|
+
const client = this._createClient(sdk)
|
|
600
|
+
|
|
601
|
+
// If a specific model is requested, describe it
|
|
602
|
+
if (modelName) {
|
|
603
|
+
const command = new sdk.DescribeHubContentCommand({
|
|
604
|
+
HubName: hubName,
|
|
605
|
+
HubContentName: modelName,
|
|
606
|
+
HubContentType: 'Model'
|
|
607
|
+
})
|
|
608
|
+
const response = await client.send(command)
|
|
609
|
+
return this._mapToMetadata(response, hubName)
|
|
610
|
+
}
|
|
611
|
+
|
|
612
|
+
// Otherwise list hub contents
|
|
613
|
+
const command = new sdk.ListHubContentsCommand({
|
|
614
|
+
HubName: hubName,
|
|
615
|
+
HubContentType: 'Model'
|
|
616
|
+
})
|
|
617
|
+
const response = await client.send(command)
|
|
618
|
+
if (response.HubContentSummaries && response.HubContentSummaries.length > 0) {
|
|
619
|
+
return this._mapToMetadata(response.HubContentSummaries[0], hubName)
|
|
620
|
+
}
|
|
621
|
+
|
|
622
|
+
return null
|
|
623
|
+
} catch (err) {
|
|
624
|
+
return this._handleError(err, hubName, modelName)
|
|
625
|
+
}
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
/**
|
|
629
|
+
* Lazy-load the @aws-sdk/client-sagemaker module.
|
|
630
|
+
* @returns {Promise<object>} The SDK module
|
|
631
|
+
*/
|
|
632
|
+
async _loadSdk() {
|
|
633
|
+
if (!this._sdkModule) {
|
|
634
|
+
this._sdkModule = await import('@aws-sdk/client-sagemaker')
|
|
635
|
+
}
|
|
636
|
+
return this._sdkModule
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
/**
|
|
640
|
+
* Create a SageMakerClient with region and timeout configuration.
|
|
641
|
+
* Reuses the client across calls.
|
|
642
|
+
*
|
|
643
|
+
* @param {object} sdk - The loaded @aws-sdk/client-sagemaker module
|
|
644
|
+
* @returns {object} SageMakerClient instance
|
|
645
|
+
*/
|
|
646
|
+
_createClient(sdk) {
|
|
647
|
+
if (!this._client) {
|
|
648
|
+
this._client = new sdk.SageMakerClient({
|
|
649
|
+
region: this.region,
|
|
650
|
+
requestHandler: {
|
|
651
|
+
requestTimeout: this.timeout
|
|
652
|
+
}
|
|
653
|
+
})
|
|
654
|
+
}
|
|
655
|
+
return this._client
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
/**
|
|
659
|
+
* Map a JumpStart hub API response to the common ModelMetadata shape.
|
|
660
|
+
*
|
|
661
|
+
* @param {object} apiResponse - DescribeHubContent or HubContentSummary from the API
|
|
662
|
+
* @param {string} hubName - The hub name from the URI
|
|
663
|
+
* @returns {object} ModelMetadata
|
|
664
|
+
*/
|
|
665
|
+
_mapToMetadata(apiResponse, hubName) {
|
|
666
|
+
if (!apiResponse) return null
|
|
667
|
+
|
|
668
|
+
const contentName = apiResponse.HubContentName || apiResponse.HubContentDisplayName || ''
|
|
669
|
+
const metadata = {
|
|
670
|
+
provider: 'jumpstart-hub',
|
|
671
|
+
modelId: `jumpstart-hub://${hubName}/${contentName}`,
|
|
672
|
+
description: apiResponse.HubContentDescription || apiResponse.HubContentDisplayName || contentName,
|
|
673
|
+
hubName
|
|
674
|
+
}
|
|
675
|
+
|
|
676
|
+
// Extract framework from hub content document schema or search keywords
|
|
677
|
+
if (apiResponse.HubContentDocument) {
|
|
678
|
+
try {
|
|
679
|
+
const doc = typeof apiResponse.HubContentDocument === 'string'
|
|
680
|
+
? JSON.parse(apiResponse.HubContentDocument)
|
|
681
|
+
: apiResponse.HubContentDocument
|
|
682
|
+
if (doc.Framework) {
|
|
683
|
+
metadata.framework = doc.Framework
|
|
684
|
+
}
|
|
685
|
+
if (doc.ModelFormat) {
|
|
686
|
+
metadata.modelFormat = doc.ModelFormat
|
|
687
|
+
}
|
|
688
|
+
// artifactUri extraction (Requirement 1.2): extract from
|
|
689
|
+
// HubContentDocument — check both ArtifactUri and HostingArtifactUri
|
|
690
|
+
// as the field name varies by hub content document schema
|
|
691
|
+
if (doc.ArtifactUri) {
|
|
692
|
+
metadata.artifactUri = doc.ArtifactUri
|
|
693
|
+
} else if (doc.HostingArtifactUri) {
|
|
694
|
+
metadata.artifactUri = doc.HostingArtifactUri
|
|
695
|
+
}
|
|
696
|
+
} catch {
|
|
697
|
+
// Ignore JSON parse errors in hub content document
|
|
698
|
+
}
|
|
699
|
+
}
|
|
700
|
+
|
|
701
|
+
// Extract tags from search keywords
|
|
702
|
+
if (Array.isArray(apiResponse.HubContentSearchKeywords)) {
|
|
703
|
+
metadata.tags = apiResponse.HubContentSearchKeywords
|
|
704
|
+
}
|
|
705
|
+
|
|
706
|
+
return metadata
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
/**
|
|
710
|
+
* Check if an error is a credential-related error.
|
|
711
|
+
* @param {Error} err
|
|
712
|
+
* @returns {boolean}
|
|
713
|
+
*/
|
|
714
|
+
_isCredentialError(err) {
|
|
715
|
+
return CREDENTIAL_ERROR_NAMES.has(err.name) ||
|
|
716
|
+
CREDENTIAL_ERROR_NAMES.has(err.Code) ||
|
|
717
|
+
(err.message && err.message.includes('credentials'))
|
|
718
|
+
}
|
|
719
|
+
|
|
720
|
+
/**
|
|
721
|
+
* Handle errors from SageMaker API calls with distinct error messages.
|
|
722
|
+
*
|
|
723
|
+
* @param {Error} err - The caught error
|
|
724
|
+
* @param {string} hubName - The hub name from the URI
|
|
725
|
+
* @param {string|null} modelName - The model name, if provided
|
|
726
|
+
* @returns {null}
|
|
727
|
+
*/
|
|
728
|
+
_handleError(err, hubName, modelName) {
|
|
729
|
+
if (this._isCredentialError(err)) {
|
|
730
|
+
process.stderr.write(
|
|
731
|
+
`[jumpstart-hub] AWS credentials required for private hub access.\n`
|
|
732
|
+
)
|
|
733
|
+
return null
|
|
734
|
+
}
|
|
735
|
+
|
|
736
|
+
if (err.name === 'ResourceNotFoundException' || err.Code === 'ResourceNotFoundException') {
|
|
737
|
+
if (modelName) {
|
|
738
|
+
process.stderr.write(
|
|
739
|
+
`[jumpstart-hub] Model not found in hub: ${hubName}/${modelName}\n`
|
|
740
|
+
)
|
|
741
|
+
} else {
|
|
742
|
+
process.stderr.write(
|
|
743
|
+
`[jumpstart-hub] Hub not found: ${hubName}\n`
|
|
744
|
+
)
|
|
745
|
+
}
|
|
746
|
+
return null
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException' ||
|
|
750
|
+
err.$metadata?.httpStatusCode === 403) {
|
|
751
|
+
process.stderr.write(
|
|
752
|
+
`[jumpstart-hub] Access denied to hub: ${hubName}\n`
|
|
753
|
+
)
|
|
754
|
+
return null
|
|
755
|
+
}
|
|
756
|
+
|
|
757
|
+
process.stderr.write(
|
|
758
|
+
`[jumpstart-hub] SageMaker API error: ${err.name || err.code || 'Unknown'}.\n`
|
|
759
|
+
)
|
|
760
|
+
return null
|
|
761
|
+
}
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
// ── ModelRegistryResolver ──────────────────────────────────────────────────────
|
|
766
|
+
|
|
767
|
+
/**
 * ModelRegistryResolver — fetches model metadata from SageMaker Model Registry
 * via the SageMaker API.
 *
 * Handles model IDs matching the `registry://` URI prefix. Parses the URI
 * into group name and optional version, then queries:
 *   - ListModelPackages    — list versions in a model package group (no version)
 *   - DescribeModelPackage — get detailed metadata for a specific version
 *
 * Every failure path (invalid URI, credential failure, group not found,
 * access denied, any other API error) returns null and logs to stderr.
 * AWS SDK is lazy-imported to keep the server fast in static mode.
 */
class ModelRegistryResolver extends ModelResolver {
    /**
     * @param {object} [options]
     * @param {number} [options.timeout=10000] - passed to the SDK client as
     *   requestHandler.requestTimeout
     * @param {string} [options.region] - falls back to the AWS_REGION
     *   environment variable, then to 'us-east-1'
     */
    constructor(options = {}) {
        super()
        this.timeout = options.timeout ?? 10000
        this.region = options.region || process.env.AWS_REGION || 'us-east-1'
        // Both are created lazily on first use and then reused
        this._client = null
        this._sdkModule = null
    }

    /**
     * @returns {string[]} Model ID patterns this resolver declares support for
     */
    supportedPatterns() {
        return ['registry://*']
    }

    /**
     * Parse a registry:// URI into group name and optional version.
     *
     * Only the first '/' after the scheme splits group from version; a
     * trailing slash with nothing after it counts as "no version".
     *
     * @param {string} modelId - e.g. 'registry://my-model-group/3'
     * @returns {{ groupName: string, version: string|null } | null}
     *   null when modelId is not a string or carries no group name
     */
    _parseRegistryUri(modelId) {
        // fetchModelMetadata calls this outside its try/catch, so a
        // non-string input must yield null here instead of throwing.
        if (typeof modelId !== 'string') return null

        const withoutPrefix = modelId.replace(/^registry:\/\//, '')
        if (!withoutPrefix) return null

        const slashIndex = withoutPrefix.indexOf('/')
        if (slashIndex === -1) {
            // Only group name, no version — list mode
            return { groupName: withoutPrefix, version: null }
        }

        const groupName = withoutPrefix.slice(0, slashIndex)
        const version = withoutPrefix.slice(slashIndex + 1) || null

        if (!groupName) return null
        return { groupName, version }
    }

    /**
     * Fetch metadata for a model in SageMaker Model Registry.
     *
     * With a version in the URI the package '<group>/<version>' is
     * described; without one the group's packages are listed and the first
     * summary returned by the API is mapped.
     *
     * @param {string} modelId - e.g. 'registry://my-model-group/3'
     * @param {object} options - { fields, context } (not read by this method)
     * @returns {Promise<object|null>} ModelMetadata or null
     */
    async fetchModelMetadata(modelId, options = {}) {
        const parsed = this._parseRegistryUri(modelId)
        if (!parsed) {
            process.stderr.write(
                `[registry] Invalid registry URI: ${modelId}\n`
            )
            return null
        }

        const { groupName, version } = parsed

        try {
            const sdk = await this._loadSdk()
            const client = this._createClient(sdk)

            // If a specific version is requested, describe that model package
            if (version) {
                const command = new sdk.DescribeModelPackageCommand({
                    ModelPackageName: `${groupName}/${version}`
                })
                const response = await client.send(command)
                return this._mapToMetadata(response, groupName)
            }

            // Otherwise list model packages in the group
            const command = new sdk.ListModelPackagesCommand({
                ModelPackageGroupName: groupName
            })
            const response = await client.send(command)
            if (response.ModelPackageSummaryList && response.ModelPackageSummaryList.length > 0) {
                return this._mapToMetadata(response.ModelPackageSummaryList[0], groupName)
            }

            return null
        } catch (err) {
            return this._handleError(err, groupName)
        }
    }

    /**
     * Lazy-load the @aws-sdk/client-sagemaker module and cache it on the
     * instance.
     *
     * @returns {Promise<object>} The SDK module
     */
    async _loadSdk() {
        if (!this._sdkModule) {
            this._sdkModule = await import('@aws-sdk/client-sagemaker')
        }
        return this._sdkModule
    }

    /**
     * Create a SageMakerClient with region and timeout configuration.
     * Reuses the client across calls.
     *
     * @param {object} sdk - The loaded @aws-sdk/client-sagemaker module
     * @returns {object} SageMakerClient instance
     */
    _createClient(sdk) {
        if (!this._client) {
            this._client = new sdk.SageMakerClient({
                region: this.region,
                requestHandler: {
                    requestTimeout: this.timeout
                }
            })
        }
        return this._client
    }

    /**
     * Map a Model Registry API response to the common ModelMetadata shape.
     *
     * @param {object} apiResponse - DescribeModelPackage or ModelPackageSummary from the API
     * @param {string} groupName - The model package group name from the URI
     * @returns {object|null} ModelMetadata, or null when apiResponse is falsy
     */
    _mapToMetadata(apiResponse, groupName) {
        if (!apiResponse) return null

        const metadata = {
            provider: 'registry',
            // Group-level ID; replaced below when the response carries a version
            modelId: `registry://${groupName}`,
            description: apiResponse.ModelPackageDescription || `Model package group: ${groupName}`
        }

        // Model package ARN
        if (apiResponse.ModelPackageArn) {
            metadata.modelPackageArn = apiResponse.ModelPackageArn
        }

        // Group name: prefer what the API reports, fall back to the URI
        metadata.modelPackageGroupName = apiResponse.ModelPackageGroupName || groupName

        // Version: explicit null/undefined test rather than a truthiness
        // test, so a falsy version value such as 0 is still recorded
        if (apiResponse.ModelPackageVersion !== undefined && apiResponse.ModelPackageVersion !== null) {
            metadata.modelPackageVersion = apiResponse.ModelPackageVersion
            metadata.modelId = `registry://${groupName}/${apiResponse.ModelPackageVersion}`
        }

        // Approval status
        if (apiResponse.ModelApprovalStatus) {
            metadata.approvalStatus = apiResponse.ModelApprovalStatus
        }

        // artifactUri extraction (Requirement 1.3): extract from
        // InferenceSpecification.Containers[0].ModelDataUrl — the S3 URI
        // where the registered model package stores its inference artifacts
        const container = apiResponse.InferenceSpecification?.Containers?.[0]
        if (container) {
            if (container.Framework) {
                metadata.framework = container.Framework
            }
            if (container.ModelDataUrl) {
                metadata.artifactUri = container.ModelDataUrl
            }
        }

        // Fallback: top-level ModelDataUrl when the first container did not
        // provide one
        if (!metadata.artifactUri && apiResponse.ModelDataUrl) {
            metadata.artifactUri = apiResponse.ModelDataUrl
        }

        return metadata
    }

    /**
     * Check if an error is a credential-related error.
     *
     * An error qualifies when its `name` or `Code` is listed in the
     * module-level CREDENTIAL_ERROR_NAMES collection, or when its message
     * mentions 'credentials'.
     *
     * @param {Error} err
     * @returns {boolean} always a strict boolean
     */
    _isCredentialError(err) {
        if (CREDENTIAL_ERROR_NAMES.has(err.name) || CREDENTIAL_ERROR_NAMES.has(err.Code)) {
            return true
        }
        // Type-check the message so a missing or non-string message yields
        // false instead of leaking undefined/'' or throwing on .includes().
        return typeof err.message === 'string' && err.message.includes('credentials')
    }

    /**
     * Handle errors from SageMaker API calls with distinct error messages.
     *
     * Checked in order: credential problem, group not found
     * (ResourceNotFoundException or ValidationException), access denied
     * (by error name/Code or HTTP 403), anything else.
     *
     * @param {Error} err - The caught error
     * @param {string} groupName - The model package group name from the URI
     * @returns {null}
     */
    _handleError(err, groupName) {
        if (this._isCredentialError(err)) {
            process.stderr.write(
                `[registry] AWS credentials required for Model Registry access.\n`
            )
            return null
        }

        if (err.name === 'ResourceNotFoundException' || err.Code === 'ResourceNotFoundException' ||
            err.name === 'ValidationException') {
            process.stderr.write(
                `[registry] Model package group not found: ${groupName}\n`
            )
            return null
        }

        if (err.name === 'AccessDeniedException' || err.Code === 'AccessDeniedException' ||
            err.$metadata?.httpStatusCode === 403) {
            process.stderr.write(
                `[registry] Access denied to model package group: ${groupName}\n`
            )
            return null
        }

        process.stderr.write(
            `[registry] SageMaker API error: ${err.name || err.code || 'Unknown'}.\n`
        )
        return null
    }
}
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
// ── S3Resolver ────────────────────────────────────────────────────────────────
|
|
998
|
+
|
|
999
|
+
/**
 * S3Resolver — validates S3 URIs and inspects model artifacts stored in Amazon S3.
 *
 * Handles model IDs matching the `s3://` URI prefix. Uses:
 *   - HeadObject    — check single-file artifacts (e.g. model.tar.gz)
 *   - ListObjectsV2 — inspect directory-style artifacts
 *   - GetObject     — read known config files found in a directory
 *
 * Infers framework from config files (config.json, tokenizer_config.json,
 * serving.properties) when the artifact is a directory.
 *
 * On an invalid URI, credential failure, bucket/key not found, or access
 * denied, returns null with a descriptive message logged to stderr.
 * AWS SDK is lazy-imported.
 */
class S3Resolver extends ModelResolver {
    /**
     * @param {object} [options]
     * @param {number} [options.timeout=10000] - passed to the SDK client as
     *   requestHandler.requestTimeout
     * @param {string} [options.region] - falls back to the AWS_REGION
     *   environment variable, then to 'us-east-1'
     */
    constructor(options = {}) {
        super()
        this.timeout = options.timeout ?? 10000
        this.region = options.region || process.env.AWS_REGION || 'us-east-1'
        // Both are created lazily on first use and then reused
        this._client = null
        this._sdkModule = null
    }

    /**
     * @returns {string[]} Model ID patterns this resolver declares support for
     */
    supportedPatterns() {
        return ['s3://*']
    }

    /**
     * Fetch metadata for a model artifact in S3.
     *
     * A key that does not end in '/' is first probed with HeadObject and,
     * when found, reported as a 'tarball' or 'single-file' artifact. A 404
     * from that probe, a key ending in '/', or an empty key leads to a
     * ListObjectsV2 call on the key treated as a directory prefix.
     *
     * Only one ListObjectsV2 page (MaxKeys: 1000) is requested, so the
     * reported size and modification time cover at most 1000 objects.
     *
     * @param {string} modelId - e.g. 's3://my-bucket/path/to/model.tar.gz'
     * @param {object} options - { fields, context } (not read by this method)
     * @returns {Promise<object|null>} ModelMetadata or null
     */
    async fetchModelMetadata(modelId, options = {}) {
        const parsed = parseS3Uri(modelId)
        if (parsed.error) {
            process.stderr.write(
                `[s3] Invalid S3 URI: ${parsed.error}\n`
            )
            return null
        }

        const { bucket, key } = parsed

        try {
            const sdk = await this._loadSdk()
            const client = this._createClient(sdk)

            // Try HeadObject first to check if it's a single file
            if (key && !key.endsWith('/')) {
                try {
                    const headCommand = new sdk.HeadObjectCommand({
                        Bucket: bucket,
                        Key: key
                    })
                    const headResponse = await client.send(headCommand)

                    const artifactType = key.endsWith('.tar.gz') || key.endsWith('.tgz')
                        ? 'tarball' : 'single-file'

                    return {
                        provider: 's3',
                        modelId,
                        description: `S3 model artifact: ${modelId}`,
                        // artifactUri extraction (Requirement 1.4): for S3 models,
                        // artifactUri is the original s3:// URI itself — the model
                        // is already in S3, so no additional resolution is needed
                        artifactUri: modelId,
                        artifactType,
                        artifactSizeBytes: headResponse.ContentLength ?? null,
                        lastModified: headResponse.LastModified
                            ? headResponse.LastModified.toISOString() : null
                    }
                } catch (headErr) {
                    // If it's a 404, the key might be a directory prefix — fall through to ListObjectsV2
                    if (headErr.name !== 'NotFound' && headErr.$metadata?.httpStatusCode !== 404) {
                        throw headErr
                    }
                }
            }

            // List objects under the key prefix (directory-style artifact)
            const prefix = key ? (key.endsWith('/') ? key : key + '/') : ''
            const listCommand = new sdk.ListObjectsV2Command({
                Bucket: bucket,
                Prefix: prefix,
                MaxKeys: 1000
            })
            const listResponse = await client.send(listCommand)

            if (!listResponse.Contents || listResponse.Contents.length === 0) {
                process.stderr.write(
                    `[s3] Key not found: ${bucket}/${key}\n`
                )
                return null
            }

            // Calculate total size and find last modified
            let totalSize = 0
            let latestModified = null
            const fileNames = []

            for (const obj of listResponse.Contents) {
                totalSize += obj.Size ?? 0
                if (obj.LastModified && (!latestModified || obj.LastModified > latestModified)) {
                    latestModified = obj.LastModified
                }
                // Extract relative file name from the key; an object whose key
                // equals the prefix itself yields '' and is skipped
                const relativeName = prefix ? obj.Key.slice(prefix.length) : obj.Key
                if (relativeName) {
                    fileNames.push(relativeName)
                }
            }

            // Try to infer framework from config files
            const configFiles = await this._readConfigFiles(sdk, client, bucket, prefix, fileNames)
            const framework = this._inferFramework(configFiles)

            const metadata = {
                provider: 's3',
                modelId,
                description: `S3 model directory: ${modelId}`,
                // artifactUri extraction (Requirement 1.4): for S3 models,
                // artifactUri is the original s3:// URI itself — the model
                // is already in S3, so no additional resolution is needed
                artifactUri: modelId,
                artifactType: 'directory',
                artifactSizeBytes: totalSize,
                lastModified: latestModified ? latestModified.toISOString() : null
            }

            if (framework) {
                metadata.framework = framework
            }

            return metadata
        } catch (err) {
            return this._handleError(err, bucket, key, modelId)
        }
    }

    /**
     * Download the known config files that sit directly under the prefix.
     *
     * The reads are independent, so they are issued together. A file that
     * cannot be read is left out of the result instead of failing the
     * whole lookup.
     *
     * @param {object} sdk - The loaded @aws-sdk/client-s3 module
     * @param {object} client - S3Client instance
     * @param {string} bucket - The bucket name
     * @param {string} prefix - Directory prefix ('' or ending in '/')
     * @param {string[]} fileNames - Object names relative to the prefix
     * @returns {Promise<object>} Map of filename → file content string
     */
    async _readConfigFiles(sdk, client, bucket, prefix, fileNames) {
        const configFileNames = ['config.json', 'tokenizer_config.json', 'serving.properties']
        const present = configFileNames.filter(cfgName => fileNames.includes(cfgName))

        const reads = await Promise.all(present.map(async (cfgName) => {
            try {
                const getCommand = new sdk.GetObjectCommand({
                    Bucket: bucket,
                    Key: prefix + cfgName
                })
                const getResponse = await client.send(getCommand)
                const body = await getResponse.Body.transformToString()
                return [cfgName, body]
            } catch {
                // Ignore errors reading individual config files
                return null
            }
        }))

        const configFiles = {}
        for (const read of reads) {
            if (read) {
                configFiles[read[0]] = read[1]
            }
        }
        return configFiles
    }

    /**
     * Infer the ML framework from config file contents.
     *
     * Checks run in order and the first match wins:
     *   1. config.json with a non-empty `architectures` array or a
     *      `model_type` → 'huggingface'
     *   2. tokenizer_config.json that parses as JSON → 'huggingface'
     *   3. serving.properties mentioning `model_id` → 'djl'
     *
     * @param {object} configFiles - Map of filename → file content string
     * @returns {string|null} Inferred framework name or null
     */
    _inferFramework(configFiles) {
        // Check config.json for HuggingFace transformer architectures
        if (configFiles['config.json']) {
            try {
                const config = JSON.parse(configFiles['config.json'])
                if (config.architectures && Array.isArray(config.architectures) && config.architectures.length > 0) {
                    return 'huggingface'
                }
                if (config.model_type) {
                    return 'huggingface'
                }
            } catch {
                // Invalid JSON — skip
            }
        }

        // Check tokenizer_config.json — presence implies HuggingFace
        if (configFiles['tokenizer_config.json']) {
            try {
                JSON.parse(configFiles['tokenizer_config.json'])
                return 'huggingface'
            } catch {
                // Invalid JSON — skip
            }
        }

        // Check serving.properties for DJL serving configuration.
        // 'option.model_id' contains 'model_id', so one substring test
        // covers both spellings.
        if (configFiles['serving.properties']) {
            const content = configFiles['serving.properties']
            if (content.includes('model_id')) {
                return 'djl'
            }
        }

        return null
    }

    /**
     * Lazy-load the @aws-sdk/client-s3 module and cache it on the instance.
     *
     * @returns {Promise<object>} The SDK module
     */
    async _loadSdk() {
        if (!this._sdkModule) {
            this._sdkModule = await import('@aws-sdk/client-s3')
        }
        return this._sdkModule
    }

    /**
     * Create an S3Client with region and timeout configuration.
     * Reuses the client across calls.
     *
     * @param {object} sdk - The loaded @aws-sdk/client-s3 module
     * @returns {object} S3Client instance
     */
    _createClient(sdk) {
        if (!this._client) {
            this._client = new sdk.S3Client({
                region: this.region,
                requestHandler: {
                    requestTimeout: this.timeout
                }
            })
        }
        return this._client
    }

    /**
     * Check if an error is a credential-related error.
     *
     * An error qualifies when its `name` or `Code` is listed in the
     * module-level CREDENTIAL_ERROR_NAMES collection, or when its message
     * mentions 'credentials'.
     *
     * @param {Error} err
     * @returns {boolean} always a strict boolean
     */
    _isCredentialError(err) {
        if (CREDENTIAL_ERROR_NAMES.has(err.name) || CREDENTIAL_ERROR_NAMES.has(err.Code)) {
            return true
        }
        // Type-check the message so a missing or non-string message yields
        // false instead of leaking undefined/'' or throwing on .includes().
        return typeof err.message === 'string' && err.message.includes('credentials')
    }

    /**
     * Handle errors from S3 API calls with distinct error messages.
     *
     * Checked in order: credential problem, missing bucket, missing key
     * (by name/Code or HTTP 404), access denied (by name/Code or HTTP 403),
     * anything else.
     *
     * @param {Error} err - The caught error
     * @param {string} bucket - The bucket name
     * @param {string} key - The object key
     * @param {string} uri - The original S3 URI
     * @returns {null}
     */
    _handleError(err, bucket, key, uri) {
        if (this._isCredentialError(err)) {
            process.stderr.write(
                `[s3] AWS credentials required for S3 access.\n`
            )
            return null
        }

        if (err.name === 'NoSuchBucket' || err.Code === 'NoSuchBucket') {
            process.stderr.write(
                `[s3] Bucket not found: ${bucket}\n`
            )
            return null
        }

        if (err.name === 'NoSuchKey' || err.Code === 'NoSuchKey' ||
            err.name === 'NotFound' || err.$metadata?.httpStatusCode === 404) {
            process.stderr.write(
                `[s3] Key not found: ${bucket}/${key}\n`
            )
            return null
        }

        if (err.name === 'AccessDenied' || err.Code === 'AccessDenied' ||
            err.$metadata?.httpStatusCode === 403) {
            process.stderr.write(
                `[s3] Access denied: ${uri}\n`
            )
            return null
        }

        process.stderr.write(
            `[s3] S3 API error: ${err.name || err.code || 'Unknown'}.\n`
        )
        return null
    }
}
|
|
1290
|
+
|
|
1291
|
+
|
|
1292
|
+
// ── ResolverRegistry ─────────────────────────────────────────────────────────
|
|
1293
|
+
|
|
1294
|
+
/**
 * ResolverRegistry — routes a model ID to the ModelResolver that handles it.
 *
 * Resolvers are stored together with a predicate. Lookup walks the entries
 * in registration order and returns the resolver of the first predicate
 * that accepts the ID; when none does, the default resolver (null unless
 * one was set) is returned.
 */
class ResolverRegistry {
    constructor() {
        // Ordered list of { resolver, matchFn } pairs
        this._resolvers = []
        this._defaultResolver = null
    }

    /**
     * Append a resolver and the predicate that selects it.
     * @param {ModelResolver} resolver
     * @param {function(string): boolean} matchFn
     */
    register(resolver, matchFn) {
        const entry = { resolver, matchFn }
        this._resolvers.push(entry)
    }

    /**
     * Choose the resolver returned when no predicate accepts a model ID.
     * @param {ModelResolver} resolver
     */
    setDefault(resolver) {
        this._defaultResolver = resolver
    }

    /**
     * Look up the resolver responsible for a model ID.
     * @param {string} modelId
     * @returns {ModelResolver|null}
     */
    getResolver(modelId) {
        // Destructure so the predicate is invoked as a plain function
        const match = this._resolvers.find(({ matchFn }) => matchFn(modelId))
        return match ? match.resolver : this._defaultResolver
    }
}
|
|
1336
|
+
|
|
1337
|
+
// ── Merge logic ──────────────────────────────────────────────────────────────
|
|
1338
|
+
|
|
1339
|
+
/**
 * Combine metadata from a live API with metadata from the static catalog.
 *
 * The static record is the base; every live field whose value is neither
 * null nor undefined overwrites it. The merge is shallow and neither
 * input is mutated. When only one input is present, a shallow copy of it
 * is returned as-is.
 *
 * @param {object|null} liveData - Metadata from live API (e.g. HuggingFace)
 * @param {object|null} staticData - Metadata from static catalog
 * @returns {object|null} Merged metadata, or null if both inputs are missing
 */
function mergeMetadata(liveData, staticData) {
    if (!liveData) {
        return staticData ? { ...staticData } : null
    }
    if (!staticData) {
        return { ...liveData }
    }

    const merged = { ...staticData }
    Object.keys(liveData).forEach((field) => {
        const liveValue = liveData[field]
        // == null matches both null and undefined: keep the static value
        if (liveValue == null) return
        merged[field] = liveValue
    })
    return merged
}
|
|
1361
|
+
|
|
1362
|
+
// ── S3 URI parsing ───────────────────────────────────────────────────────────
|
|
1363
|
+
|
|
1364
|
+
/**
 * Shape check for an S3 bucket name: lowercase letters, digits, hyphens and
 * periods only, beginning and ending with a letter or digit, and not
 * formatted like an IPv4 address. Length and consecutive-period rules are
 * enforced separately in parseS3Uri.
 */
const S3_BUCKET_REGEX = /^(?!(\d{1,3}\.){3}\d{1,3}$)[a-z0-9]([a-z0-9.\-]*[a-z0-9])?$/

/**
 * Split an S3 URI into its bucket and key, validating both.
 * Never throws — the outcome is either { bucket, key } or { error }.
 *
 * Checks, in order:
 *   - the input is a string that starts with 's3://'
 *   - the bucket is non-empty and 3–63 characters long
 *   - the bucket has no consecutive periods and matches S3_BUCKET_REGEX
 *   - the key is at most 1024 characters
 *
 * The key is everything after the first '/' following the bucket and is
 * '' when the URI names only a bucket.
 *
 * @param {string} uri - e.g. 's3://my-bucket/path/to/model.tar.gz'
 * @returns {{ bucket: string, key: string } | { error: string }}
 */
function parseS3Uri(uri) {
    const fail = (error) => ({ error })

    if (typeof uri !== 'string') return fail('S3 URI must be a string')

    const scheme = 's3://'
    if (!uri.startsWith(scheme)) return fail('S3 URI must start with s3://')

    const remainder = uri.slice(scheme.length)
    const cut = remainder.indexOf('/')
    const hasSlash = cut !== -1
    const bucket = hasSlash ? remainder.slice(0, cut) : remainder
    const key = hasSlash ? remainder.slice(cut + 1) : ''

    // Bucket name rules
    const nameLength = bucket.length
    if (nameLength === 0) return fail('Bucket name must not be empty')
    if (nameLength < 3 || nameLength > 63) {
        return fail(`Bucket name must be 3–63 characters, got ${nameLength}`)
    }
    if (bucket.includes('..')) {
        return fail('Bucket name must not contain consecutive periods')
    }
    if (!S3_BUCKET_REGEX.test(bucket)) {
        return fail(`Invalid bucket name: ${bucket}`)
    }

    // Key length rule
    if (key.length > 1024) {
        return fail(`Key must be ≤ 1024 characters, got ${key.length}`)
    }

    return { bucket, key }
}
|
|
1418
|
+
|
|
1419
|
+
/**
 * Assemble an S3 URI from a bucket name and an object key.
 *
 * No validation or escaping is performed; the two parts are stringified and
 * joined as-is, so an empty key produces a trailing slash ('s3://bucket/').
 *
 * @param {string} bucket - Bucket name
 * @param {string} key - Object key (may be empty)
 * @returns {string} 's3://<bucket>/<key>'
 */
function buildS3Uri(bucket, key) {
    return 's3://'.concat(bucket, '/', key)
}
|
|
1429
|
+
|
|
1430
|
+
// ── Load catalogs ────────────────────────────────────────────────────────────
|
|
1431
|
+
|
|
1432
|
+
// Combined view of the bundled catalogs. The files are loaded and merged in
// the order listed, so when two catalogs share a key the later file wins.
let POPULAR_MODELS_CATALOG

try {
    const catalogFiles = [
        './catalogs/popular-transformers.json',
        './catalogs/popular-diffusors.json',
        './catalogs/jumpstart-public.json'
    ]
    POPULAR_MODELS_CATALOG = catalogFiles.reduce(
        (combined, file) => ({ ...combined, ...loadCatalog(file) }),
        {}
    )
} catch (loadError) {
    // A catalog that fails to load is treated as fatal: report on stderr and
    // exit with a non-zero status.
    process.stderr.write(`[model-picker] Fatal: ${loadError.message}\n`)
    process.exit(1)
}
|
|
1444
|
+
|
|
1445
|
+
// ── Wiring ───────────────────────────────────────────────────────────────────
|
|
1446
|
+
|
|
1447
|
+
// One shared instance per resolver; all of them are exported for tests.
const staticResolver = new StaticCatalogResolver(POPULAR_MODELS_CATALOG)
const hfResolver = new HuggingFaceResolver()
const jumpStartPublicResolver = new JumpStartPublicResolver()
const jumpStartPrivateResolver = new JumpStartPrivateResolver()
const modelRegistryResolver = new ModelRegistryResolver()
const s3Resolver = new S3Resolver()
const registry = new ResolverRegistry()

// Scheme-prefixed identifiers are routed by their prefix. Registration order
// is kept exactly as listed.
// NOTE(review): whether the registry picks the first matching entry is not
// visible from here — confirm against ResolverRegistry.
for (const [scheme, resolver] of [
    ['jumpstart://', jumpStartPublicResolver],
    ['jumpstart-hub://', jumpStartPrivateResolver],
    ['registry://', modelRegistryResolver],
    ['s3://', s3Resolver]
]) {
    registry.register(resolver, id => id.startsWith(scheme))
}

// Scheme-less identifiers with exactly one '/' (the 'owner/name' shape) go to
// the Hugging Face resolver.
registry.register(
    hfResolver,
    id => /^[^/]+\/[^/]+$/.test(id) && !id.includes('://')
)

// Anything else is handed to the static catalog.
registry.setDefault(staticResolver)
|
|
1476
|
+
|
|
1477
|
+
// ── Choice formatting helpers ─────────────────────────────────────────────────
|
|
1478
|
+
|
|
1479
|
+
/**
 * Provider prefix label mapping for model choice formatting.
 * Keys are the `provider` values found on model metadata objects.
 */
const PROVIDER_LABELS = {
    'jumpstart': '[JumpStart]',
    'jumpstart-hub': '[JumpStart Hub]',
    'registry': '[Registry]',
    's3': '[S3]',
    'huggingface': '[HuggingFace]'
}

/**
 * Format a model choice with a provider prefix label.
 *
 * Returns '' when `metadata` is missing or has no `modelId`. When the provider
 * has no entry in PROVIDER_LABELS (or is absent) the bare `modelId` is returned.
 *
 * @param {object} metadata - Model metadata object with `provider` and `modelId` fields
 * @returns {string} Formatted choice string, e.g. '[JumpStart] huggingface-llm-falcon-7b'
 */
function formatModelChoice(metadata) {
    if (!metadata || !metadata.modelId) return ''

    const { provider, modelId } = metadata

    // Own-property check: a plain `PROVIDER_LABELS[provider]` lookup would also
    // see members inherited from Object.prototype (e.g. 'constructor',
    // 'toString'), which are truthy and would be rendered as a bogus label.
    const isKnownProvider = Object.prototype.hasOwnProperty.call(PROVIDER_LABELS, provider)

    return isKnownProvider ? `${PROVIDER_LABELS[provider]} ${modelId}` : modelId
}
|
|
1504
|
+
|
|
1505
|
+
/**
 * Filter an array of model metadata objects by provider.
 *
 * - A non-array `models` (null, undefined, string, object, ...) yields an
 *   empty array, so the return value is always an array.
 * - A falsy `provider` disables filtering and returns `models` itself
 *   (the same array instance, not a copy).
 * - Falsy entries inside `models` are dropped when filtering.
 *
 * @param {object[]} models - Array of model metadata objects
 * @param {string} provider - Provider string to filter by
 * @returns {object[]} Filtered array containing only models whose `provider` matches
 */
function filterByProvider(models, provider) {
    if (!Array.isArray(models)) return []
    if (!provider) return models
    return models.filter(m => m && m.provider === provider)
}
|
|
1516
|
+
|
|
1517
|
+
// ── Tool handler ─────────────────────────────────────────────────────────────
|
|
1518
|
+
|
|
1519
|
+
/**
 * Identifier schemes that get a tagged status message in discover mode.
 * `tag` is the bracketed prefix of the message; `api` is the service named in
 * the "unreachable" fallback notice.
 */
const LIVE_SOURCE_SCHEMES = [
    { prefix: 'jumpstart://', tag: 'jumpstart', api: 'SageMaker' },
    { prefix: 'jumpstart-hub://', tag: 'jumpstart-hub', api: 'SageMaker' },
    { prefix: 'registry://', tag: 'registry', api: 'SageMaker' },
    { prefix: 's3://', tag: 's3', api: 'S3' }
]

/** Return the LIVE_SOURCE_SCHEMES entry whose prefix starts `modelId`, or null. */
function findLiveSourceScheme(modelId) {
    return LIVE_SOURCE_SCHEMES.find(({ prefix }) => modelId.startsWith(prefix)) || null
}

/** Static mode: consult only the static catalog resolver. Returns { values, message }. */
async function resolveFromStaticCatalog(model_id, fields) {
    const metadata = await staticResolver.fetchModelMetadata(model_id, { fields })
    if (metadata) {
        return { values: { ...metadata }, message: null }
    }
    const message = model_id.startsWith('jumpstart://')
        ? `Model not found in JumpStart static catalog: ${model_id}`
        : `Model not found in static catalog: ${model_id}`
    return { values: {}, message }
}

/** Discover mode: query the registered resolver, then merge with static data. Returns { values, message }. */
async function resolveFromLiveSources(model_id, fields) {
    const resolver = registry.getResolver(model_id)
    const liveData = resolver
        ? await resolver.fetchModelMetadata(model_id, { fields })
        : null
    // Only an explicit null from a resolver that was actually consulted
    // counts as a resolver failure.
    const resolverFailed = Boolean(resolver) && liveData === null

    const staticData = await staticResolver.fetchModelMetadata(model_id, { fields })
    const merged = mergeMetadata(liveData, staticData)

    if (merged) {
        let message = null
        // Resolver failed but the static catalog had data: note the fallback.
        // Identifiers without a known scheme fall back silently.
        if (resolverFailed && staticData) {
            const scheme = findLiveSourceScheme(model_id)
            if (scheme) {
                message = `[${scheme.tag}] ${scheme.api} API unreachable. Using static catalog fallback.`
            }
        }
        return { values: { ...merged }, message }
    }

    // No data from either source.
    const scheme = resolverFailed ? findLiveSourceScheme(model_id) : null
    const message = scheme
        ? `[${scheme.tag}] Resolver could not fetch data for: ${model_id}`
        : `Model not found: ${model_id}`
    return { values: {}, message }
}

/**
 * Handle a get_models tool call.
 * Extracted as a standalone function so tests can call it directly.
 *
 * Steps:
 *   1. Resolve metadata — 'static' mode reads only the static catalog; any
 *      other mode uses the resolver registry and merges with static data.
 *   2. If `context.provider` is set and differs from the resolved provider,
 *      discard the result and explain why in `message`.
 *   3. If `fields` is non-empty, keep only those fields.
 *   4. Build a single-entry `choices` map (label -> model id).
 *
 * NOTE(review): choices are derived from the field-filtered values, so a
 * `fields` list that omits `modelId` yields no choices, and one that omits
 * `provider` yields an unprefixed label — confirm this is intended.
 *
 * @param {object} params
 * @param {string} params.model_id - Model identifier
 * @param {string[]} [params.fields] - Metadata fields to return
 * @param {string} [params.mode] - 'static' or 'discover'
 * @param {object} [params.context] - Configuration context
 * @returns {Promise<{content: Array}>} MCP response whose single text item is
 *   the JSON encoding of { values, choices, message }
 */
async function resolveModel({ model_id, fields, mode = 'discover', context }) {
    const resolved = mode === 'static'
        ? await resolveFromStaticCatalog(model_id, fields)
        : await resolveFromLiveSources(model_id, fields)

    let values = resolved.values
    let message = resolved.message

    // Apply provider filter from context
    if (context && context.provider && Object.keys(values).length > 0) {
        if (values.provider && values.provider !== context.provider) {
            message = `Model ${model_id} is from provider '${values.provider}', not '${context.provider}'`
            values = {}
        }
    }

    // Filter fields if specified. Own-property check so that names inherited
    // from Object.prototype (e.g. 'constructor') are never copied.
    if (fields && fields.length > 0 && Object.keys(values).length > 0) {
        const filtered = {}
        for (const field of fields) {
            if (Object.prototype.hasOwnProperty.call(values, field)) {
                filtered[field] = values[field]
            }
        }
        values = filtered
    }

    // Build choices with provider prefix labels
    const choices = {}
    if (Object.keys(values).length > 0) {
        const choiceLabel = formatModelChoice(values)
        if (choiceLabel) {
            choices[choiceLabel] = values.modelId || model_id
        }
    }

    return {
        content: [{
            type: 'text',
            text: JSON.stringify({ values, choices, message })
        }]
    }
}
|
|
1632
|
+
|
|
1633
|
+
// ── MCP Server ───────────────────────────────────────────────────────────────
|
|
1634
|
+
|
|
1635
|
+
const server = new McpServer({ name: 'model-picker', version: '1.0.0' })

// Register the single `get_models` tool. The input shape is declared with zod
// validators; the handler forwards the parsed parameters to resolveModel.
server.tool(
    'get_models',
    'Returns model metadata for ML Container Creator',
    {
        // Required, non-empty identifier.
        model_id: z.string().min(1).describe('Model identifier'),
        // Optional field list.
        fields: z.array(z.string()).optional().describe('Metadata fields to return (omit for all)'),
        // Optional; falls back to 'discover' when omitted.
        mode: z.enum(['static', 'discover']).optional().default('discover').describe('Operating mode'),
        // Optional free-form string-keyed record.
        context: z.record(z.string(), z.any()).optional().describe('Current configuration context')
    },
    async (toolParams) => resolveModel(toolParams)
)
|
|
1656
|
+
|
|
1657
|
+
// ── Exports for testing ──────────────────────────────────────────────────────
|
|
1658
|
+
|
|
1659
|
+
export {
|
|
1660
|
+
loadCatalog,
|
|
1661
|
+
ModelResolver,
|
|
1662
|
+
StaticCatalogResolver,
|
|
1663
|
+
HuggingFaceResolver,
|
|
1664
|
+
JumpStartPublicResolver,
|
|
1665
|
+
JumpStartPrivateResolver,
|
|
1666
|
+
ModelRegistryResolver,
|
|
1667
|
+
S3Resolver,
|
|
1668
|
+
ResolverRegistry,
|
|
1669
|
+
mergeMetadata,
|
|
1670
|
+
parseS3Uri,
|
|
1671
|
+
buildS3Uri,
|
|
1672
|
+
formatModelChoice,
|
|
1673
|
+
filterByProvider,
|
|
1674
|
+
resolveModel,
|
|
1675
|
+
staticResolver,
|
|
1676
|
+
hfResolver,
|
|
1677
|
+
jumpStartPublicResolver,
|
|
1678
|
+
jumpStartPrivateResolver,
|
|
1679
|
+
modelRegistryResolver,
|
|
1680
|
+
s3Resolver,
|
|
1681
|
+
registry,
|
|
1682
|
+
POPULAR_MODELS_CATALOG
|
|
1683
|
+
}
|
|
1684
|
+
|
|
1685
|
+
// ── Main guard ───────────────────────────────────────────────────────────────
|
|
1686
|
+
|
|
1687
|
+
// True only when the script path given to node resolves to this very file;
// importing the module (e.g. from tests) leaves the server unconnected.
const isMain = process.argv[1]
    ? resolve(process.argv[1]) === __filename
    : false

if (isMain) {
    // Status output goes to stderr.
    process.stderr.write('[model-picker] Starting model-picker MCP server\n')
    await server.connect(new StdioServerTransport())
}
|