@aws/ml-container-creator 0.15.1 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +1 -1
- package/servers/endpoint-picker/index.js +24 -4
- package/src/lib/bootstrap-command-handler.js +8 -0
- package/src/lib/bootstrap-profile-manager.js +17 -0
- package/src/lib/bootstrap-provisioners.js +48 -0
- package/src/lib/path-prover-brain.js +57 -0
- package/src/lib/prove-pipeline-executor.js +35 -0
- package/templates/do/.benchmark_writer.py +114 -4
- package/templates/do/.register_helper.py +643 -67
- package/templates/do/.stage_helper.py +1 -0
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +267 -171
- package/templates/do/benchmark +74 -14
- package/templates/do/config +1 -1
- package/templates/do/lib/inference-component.sh +6 -25
- package/templates/do/register +29 -2
- package/templates/do/tune +94 -12
package/templates/do/benchmark
CHANGED
|
@@ -69,6 +69,9 @@ done
|
|
|
69
69
|
# Query the tracked benchmark job, display status, and if completed:
|
|
70
70
|
# download results, display metrics, and write to Athena (if not already done).
|
|
71
71
|
if [ "${ARG_STATUS}" = true ]; then
|
|
72
|
+
# Resolve instance type: BENCHMARK_INSTANCE_TYPE (persisted by main flow) > INSTANCE_TYPE from config
|
|
73
|
+
_STATUS_INSTANCE_TYPE="${BENCHMARK_INSTANCE_TYPE:-${INSTANCE_TYPE:-}}"
|
|
74
|
+
|
|
72
75
|
JOB_NAME="${BENCHMARK_JOB_NAME:-}"
|
|
73
76
|
if [ -z "${JOB_NAME}" ]; then
|
|
74
77
|
echo "❌ No benchmark job tracked"
|
|
@@ -98,7 +101,7 @@ if [ "${ARG_STATUS}" = true ]; then
|
|
|
98
101
|
# Check if results already exist locally
|
|
99
102
|
PROJECT_ROOT="${SCRIPT_DIR}/.."
|
|
100
103
|
LOCAL_RESULTS_DIR="${PROJECT_ROOT}/benchmarks/${JOB_NAME}"
|
|
101
|
-
RESULTS_JSONL=$(find "${LOCAL_RESULTS_DIR}" -name "profile_export.jsonl" -type f 2>/dev/null | head -1)
|
|
104
|
+
RESULTS_JSONL=$(find "${LOCAL_RESULTS_DIR}" -name "profile_export.jsonl" -type f 2>/dev/null | head -1 || true)
|
|
102
105
|
|
|
103
106
|
if [ -z "${RESULTS_JSONL}" ]; then
|
|
104
107
|
echo ""
|
|
@@ -115,19 +118,21 @@ if [ "${ARG_STATUS}" = true ]; then
|
|
|
115
118
|
--region "${AWS_REGION}" --quiet
|
|
116
119
|
# Untar if output.tar.gz exists
|
|
117
120
|
tar_file=""
|
|
118
|
-
tar_file=$(find "${LOCAL_RESULTS_DIR}" -name "output.tar.gz" -type f 2>/dev/null | head -1)
|
|
121
|
+
tar_file=$(find "${LOCAL_RESULTS_DIR}" -name "output.tar.gz" -type f 2>/dev/null | head -1 || true)
|
|
119
122
|
if [ -n "${tar_file}" ]; then
|
|
120
|
-
# Detect whether
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
123
|
+
# Detect whether ALL entries share a common leading directory prefix
|
|
124
|
+
_tar_prefix_count=""
|
|
125
|
+
_tar_prefix_count=$(tar -tzf "${tar_file}" 2>/dev/null | sed 's|/.*||' | sort -u | wc -l | tr -d ' ')
|
|
126
|
+
_tar_first_dir=""
|
|
127
|
+
_tar_first_dir=$(tar -tzf "${tar_file}" 2>/dev/null | head -1)
|
|
128
|
+
if [ "${_tar_prefix_count}" = "1" ] && echo "${_tar_first_dir}" | grep -qE '^[^/]+/$'; then
|
|
124
129
|
tar -xzf "${tar_file}" --strip-components=1 -C "${LOCAL_RESULTS_DIR}/output/" 2>/dev/null || true
|
|
125
130
|
else
|
|
126
131
|
tar -xzf "${tar_file}" -C "${LOCAL_RESULTS_DIR}/output/" 2>/dev/null || true
|
|
127
132
|
fi
|
|
128
133
|
fi
|
|
129
134
|
# Re-search after extraction
|
|
130
|
-
RESULTS_JSONL=$(find "${LOCAL_RESULTS_DIR}" -name "profile_export.jsonl" -type f 2>/dev/null | head -1)
|
|
135
|
+
RESULTS_JSONL=$(find "${LOCAL_RESULTS_DIR}" -name "profile_export.jsonl" -type f 2>/dev/null | head -1 || true)
|
|
131
136
|
echo " ✅ Results downloaded to: benchmarks/${JOB_NAME}/"
|
|
132
137
|
fi
|
|
133
138
|
else
|
|
@@ -140,7 +145,7 @@ if [ "${ARG_STATUS}" = true ]; then
|
|
|
140
145
|
if [ -n "${RESULTS_JSONL}" ] && [ -f "${RESULTS_JSONL}" ]; then
|
|
141
146
|
_WRITER_INPUT="${RESULTS_JSONL}"
|
|
142
147
|
else
|
|
143
|
-
_WRITER_INPUT=$(find "${LOCAL_RESULTS_DIR}" -name "profile_export_aiperf.json" -type f 2>/dev/null | head -1)
|
|
148
|
+
_WRITER_INPUT=$(find "${LOCAL_RESULTS_DIR}" -name "profile_export_aiperf.json" -type f 2>/dev/null | head -1 || true)
|
|
144
149
|
fi
|
|
145
150
|
|
|
146
151
|
if [ -n "${_WRITER_INPUT}" ]; then
|
|
@@ -153,7 +158,8 @@ if [ "${ARG_STATUS}" = true ]; then
|
|
|
153
158
|
--workload "${BENCHMARK_WORKLOAD:-manual}" \
|
|
154
159
|
--concurrency "${BENCHMARK_CONCURRENCY:-2}" \
|
|
155
160
|
--bucket "${CI_BENCHMARK_RESULTS_BUCKET}" \
|
|
156
|
-
--region "${AWS_REGION:-${REGION}}" \
|
|
161
|
+
--region "${AWS_REGION:-${REGION:-us-east-1}}" \
|
|
162
|
+
${_STATUS_INSTANCE_TYPE:+--instance-type "${_STATUS_INSTANCE_TYPE}"} \
|
|
157
163
|
${ADAPTER_ARG:+--adapter-name "${ADAPTER_ARG}"}; then
|
|
158
164
|
echo " ✅ Results persisted to Athena"
|
|
159
165
|
else
|
|
@@ -559,6 +565,7 @@ print(f'Combined {n_metrics} concurrency level results')
|
|
|
559
565
|
--workload "${BENCHMARK_WORKLOAD:-manual}" \
|
|
560
566
|
--bucket "${CI_BENCHMARK_RESULTS_BUCKET}" \
|
|
561
567
|
--region "${AWS_REGION:-${REGION}}" \
|
|
568
|
+
${RESOLVED_INSTANCE_TYPE:+--instance-type "${RESOLVED_INSTANCE_TYPE}"} \
|
|
562
569
|
${ADAPTER_ARG:+--adapter-name "${ADAPTER_ARG}"}; then
|
|
563
570
|
echo "✅ Multi-level benchmark results persisted to S3"
|
|
564
571
|
else
|
|
@@ -799,6 +806,55 @@ fi
|
|
|
799
806
|
|
|
800
807
|
echo "✅ Endpoint is InService: ${ENDPOINT_NAME}"
|
|
801
808
|
|
|
809
|
+
# ── Resolve actual instance type from endpoint ────────────────────────────────
|
|
810
|
+
# For heterogeneous instance pools, INSTANCE_TYPE in do/config may not reflect
|
|
811
|
+
# the actual provisioned instance. Query the endpoint to determine the real type.
|
|
812
|
+
RESOLVED_INSTANCE_TYPE=""
|
|
813
|
+
_EP_JSON=$(aws sagemaker describe-endpoint \
|
|
814
|
+
--endpoint-name "${ENDPOINT_NAME}" \
|
|
815
|
+
--region "${AWS_REGION}" \
|
|
816
|
+
--output json 2>/dev/null) || _EP_JSON=""
|
|
817
|
+
|
|
818
|
+
if [ -n "${_EP_JSON}" ]; then
|
|
819
|
+
# Try InstanceType from the primary variant runtime response
|
|
820
|
+
RESOLVED_INSTANCE_TYPE=$(echo "${_EP_JSON}" | python3 -c "
|
|
821
|
+
import sys, json
|
|
822
|
+
try:
|
|
823
|
+
ep = json.load(sys.stdin)
|
|
824
|
+
variant = ep.get('ProductionVariants', [{}])[0]
|
|
825
|
+
it = variant.get('CurrentInstanceType') or variant.get('InstanceType') or ''
|
|
826
|
+
if it:
|
|
827
|
+
print(it)
|
|
828
|
+
else:
|
|
829
|
+
# Fall back to endpoint config for pool-based endpoints
|
|
830
|
+
print('')
|
|
831
|
+
except:
|
|
832
|
+
print('')
|
|
833
|
+
" 2>/dev/null) || RESOLVED_INSTANCE_TYPE=""
|
|
834
|
+
|
|
835
|
+
# If still empty, query endpoint config for InstancePools
|
|
836
|
+
if [ -z "${RESOLVED_INSTANCE_TYPE}" ]; then
|
|
837
|
+
_EC_NAME=$(echo "${_EP_JSON}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('EndpointConfigName',''))" 2>/dev/null) || _EC_NAME=""
|
|
838
|
+
if [ -n "${_EC_NAME}" ]; then
|
|
839
|
+
RESOLVED_INSTANCE_TYPE=$(aws sagemaker describe-endpoint-config \
|
|
840
|
+
--endpoint-config-name "${_EC_NAME}" \
|
|
841
|
+
--region "${AWS_REGION}" \
|
|
842
|
+
--query 'ProductionVariants[0].InstanceType' \
|
|
843
|
+
--output text 2>/dev/null) || RESOLVED_INSTANCE_TYPE=""
|
|
844
|
+
# 'None' is returned as literal text when the field is null
|
|
845
|
+
[ "${RESOLVED_INSTANCE_TYPE}" = "None" ] && RESOLVED_INSTANCE_TYPE=""
|
|
846
|
+
fi
|
|
847
|
+
fi
|
|
848
|
+
fi
|
|
849
|
+
|
|
850
|
+
# Final fallback: use INSTANCE_TYPE from do/config
|
|
851
|
+
RESOLVED_INSTANCE_TYPE="${RESOLVED_INSTANCE_TYPE:-${INSTANCE_TYPE:-}}"
|
|
852
|
+
|
|
853
|
+
# Persist to do/config for --status (endpoint may be gone by then)
|
|
854
|
+
if [ -n "${RESOLVED_INSTANCE_TYPE}" ]; then
|
|
855
|
+
_update_benchmark_var "BENCHMARK_INSTANCE_TYPE" "${RESOLVED_INSTANCE_TYPE}"
|
|
856
|
+
fi
|
|
857
|
+
|
|
802
858
|
# ── Pre-flight check: Ensure S3 output bucket exists ──────────────────────────
|
|
803
859
|
echo "🔍 Pre-flight: Checking S3 output bucket..."
|
|
804
860
|
|
|
@@ -1097,11 +1153,14 @@ if [ "${JOB_STATUS}" = "Completed" ]; then
|
|
|
1097
1153
|
# Extract any tar.gz archives (benchmark service packages results as output.tar.gz)
|
|
1098
1154
|
for ARCHIVE in $(find "${LOCAL_RESULTS_DIR}" -name "*.tar.gz" -type f 2>/dev/null); do
|
|
1099
1155
|
ARCHIVE_DIR=$(dirname "${ARCHIVE}")
|
|
1100
|
-
# Detect whether
|
|
1101
|
-
#
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1156
|
+
# Detect whether ALL entries share a common leading directory prefix.
|
|
1157
|
+
# Only strip if every entry starts with the same dir (e.g., "output/file1", "output/file2").
|
|
1158
|
+
# A flat archive with mixed top-level files/dirs (e.g., "plots/", "profile_export.jsonl")
|
|
1159
|
+
# must NOT be stripped.
|
|
1160
|
+
_TAR_PREFIX=$(tar -tzf "${ARCHIVE}" 2>/dev/null | sed 's|/.*||' | sort -u | wc -l | tr -d ' ')
|
|
1161
|
+
_TAR_FIRST_DIR=$(tar -tzf "${ARCHIVE}" 2>/dev/null | head -1)
|
|
1162
|
+
if [ "${_TAR_PREFIX}" = "1" ] && echo "${_TAR_FIRST_DIR}" | grep -qE '^[^/]+/$'; then
|
|
1163
|
+
# Single common leading directory (e.g., all under "output/") — strip it
|
|
1105
1164
|
tar -xzf "${ARCHIVE}" --strip-components=1 -C "${ARCHIVE_DIR}" 2>/dev/null || true
|
|
1106
1165
|
else
|
|
1107
1166
|
# Flat archive — extract as-is
|
|
@@ -1366,6 +1425,7 @@ except Exception as e:
|
|
|
1366
1425
|
--concurrency "${BENCHMARK_CONCURRENCY}" \
|
|
1367
1426
|
--bucket "${CI_BENCHMARK_RESULTS_BUCKET}" \
|
|
1368
1427
|
--region "${AWS_REGION:-${REGION}}" \
|
|
1428
|
+
${RESOLVED_INSTANCE_TYPE:+--instance-type "${RESOLVED_INSTANCE_TYPE}"} \
|
|
1369
1429
|
${ADAPTER_ARG:+--adapter-name "${ADAPTER_ARG}"}; then
|
|
1370
1430
|
echo "✅ Benchmark results persisted to S3"
|
|
1371
1431
|
else
|
package/templates/do/config
CHANGED
|
@@ -46,7 +46,7 @@ export INSTANCE_TYPE="<%= instanceType %>"
|
|
|
46
46
|
<% if (typeof instancePools !== 'undefined' && instancePools && instancePools.length > 1) { %>
|
|
47
47
|
# Instance pools: heterogeneous instance types with priority-based fallback
|
|
48
48
|
# Priority = selection order (1 = preferred, higher = fallback)
|
|
49
|
-
export INSTANCE_POOLS='
|
|
49
|
+
export INSTANCE_POOLS='<%- JSON.stringify(instancePools) %>'
|
|
50
50
|
<% } else { %>
|
|
51
51
|
# Instance pools: heterogeneous instance types with priority-based fallback (uncomment to enable)
|
|
52
52
|
# Format: [{"InstanceType":"ml.g6e.48xlarge","Priority":1},{"InstanceType":"ml.g5.48xlarge","Priority":2}]
|
|
@@ -112,31 +112,12 @@ create_inference_component() {
|
|
|
112
112
|
|
|
113
113
|
# Build specification JSON — multi-spec (Specifications array) or single (Specification object)
|
|
114
114
|
local spec_json
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
local spec_gpu_count_var="IC_SPEC_${i}_GPU_COUNT"
|
|
122
|
-
local spec_min_memory_var="IC_SPEC_${i}_MIN_MEMORY_MB"
|
|
123
|
-
|
|
124
|
-
local spec_instance_type="${!spec_instance_type_var}"
|
|
125
|
-
local spec_gpu_count="${!spec_gpu_count_var:-1}"
|
|
126
|
-
local spec_min_memory="${!spec_min_memory_var:-1024}"
|
|
127
|
-
|
|
128
|
-
if [ "${i}" -gt 1 ]; then
|
|
129
|
-
spec_json="${spec_json},"
|
|
130
|
-
fi
|
|
131
|
-
spec_json="${spec_json}{\"Container\":${container_spec},\"StartupParameters\":{\"ContainerStartupHealthCheckTimeoutInSeconds\":${IC_STARTUP_TIMEOUT:-900}},\"ComputeResourceRequirements\":{\"NumberOfAcceleratorDevicesRequired\":${spec_gpu_count},\"MinMemoryRequiredInMb\":${spec_min_memory}}}"
|
|
132
|
-
|
|
133
|
-
i=$((i + 1))
|
|
134
|
-
done
|
|
135
|
-
spec_json="${spec_json}]}"
|
|
136
|
-
else
|
|
137
|
-
# Single spec: standard Specification object (existing behavior)
|
|
138
|
-
spec_json="{\"Container\":${container_spec},\"StartupParameters\":{\"ContainerStartupHealthCheckTimeoutInSeconds\":${IC_STARTUP_TIMEOUT:-900}},\"ComputeResourceRequirements\":{\"NumberOfAcceleratorDevicesRequired\":${IC_GPU_COUNT:-1},\"MinMemoryRequiredInMb\":${IC_MIN_MEMORY_MB:-1024}}}"
|
|
139
|
-
fi
|
|
115
|
+
# Always use singular Specification. For heterogeneous instance pools, the IC
|
|
116
|
+
# declares its minimum resource requirements and SageMaker places it on whatever
|
|
117
|
+
# instance was provisioned from the pool. Multi-spec (Specifications plural) is
|
|
118
|
+
# only needed when you want different configurations per instance type (e.g.,
|
|
119
|
+
# different TP, different model artifact) — a future optimization.
|
|
120
|
+
spec_json="{\"Container\":${container_spec},\"StartupParameters\":{\"ContainerStartupHealthCheckTimeoutInSeconds\":${IC_STARTUP_TIMEOUT:-900}},\"ComputeResourceRequirements\":{\"NumberOfAcceleratorDevicesRequired\":${IC_GPU_COUNT:-1},\"MinMemoryRequiredInMb\":${IC_MIN_MEMORY_MB:-1024}}}"
|
|
140
121
|
|
|
141
122
|
echo "📦 Creating inference component: ${ic_name}"
|
|
142
123
|
if ! aws sagemaker create-inference-component \
|
package/templates/do/register
CHANGED
|
@@ -41,6 +41,7 @@ _show_usage() {
|
|
|
41
41
|
echo " --technique <tech> Technique: sft, dpo, rlaif, rlvr (default: sft)"
|
|
42
42
|
echo " --row-count <n> Number of records"
|
|
43
43
|
echo " --column-schema <j> Column schema as JSON string"
|
|
44
|
+
echo " --force Force new version even if content is unchanged"
|
|
44
45
|
echo ""
|
|
45
46
|
echo "Evaluator options:"
|
|
46
47
|
echo " <name> Evaluator name (required, positional)"
|
|
@@ -89,6 +90,7 @@ DATASET_FORMAT="jsonl"
|
|
|
89
90
|
DATASET_TECHNIQUE="sft"
|
|
90
91
|
DATASET_ROW_COUNT=""
|
|
91
92
|
DATASET_COLUMN_SCHEMA=""
|
|
93
|
+
DATASET_FORCE=false
|
|
92
94
|
EVALUATOR_NAME=""
|
|
93
95
|
EVALUATOR_TYPE=""
|
|
94
96
|
EVALUATOR_ARN_OR_URI=""
|
|
@@ -135,6 +137,7 @@ if [ "${SUBCOMMAND}" = "dataset" ]; then
|
|
|
135
137
|
--dataset-technique) DATASET_TECHNIQUE="$2"; shift 2 ;;
|
|
136
138
|
--dataset-row-count) DATASET_ROW_COUNT="$2"; shift 2 ;;
|
|
137
139
|
--dataset-column-schema) DATASET_COLUMN_SCHEMA="$2"; shift 2 ;;
|
|
140
|
+
--force) DATASET_FORCE=true; shift ;;
|
|
138
141
|
--help|-h) _show_usage; exit 0 ;;
|
|
139
142
|
*) echo "⚠️ Unknown dataset option: $1"; _show_usage; exit 1 ;;
|
|
140
143
|
esac
|
|
@@ -196,7 +199,10 @@ if [ "${SUBCOMMAND}" = "dataset" ]; then
|
|
|
196
199
|
_slug=$(basename "${_source}" .jsonl)
|
|
197
200
|
fi
|
|
198
201
|
_slug=$(echo "${_slug}" | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//' | sed 's/-$//')
|
|
199
|
-
|
|
202
|
+
# Salt with 4-char hash of S3 URI to prevent slug conflicts for
|
|
203
|
+
# same repo with different preprocessing/split/technique
|
|
204
|
+
_salt=$(echo "${DATASET_S3_URI}" | shasum | cut -c1-4)
|
|
205
|
+
DATASET_NAME="${_slug:-dataset}-${_salt}"
|
|
200
206
|
fi
|
|
201
207
|
fi
|
|
202
208
|
|
|
@@ -427,6 +433,11 @@ if [ "${SUBCOMMAND}" = "dataset" ]; then
|
|
|
427
433
|
fi
|
|
428
434
|
|
|
429
435
|
DS_ARGS+=("--project-name" "${PROJECT_NAME}")
|
|
436
|
+
DS_ARGS+=("--region" "${AWS_REGION}")
|
|
437
|
+
|
|
438
|
+
if [ "${DATASET_FORCE}" = true ]; then
|
|
439
|
+
DS_ARGS+=("--force")
|
|
440
|
+
fi
|
|
430
441
|
|
|
431
442
|
# Call .register_helper.py register-dataset
|
|
432
443
|
if ds_output=$(python3 "${SCRIPT_DIR}/.register_helper.py" "${DS_ARGS[@]}" 2>/dev/null); then
|
|
@@ -1288,9 +1299,10 @@ elif [ -n "${MODEL_PKG_ARN:-}" ] && [ -d "${SCRIPT_DIR}/adapters" ]; then
|
|
|
1288
1299
|
ADAPTER_TECHNIQUE=""
|
|
1289
1300
|
eval "$(grep '^export ADAPTER_WEIGHTS_URI=' "${conf}" 2>/dev/null)" 2>/dev/null || true
|
|
1290
1301
|
eval "$(grep '^export ADAPTER_TECHNIQUE=' "${conf}" 2>/dev/null)" 2>/dev/null || true
|
|
1302
|
+
eval "$(grep '^export ADAPTER_TUNE_TECHNIQUE=' "${conf}" 2>/dev/null)" 2>/dev/null || true
|
|
1291
1303
|
|
|
1292
1304
|
_ADAPTER_DATA_URL="${ADAPTER_WEIGHTS_URI:-}"
|
|
1293
|
-
_ADAPTER_TECHNIQUE="${ADAPTER_TECHNIQUE:-${TUNE_TECHNIQUE:-}}"
|
|
1305
|
+
_ADAPTER_TECHNIQUE="${ADAPTER_TECHNIQUE:-${ADAPTER_TUNE_TECHNIQUE:-${TUNE_TECHNIQUE:-}}}"
|
|
1294
1306
|
|
|
1295
1307
|
echo ""
|
|
1296
1308
|
echo "📦 Registering adapter: ${_ADAPTER_NAME}"
|
|
@@ -1316,6 +1328,21 @@ elif [ -n "${MODEL_PKG_ARN:-}" ] && [ -d "${SCRIPT_DIR}/adapters" ]; then
|
|
|
1316
1328
|
[ -n "${AWS_REGION:-}" ] && ADAPTER_REG_ARGS+=("--region" "${AWS_REGION}")
|
|
1317
1329
|
[ -n "${ROLE_ARN:-}" ] && ADAPTER_REG_ARGS+=("--role-arn" "${ROLE_ARN}")
|
|
1318
1330
|
|
|
1331
|
+
# Include dataset lineage from tune state (AC-2.7: version reference for reproducibility)
|
|
1332
|
+
if [ -n "${TUNE_DATASET_S3_URI:-}" ]; then
|
|
1333
|
+
ADAPTER_REG_ARGS+=("--dataset-s3-uri" "${TUNE_DATASET_S3_URI}")
|
|
1334
|
+
fi
|
|
1335
|
+
# Resolve dataset version from technique-specific or generic config var
|
|
1336
|
+
_ds_version=""
|
|
1337
|
+
if [ -n "${_ADAPTER_TECHNIQUE}" ]; then
|
|
1338
|
+
_ds_ver_var="TUNE_DATASET_VERSION_$(echo "${_ADAPTER_TECHNIQUE}" | tr '[:lower:]' '[:upper:]')"
|
|
1339
|
+
_ds_version="${!_ds_ver_var:-}"
|
|
1340
|
+
fi
|
|
1341
|
+
[ -z "${_ds_version}" ] && _ds_version="${TUNE_DATASET_VERSION:-}"
|
|
1342
|
+
if [ -n "${_ds_version}" ]; then
|
|
1343
|
+
ADAPTER_REG_ARGS+=("--dataset-version" "${_ds_version}")
|
|
1344
|
+
fi
|
|
1345
|
+
|
|
1319
1346
|
# Call .register_helper.py register-adapter — non-fatal on failure
|
|
1320
1347
|
if adapter_output=$(python3 "${SCRIPT_DIR}/.register_helper.py" "${ADAPTER_REG_ARGS[@]}" 2>/dev/null); then
|
|
1321
1348
|
adapter_json=$(echo "${adapter_output}" | grep -E '^\{' | tail -1)
|
package/templates/do/tune
CHANGED
|
@@ -52,7 +52,9 @@ ARG_COLUMN_MAP=""
|
|
|
52
52
|
ARG_TAKE=""
|
|
53
53
|
ARG_ACCEPT_EULA=false
|
|
54
54
|
ARG_DATASET_NAME=""
|
|
55
|
+
ARG_DATASET_VERSION=""
|
|
55
56
|
ARG_EVALUATOR_NAME=""
|
|
57
|
+
ARG_NO_REGISTER=false
|
|
56
58
|
|
|
57
59
|
|
|
58
60
|
# ── _parse_args() ─────────────────────────────────────────────────────────────
|
|
@@ -147,6 +149,7 @@ _parse_args() {
|
|
|
147
149
|
--force) ARG_FORCE=true; shift ;;
|
|
148
150
|
--accept-eula) ARG_ACCEPT_EULA=true; shift ;;
|
|
149
151
|
--no-wait) ARG_NO_WAIT=true; shift ;;
|
|
152
|
+
--no-register) ARG_NO_REGISTER=true; shift ;;
|
|
150
153
|
--status) ARG_STATUS=true; shift ;;
|
|
151
154
|
--help|-h) ARG_HELP=true; shift ;;
|
|
152
155
|
--dry-run) ARG_DRY_RUN=true; shift ;;
|
|
@@ -273,6 +276,8 @@ _show_help() {
|
|
|
273
276
|
echo " --force Force new job even if one exists for this technique"
|
|
274
277
|
echo " --accept-eula Accept model EULA (required for gated models like Llama)"
|
|
275
278
|
echo " --no-wait Submit and exit without polling for completion"
|
|
279
|
+
echo " --no-register Skip auto-stage and auto-register after completion"
|
|
280
|
+
echo " (prints next-step commands instead)"
|
|
276
281
|
echo " --status Show status of all tracked tune jobs"
|
|
277
282
|
echo ""
|
|
278
283
|
echo "Dataset options:"
|
|
@@ -791,13 +796,26 @@ else:
|
|
|
791
796
|
_validate_dataset() {
|
|
792
797
|
local dataset="${ARG_DATASET}"
|
|
793
798
|
|
|
799
|
+
# ── Parse @v<N> version suffix (AC-2.1, AC-2.3) ──────────────────────────
|
|
800
|
+
# Syntax: dataset-name@v2 → name="dataset-name", version ordinal=2
|
|
801
|
+
if [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
|
|
802
|
+
ARG_DATASET_NAME="${BASH_REMATCH[1]}"
|
|
803
|
+
ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
|
|
804
|
+
dataset="" # Clear so name-based resolution takes over
|
|
805
|
+
fi
|
|
806
|
+
|
|
794
807
|
# If --dataset-name is set, resolve from registry (AC-2b.4)
|
|
795
808
|
# --dataset-name takes precedence over --dataset for named registry lookup
|
|
796
809
|
if [ -n "${ARG_DATASET_NAME}" ]; then
|
|
797
810
|
echo "🔍 Resolving dataset '${ARG_DATASET_NAME}' from registry..."
|
|
811
|
+
local resolve_args=("--name" "${ARG_DATASET_NAME}")
|
|
812
|
+
if [ -n "${ARG_DATASET_VERSION}" ]; then
|
|
813
|
+
resolve_args+=("--version" "${ARG_DATASET_VERSION}")
|
|
814
|
+
echo " Version: v${ARG_DATASET_VERSION}"
|
|
815
|
+
fi
|
|
798
816
|
local resolve_result
|
|
799
817
|
resolve_result=$(python3 "${SCRIPT_DIR}/.register_helper.py" resolve-dataset \
|
|
800
|
-
|
|
818
|
+
"${resolve_args[@]}" 2>/dev/null) || resolve_result=""
|
|
801
819
|
|
|
802
820
|
if [ -n "${resolve_result}" ]; then
|
|
803
821
|
local resolved_uri
|
|
@@ -830,9 +848,14 @@ _validate_dataset() {
|
|
|
830
848
|
if [ -n "${ARG_DATASET_NAME}" ]; then
|
|
831
849
|
# Name-based resolution happens below via resolve-dataset
|
|
832
850
|
echo "🔍 Resolving dataset '${ARG_DATASET_NAME}' from registry..."
|
|
851
|
+
local resolve_args=("--name" "${ARG_DATASET_NAME}")
|
|
852
|
+
if [ -n "${ARG_DATASET_VERSION}" ]; then
|
|
853
|
+
resolve_args+=("--version" "${ARG_DATASET_VERSION}")
|
|
854
|
+
echo " Version: v${ARG_DATASET_VERSION}"
|
|
855
|
+
fi
|
|
833
856
|
local resolve_result
|
|
834
857
|
resolve_result=$(python3 "${SCRIPT_DIR}/.register_helper.py" resolve-dataset \
|
|
835
|
-
|
|
858
|
+
"${resolve_args[@]}" 2>/dev/null) || resolve_result=""
|
|
836
859
|
|
|
837
860
|
if [ -n "${resolve_result}" ]; then
|
|
838
861
|
local resolved_uri
|
|
@@ -1330,6 +1353,10 @@ print(entry.get('provider', ''))
|
|
|
1330
1353
|
_update_config_var "TUNE_DATASET_S3_URI_${technique_upper}" "${RESOLVED_DATASET_S3_URI:-}"
|
|
1331
1354
|
_update_config_var "TUNE_DATASET_ROW_COUNT_${technique_upper}" "${RESOLVED_DATASET_ROW_COUNT:-0}"
|
|
1332
1355
|
_update_config_var "TUNE_DATASET_SOURCE_${technique_upper}" "${ARG_DATASET}"
|
|
1356
|
+
# Store dataset version ordinal if pinned (AC-2.6)
|
|
1357
|
+
if [ -n "${ARG_DATASET_VERSION}" ]; then
|
|
1358
|
+
_update_config_var "TUNE_DATASET_VERSION_${technique_upper}" "${ARG_DATASET_VERSION}"
|
|
1359
|
+
fi
|
|
1333
1360
|
}
|
|
1334
1361
|
|
|
1335
1362
|
|
|
@@ -1528,12 +1555,59 @@ _handle_completion() {
|
|
|
1528
1555
|
_update_config_var "TUNE_OUTPUT_PATH_LATEST" "${artifact_path}"
|
|
1529
1556
|
_update_config_var "TUNE_OUTPUT_TYPE_LATEST" "${output_type}"
|
|
1530
1557
|
|
|
1531
|
-
#
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1558
|
+
# Auto-register or print next-step commands
|
|
1559
|
+
if [ "${output_type}" = "adapter" ] && [ "${ARG_NO_REGISTER}" != true ]; then
|
|
1560
|
+
# Auto-register: stage adapter and register in deployment MPG
|
|
1561
|
+
local dataset_slug
|
|
1562
|
+
dataset_slug=$(_derive_dataset_slug "${ARG_DATASET:-}")
|
|
1563
|
+
local adapter_name="tuned-${ARG_TECHNIQUE}-${dataset_slug}"
|
|
1564
|
+
if [ -z "${dataset_slug}" ]; then
|
|
1565
|
+
adapter_name="tuned-${ARG_TECHNIQUE}"
|
|
1566
|
+
fi
|
|
1567
|
+
|
|
1568
|
+
echo "🔄 Auto-registering adapter: ${adapter_name}"
|
|
1569
|
+
echo ""
|
|
1570
|
+
|
|
1571
|
+
# Step 1: Stage the adapter via do/adapter add
|
|
1572
|
+
local adapter_add_output
|
|
1573
|
+
if adapter_add_output=$("${SCRIPT_DIR}/adapter" add "${adapter_name}" --from-tune "${ARG_TECHNIQUE}" 2>&1); then
|
|
1574
|
+
echo " ✅ Adapter staged: ${adapter_name}"
|
|
1575
|
+
|
|
1576
|
+
# Step 2: Register in deployment MPG via do/register
|
|
1577
|
+
local register_output
|
|
1578
|
+
if register_output=$("${SCRIPT_DIR}/register" 2>&1); then
|
|
1579
|
+
echo " ✅ Registration complete"
|
|
1580
|
+
|
|
1581
|
+
# Step 3: Extract adapter deployment ARN from register output
|
|
1582
|
+
local adapter_deploy_arn
|
|
1583
|
+
adapter_deploy_arn=$(echo "${register_output}" | grep "${adapter_name}" | grep -E '^\{' | tail -1 | jq -r '.model_package_arn' 2>/dev/null) || adapter_deploy_arn=""
|
|
1584
|
+
|
|
1585
|
+
if [ -n "${adapter_deploy_arn}" ] && [ "${adapter_deploy_arn}" != "null" ]; then
|
|
1586
|
+
_update_config_var "TUNE_ADAPTER_DEPLOY_ARN_${technique_upper}" "${adapter_deploy_arn}"
|
|
1587
|
+
echo " ✅ Deployment ARN stored: ${adapter_deploy_arn}"
|
|
1588
|
+
else
|
|
1589
|
+
echo " ⚠️ Could not extract adapter deployment ARN from register output"
|
|
1590
|
+
echo " (adapter was staged and registered — ARN can be found via do/register --status)"
|
|
1591
|
+
fi
|
|
1592
|
+
echo ""
|
|
1593
|
+
else
|
|
1594
|
+
echo " ⚠️ Registration failed (adapter was staged successfully)"
|
|
1595
|
+
echo " Run manually: ./do/register"
|
|
1596
|
+
echo ""
|
|
1597
|
+
fi
|
|
1598
|
+
else
|
|
1599
|
+
echo " ⚠️ Adapter staging failed"
|
|
1600
|
+
echo " Run manually:"
|
|
1601
|
+
echo " ./do/adapter add ${adapter_name} --from-tune ${ARG_TECHNIQUE}"
|
|
1602
|
+
echo " ./do/register"
|
|
1603
|
+
echo ""
|
|
1604
|
+
fi
|
|
1605
|
+
elif [ "${output_type}" = "adapter" ]; then
|
|
1606
|
+
# --no-register: print next steps as before
|
|
1535
1607
|
local dataset_slug
|
|
1536
1608
|
dataset_slug=$(_derive_dataset_slug "${ARG_DATASET:-}")
|
|
1609
|
+
echo "📋 Next steps:"
|
|
1610
|
+
echo ""
|
|
1537
1611
|
echo " Deploy as LoRA adapter:"
|
|
1538
1612
|
echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune"
|
|
1539
1613
|
echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune ${ARG_TECHNIQUE}"
|
|
@@ -1541,13 +1615,16 @@ _handle_completion() {
|
|
|
1541
1615
|
echo " ./do/adapter add tuned-${ARG_TECHNIQUE}-${dataset_slug} --from-tune ${ARG_TECHNIQUE}-${dataset_slug}"
|
|
1542
1616
|
fi
|
|
1543
1617
|
echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --weights ${artifact_path}"
|
|
1618
|
+
echo ""
|
|
1544
1619
|
else
|
|
1620
|
+
echo "📋 Next steps:"
|
|
1621
|
+
echo ""
|
|
1545
1622
|
echo " Deploy as new inference component:"
|
|
1546
1623
|
echo " ./do/add-ic tuned-v1 --from-tune"
|
|
1547
1624
|
echo " ./do/add-ic tuned-v1 --model-data ${artifact_path}"
|
|
1548
1625
|
echo " ./do/deploy --force-ic --model-data ${artifact_path}"
|
|
1626
|
+
echo ""
|
|
1549
1627
|
fi
|
|
1550
|
-
echo ""
|
|
1551
1628
|
}
|
|
1552
1629
|
|
|
1553
1630
|
|
|
@@ -1648,17 +1725,21 @@ if [ "${ARG_LIST_DATASETS}" = true ]; then
|
|
|
1648
1725
|
if [ -n "${_ds_json}" ]; then
|
|
1649
1726
|
_ds_count=$(echo "${_ds_json}" | python3 -c "import sys,json; print(len(json.load(sys.stdin).get('datasets',[])))" 2>/dev/null) || _ds_count=0
|
|
1650
1727
|
if [ "${_ds_count}" -gt 0 ]; then
|
|
1651
|
-
printf " %-
|
|
1652
|
-
printf " %-
|
|
1728
|
+
printf " %-20s %-10s %-10s %-8s %s\n" "NAME" "TECHNIQUE" "LATEST" "ROWS" "S3 URI"
|
|
1729
|
+
printf " %-20s %-10s %-10s %-8s %s\n" "----" "---------" "------" "----" "------"
|
|
1653
1730
|
echo "${_ds_json}" | python3 -c "
|
|
1654
1731
|
import sys, json
|
|
1655
1732
|
data = json.load(sys.stdin)
|
|
1656
1733
|
for ds in data.get('datasets', []):
|
|
1657
|
-
name = ds.get('name','')[:
|
|
1734
|
+
name = ds.get('name','')[:20]
|
|
1658
1735
|
tech = ds.get('technique','')[:10]
|
|
1659
|
-
|
|
1736
|
+
latest = ds.get('latest_version','')[:10]
|
|
1737
|
+
ver_count = ds.get('version_count', 1)
|
|
1738
|
+
if ver_count > 1:
|
|
1739
|
+
latest = f'{latest} ({ver_count}v)'
|
|
1740
|
+
rows = str(ds.get('row_count','') or '')[:8]
|
|
1660
1741
|
uri = ds.get('s3_uri','')
|
|
1661
|
-
print(f' {name:<
|
|
1742
|
+
print(f' {name:<20} {tech:<10} {latest:<10} {rows:<8} {uri}')
|
|
1662
1743
|
" 2>/dev/null
|
|
1663
1744
|
else
|
|
1664
1745
|
echo " (none registered)"
|
|
@@ -1669,6 +1750,7 @@ for ds in data.get('datasets', []):
|
|
|
1669
1750
|
echo ""
|
|
1670
1751
|
echo " Register: ./do/register dataset <name> --s3-uri <uri> --technique <sft|dpo>"
|
|
1671
1752
|
echo " Use: ./do/tune --technique sft --dataset <name>"
|
|
1753
|
+
echo " Versions: python3 .register_helper.py list-dataset-versions --name <name>"
|
|
1672
1754
|
echo ""
|
|
1673
1755
|
exit 0
|
|
1674
1756
|
fi
|