@aws/ml-container-creator 0.15.0 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/config/parameter-schema-v2.json +1 -1
- package/package.json +1 -1
- package/servers/endpoint-picker/index.js +23 -14
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +2 -2
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-client.js +15 -1
- package/src/lib/mcp-query-runner.js +5 -1
- package/src/lib/prompt-runner.js +35 -20
- package/templates/do/.benchmark_writer.py +34 -0
- package/templates/do/.register_helper.py +63 -41
- package/templates/do/benchmark +14 -9
- package/templates/do/register +8 -3
- package/templates/do/tune +19 -1
package/package.json
CHANGED
|
@@ -78,6 +78,7 @@ function getGpusForInstance(instanceType) {
|
|
|
78
78
|
let _SageMakerClient = null;
|
|
79
79
|
let _ListEndpointsCommand = null;
|
|
80
80
|
let _DescribeEndpointCommand = null;
|
|
81
|
+
let _DescribeEndpointConfigCommand = null;
|
|
81
82
|
let _ListInferenceComponentsCommand = null;
|
|
82
83
|
let _fromIni = null;
|
|
83
84
|
|
|
@@ -90,6 +91,7 @@ async function _ensureSdkLoaded() {
|
|
|
90
91
|
_SageMakerClient = sdk.SageMakerClient;
|
|
91
92
|
_ListEndpointsCommand = sdk.ListEndpointsCommand;
|
|
92
93
|
_DescribeEndpointCommand = sdk.DescribeEndpointCommand;
|
|
94
|
+
_DescribeEndpointConfigCommand = sdk.DescribeEndpointConfigCommand;
|
|
93
95
|
_ListInferenceComponentsCommand = sdk.ListInferenceComponentsCommand;
|
|
94
96
|
try {
|
|
95
97
|
const credentialProviders = await import('@aws-sdk/credential-providers');
|
|
@@ -197,9 +199,24 @@ async function fetchEndpoints(client, { limit = 10, showFull = false } = {}) {
|
|
|
197
199
|
const primaryVariant = variants[0] || {};
|
|
198
200
|
|
|
199
201
|
const variantName = primaryVariant.VariantName || 'AllTraffic';
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
202
|
+
let instanceType = primaryVariant.InstanceType || null;
|
|
203
|
+
|
|
204
|
+
// For IC-based endpoints, InstanceType may not be in the variant runtime response.
|
|
205
|
+
// Fall back to DescribeEndpointConfig which always has it.
|
|
206
|
+
if (!instanceType && detail.EndpointConfigName) {
|
|
207
|
+
try {
|
|
208
|
+
const ecCmd = new _DescribeEndpointConfigCommand({ EndpointConfigName: detail.EndpointConfigName });
|
|
209
|
+
const ecDetail = await client.send(ecCmd);
|
|
210
|
+
const ecVariant = (ecDetail.ProductionVariants || [])[0];
|
|
211
|
+
if (ecVariant?.InstanceType) {
|
|
212
|
+
instanceType = ecVariant.InstanceType;
|
|
213
|
+
}
|
|
214
|
+
} catch (ecErr) {
|
|
215
|
+
log(`Warning: could not describe endpoint config for "${endpointName}": ${ecErr.message}`);
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
instanceType = instanceType || 'unknown';
|
|
219
|
+
|
|
203
220
|
const instanceCount = primaryVariant.CurrentInstanceCount ?? primaryVariant.DesiredInstanceCount ?? 1;
|
|
204
221
|
const hasInstancePools = !!(primaryVariant.InstancePools && primaryVariant.InstancePools.length > 0);
|
|
205
222
|
|
|
@@ -387,17 +404,9 @@ server.tool(
|
|
|
387
404
|
limit: z.number().int().positive().default(10).describe('Maximum number of endpoints to return'),
|
|
388
405
|
context: z.record(z.string(), z.any()).optional().describe('Current configuration context (awsRegion, awsProfile, deploymentTarget)')
|
|
389
406
|
},
|
|
390
|
-
async ({ parameters, limit, context }) => {
|
|
391
|
-
// Only respond if
|
|
392
|
-
|
|
393
|
-
return {
|
|
394
|
-
content: [{
|
|
395
|
-
type: 'text',
|
|
396
|
-
text: JSON.stringify({ values: {}, choices: {} })
|
|
397
|
-
}]
|
|
398
|
-
};
|
|
399
|
-
}
|
|
400
|
-
|
|
407
|
+
async ({ parameters: _parameters, limit, context }) => {
|
|
408
|
+
// Only respond if context.deploymentTarget is realtime-inference
|
|
409
|
+
// Note: parameters may be empty when called on-demand via queryMcpServer()
|
|
401
410
|
if (context?.deploymentTarget && context.deploymentTarget !== 'realtime-inference') {
|
|
402
411
|
return {
|
|
403
412
|
content: [{
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
// AUTO-GENERATED by scripts/codegen-parameter-matrix.js — DO NOT EDIT
|
|
2
2
|
// Source: config/parameter-schema-v2.json
|
|
3
|
-
// Generated: 2026-06-
|
|
3
|
+
// Generated: 2026-06-23T20:55:23.482Z
|
|
4
4
|
|
|
5
5
|
/**
|
|
6
6
|
* Parameter matrix defining how each parameter is loaded from various sources.
|
|
@@ -225,7 +225,7 @@ export const parameterMatrix = {
|
|
|
225
225
|
'configFile': true,
|
|
226
226
|
'packageJson': false,
|
|
227
227
|
'mcp': true,
|
|
228
|
-
'promptable':
|
|
228
|
+
'promptable': true,
|
|
229
229
|
'required': false,
|
|
230
230
|
'default': null,
|
|
231
231
|
'valueSpace': 'unbounded'
|
package/src/lib/mcp-client.js
CHANGED
|
@@ -143,9 +143,23 @@ class McpClient {
|
|
|
143
143
|
// Build context from bounded parameters that have defaults
|
|
144
144
|
const context = this._buildContext();
|
|
145
145
|
|
|
146
|
+
// Auto-discover tool name if using the default (get_ml_config)
|
|
147
|
+
// Each server registers its own tool name (e.g. get_base_images, get_inference_endpoints)
|
|
148
|
+
let toolName = this.toolName;
|
|
149
|
+
if (toolName === DEFAULT_TOOL_NAME) {
|
|
150
|
+
try {
|
|
151
|
+
const toolList = await this._client.listTools();
|
|
152
|
+
if (toolList && toolList.tools && toolList.tools.length > 0) {
|
|
153
|
+
toolName = toolList.tools[0].name;
|
|
154
|
+
}
|
|
155
|
+
} catch (_listErr) {
|
|
156
|
+
// Fall through to use default tool name
|
|
157
|
+
}
|
|
158
|
+
}
|
|
159
|
+
|
|
146
160
|
// Call the configured tool
|
|
147
161
|
const result = await this._client.callTool({
|
|
148
|
-
name:
|
|
162
|
+
name: toolName,
|
|
149
163
|
arguments: {
|
|
150
164
|
parameters: unboundedParams,
|
|
151
165
|
limit: this.limit,
|
|
@@ -371,9 +371,13 @@ export default class McpQueryRunner {
|
|
|
371
371
|
console.log(' 🔍 Querying endpoint-picker...');
|
|
372
372
|
|
|
373
373
|
try {
|
|
374
|
+
// Pass awsProfile from bootstrap config for credential resolution
|
|
375
|
+
const awsProfile = this.runner.configManager?.config?.awsProfile
|
|
376
|
+
|| this.runner.options?.profile || process.env.AWS_PROFILE || null;
|
|
374
377
|
const result = await cm.queryMcpServer('endpoint-picker', {
|
|
375
378
|
awsRegion: infraAnswers.awsRegion,
|
|
376
|
-
deploymentTarget: 'realtime-inference'
|
|
379
|
+
deploymentTarget: 'realtime-inference',
|
|
380
|
+
...(awsProfile ? { awsProfile } : {})
|
|
377
381
|
});
|
|
378
382
|
|
|
379
383
|
if (result && result.choices?.endpointName?.length > 0) {
|
package/src/lib/prompt-runner.js
CHANGED
|
@@ -224,25 +224,39 @@ export default class PromptRunner {
|
|
|
224
224
|
// Requirements: 3.3, 4.3, 4.4 — endpoint-picker MCP query
|
|
225
225
|
let existingEndpointAnswers = {};
|
|
226
226
|
if (regionAndTargetAnswers.deploymentTarget === 'realtime-inference') {
|
|
227
|
-
//
|
|
228
|
-
const
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
const endpointPreviousAnswers = {
|
|
232
|
-
...regionAndTargetAnswers,
|
|
233
|
-
...(this._mcpEndpointChoices ? { _mcpEndpointChoices: this._mcpEndpointChoices } : {})
|
|
234
|
-
};
|
|
235
|
-
existingEndpointAnswers = await this._runPhase(
|
|
236
|
-
infraExistingEndpointPrompts,
|
|
237
|
-
endpointPreviousAnswers,
|
|
227
|
+
// First ask if user wants to attach to existing endpoint (no MCP call yet)
|
|
228
|
+
const attachAnswer = await this._runPhase(
|
|
229
|
+
[infraExistingEndpointPrompts[0]],
|
|
230
|
+
{ ...regionAndTargetAnswers },
|
|
238
231
|
explicitConfig,
|
|
239
232
|
existingConfig
|
|
240
233
|
);
|
|
241
234
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
235
|
+
if (attachAnswer.useExistingEndpoint === 'yes') {
|
|
236
|
+
// Only now query endpoint-picker MCP server
|
|
237
|
+
const resolvedRegion = regionAndTargetAnswers.customAwsRegion || regionAndTargetAnswers.awsRegion;
|
|
238
|
+
await this.mcpQueryRunner._queryMcpForEndpoints({ ...regionAndTargetAnswers, awsRegion: resolvedRegion }, explicitConfig);
|
|
239
|
+
|
|
240
|
+
const endpointPreviousAnswers = {
|
|
241
|
+
...regionAndTargetAnswers,
|
|
242
|
+
...attachAnswer,
|
|
243
|
+
...(this._mcpEndpointChoices ? { _mcpEndpointChoices: this._mcpEndpointChoices } : {})
|
|
244
|
+
};
|
|
245
|
+
existingEndpointAnswers = await this._runPhase(
|
|
246
|
+
infraExistingEndpointPrompts.slice(1),
|
|
247
|
+
endpointPreviousAnswers,
|
|
248
|
+
explicitConfig,
|
|
249
|
+
existingConfig
|
|
250
|
+
);
|
|
251
|
+
existingEndpointAnswers.useExistingEndpoint = 'yes';
|
|
252
|
+
|
|
253
|
+
// Resolve custom endpoint name
|
|
254
|
+
if (existingEndpointAnswers.customExistingEndpointName) {
|
|
255
|
+
existingEndpointAnswers.existingEndpointName = existingEndpointAnswers.customExistingEndpointName;
|
|
256
|
+
delete existingEndpointAnswers.customExistingEndpointName;
|
|
257
|
+
}
|
|
258
|
+
} else {
|
|
259
|
+
existingEndpointAnswers = attachAnswer;
|
|
246
260
|
}
|
|
247
261
|
}
|
|
248
262
|
|
|
@@ -376,11 +390,12 @@ export default class PromptRunner {
|
|
|
376
390
|
const sizerRecs = this._instanceSizerMetadata.recommendations || [];
|
|
377
391
|
const finalInstanceType = instanceAnswers.customInstanceType || instanceAnswers.instanceType;
|
|
378
392
|
const matchingRec = sizerRecs.find(r => r.instanceType === finalInstanceType);
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
this.
|
|
383
|
-
|
|
393
|
+
// Only use sizer TP recommendation if user selected a recommended instance
|
|
394
|
+
// Custom instances resolve TP from the instance catalog in template-variable-resolver
|
|
395
|
+
if (matchingRec && matchingRec.tensorParallelism > 1) {
|
|
396
|
+
this._autoTensorParallelism = matchingRec.tensorParallelism;
|
|
397
|
+
this._autoGpuCount = matchingRec.gpuCount;
|
|
398
|
+
console.log(` ✓ Auto-set tensor parallelism: TP=${matchingRec.tensorParallelism} (${matchingRec.gpuCount} GPUs)`);
|
|
384
399
|
}
|
|
385
400
|
|
|
386
401
|
// Display capacity type confirmation for selected instance
|
|
@@ -1401,8 +1401,18 @@ def _load_config_file(config_path):
|
|
|
1401
1401
|
'BASE_IMAGE_VERSION': 'base_image_version',
|
|
1402
1402
|
'BENCHMARK_CONCURRENCY': 'benchmark_concurrency',
|
|
1403
1403
|
}
|
|
1404
|
+
# Also capture IC_ENV_* serving config vars
|
|
1405
|
+
ic_env_map = {
|
|
1406
|
+
'IC_ENV_VLLM_MAX_MODEL_LEN': 'max_model_len',
|
|
1407
|
+
'IC_ENV_VLLM_QUANTIZATION': 'quantization',
|
|
1408
|
+
'IC_ENV_VLLM_GPU_MEMORY_UTILIZATION': 'gpu_memory_utilization',
|
|
1409
|
+
'IC_ENV_VLLM_KV_CACHE_DTYPE': 'kv_cache_dtype',
|
|
1410
|
+
'IC_ENV_VLLM_TENSOR_PARALLEL_SIZE': 'tensor_parallel_degree',
|
|
1411
|
+
}
|
|
1404
1412
|
if key in shell_map:
|
|
1405
1413
|
context[shell_map[key]] = value
|
|
1414
|
+
elif key in ic_env_map:
|
|
1415
|
+
context[ic_env_map[key]] = value
|
|
1406
1416
|
|
|
1407
1417
|
except Exception:
|
|
1408
1418
|
pass
|
|
@@ -1419,6 +1429,30 @@ def _load_config_file(config_path):
|
|
|
1419
1429
|
parts = context['model_name'].rstrip('/').split('/')
|
|
1420
1430
|
context['model_name'] = parts[-1] if parts else context['model_name']
|
|
1421
1431
|
|
|
1432
|
+
# Also scan IC config files (do/ic/*.conf) for IC_ENV_* serving params
|
|
1433
|
+
# These override do/config values for serving-specific settings
|
|
1434
|
+
try:
|
|
1435
|
+
import glob
|
|
1436
|
+
config_dir = os.path.dirname(os.path.abspath(config_path))
|
|
1437
|
+
ic_dir = os.path.join(config_dir, 'ic')
|
|
1438
|
+
ic_env_map = {
|
|
1439
|
+
'IC_ENV_VLLM_MAX_MODEL_LEN': 'max_model_len',
|
|
1440
|
+
'IC_ENV_VLLM_QUANTIZATION': 'quantization',
|
|
1441
|
+
'IC_ENV_VLLM_GPU_MEMORY_UTILIZATION': 'gpu_memory_utilization',
|
|
1442
|
+
'IC_ENV_VLLM_KV_CACHE_DTYPE': 'kv_cache_dtype',
|
|
1443
|
+
'IC_ENV_VLLM_TENSOR_PARALLEL_SIZE': 'tensor_parallel_degree',
|
|
1444
|
+
}
|
|
1445
|
+
for conf_file in sorted(glob.glob(os.path.join(ic_dir, '*.conf'))):
|
|
1446
|
+
with open(conf_file, 'r') as f:
|
|
1447
|
+
for line in f:
|
|
1448
|
+
match = re.match(r'^export\s+([A-Z_][A-Z0-9_]*)=["\']?([^"\']*)["\']?\s*$', line.strip())
|
|
1449
|
+
if match:
|
|
1450
|
+
key, value = match.group(1), match.group(2)
|
|
1451
|
+
if key in ic_env_map and value:
|
|
1452
|
+
context[ic_env_map[key]] = value
|
|
1453
|
+
except Exception:
|
|
1454
|
+
pass # IC config scanning is best-effort
|
|
1455
|
+
|
|
1422
1456
|
return context
|
|
1423
1457
|
|
|
1424
1458
|
|
|
@@ -89,6 +89,8 @@ def _truncate_metadata(props):
|
|
|
89
89
|
result = {}
|
|
90
90
|
for key, value in props.items():
|
|
91
91
|
str_val = str(value) if value is not None else ""
|
|
92
|
+
if not str_val:
|
|
93
|
+
continue # SageMaker requires min length 1 for metadata values — skip empty
|
|
92
94
|
if len(str_val) > MAX_METADATA_VALUE_LEN:
|
|
93
95
|
_warn(f"Metadata '{key}' truncated ({len(str_val)} → {MAX_METADATA_VALUE_LEN} chars)")
|
|
94
96
|
str_val = str_val[: MAX_METADATA_VALUE_LEN - 1] + "…"
|
|
@@ -264,33 +266,41 @@ def cmd_register_model(args):
|
|
|
264
266
|
container_image = args.container_image or ""
|
|
265
267
|
model_data_url = args.model_data_url or ""
|
|
266
268
|
|
|
267
|
-
inference_spec = {
|
|
268
|
-
"Containers": [
|
|
269
|
-
{
|
|
270
|
-
"Image": container_image,
|
|
271
|
-
}
|
|
272
|
-
],
|
|
273
|
-
"SupportedContentTypes": ["application/json"],
|
|
274
|
-
"SupportedResponseMIMETypes": ["application/json"],
|
|
275
|
-
}
|
|
276
|
-
# Only include ModelDataUrl if provided
|
|
277
|
-
if model_data_url:
|
|
278
|
-
inference_spec["Containers"][0]["ModelDataUrl"] = model_data_url
|
|
279
|
-
|
|
280
269
|
# Step 4: Create Model Package version (AC-1.2, AC-1.7)
|
|
281
270
|
description = f"{args.deployment_config or 'model'} on {args.instance_type or 'unknown'}"
|
|
282
271
|
|
|
283
272
|
print(f"Registering model version in {project_name}...", file=sys.stderr)
|
|
284
273
|
try:
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
274
|
+
# Use boto3 directly — sagemaker-core v2.14 has a KeyError bug in ModelPackage.create()
|
|
275
|
+
# where it tries to read response["ModelPackageName"] but the API returns "ModelPackageArn".
|
|
276
|
+
import boto3
|
|
277
|
+
sm_client = boto3.client("sagemaker", region_name=region)
|
|
278
|
+
|
|
279
|
+
create_params = {
|
|
280
|
+
"ModelPackageGroupName": project_name,
|
|
281
|
+
"ModelPackageDescription": description,
|
|
282
|
+
"ModelApprovalStatus": "Approved",
|
|
283
|
+
}
|
|
284
|
+
if container_image:
|
|
285
|
+
create_params["InferenceSpecification"] = {
|
|
286
|
+
"Containers": [{"Image": container_image}],
|
|
287
|
+
"SupportedContentTypes": ["application/json"],
|
|
288
|
+
"SupportedResponseMIMETypes": ["application/json"],
|
|
289
|
+
}
|
|
290
|
+
if model_data_url:
|
|
291
|
+
create_params["InferenceSpecification"]["Containers"][0]["ModelDataUrl"] = model_data_url
|
|
292
|
+
if model_data_url:
|
|
293
|
+
if "InferenceSpecification" not in create_params:
|
|
294
|
+
# Store model data URL in metadata if no container image
|
|
295
|
+
if not metadata:
|
|
296
|
+
metadata = {}
|
|
297
|
+
metadata["modelDataUrl"] = model_data_url[:1024]
|
|
298
|
+
if metadata:
|
|
299
|
+
create_params["CustomerMetadataProperties"] = metadata
|
|
300
|
+
|
|
301
|
+
response = sm_client.create_model_package(**create_params)
|
|
302
|
+
model_package_arn = response["ModelPackageArn"]
|
|
292
303
|
|
|
293
|
-
model_package_arn = pkg.model_package_arn
|
|
294
304
|
# Extract version number from ARN (format: .../project-name/version)
|
|
295
305
|
version = _extract_version_from_arn(model_package_arn)
|
|
296
306
|
|
|
@@ -407,33 +417,39 @@ def cmd_register_adapter(args):
|
|
|
407
417
|
container_image = args.container_image or ""
|
|
408
418
|
model_data_url = args.model_data_url or ""
|
|
409
419
|
|
|
410
|
-
inference_spec = {
|
|
411
|
-
"Containers": [
|
|
412
|
-
{
|
|
413
|
-
"Image": container_image,
|
|
414
|
-
}
|
|
415
|
-
],
|
|
416
|
-
"SupportedContentTypes": ["application/json"],
|
|
417
|
-
"SupportedResponseMIMETypes": ["application/json"],
|
|
418
|
-
}
|
|
419
|
-
if model_data_url:
|
|
420
|
-
inference_spec["Containers"][0]["ModelDataUrl"] = model_data_url
|
|
421
|
-
|
|
422
420
|
# Step 4: Create adapter Model Package version (AC-2.1)
|
|
423
421
|
technique = args.tune_technique or "unknown"
|
|
424
422
|
description = f"adapter ({technique}) on {args.instance_type or 'unknown'}, parent: {parent_version_arn}"
|
|
425
423
|
|
|
426
424
|
print(f"Registering adapter version in {project_name}...", file=sys.stderr)
|
|
427
425
|
try:
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
426
|
+
# Use boto3 directly — sagemaker-core v2.14 has a KeyError bug in ModelPackage.create()
|
|
427
|
+
import boto3
|
|
428
|
+
sm_client = boto3.client("sagemaker", region_name=region)
|
|
429
|
+
|
|
430
|
+
create_params = {
|
|
431
|
+
"ModelPackageGroupName": project_name,
|
|
432
|
+
"ModelPackageDescription": description,
|
|
433
|
+
"ModelApprovalStatus": "Approved",
|
|
434
|
+
}
|
|
435
|
+
if container_image:
|
|
436
|
+
create_params["InferenceSpecification"] = {
|
|
437
|
+
"Containers": [{"Image": container_image}],
|
|
438
|
+
"SupportedContentTypes": ["application/json"],
|
|
439
|
+
"SupportedResponseMIMETypes": ["application/json"],
|
|
440
|
+
}
|
|
441
|
+
if model_data_url:
|
|
442
|
+
create_params["InferenceSpecification"]["Containers"][0]["ModelDataUrl"] = model_data_url
|
|
443
|
+
elif model_data_url:
|
|
444
|
+
if not metadata:
|
|
445
|
+
metadata = {}
|
|
446
|
+
metadata["modelDataUrl"] = model_data_url[:1024]
|
|
447
|
+
if metadata:
|
|
448
|
+
create_params["CustomerMetadataProperties"] = metadata
|
|
449
|
+
|
|
450
|
+
response = sm_client.create_model_package(**create_params)
|
|
451
|
+
model_package_arn = response["ModelPackageArn"]
|
|
435
452
|
|
|
436
|
-
model_package_arn = pkg.model_package_arn
|
|
437
453
|
version = _extract_version_from_arn(model_package_arn)
|
|
438
454
|
|
|
439
455
|
print(f"Registered adapter version {version}: {model_package_arn}", file=sys.stderr)
|
|
@@ -1133,6 +1149,12 @@ def main():
|
|
|
1133
1149
|
parser.print_help()
|
|
1134
1150
|
sys.exit(1)
|
|
1135
1151
|
|
|
1152
|
+
# Set region before any sagemaker-core import (creates boto3 clients at import time)
|
|
1153
|
+
region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
|
|
1154
|
+
if region:
|
|
1155
|
+
os.environ['AWS_DEFAULT_REGION'] = region
|
|
1156
|
+
os.environ.setdefault('AWS_REGION', region)
|
|
1157
|
+
|
|
1136
1158
|
if args.command == "create-mpg":
|
|
1137
1159
|
cmd_create_mpg(args)
|
|
1138
1160
|
elif args.command == "register-model":
|
package/templates/do/benchmark
CHANGED
|
@@ -117,10 +117,12 @@ if [ "${ARG_STATUS}" = true ]; then
|
|
|
117
117
|
tar_file=""
|
|
118
118
|
tar_file=$(find "${LOCAL_RESULTS_DIR}" -name "output.tar.gz" -type f 2>/dev/null | head -1)
|
|
119
119
|
if [ -n "${tar_file}" ]; then
|
|
120
|
-
# Detect whether
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
120
|
+
# Detect whether ALL entries share a common leading directory prefix
|
|
121
|
+
_tar_prefix_count=""
|
|
122
|
+
_tar_prefix_count=$(tar -tzf "${tar_file}" 2>/dev/null | sed 's|/.*||' | sort -u | wc -l | tr -d ' ')
|
|
123
|
+
_tar_first_dir=""
|
|
124
|
+
_tar_first_dir=$(tar -tzf "${tar_file}" 2>/dev/null | head -1)
|
|
125
|
+
if [ "${_tar_prefix_count}" = "1" ] && echo "${_tar_first_dir}" | grep -qE '^[^/]+/$'; then
|
|
124
126
|
tar -xzf "${tar_file}" --strip-components=1 -C "${LOCAL_RESULTS_DIR}/output/" 2>/dev/null || true
|
|
125
127
|
else
|
|
126
128
|
tar -xzf "${tar_file}" -C "${LOCAL_RESULTS_DIR}/output/" 2>/dev/null || true
|
|
@@ -1097,11 +1099,14 @@ if [ "${JOB_STATUS}" = "Completed" ]; then
|
|
|
1097
1099
|
# Extract any tar.gz archives (benchmark service packages results as output.tar.gz)
|
|
1098
1100
|
for ARCHIVE in $(find "${LOCAL_RESULTS_DIR}" -name "*.tar.gz" -type f 2>/dev/null); do
|
|
1099
1101
|
ARCHIVE_DIR=$(dirname "${ARCHIVE}")
|
|
1100
|
-
# Detect whether
|
|
1101
|
-
#
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1102
|
+
# Detect whether ALL entries share a common leading directory prefix.
|
|
1103
|
+
# Only strip if every entry starts with the same dir (e.g., "output/file1", "output/file2").
|
|
1104
|
+
# A flat archive with mixed top-level files/dirs (e.g., "plots/", "profile_export.jsonl")
|
|
1105
|
+
# must NOT be stripped.
|
|
1106
|
+
_TAR_PREFIX=$(tar -tzf "${ARCHIVE}" 2>/dev/null | sed 's|/.*||' | sort -u | wc -l | tr -d ' ')
|
|
1107
|
+
_TAR_FIRST_DIR=$(tar -tzf "${ARCHIVE}" 2>/dev/null | head -1)
|
|
1108
|
+
if [ "${_TAR_PREFIX}" = "1" ] && echo "${_TAR_FIRST_DIR}" | grep -qE '^[^/]+/$'; then
|
|
1109
|
+
# Single common leading directory (e.g., all under "output/") — strip it
|
|
1105
1110
|
tar -xzf "${ARCHIVE}" --strip-components=1 -C "${ARCHIVE_DIR}" 2>/dev/null || true
|
|
1106
1111
|
else
|
|
1107
1112
|
# Flat archive — extract as-is
|
package/templates/do/register
CHANGED
|
@@ -1137,9 +1137,14 @@ ml-container-creator "${CMD_ARGS[@]}"
|
|
|
1137
1137
|
# ============================================================
|
|
1138
1138
|
|
|
1139
1139
|
# Container image URI for the deployed model
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1140
|
+
# Build full ECR URI from profile (accountId + region + repoName + tag)
|
|
1141
|
+
_ACCOUNT_ID="${_PROFILE_accountId:-}"
|
|
1142
|
+
_REGION="${AWS_DEFAULT_REGION:-${_PROFILE_awsRegion:-us-east-1}}"
|
|
1143
|
+
if [ -n "${_ACCOUNT_ID}" ] && [ -n "${ECR_REPOSITORY_NAME}" ]; then
|
|
1144
|
+
CONTAINER_IMAGE_URI="${_ACCOUNT_ID}.dkr.ecr.${_REGION}.amazonaws.com/${ECR_REPOSITORY_NAME}:${PROJECT_NAME}-latest"
|
|
1145
|
+
else
|
|
1146
|
+
# No ECR info available — MPG will be registered without InferenceSpecification
|
|
1147
|
+
CONTAINER_IMAGE_URI=""
|
|
1143
1148
|
fi
|
|
1144
1149
|
|
|
1145
1150
|
# Model data S3 URI (from do/config if set)
|
package/templates/do/tune
CHANGED
|
@@ -829,7 +829,25 @@ _validate_dataset() {
|
|
|
829
829
|
if [ -z "${dataset}" ]; then
|
|
830
830
|
if [ -n "${ARG_DATASET_NAME}" ]; then
|
|
831
831
|
# Name-based resolution happens below via resolve-dataset
|
|
832
|
-
|
|
832
|
+
echo "🔍 Resolving dataset '${ARG_DATASET_NAME}' from registry..."
|
|
833
|
+
local resolve_result
|
|
834
|
+
resolve_result=$(python3 "${SCRIPT_DIR}/.register_helper.py" resolve-dataset \
|
|
835
|
+
--name "${ARG_DATASET_NAME}" 2>/dev/null) || resolve_result=""
|
|
836
|
+
|
|
837
|
+
if [ -n "${resolve_result}" ]; then
|
|
838
|
+
local resolved_uri
|
|
839
|
+
resolved_uri=$(echo "${resolve_result}" | grep -E '^\{' | tail -1 | python3 -c "import sys,json; print(json.load(sys.stdin).get('s3_uri',''))" 2>/dev/null) || resolved_uri=""
|
|
840
|
+
if [ -n "${resolved_uri}" ]; then
|
|
841
|
+
echo " Resolved to: ${resolved_uri}"
|
|
842
|
+
dataset="${resolved_uri}"
|
|
843
|
+
RESOLVED_DATASET_S3_URI="${resolved_uri}"
|
|
844
|
+
return 0
|
|
845
|
+
fi
|
|
846
|
+
fi
|
|
847
|
+
echo "❌ Dataset '${ARG_DATASET_NAME}' not found in registry"
|
|
848
|
+
echo " Run ./do/tune --list-datasets to see available datasets."
|
|
849
|
+
echo " Register: ./do/register dataset <name> --s3-uri <uri> --technique <sft|dpo>"
|
|
850
|
+
exit 1
|
|
833
851
|
else
|
|
834
852
|
echo "❌ --dataset is required"
|
|
835
853
|
echo " Provide an S3 URI (s3://bucket/path.jsonl), HF reference (hf://org/name), or registered name"
|