@aws/ml-container-creator 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. package/README.md +10 -1
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +5 -2
  6. package/pyproject.toml +3 -0
  7. package/servers/agent-knowledge/index.js +592 -0
  8. package/servers/agent-knowledge/package.json +15 -0
  9. package/servers/base-image-picker/index.js +65 -18
  10. package/servers/instance-sizer/index.js +32 -0
  11. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  12. package/servers/lib/catalogs/model-arch-support.json +51 -0
  13. package/servers/lib/catalogs/model-servers.json +2842 -1730
  14. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  15. package/src/agent/__init__.py +2 -0
  16. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  17. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  18. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  19. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  20. package/src/agent/agent.py +513 -0
  21. package/src/agent/config_loader.py +215 -0
  22. package/src/agent/context.py +380 -0
  23. package/src/agent/data/capability-matrix.json +106 -0
  24. package/src/agent/health_check.py +341 -0
  25. package/src/agent/prompts/system.md +173 -0
  26. package/src/agent/requirements-agent.txt +3 -0
  27. package/src/app.js +6 -4
  28. package/src/lib/generated/cli-options.js +1 -1
  29. package/src/lib/generated/parameter-matrix.js +1 -1
  30. package/src/lib/generated/validation-rules.js +1 -1
  31. package/src/lib/mcp-query-runner.js +110 -3
  32. package/src/lib/prompt-runner.js +66 -22
  33. package/src/lib/template-variable-resolver.js +8 -0
  34. package/src/lib/train-config-builder.js +339 -0
  35. package/src/lib/tune-config-state.js +89 -68
  36. package/templates/do/.benchmark_writer.py +3 -0
  37. package/templates/do/.eval_helper.py +409 -0
  38. package/templates/do/.register_helper.py +185 -11
  39. package/templates/do/.train_build_request.py +102 -113
  40. package/templates/do/.train_helper.py +433 -0
  41. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  42. package/templates/do/adapter +157 -0
  43. package/templates/do/benchmark +60 -3
  44. package/templates/do/config +6 -1
  45. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  46. package/templates/do/evaluate +272 -0
  47. package/templates/do/lib/resolve-instance.sh +155 -0
  48. package/templates/do/register +5 -0
  49. package/templates/do/test +1 -0
  50. package/templates/do/train +879 -126
  51. package/templates/do/training/config.yaml +83 -11
  52. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  53. package/templates/do/training/dpo/defaults.yaml +26 -0
  54. package/templates/do/training/dpo/prompts.json +8 -0
  55. package/templates/do/training/dpo/train.py +363 -0
  56. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  57. package/templates/do/training/sft/defaults.yaml +18 -0
  58. package/templates/do/training/sft/prompts.json +7 -0
  59. package/templates/do/training/sft/train.py +310 -0
  60. package/templates/do/tune +11 -2
  61. package/src/lib/auto-prompt-builder.js +0 -172
  62. package/src/lib/cli-handler.js +0 -529
  63. package/src/lib/community-reports-validator.js +0 -91
  64. package/src/lib/configuration-exporter.js +0 -204
  65. package/src/lib/dataset-slug.js +0 -152
  66. package/src/lib/docker-introspection-validator.js +0 -51
  67. package/src/lib/known-flags-validator.js +0 -200
  68. package/src/lib/schema-validator.js +0 -157
  69. package/src/lib/train-config-parser.js +0 -136
  70. package/src/lib/train-config-persistence.js +0 -143
  71. package/src/lib/train-config-validator.js +0 -112
  72. package/src/lib/train-feedback.js +0 -46
  73. package/src/lib/train-idempotency.js +0 -97
  74. package/src/lib/train-request-builder.js +0 -120
  75. package/src/lib/tune-dataset-validator.js +0 -279
  76. package/src/lib/tune-output-resolver.js +0 -66
  77. package/templates/do/.train_poll_parser.py +0 -135
  78. package/templates/do/.train_status_parser.py +0 -187
  79. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -0,0 +1,433 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """SageMaker Training Job helper (SDK v3).
6
+
7
+ Subcommands:
8
+ submit - Create a training job via TrainingJob.create()
9
+ status - Get job status via TrainingJob.get()
10
+ resolve - Extract artifact path from completed job
11
+ stop - Stop a running training job
12
+
13
+ All output is JSON on stdout for bash consumption.
14
+ Pattern: grep -E '^\\{' | tail -1 to extract JSON from mixed output.
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ import warnings
22
+
23
+ # Suppress noisy dependency warnings
24
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
25
+ warnings.filterwarnings("ignore", message=".*urllib3.*")
26
+
27
+ # Suppress ALL logging to prevent sagemaker-core/rich from writing to stdout
28
+ import logging as _logging
29
+ _logging.disable(_logging.CRITICAL)
30
+ os.environ.setdefault("SAGEMAKER_LOG_LEVEL", "CRITICAL")
31
+
32
+
33
+ # ── Utility functions ─────────────────────────────────────────────────────────
34
+
35
def _error_exit(message):
    """Write a JSON error object to stdout, then end the process with status 1.

    The emitted object has the shape {"error": true, "message": <message>}.
    """
    payload = {"error": True, "message": message}
    print(json.dumps(payload))
    raise SystemExit(1)
39
+
40
+
41
def _output(data):
    """Print ``data`` as a single JSON line on stdout and terminate the process.

    On success the process exits with status 0, so code placed after a call
    to this function never runs.

    If ``data`` cannot be serialized (for example it still contains an SDK
    object or a datetime), a JSON error object of the form
    {"error": true, "message": str} is printed instead and the process exits
    with status 1. This keeps stdout parseable for the shell callers instead
    of leaving them with a bare traceback.
    """
    try:
        payload = json.dumps(data)
    except (TypeError, ValueError) as exc:
        # TypeError: unsupported value type; ValueError: circular reference.
        print(json.dumps({
            "error": True,
            "message": f"Result is not JSON-serializable: {exc}",
        }))
        sys.exit(1)
    print(payload)
    sys.exit(0)
45
+
46
+
47
def _sanitize_for_json(value):
    """Replace sagemaker-core "unassigned" placeholder objects with None.

    A value counts as a placeholder when the name of its class is
    ``Unassigned`` or ``UnassignedValue``. ``None`` stays ``None``; every
    other value is returned unchanged.
    """
    placeholder_names = ("Unassigned", "UnassignedValue")
    if value is not None and type(value).__name__ not in placeholder_names:
        return value
    return None
55
+
56
+
57
+ # ── cmd_submit ────────────────────────────────────────────────────────────────
58
+
59
def cmd_submit(args):
    """Create a SageMaker Training Job via SDK v3.

    Reads a CreateTrainingJob-style JSON document from ``args.config``, maps
    the supported fields onto the snake_case keyword arguments given to
    ``TrainingJob.create()``, and submits the job.

    Args:
        args: Parsed CLI namespace. Reads ``config`` (path to the JSON file)
            and the optional ``region``.

    Output (JSON on stdout; the process exits afterwards):
        Success: {"job_name": str, "job_arn": str, "status": "InProgress"}
        Failure: {"error": true, "message": str} via ``_error_exit``.
    """
    # Set region BEFORE any sagemaker import (Bug 26 pattern)
    region = (
        getattr(args, 'region', None)
        or os.environ.get('AWS_DEFAULT_REGION')
        or os.environ.get('AWS_REGION')
    )
    if region:
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ.setdefault('AWS_REGION', region)

    # Read config file. ValueError covers both json.JSONDecodeError and
    # UnicodeDecodeError (file is not valid UTF-8).
    try:
        with open(args.config, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        _error_exit(f"Failed to read config file: {e}")

    # Valid JSON that is not an object (list, string, ...) has no .get().
    if not isinstance(config, dict):
        _error_exit(
            "Config file must contain a JSON object at the top level, "
            f"got {type(config).__name__}"
        )

    # Import SDK v3 TrainingJob (same pattern as .tune_helper.py cmd_status)
    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit(
            "sagemaker SDK v3 not installed. "
            "Install: pip install 'sagemaker>=3.0'"
        )

    # Extract fields from the CreateTrainingJob-format config. ``or`` guards
    # against sections that are present but explicitly null.
    job_name = config.get('TrainingJobName') or ''
    role_arn = config.get('RoleArn') or ''
    algo_spec = config.get('AlgorithmSpecification') or {}
    resource_config = config.get('ResourceConfig') or {}
    input_data_config = config.get('InputDataConfig') or []
    output_data_config = config.get('OutputDataConfig') or {}
    stopping_condition = config.get('StoppingCondition') or {}
    hyper_parameters = config.get('HyperParameters') or {}
    checkpoint_config = config.get('CheckpointConfig')
    environment = config.get('Environment') or {}
    enable_spot = config.get('EnableManagedSpotTraining', False)
    tags = config.get('Tags') or []

    # Fail early with a clear message instead of sending empty strings.
    missing = [
        field for field, value in (
            ('TrainingJobName', job_name),
            ('RoleArn', role_arn),
        )
        if not value
    ]
    if missing:
        _error_exit(f"Config is missing required field(s): {', '.join(missing)}")

    # Build SDK v3 create kwargs (snake_case per Pydantic v2)
    create_kwargs = {
        'training_job_name': job_name,
        'role_arn': role_arn,
        'algorithm_specification': {
            'training_image': algo_spec.get('TrainingImage', ''),
            'training_input_mode': algo_spec.get('TrainingInputMode', 'File'),
        },
        'resource_config': {
            'instance_type': resource_config.get('InstanceType', 'ml.g5.xlarge'),
            'instance_count': resource_config.get('InstanceCount', 1),
            'volume_size_in_gb': resource_config.get('VolumeSizeInGB', 50),
        },
        'output_data_config': {
            's3_output_path': output_data_config.get('S3OutputPath', ''),
        },
        'stopping_condition': {
            'max_runtime_in_seconds': stopping_condition.get('MaxRuntimeInSeconds', 86400),
        },
    }

    # Input data channels
    if input_data_config:
        channels = []
        for channel in input_data_config:
            s3_source = (channel.get('DataSource') or {}).get('S3DataSource') or {}
            channels.append({
                'channel_name': channel.get('ChannelName', 'training'),
                'data_source': {
                    's3_data_source': {
                        's3_data_type': s3_source.get('S3DataType', 'S3Prefix'),
                        's3_uri': s3_source.get('S3Uri', ''),
                        's3_data_distribution_type': s3_source.get(
                            'S3DataDistributionType', 'FullyReplicated'
                        ),
                    }
                },
            })
        create_kwargs['input_data_config'] = channels

    # Hyperparameters (all values must be strings)
    if hyper_parameters:
        create_kwargs['hyper_parameters'] = {
            str(k): str(v) for k, v in hyper_parameters.items()
        }

    # Metric definitions
    metric_defs = algo_spec.get('MetricDefinitions') or []
    if metric_defs:
        create_kwargs['algorithm_specification']['metric_definitions'] = [
            {'name': m.get('Name', ''), 'regex': m.get('Regex', '')}
            for m in metric_defs
        ]

    # Managed spot training; the max wait time is only forwarded for spot jobs.
    if enable_spot:
        create_kwargs['enable_managed_spot_training'] = True
        max_wait = stopping_condition.get('MaxWaitTimeInSeconds')
        if max_wait:
            create_kwargs['stopping_condition']['max_wait_time_in_seconds'] = max_wait

    # Checkpoint config. LocalPath is forwarded when given so a custom
    # in-container checkpoint directory is not silently ignored.
    if checkpoint_config:
        checkpoint_kwargs = {'s3_uri': checkpoint_config.get('S3Uri', '')}
        local_path = checkpoint_config.get('LocalPath')
        if local_path:
            checkpoint_kwargs['local_path'] = local_path
        create_kwargs['checkpoint_config'] = checkpoint_kwargs

    # Environment
    if environment:
        create_kwargs['environment'] = environment

    # Tags
    if tags:
        create_kwargs['tags'] = [
            {'key': t.get('Key', ''), 'value': t.get('Value', '')}
            for t in tags
        ]

    # Submit the job
    try:
        job = TrainingJob.create(**create_kwargs)
    except Exception as e:
        error_msg = str(e)
        # "AccessDenied" also matches "AccessDeniedException".
        if "AccessDenied" in error_msg:
            _error_exit(
                "Access denied when submitting training job. "
                "Ensure the role has sagemaker:CreateTrainingJob permission. "
                f"Details: {error_msg}"
            )
        else:
            _error_exit(f"Failed to create training job: {error_msg}")

    # Fall back to the job name when the SDK object carries no usable ARN.
    job_arn = _sanitize_for_json(getattr(job, 'training_job_arn', None))
    _output({
        "job_name": job_name,
        "job_arn": job_arn or job_name,
        "status": "InProgress"
    })
198
+
199
+
200
+ # ── cmd_status ────────────────────────────────────────────────────────────────
201
+
202
def cmd_status(args):
    """Query a training job via TrainingJob.get() and print a status summary.

    Args:
        args: Parsed CLI namespace. Reads ``job_name`` and the optional
            ``region``.

    Output (JSON on stdout via ``_output``; the process exits afterwards):
        {"status": str, "secondary_status": str, "failure_reason": str|null,
         "elapsed_seconds": int|null, "metrics": dict|null,
         "model_artifacts": str|null, "display": str}
        On failure ``_error_exit`` emits {"error": true, "message": str}.
    """
    region = (
        getattr(args, 'region', None)
        or os.environ.get('AWS_DEFAULT_REGION')
        or os.environ.get('AWS_REGION')
    )
    if region:
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ.setdefault('AWS_REGION', region)

    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit("sagemaker SDK v3 not installed.")

    # Get job
    try:
        job = TrainingJob.get(training_job_name=args.job_name)
    except Exception as e:
        _error_exit(f"Failed to describe training job '{args.job_name}': {e}")

    status = _sanitize_for_json(getattr(job, 'training_job_status', 'Unknown')) or 'Unknown'
    secondary = _sanitize_for_json(getattr(job, 'secondary_status', '')) or ''
    failure_reason = _sanitize_for_json(getattr(job, 'failure_reason', None))

    # Elapsed time: end - start for finished jobs, now - start otherwise.
    elapsed_seconds = None
    start_time = _sanitize_for_json(getattr(job, 'training_start_time', None))
    end_time = _sanitize_for_json(getattr(job, 'training_end_time', None))
    if start_time:
        from datetime import datetime, timezone
        try:
            reference = end_time if end_time else datetime.now(timezone.utc)
            elapsed_seconds = int((reference - start_time).total_seconds())
        except (TypeError, AttributeError):
            # Best-effort: unexpected timestamp types (or mixing naive and
            # aware datetimes) leave elapsed_seconds as None.
            pass

    # Metrics. Names and values are sanitized individually because a metric
    # object may itself hold an unassigned placeholder, which json.dumps
    # cannot serialize.
    metrics = None
    final_metrics = _sanitize_for_json(getattr(job, 'final_metric_data_list', None))
    if final_metrics:
        try:
            collected = {}
            for m in final_metrics:
                name = _sanitize_for_json(getattr(m, 'metric_name', None))
                if name is None or not hasattr(m, 'value'):
                    continue
                collected[name] = _sanitize_for_json(m.value)
            metrics = collected
        except (TypeError, AttributeError):
            # Best-effort: a non-iterable or oddly shaped list yields no metrics.
            pass

    # Model artifacts
    model_artifacts = None
    artifacts_obj = _sanitize_for_json(getattr(job, 'model_artifacts', None))
    if artifacts_obj:
        model_artifacts = _sanitize_for_json(getattr(artifacts_obj, 's3_model_artifacts', None))

    # Build display line
    emoji_map = {'InProgress': '🔄', 'Completed': '✅', 'Failed': '❌', 'Stopped': '⏹️'}
    emoji = emoji_map.get(status, '❓')
    display_parts = [f" {emoji} {status}"]
    if secondary:
        display_parts.append(f"| {secondary}")
    if elapsed_seconds is not None:
        hours = elapsed_seconds // 3600
        mins = (elapsed_seconds % 3600) // 60
        secs = elapsed_seconds % 60
        if hours > 0:
            display_parts.append(f"| elapsed: {hours}h {mins}m {secs}s")
        elif mins > 0:
            display_parts.append(f"| elapsed: {mins}m {secs}s")
        else:
            display_parts.append(f"| elapsed: {secs}s")

    _output({
        "status": status,
        "secondary_status": secondary,
        "failure_reason": failure_reason,
        "elapsed_seconds": elapsed_seconds,
        "metrics": metrics,
        "model_artifacts": model_artifacts,
        "display": " ".join(display_parts),
    })
289
+
290
+
291
+ # ── cmd_resolve ───────────────────────────────────────────────────────────────
292
+
293
def cmd_resolve(args):
    """Extract the model artifact or checkpoint S3 path from a training job.

    Args:
        args: Parsed CLI namespace. Reads ``job_name`` plus the optional
            ``region``, ``checkpoints`` and ``technique``.

    Output (JSON on stdout via ``_output``; the process exits afterwards):
        With ``--checkpoints`` (job may be in any status):
            {"checkpoint_path": str, "job_name": str}
            ``checkpoint_path`` is the job's checkpoint S3 URI, else
            "<s3_output_path>/checkpoints/", else "" when neither is known.
        Without it (job must be ``Completed``):
            {"artifact_path": str, "output_type": "adapter"|"full-model"}
            ``output_type`` is "adapter" when the technique hint is ``sft``
            or ``dpo`` (compared case-insensitively), otherwise "full-model".
        On failure ``_error_exit`` emits {"error": true, "message": str}.
    """
    region = (
        getattr(args, 'region', None)
        or os.environ.get('AWS_DEFAULT_REGION')
        or os.environ.get('AWS_REGION')
    )
    if region:
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ.setdefault('AWS_REGION', region)

    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit("sagemaker SDK v3 not installed.")

    try:
        job = TrainingJob.get(training_job_name=args.job_name)
    except Exception as e:
        _error_exit(f"Failed to describe training job '{args.job_name}': {e}")

    # If --checkpoints flag, return checkpoint path (job can be any status)
    if getattr(args, 'checkpoints', False):
        checkpoint_config = _sanitize_for_json(getattr(job, 'checkpoint_config', None))
        checkpoint_path = None
        if checkpoint_config:
            checkpoint_path = _sanitize_for_json(getattr(checkpoint_config, 's3_uri', None))

        # Fallback: derive from output path
        if not checkpoint_path:
            output_config = _sanitize_for_json(getattr(job, 'output_data_config', None))
            if output_config:
                s3_output = _sanitize_for_json(getattr(output_config, 's3_output_path', None))
                if s3_output:
                    checkpoint_path = f"{s3_output.rstrip('/')}/checkpoints/"

        _output({
            "checkpoint_path": checkpoint_path or "",
            "job_name": args.job_name,
        })
        # _output() exits the process; this return only guards against
        # falling through if that ever changes.
        return

    # Normal resolve: require completed status
    status = _sanitize_for_json(getattr(job, 'training_job_status', 'Unknown')) or 'Unknown'
    if status != 'Completed':
        _error_exit(f"Job '{args.job_name}' is not completed (status: {status})")

    artifacts_obj = _sanitize_for_json(getattr(job, 'model_artifacts', None))
    if not artifacts_obj:
        _error_exit(f"No model artifacts found for job '{args.job_name}'")

    artifact_path = _sanitize_for_json(getattr(artifacts_obj, 's3_model_artifacts', None))
    if not artifact_path:
        _error_exit(f"No S3 model artifacts path for job '{args.job_name}'")

    # Detect output type based on technique hint. Normalize case and
    # whitespace so "SFT" / "Dpo" are classified the same as "sft" / "dpo".
    technique = (getattr(args, 'technique', None) or '').strip().lower()
    output_type = "adapter" if technique in ('sft', 'dpo') else "full-model"

    _output({
        "artifact_path": artifact_path,
        "output_type": output_type,
    })
360
+
361
+
362
+ # ── cmd_stop ──────────────────────────────────────────────────────────────────
363
+
364
def cmd_stop(args):
    """Request that a training job be stopped.

    Looks the job up with ``TrainingJob.get()`` and calls ``stop()`` on it.

    Output (JSON on stdout; the process exits afterwards):
        Success: {"stopped": true, "job_name": str}
        Failure: {"error": true, "message": str} via ``_error_exit``.
    """
    # Explicit --region wins, then AWS_DEFAULT_REGION, then AWS_REGION.
    region = None
    for candidate in (
        getattr(args, 'region', None),
        os.environ.get('AWS_DEFAULT_REGION'),
        os.environ.get('AWS_REGION'),
    ):
        if candidate:
            region = candidate
            break
    if region:
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ.setdefault('AWS_REGION', region)

    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit("sagemaker SDK v3 not installed.")

    try:
        target = TrainingJob.get(training_job_name=args.job_name)
        target.stop()
    except Exception as exc:
        _error_exit(f"Failed to stop training job '{args.job_name}': {exc}")
    else:
        _output({"stopped": True, "job_name": args.job_name})
385
+
386
+
387
+ # ── Main ──────────────────────────────────────────────────────────────────────
388
+
389
def main():
    """Build the CLI, parse ``sys.argv`` and run the chosen subcommand.

    Subcommands: ``submit``, ``status``, ``resolve`` and ``stop``. Each
    handler receives the parsed namespace.
    """
    parser = argparse.ArgumentParser(description='SageMaker Training Job helper (SDK v3)')
    subparsers = parser.add_subparsers(dest='command', required=True)

    # submit
    p_submit = subparsers.add_parser('submit', help='Create a training job')
    p_submit.add_argument('--config', required=True, help='Path to job config JSON')
    p_submit.add_argument('--region', help='AWS region')

    # status
    p_status = subparsers.add_parser('status', help='Get job status')
    p_status.add_argument('--job-name', required=True, help='Training job name')
    p_status.add_argument('--region', help='AWS region')

    # resolve
    p_resolve = subparsers.add_parser('resolve', help='Resolve artifacts from completed job')
    p_resolve.add_argument('--job-name', required=True, help='Training job name')
    p_resolve.add_argument('--technique', help='Training technique (for output type hint)')
    p_resolve.add_argument(
        '--checkpoints',
        action='store_true',
        help='Return checkpoint S3 path instead of model artifacts',
    )
    p_resolve.add_argument('--region', help='AWS region')

    # stop
    p_stop = subparsers.add_parser('stop', help='Stop a running job')
    p_stop.add_argument('--job-name', required=True, help='Training job name')
    p_stop.add_argument('--region', help='AWS region')

    args = parser.parse_args()

    # Subcommand name -> handler function.
    dispatch = {
        'submit': cmd_submit,
        'status': cmd_status,
        'resolve': cmd_resolve,
        'stop': cmd_stop,
    }
    handler = dispatch.get(args.command)
    if not handler:
        _error_exit(f"Unknown command: {args.command}")
    else:
        handler(args)


if __name__ == '__main__':
    main()
@@ -35,6 +35,7 @@ _usage() {
35
35
  echo " add <name> --weights <s3-uri> Add a new LoRA adapter from S3"
36
36
  echo " add <name> --from-hub <hf-repo-id> Add a new LoRA adapter from HuggingFace Hub"
37
37
  echo " add <name> --from-tune [technique] Add adapter from do/tune output"
38
+ echo " add <name> --from-train [technique] Add adapter from do/train output"
38
39
  echo " add <name> --from-registry [arn] Add adapter from model registry"
39
40
  echo " list List all adapters on the endpoint"
40
41
  echo " remove <name> Remove an adapter"
@@ -375,6 +376,8 @@ _adapter_add() {
375
376
  local from_hub=""
376
377
  local from_tune=""
377
378
  local from_tune_technique=""
379
+ local from_train=""
380
+ local from_train_technique=""
378
381
  local from_registry=""
379
382
  local registry_arn=""
380
383
  local use_local=""
@@ -413,6 +416,16 @@ _adapter_add() {
413
416
  shift
414
417
  fi
415
418
  ;;
419
+ --from-train)
420
+ from_train="true"
421
+ # Check if next argument is a technique (not another flag and not empty)
422
+ if [ -n "${2:-}" ] && [[ "${2}" != -* ]]; then
423
+ from_train_technique="$2"
424
+ shift 2
425
+ else
426
+ shift
427
+ fi
428
+ ;;
416
429
  --from-registry)
417
430
  from_registry="true"
418
431
  # Check if next argument is an ARN (starts with arn:)
@@ -508,6 +521,7 @@ _adapter_add() {
508
521
  [ -n "${weights_uri}" ] && source_count=$((source_count + 1))
509
522
  [ -n "${from_hub}" ] && source_count=$((source_count + 1))
510
523
  [ -n "${from_tune}" ] && source_count=$((source_count + 1))
524
+ [ -n "${from_train}" ] && source_count=$((source_count + 1))
511
525
  [ -n "${from_registry}" ] && source_count=$((source_count + 1))
512
526
 
513
527
  if [ "${source_count}" -gt 1 ]; then
@@ -868,6 +882,96 @@ _adapter_add() {
868
882
  fi # end --local else branch
869
883
  fi
870
884
 
885
+ # ── Resolve --from-train to weights_uri ───────────────────────────────
886
+ if [ -n "${from_train}" ]; then
887
+ if [ -n "${from_train_technique}" ]; then
888
+ local technique_upper
889
+ technique_upper=$(echo "${from_train_technique}" | tr '[:lower:]' '[:upper:]')
890
+ local train_var="TRAIN_ADAPTER_PATH_${technique_upper}"
891
+ local train_path="${!train_var:-}"
892
+
893
+ if [ -z "${train_path}" ]; then
894
+ echo "❌ No training adapter output found for technique: ${from_train_technique}"
895
+ echo ""
896
+ echo " ${train_var} is not set in do/config."
897
+ echo ""
898
+ echo " Run a training job first:"
899
+ echo " ./do/train --technique ${from_train_technique} --dataset <source>"
900
+ exit 1
901
+ fi
902
+
903
+ weights_uri="${train_path}"
904
+ echo "📦 Using train adapter output for technique '${from_train_technique}': ${weights_uri}"
905
+ else
906
+ # No technique: read TRAIN_OUTPUT_PATH_LATEST
907
+ if [ -z "${TRAIN_OUTPUT_PATH_LATEST:-}" ]; then
908
+ echo "❌ No training output found."
909
+ echo ""
910
+ echo " TRAIN_OUTPUT_PATH_LATEST is not set in do/config."
911
+ echo ""
912
+ echo " Run a training job first:"
913
+ echo " ./do/train --technique <technique> --dataset <source>"
914
+ exit 1
915
+ fi
916
+
917
+ weights_uri="${TRAIN_OUTPUT_PATH_LATEST}"
918
+ echo "📦 Using latest train adapter output: ${weights_uri}"
919
+ fi
920
+ echo ""
921
+
922
+ # Use same staging path as --from-tune (Processing Job or local)
923
+ if [ -z "${use_local}" ]; then
924
+ echo "🚀 Submitting Processing Job to stage adapter from training output..."
925
+ echo ""
926
+
927
+ local exec_role="${EXECUTION_ROLE_ARN:-}"
928
+ if [ -z "${exec_role}" ]; then
929
+ exec_role="${ROLE_ARN:-}"
930
+ fi
931
+ if [ -z "${exec_role}" ]; then
932
+ exec_role="${SAGEMAKER_ROLE_ARN:-}"
933
+ fi
934
+ if [ -z "${exec_role}" ]; then
935
+ echo "❌ No execution role found."
936
+ echo " Run 'ml-container-creator bootstrap' to set up your profile."
937
+ exit 1
938
+ fi
939
+
940
+ local adapter_bucket="${ADAPTER_S3_BUCKET:-}"
941
+ if [ -z "${adapter_bucket}" ]; then
942
+ local account_id
943
+ account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null || echo "")
944
+ adapter_bucket="sagemaker-${AWS_REGION:-us-east-1}-${account_id}"
945
+ fi
946
+
947
+ local adapter_s3_prefix="s3://${adapter_bucket}/${PROJECT_NAME}/adapters/${adapter_name}"
948
+
949
+ local stage_args=(
950
+ --source-uri "${weights_uri}"
951
+ --output-uri "${adapter_s3_prefix}/"
952
+ --role-arn "${exec_role}"
953
+ --region "${AWS_REGION}"
954
+ )
955
+ if [ -n "${no_wait}" ]; then
956
+ stage_args+=(--no-wait)
957
+ fi
958
+
959
+ local stage_result
960
+ stage_result=$(python3 "${SCRIPT_DIR}/.adapter_helper.py" stage "${stage_args[@]}" 2>/dev/null | grep -E '^\{' | tail -1) || {
961
+ echo "❌ Failed to submit adapter staging job"
962
+ exit 1
963
+ }
964
+
965
+ weights_uri=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('adapter_s3_uri',''))" 2>/dev/null) || weights_uri=""
966
+ if [ -z "${weights_uri}" ]; then
967
+ echo "❌ Failed to extract staged adapter URI"
968
+ exit 1
969
+ fi
970
+ echo " ✅ Adapter staged to: ${weights_uri}"
971
+ fi
972
+ echo ""
973
+ fi
974
+
871
975
  # ── Resolve --from-registry to weights_uri ────────────────────────────
872
976
  if [ -n "${from_registry}" ]; then
873
977
  if [ -z "${registry_arn}" ]; then
@@ -986,6 +1090,13 @@ _adapter_add() {
986
1090
  # Extract model data URL (weights path)
987
1091
  weights_uri=$(echo "${version_line}" | python3 -c "import sys,json; data=json.loads(sys.stdin.read()); print(data.get('modelDataUrl',''))" 2>/dev/null || echo "")
988
1092
 
1093
+ # Ensure adapter weights URI ends with / (S3 prefix for directory-style adapters).
1094
+ # Registry metadata may have the slash stripped (Bug 52 rstrip), but SageMaker IC
1095
+ # ModelDataUrl requires it to download all objects under the prefix.
1096
+ if [ -n "${weights_uri}" ] && ! echo "${weights_uri}" | grep -q '\.tar\.gz$'; then
1097
+ weights_uri="${weights_uri%/}/"
1098
+ fi
1099
+
989
1100
  if [ -z "${weights_uri}" ]; then
990
1101
  echo "❌ No model data URL found for registry version: ${registry_arn}"
991
1102
  echo ""
@@ -1294,6 +1405,16 @@ EOF
1294
1405
  fi
1295
1406
  fi
1296
1407
 
1408
+ # Add train-specific metadata if --from-train was used
1409
+ if [ -n "${from_train}" ]; then
1410
+ local train_technique_meta="${from_train_technique:-${TRAIN_TECHNIQUE:-custom}}"
1411
+ cat >> "${SCRIPT_DIR}/adapters/${adapter_name}.conf" <<EOF
1412
+ export ADAPTER_SOURCE="train"
1413
+ export ADAPTER_TECHNIQUE="${train_technique_meta}"
1414
+ export ADAPTER_TRAIN_JOB="${TRAIN_JOB_NAME:-}"
1415
+ EOF
1416
+ fi
1417
+
1297
1418
  # Add registry-specific metadata if --from-registry was used
1298
1419
  if [ -n "${from_registry}" ]; then
1299
1420
  cat >> "${SCRIPT_DIR}/adapters/${adapter_name}.conf" <<EOF
@@ -1425,6 +1546,42 @@ if endpoint_name:
1425
1546
  except Exception:
1426
1547
  print("⚠️ Could not query endpoint — showing local confs only.", file=sys.stderr)
1427
1548
 
1549
+ # ── Data source 3: Registry (MPG) adapters ──
1550
+ # Query the deployment MPG for registered adapter versions (if .register_helper.py exists)
1551
+ helper_path = os.path.join(script_dir, ".register_helper.py")
1552
+ if os.path.exists(helper_path):
1553
+ try:
1554
+ result = subprocess.run(
1555
+ ["python3", helper_path, "list-adapters",
1556
+ "--project-name", project_name, "--region", region],
1557
+ capture_output=True, text=True, timeout=15)
1558
+ if result.returncode == 0:
1559
+ # Extract JSON line
1560
+ for line in result.stdout.strip().split("\n"):
1561
+ if line.startswith("{"):
1562
+ reg_data = json.loads(line)
1563
+ for adapter in reg_data.get("adapters", []):
1564
+ reg_name = adapter.get("name", "")
1565
+ if not reg_name:
1566
+ continue
1567
+ # Only add if not already tracked locally
1568
+ if reg_name not in adapters:
1569
+ adapters[reg_name] = {
1570
+ "source": "registry",
1571
+ "ic_name": "",
1572
+ "technique": adapter.get("technique", ""),
1573
+ "dataset": "",
1574
+ "status": f"v{adapter.get('version', '?')}",
1575
+ }
1576
+ else:
1577
+ # Annotate existing entry with registry version
1578
+ ver = adapter.get("version", "")
1579
+ if ver:
1580
+ adapters[reg_name]["status"] += f" (reg:v{ver})"
1581
+ break
1582
+ except Exception:
1583
+ pass # Registry query is best-effort
1584
+
1428
1585
  # ── Output ──
1429
1586
  if not adapters:
1430
1587
  print("No adapters found.")