@aws/ml-container-creator 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -1
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +5 -2
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/config +6 -1
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""SageMaker Training Job helper (SDK v3).
|
|
6
|
+
|
|
7
|
+
Subcommands:
|
|
8
|
+
submit - Create a training job via TrainingJob.create()
|
|
9
|
+
status - Get job status via TrainingJob.get()
|
|
10
|
+
resolve - Extract artifact path from completed job
|
|
11
|
+
stop - Stop a running training job
|
|
12
|
+
|
|
13
|
+
All output is JSON on stdout for bash consumption.
|
|
14
|
+
Pattern: grep -E '^\\{' | tail -1 to extract JSON from mixed output.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import warnings
|
|
22
|
+
|
|
23
|
+
# Keep stdout parseable: bash callers extract the last JSON line, so noisy
# third-party warnings are filtered out up front.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", message=".*urllib3.*")

# Silence ALL logging (CRITICAL and below) so sagemaker-core/rich cannot
# write log records into the JSON stream.
import logging as _logging

_logging.disable(level=_logging.CRITICAL)
# Only provide a default; a caller-supplied SAGEMAKER_LOG_LEVEL wins.
if "SAGEMAKER_LOG_LEVEL" not in os.environ:
    os.environ["SAGEMAKER_LOG_LEVEL"] = "CRITICAL"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ── Utility functions ─────────────────────────────────────────────────────────
|
|
34
|
+
|
|
35
|
+
def _error_exit(message):
|
|
36
|
+
"""Print JSON error to stdout and exit with code 1."""
|
|
37
|
+
print(json.dumps({"error": True, "message": message}))
|
|
38
|
+
sys.exit(1)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _output(data):
|
|
42
|
+
"""Print JSON result to stdout."""
|
|
43
|
+
print(json.dumps(data))
|
|
44
|
+
sys.exit(0)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _sanitize_for_json(value):
|
|
48
|
+
"""Convert sagemaker-core Unassigned sentinel values to None."""
|
|
49
|
+
if value is None:
|
|
50
|
+
return None
|
|
51
|
+
type_name = type(value).__name__
|
|
52
|
+
if type_name in ("Unassigned", "UnassignedValue"):
|
|
53
|
+
return None
|
|
54
|
+
return value
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ── cmd_submit ────────────────────────────────────────────────────────────────
|
|
58
|
+
|
|
59
|
+
def _build_submit_channel(channel):
    """Convert one CreateTrainingJob ``InputDataConfig`` entry to snake_case.

    A missing or ``null`` ``DataSource``/``S3DataSource`` falls back to the
    defaults below instead of raising. Optional channel fields are only
    forwarded when present in the input.
    """
    s3_source = (channel.get('DataSource') or {}).get('S3DataSource') or {}
    converted = {
        'channel_name': channel.get('ChannelName', 'training'),
        'data_source': {
            's3_data_source': {
                's3_data_type': s3_source.get('S3DataType', 'S3Prefix'),
                's3_uri': s3_source.get('S3Uri', ''),
                's3_data_distribution_type': s3_source.get(
                    'S3DataDistributionType', 'FullyReplicated'),
            },
        },
    }
    # These were previously dropped silently; forward them when configured.
    for src_key, dst_key in (('ContentType', 'content_type'),
                             ('CompressionType', 'compression_type'),
                             ('InputMode', 'input_mode')):
        if channel.get(src_key):
            converted[dst_key] = channel[src_key]
    return converted


def _build_submit_kwargs(config):
    """Translate a CreateTrainingJob-format (PascalCase) config dict into
    snake_case keyword arguments for ``TrainingJob.create()``.

    Required sections fall back to the same defaults as before. Optional
    sections (channels, hyperparameters, metrics, spot, checkpoints, VPC,
    environment, tags) are only emitted when present. A section whose value
    is JSON ``null`` is treated as missing.
    """
    algo_spec = config.get('AlgorithmSpecification') or {}
    resource_config = config.get('ResourceConfig') or {}
    output_data_config = config.get('OutputDataConfig') or {}
    stopping_condition = config.get('StoppingCondition') or {}
    input_data_config = config.get('InputDataConfig') or []
    hyper_parameters = config.get('HyperParameters') or {}
    checkpoint_config = config.get('CheckpointConfig')
    vpc_config = config.get('VpcConfig')
    environment = config.get('Environment') or {}
    tags = config.get('Tags') or []

    create_kwargs = {
        'training_job_name': config.get('TrainingJobName', ''),
        'role_arn': config.get('RoleArn', ''),
        'algorithm_specification': {
            'training_image': algo_spec.get('TrainingImage', ''),
            'training_input_mode': algo_spec.get('TrainingInputMode', 'File'),
        },
        'resource_config': {
            'instance_type': resource_config.get('InstanceType', 'ml.g5.xlarge'),
            'instance_count': resource_config.get('InstanceCount', 1),
            'volume_size_in_gb': resource_config.get('VolumeSizeInGB', 50),
        },
        'output_data_config': {
            's3_output_path': output_data_config.get('S3OutputPath', ''),
        },
        'stopping_condition': {
            'max_runtime_in_seconds': stopping_condition.get('MaxRuntimeInSeconds', 86400),
        },
    }

    # Input data channels
    if input_data_config:
        create_kwargs['input_data_config'] = [
            _build_submit_channel(channel) for channel in input_data_config
        ]

    # Hyperparameters (all values must be strings)
    if hyper_parameters:
        create_kwargs['hyper_parameters'] = {
            str(k): str(v) for k, v in hyper_parameters.items()
        }

    # Metric definitions
    metric_defs = algo_spec.get('MetricDefinitions') or []
    if metric_defs:
        create_kwargs['algorithm_specification']['metric_definitions'] = [
            {'name': m.get('Name', ''), 'regex': m.get('Regex', '')}
            for m in metric_defs
        ]

    # Managed spot training
    if config.get('EnableManagedSpotTraining', False):
        create_kwargs['enable_managed_spot_training'] = True
        max_wait = stopping_condition.get('MaxWaitTimeInSeconds')
        if max_wait:
            create_kwargs['stopping_condition']['max_wait_time_in_seconds'] = max_wait

    # Checkpoint config (keep a custom container-side directory if configured)
    if checkpoint_config:
        checkpoint_kwargs = {'s3_uri': checkpoint_config.get('S3Uri', '')}
        if checkpoint_config.get('LocalPath'):
            checkpoint_kwargs['local_path'] = checkpoint_config['LocalPath']
        create_kwargs['checkpoint_config'] = checkpoint_kwargs

    # VPC config: dropping it would run the job outside the configured network
    if vpc_config:
        create_kwargs['vpc_config'] = {
            'security_group_ids': vpc_config.get('SecurityGroupIds', []),
            'subnets': vpc_config.get('Subnets', []),
        }

    # Environment
    if environment:
        create_kwargs['environment'] = environment

    # Tags
    if tags:
        create_kwargs['tags'] = [
            {'key': t.get('Key', ''), 'value': t.get('Value', '')}
            for t in tags
        ]

    return create_kwargs


def cmd_submit(args):
    """Create a SageMaker Training Job via SDK v3.

    Reads the job configuration from the JSON file at ``args.config`` (same
    format as the old CreateTrainingJob CLI input), converts it to snake_case
    kwargs and submits it via ``TrainingJob.create()``.

    Prints {"job_name": str, "job_arn": str, "status": "InProgress"} and exits
    0 on success. On any failure it prints {"error": true, "message": str} and
    exits 1.
    """
    # Set region BEFORE any sagemaker import (Bug 26 pattern)
    region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
    if region:
        # Pin both variables to the resolved region so an explicit --region
        # also overrides a pre-existing AWS_REGION and the two never disagree.
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ['AWS_REGION'] = region

    # Read config file (ValueError covers JSONDecodeError and UnicodeDecodeError)
    try:
        with open(args.config, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        _error_exit(f"Failed to read config file: {e}")
    if not isinstance(config, dict):
        _error_exit("Failed to read config file: expected a JSON object at the top level")

    # Import SDK v3 TrainingJob (same pattern as .tune_helper.py cmd_status)
    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit(
            "sagemaker SDK v3 not installed. "
            "Install: pip install 'sagemaker>=3.0'"
        )

    create_kwargs = _build_submit_kwargs(config)
    job_name = create_kwargs['training_job_name']

    # Submit the job; keep the try narrow so only SDK failures are reported
    # as submission errors.
    try:
        job = TrainingJob.create(**create_kwargs)
    except Exception as e:
        error_msg = str(e)
        if "AccessDenied" in error_msg:  # also matches AccessDeniedException
            _error_exit(
                "Access denied when submitting training job. "
                "Ensure the role has sagemaker:CreateTrainingJob permission. "
                f"Details: {error_msg}"
            )
        _error_exit(f"Failed to create training job: {error_msg}")

    job_arn = _sanitize_for_json(getattr(job, 'training_job_arn', None))
    _output({
        "job_name": job_name,
        "job_arn": job_arn or job_name,
        "status": "InProgress"
    })
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# ── cmd_status ────────────────────────────────────────────────────────────────
|
|
201
|
+
|
|
202
|
+
def cmd_status(args):
    """Query a training job via TrainingJob.get() and print its status as JSON.

    Prints {"status": str, "secondary_status": str, "failure_reason": str|null,
    "elapsed_seconds": int|null, "metrics": dict|null,
    "model_artifacts": str|null, "display": str} and exits 0. If the SDK is
    missing or the describe call fails, it prints a JSON error and exits 1.
    """
    region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
    if region:
        # Pin both variables so an explicit --region wins over a stale AWS_REGION.
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ['AWS_REGION'] = region

    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit("sagemaker SDK v3 not installed.")

    # Get job
    try:
        job = TrainingJob.get(training_job_name=args.job_name)
    except Exception as e:
        _error_exit(f"Failed to describe training job '{args.job_name}': {e}")

    status = _sanitize_for_json(getattr(job, 'training_job_status', 'Unknown')) or 'Unknown'
    secondary = _sanitize_for_json(getattr(job, 'secondary_status', '')) or ''
    failure_reason = _sanitize_for_json(getattr(job, 'failure_reason', None))

    # Elapsed time: end - start for finished jobs, now - start otherwise
    elapsed_seconds = None
    start_time = _sanitize_for_json(getattr(job, 'training_start_time', None))
    end_time = _sanitize_for_json(getattr(job, 'training_end_time', None))
    if start_time:
        from datetime import datetime, timezone
        try:
            reference = end_time or datetime.now(timezone.utc)
            # Clamp: local clock skew can make a just-started job look negative.
            elapsed_seconds = max(0, int((reference - start_time).total_seconds()))
        except (TypeError, AttributeError):
            # e.g. naive vs aware datetimes; elapsed time is best-effort.
            pass

    # Metrics: skip entries whose name/value is missing or Unassigned, since
    # those would make json.dumps fail.
    metrics = None
    final_metrics = _sanitize_for_json(getattr(job, 'final_metric_data_list', None))
    if final_metrics:
        try:
            collected = {}
            for m in final_metrics:
                metric_name = _sanitize_for_json(getattr(m, 'metric_name', None))
                metric_value = _sanitize_for_json(getattr(m, 'value', None))
                if metric_name is None or metric_value is None:
                    continue
                collected[metric_name] = metric_value
            metrics = collected
        except TypeError:
            pass

    # Model artifacts
    model_artifacts = None
    artifacts_obj = _sanitize_for_json(getattr(job, 'model_artifacts', None))
    if artifacts_obj:
        model_artifacts = _sanitize_for_json(getattr(artifacts_obj, 's3_model_artifacts', None))

    # Build display line
    emoji_map = {
        'InProgress': '🔄',
        'Completed': '✅',
        'Failed': '❌',
        'Stopping': '⏳',
        'Stopped': '⏹️',
    }
    emoji = emoji_map.get(status, '❓')
    display_parts = [f"  {emoji} {status}"]
    if secondary:
        display_parts.append(f"| {secondary}")
    if elapsed_seconds is not None:
        hours, remainder = divmod(elapsed_seconds, 3600)
        mins, secs = divmod(remainder, 60)
        if hours:
            elapsed_text = f"{hours}h {mins}m {secs}s"
        elif mins:
            elapsed_text = f"{mins}m {secs}s"
        else:
            elapsed_text = f"{secs}s"
        display_parts.append(f"| elapsed: {elapsed_text}")

    _output({
        "status": status,
        "secondary_status": secondary,
        "failure_reason": failure_reason,
        "elapsed_seconds": elapsed_seconds,
        "metrics": metrics,
        "model_artifacts": model_artifacts,
        "display": " ".join(display_parts),
    })
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
# ── cmd_resolve ───────────────────────────────────────────────────────────────
|
|
292
|
+
|
|
293
|
+
def cmd_resolve(args):
    """Print a training job's checkpoint or model-artifact S3 path as JSON.

    With --checkpoints (the job may be in any status): prints
    {"checkpoint_path": str, "job_name": str}. The path is the job's
    checkpoint S3 URI (used for --resume), falling back to
    "<S3OutputPath>/checkpoints/". It is "" when neither is available.

    Without --checkpoints: the job must be Completed. Prints
    {"artifact_path": str, "output_type": "adapter" | "full-model"} (the
    artifact path is used for adapter staging). "adapter" is reported when
    --technique is sft or dpo, case-insensitively.

    Any failure prints {"error": true, "message": str} and exits 1.
    """
    region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
    if region:
        # Pin both variables so an explicit --region wins over a stale AWS_REGION.
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ['AWS_REGION'] = region

    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit("sagemaker SDK v3 not installed.")

    try:
        job = TrainingJob.get(training_job_name=args.job_name)
    except Exception as e:
        _error_exit(f"Failed to describe training job '{args.job_name}': {e}")

    # If --checkpoints flag, return checkpoint path (job can be any status)
    if getattr(args, 'checkpoints', False):
        checkpoint_config = _sanitize_for_json(getattr(job, 'checkpoint_config', None))
        checkpoint_path = None
        if checkpoint_config:
            checkpoint_path = _sanitize_for_json(getattr(checkpoint_config, 's3_uri', None))

        # Fallback: derive from output path
        if not checkpoint_path:
            output_config = _sanitize_for_json(getattr(job, 'output_data_config', None))
            if output_config:
                s3_output = _sanitize_for_json(getattr(output_config, 's3_output_path', None))
                if s3_output:
                    checkpoint_path = f"{s3_output.rstrip('/')}/checkpoints/"

        _output({
            "checkpoint_path": checkpoint_path or "",
            "job_name": args.job_name,
        })
        return  # _output exits; kept as a guard for readers

    # Normal resolve: require completed status
    status = _sanitize_for_json(getattr(job, 'training_job_status', 'Unknown')) or 'Unknown'
    if status != 'Completed':
        _error_exit(f"Job '{args.job_name}' is not completed (status: {status})")

    artifacts_obj = _sanitize_for_json(getattr(job, 'model_artifacts', None))
    if not artifacts_obj:
        _error_exit(f"No model artifacts found for job '{args.job_name}'")

    artifact_path = _sanitize_for_json(getattr(artifacts_obj, 's3_model_artifacts', None))
    if not artifact_path:
        _error_exit(f"No S3 model artifacts path for job '{args.job_name}'")

    # Detect output type based on technique hint; normalize case so e.g.
    # "SFT" (the adapter script upper-cases technique names) still matches.
    technique = (getattr(args, 'technique', None) or '').strip().lower()
    output_type = "adapter" if technique in ('sft', 'dpo') else "full-model"

    _output({
        "artifact_path": artifact_path,
        "output_type": output_type,
    })
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
# ── cmd_stop ──────────────────────────────────────────────────────────────────
|
|
363
|
+
|
|
364
|
+
def cmd_stop(args):
    """Ask SageMaker to stop a training job and print a JSON confirmation.

    Prints {"stopped": true, "job_name": str} and exits 0 after job.stop()
    returns. If the SDK is missing or the describe/stop call fails, it prints
    a JSON error and exits 1.
    """
    region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
    if region:
        # Pin both variables so an explicit --region wins over a stale AWS_REGION.
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ['AWS_REGION'] = region

    try:
        from sagemaker.core.resources import TrainingJob
    except ImportError:
        _error_exit("sagemaker SDK v3 not installed.")

    # Only the SDK calls belong in the try; success output follows it.
    try:
        job = TrainingJob.get(training_job_name=args.job_name)
        job.stop()
    except Exception as e:
        _error_exit(f"Failed to stop training job '{args.job_name}': {e}")

    _output({"stopped": True, "job_name": args.job_name})
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# ── Main ──────────────────────────────────────────────────────────────────────
|
|
388
|
+
|
|
389
|
+
def main():
    """Build the CLI, parse ``sys.argv`` and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(description='SageMaker Training Job helper (SDK v3)')
    subparsers = parser.add_subparsers(dest='command', required=True)

    def _add_job_name(sub):
        # Shared by status/resolve/stop.
        sub.add_argument('--job-name', required=True, help='Training job name')

    def _add_region(sub):
        # Every subcommand accepts an optional region override.
        sub.add_argument('--region', help='AWS region')

    submit_cmd = subparsers.add_parser('submit', help='Create a training job')
    submit_cmd.add_argument('--config', required=True, help='Path to job config JSON')
    _add_region(submit_cmd)

    status_cmd = subparsers.add_parser('status', help='Get job status')
    _add_job_name(status_cmd)
    _add_region(status_cmd)

    resolve_cmd = subparsers.add_parser('resolve', help='Resolve artifacts from completed job')
    _add_job_name(resolve_cmd)
    resolve_cmd.add_argument('--technique', help='Training technique (for output type hint)')
    resolve_cmd.add_argument('--checkpoints', action='store_true', help='Return checkpoint S3 path instead of model artifacts')
    _add_region(resolve_cmd)

    stop_cmd = subparsers.add_parser('stop', help='Stop a running job')
    _add_job_name(stop_cmd)
    _add_region(stop_cmd)

    args = parser.parse_args()

    dispatch = {
        'submit': cmd_submit,
        'status': cmd_status,
        'resolve': cmd_resolve,
        'stop': cmd_stop,
    }
    handler = dispatch.get(args.command)
    if handler is None:
        _error_exit(f"Unknown command: {args.command}")
    else:
        handler(args)


if __name__ == '__main__':
    main()
|
|
Binary file
|
package/templates/do/adapter
CHANGED
|
@@ -35,6 +35,7 @@ _usage() {
|
|
|
35
35
|
echo " add <name> --weights <s3-uri> Add a new LoRA adapter from S3"
|
|
36
36
|
echo " add <name> --from-hub <hf-repo-id> Add a new LoRA adapter from HuggingFace Hub"
|
|
37
37
|
echo " add <name> --from-tune [technique] Add adapter from do/tune output"
|
|
38
|
+
echo " add <name> --from-train [technique] Add adapter from do/train output"
|
|
38
39
|
echo " add <name> --from-registry [arn] Add adapter from model registry"
|
|
39
40
|
echo " list List all adapters on the endpoint"
|
|
40
41
|
echo " remove <name> Remove an adapter"
|
|
@@ -375,6 +376,8 @@ _adapter_add() {
|
|
|
375
376
|
local from_hub=""
|
|
376
377
|
local from_tune=""
|
|
377
378
|
local from_tune_technique=""
|
|
379
|
+
local from_train=""
|
|
380
|
+
local from_train_technique=""
|
|
378
381
|
local from_registry=""
|
|
379
382
|
local registry_arn=""
|
|
380
383
|
local use_local=""
|
|
@@ -413,6 +416,16 @@ _adapter_add() {
|
|
|
413
416
|
shift
|
|
414
417
|
fi
|
|
415
418
|
;;
|
|
419
|
+
--from-train)
|
|
420
|
+
from_train="true"
|
|
421
|
+
# Check if next argument is a technique (not another flag and not empty)
|
|
422
|
+
if [ -n "${2:-}" ] && [[ "${2}" != -* ]]; then
|
|
423
|
+
from_train_technique="$2"
|
|
424
|
+
shift 2
|
|
425
|
+
else
|
|
426
|
+
shift
|
|
427
|
+
fi
|
|
428
|
+
;;
|
|
416
429
|
--from-registry)
|
|
417
430
|
from_registry="true"
|
|
418
431
|
# Check if next argument is an ARN (starts with arn:)
|
|
@@ -508,6 +521,7 @@ _adapter_add() {
|
|
|
508
521
|
[ -n "${weights_uri}" ] && source_count=$((source_count + 1))
|
|
509
522
|
[ -n "${from_hub}" ] && source_count=$((source_count + 1))
|
|
510
523
|
[ -n "${from_tune}" ] && source_count=$((source_count + 1))
|
|
524
|
+
[ -n "${from_train}" ] && source_count=$((source_count + 1))
|
|
511
525
|
[ -n "${from_registry}" ] && source_count=$((source_count + 1))
|
|
512
526
|
|
|
513
527
|
if [ "${source_count}" -gt 1 ]; then
|
|
@@ -868,6 +882,96 @@ _adapter_add() {
|
|
|
868
882
|
fi # end --local else branch
|
|
869
883
|
fi
|
|
870
884
|
|
|
885
|
+
# ── Resolve --from-train to weights_uri ───────────────────────────────
|
|
886
|
+
if [ -n "${from_train}" ]; then
|
|
887
|
+
if [ -n "${from_train_technique}" ]; then
|
|
888
|
+
local technique_upper
|
|
889
|
+
technique_upper=$(echo "${from_train_technique}" | tr '[:lower:]' '[:upper:]')
|
|
890
|
+
local train_var="TRAIN_ADAPTER_PATH_${technique_upper}"
|
|
891
|
+
local train_path="${!train_var:-}"
|
|
892
|
+
|
|
893
|
+
if [ -z "${train_path}" ]; then
|
|
894
|
+
echo "❌ No training adapter output found for technique: ${from_train_technique}"
|
|
895
|
+
echo ""
|
|
896
|
+
echo " ${train_var} is not set in do/config."
|
|
897
|
+
echo ""
|
|
898
|
+
echo " Run a training job first:"
|
|
899
|
+
echo " ./do/train --technique ${from_train_technique} --dataset <source>"
|
|
900
|
+
exit 1
|
|
901
|
+
fi
|
|
902
|
+
|
|
903
|
+
weights_uri="${train_path}"
|
|
904
|
+
echo "📦 Using train adapter output for technique '${from_train_technique}': ${weights_uri}"
|
|
905
|
+
else
|
|
906
|
+
# No technique: read TRAIN_OUTPUT_PATH_LATEST
|
|
907
|
+
if [ -z "${TRAIN_OUTPUT_PATH_LATEST:-}" ]; then
|
|
908
|
+
echo "❌ No training output found."
|
|
909
|
+
echo ""
|
|
910
|
+
echo " TRAIN_OUTPUT_PATH_LATEST is not set in do/config."
|
|
911
|
+
echo ""
|
|
912
|
+
echo " Run a training job first:"
|
|
913
|
+
echo " ./do/train --technique <technique> --dataset <source>"
|
|
914
|
+
exit 1
|
|
915
|
+
fi
|
|
916
|
+
|
|
917
|
+
weights_uri="${TRAIN_OUTPUT_PATH_LATEST}"
|
|
918
|
+
echo "📦 Using latest train adapter output: ${weights_uri}"
|
|
919
|
+
fi
|
|
920
|
+
echo ""
|
|
921
|
+
|
|
922
|
+
# Use same staging path as --from-tune (Processing Job or local)
|
|
923
|
+
if [ -z "${use_local}" ]; then
|
|
924
|
+
echo "🚀 Submitting Processing Job to stage adapter from training output..."
|
|
925
|
+
echo ""
|
|
926
|
+
|
|
927
|
+
local exec_role="${EXECUTION_ROLE_ARN:-}"
|
|
928
|
+
if [ -z "${exec_role}" ]; then
|
|
929
|
+
exec_role="${ROLE_ARN:-}"
|
|
930
|
+
fi
|
|
931
|
+
if [ -z "${exec_role}" ]; then
|
|
932
|
+
exec_role="${SAGEMAKER_ROLE_ARN:-}"
|
|
933
|
+
fi
|
|
934
|
+
if [ -z "${exec_role}" ]; then
|
|
935
|
+
echo "❌ No execution role found."
|
|
936
|
+
echo " Run 'ml-container-creator bootstrap' to set up your profile."
|
|
937
|
+
exit 1
|
|
938
|
+
fi
|
|
939
|
+
|
|
940
|
+
local adapter_bucket="${ADAPTER_S3_BUCKET:-}"
|
|
941
|
+
if [ -z "${adapter_bucket}" ]; then
|
|
942
|
+
local account_id
|
|
943
|
+
account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null || echo "")
|
|
944
|
+
adapter_bucket="sagemaker-${AWS_REGION:-us-east-1}-${account_id}"
|
|
945
|
+
fi
|
|
946
|
+
|
|
947
|
+
local adapter_s3_prefix="s3://${adapter_bucket}/${PROJECT_NAME}/adapters/${adapter_name}"
|
|
948
|
+
|
|
949
|
+
local stage_args=(
|
|
950
|
+
--source-uri "${weights_uri}"
|
|
951
|
+
--output-uri "${adapter_s3_prefix}/"
|
|
952
|
+
--role-arn "${exec_role}"
|
|
953
|
+
--region "${AWS_REGION}"
|
|
954
|
+
)
|
|
955
|
+
if [ -n "${no_wait}" ]; then
|
|
956
|
+
stage_args+=(--no-wait)
|
|
957
|
+
fi
|
|
958
|
+
|
|
959
|
+
local stage_result
|
|
960
|
+
stage_result=$(python3 "${SCRIPT_DIR}/.adapter_helper.py" stage "${stage_args[@]}" 2>/dev/null | grep -E '^\{' | tail -1) || {
|
|
961
|
+
echo "❌ Failed to submit adapter staging job"
|
|
962
|
+
exit 1
|
|
963
|
+
}
|
|
964
|
+
|
|
965
|
+
weights_uri=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('adapter_s3_uri',''))" 2>/dev/null) || weights_uri=""
|
|
966
|
+
if [ -z "${weights_uri}" ]; then
|
|
967
|
+
echo "❌ Failed to extract staged adapter URI"
|
|
968
|
+
exit 1
|
|
969
|
+
fi
|
|
970
|
+
echo " ✅ Adapter staged to: ${weights_uri}"
|
|
971
|
+
fi
|
|
972
|
+
echo ""
|
|
973
|
+
fi
|
|
974
|
+
|
|
871
975
|
# ── Resolve --from-registry to weights_uri ────────────────────────────
|
|
872
976
|
if [ -n "${from_registry}" ]; then
|
|
873
977
|
if [ -z "${registry_arn}" ]; then
|
|
@@ -986,6 +1090,13 @@ _adapter_add() {
|
|
|
986
1090
|
# Extract model data URL (weights path)
|
|
987
1091
|
weights_uri=$(echo "${version_line}" | python3 -c "import sys,json; data=json.loads(sys.stdin.read()); print(data.get('modelDataUrl',''))" 2>/dev/null || echo "")
|
|
988
1092
|
|
|
1093
|
+
# Ensure adapter weights URI ends with / (S3 prefix for directory-style adapters).
|
|
1094
|
+
# Registry metadata may have the slash stripped (Bug 52 rstrip), but SageMaker IC
|
|
1095
|
+
# ModelDataUrl requires it to download all objects under the prefix.
|
|
1096
|
+
if [ -n "${weights_uri}" ] && ! echo "${weights_uri}" | grep -q '\.tar\.gz$'; then
|
|
1097
|
+
weights_uri="${weights_uri%/}/"
|
|
1098
|
+
fi
|
|
1099
|
+
|
|
989
1100
|
if [ -z "${weights_uri}" ]; then
|
|
990
1101
|
echo "❌ No model data URL found for registry version: ${registry_arn}"
|
|
991
1102
|
echo ""
|
|
@@ -1294,6 +1405,16 @@ EOF
|
|
|
1294
1405
|
fi
|
|
1295
1406
|
fi
|
|
1296
1407
|
|
|
1408
|
+
# Add train-specific metadata if --from-train was used
|
|
1409
|
+
if [ -n "${from_train}" ]; then
|
|
1410
|
+
local train_technique_meta="${from_train_technique:-${TRAIN_TECHNIQUE:-custom}}"
|
|
1411
|
+
cat >> "${SCRIPT_DIR}/adapters/${adapter_name}.conf" <<EOF
|
|
1412
|
+
export ADAPTER_SOURCE="train"
|
|
1413
|
+
export ADAPTER_TECHNIQUE="${train_technique_meta}"
|
|
1414
|
+
export ADAPTER_TRAIN_JOB="${TRAIN_JOB_NAME:-}"
|
|
1415
|
+
EOF
|
|
1416
|
+
fi
|
|
1417
|
+
|
|
1297
1418
|
# Add registry-specific metadata if --from-registry was used
|
|
1298
1419
|
if [ -n "${from_registry}" ]; then
|
|
1299
1420
|
cat >> "${SCRIPT_DIR}/adapters/${adapter_name}.conf" <<EOF
|
|
@@ -1425,6 +1546,42 @@ if endpoint_name:
|
|
|
1425
1546
|
except Exception:
|
|
1426
1547
|
print("⚠️ Could not query endpoint — showing local confs only.", file=sys.stderr)
|
|
1427
1548
|
|
|
1549
|
+
# ── Data source 3: Registry (MPG) adapters ──
|
|
1550
|
+
# Query the deployment MPG for registered adapter versions (if .register_helper.py exists)
|
|
1551
|
+
helper_path = os.path.join(script_dir, ".register_helper.py")
|
|
1552
|
+
if os.path.exists(helper_path):
|
|
1553
|
+
try:
|
|
1554
|
+
result = subprocess.run(
|
|
1555
|
+
["python3", helper_path, "list-adapters",
|
|
1556
|
+
"--project-name", project_name, "--region", region],
|
|
1557
|
+
capture_output=True, text=True, timeout=15)
|
|
1558
|
+
if result.returncode == 0:
|
|
1559
|
+
# Extract JSON line
|
|
1560
|
+
for line in result.stdout.strip().split("\n"):
|
|
1561
|
+
if line.startswith("{"):
|
|
1562
|
+
reg_data = json.loads(line)
|
|
1563
|
+
for adapter in reg_data.get("adapters", []):
|
|
1564
|
+
reg_name = adapter.get("name", "")
|
|
1565
|
+
if not reg_name:
|
|
1566
|
+
continue
|
|
1567
|
+
# Only add if not already tracked locally
|
|
1568
|
+
if reg_name not in adapters:
|
|
1569
|
+
adapters[reg_name] = {
|
|
1570
|
+
"source": "registry",
|
|
1571
|
+
"ic_name": "",
|
|
1572
|
+
"technique": adapter.get("technique", ""),
|
|
1573
|
+
"dataset": "",
|
|
1574
|
+
"status": f"v{adapter.get('version', '?')}",
|
|
1575
|
+
}
|
|
1576
|
+
else:
|
|
1577
|
+
# Annotate existing entry with registry version
|
|
1578
|
+
ver = adapter.get("version", "")
|
|
1579
|
+
if ver:
|
|
1580
|
+
adapters[reg_name]["status"] += f" (reg:v{ver})"
|
|
1581
|
+
break
|
|
1582
|
+
except Exception:
|
|
1583
|
+
pass # Registry query is best-effort
|
|
1584
|
+
|
|
1428
1585
|
# ── Output ──
|
|
1429
1586
|
if not adapters:
|
|
1430
1587
|
print("No adapters found.")
|