@aws/ml-container-creator 0.6.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,279 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Tune Dataset Validator
6
+ *
7
+ * Parses dataset arguments (S3 URIs and Hugging Face references) and
8
+ * validates JSONL dataset lines against catalog-driven schemas.
9
+ *
10
+ * Requirements: 3.1, 3.5, 3.6, 3.7, 3.8, 3.10, 3.11, 3.12
11
+ */
12
+
13
/**
 * Parse a dataset argument string into a structured object.
 * Accepts S3 URIs (`s3://bucket/key`) or Hugging Face references
 * (`hf://org/name` or `hf://org/name/split`). Surrounding whitespace is ignored.
 *
 * @param {string} datasetStr - The dataset argument string
 * @returns {{ valid: boolean, type?: string, bucket?: string, key?: string, org?: string, name?: string, split?: string, error?: string }}
 *   Invalid input yields `{ valid: false, error }`; this function never throws.
 */
export function parseDatasetArg(datasetStr) {
  // Non-strings, empty strings and whitespace-only strings are all treated as
  // a missing argument (rather than reporting `Invalid dataset format: ""`).
  const trimmed = typeof datasetStr === 'string' ? datasetStr.trim() : '';
  if (trimmed === '') {
    return {
      valid: false,
      error: 'Dataset argument is required and must be a non-empty string.'
    };
  }

  if (trimmed.startsWith('s3://')) {
    return _parseS3Uri(trimmed);
  }

  if (trimmed.startsWith('hf://')) {
    return _parseHfReference(trimmed);
  }

  return {
    valid: false,
    error: `Invalid dataset format: "${trimmed}". Expected s3://bucket/key or hf://org/name[/split].`
  };
}
44
+
45
/**
 * Validate JSONL lines against a dataset schema from the catalog.
 * Only the first `maxLines` lines are inspected (10 by default, per
 * requirement). Blank lines inside that window are skipped but still count
 * toward it. Validation stops at the first problem found.
 *
 * @param {string[]} lines - Array of JSONL line strings
 * @param {Object} schema - The datasetSchema object from the catalog
 * @param {string[]} schema.required - Array of required top-level keys
 * @param {Object} [schema.types] - Object mapping key to expected type ("string", "array", "object", "number")
 * @param {number} [maxLines=10] - Number of leading lines to inspect (non-negative integer or Infinity)
 * @returns {{ valid: boolean, error: string|null, lineNumber: number|null, malformedLine: string|null, expectedFormat: string|null }}
 *   On a per-line failure, `lineNumber` is 1-based and `malformedLine` echoes the
 *   offending line (null when the entry is not a string).
 */
export function validateDatasetFormat(lines, schema, maxLines = 10) {
  const hasValidSchema = Boolean(schema) && Array.isArray(schema.required);

  if (!Array.isArray(lines)) {
    return {
      valid: false,
      error: 'Lines must be provided as an array.',
      lineNumber: null,
      malformedLine: null,
      // Describe the format only when the schema is usable; a malformed
      // schema must not turn this error report into a thrown exception.
      expectedFormat: hasValidSchema ? _buildExpectedFormat(schema) : null
    };
  }

  if (!hasValidSchema) {
    return {
      valid: false,
      error: 'Schema must include a "required" array of keys.',
      lineNumber: null,
      malformedLine: null,
      expectedFormat: null
    };
  }

  // Builds the failure result for one line; the format hint is computed lazily.
  const fail = (lineNumber, malformedLine, error) => ({
    valid: false,
    error,
    lineNumber,
    malformedLine,
    expectedFormat: _buildExpectedFormat(schema)
  });

  const sample = lines.slice(0, maxLines);

  for (const [index, line] of sample.entries()) {
    const lineNumber = index + 1;

    // Skip empty entries (null/undefined/'' and other falsy values).
    if (!line) {
      continue;
    }

    // A non-string entry cannot be a JSONL line; report it instead of
    // crashing on `.trim()`.
    if (typeof line !== 'string') {
      return fail(lineNumber, null, `Line ${lineNumber} must be a string, got "${_getType(line)}".`);
    }

    // Skip whitespace-only lines
    if (line.trim() === '') {
      continue;
    }

    let parsed;
    try {
      parsed = JSON.parse(line);
    } catch (e) {
      return fail(lineNumber, line, `Line ${lineNumber} is not valid JSON: ${e.message}`);
    }

    // Each line must decode to a plain JSON object (not an array or primitive)
    if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
      return fail(lineNumber, line, `Line ${lineNumber} must be a JSON object.`);
    }

    for (const key of schema.required) {
      if (!Object.hasOwn(parsed, key)) {
        return fail(lineNumber, line, `Line ${lineNumber} is missing required key "${key}".`);
      }
    }

    // Type constraints apply only to keys actually present on the line
    if (schema.types) {
      for (const [key, expectedType] of Object.entries(schema.types)) {
        if (!Object.hasOwn(parsed, key)) {
          continue;
        }

        const value = parsed[key];
        if (!_checkType(value, expectedType)) {
          return fail(
            lineNumber,
            line,
            `Line ${lineNumber} has key "${key}" with wrong type. Expected "${expectedType}", got "${_getType(value)}".`
          );
        }
      }
    }
  }

  return {
    valid: true,
    error: null,
    lineNumber: null,
    malformedLine: null,
    expectedFormat: null
  };
}
154
+
155
/**
 * Parse an S3 URI into bucket and key components.
 * The bucket is everything between "s3://" and the first "/"; the key is
 * everything after that slash, kept verbatim (including further slashes).
 *
 * @param {string} uri - The S3 URI (e.g., "s3://bucket/path/to/file.jsonl")
 * @returns {Object} `{ valid: true, type: 's3', bucket, key }` on success,
 *   otherwise `{ valid: false, error }`
 * @private
 */
function _parseS3Uri(uri) {
  const withoutScheme = uri.slice('s3://'.length);
  const slashIndex = withoutScheme.indexOf('/');

  if (slashIndex === -1) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Expected format: s3://bucket/key.`
    };
  }

  const bucket = withoutScheme.slice(0, slashIndex);
  const key = withoutScheme.slice(slashIndex + 1);

  // Reached for inputs like "s3:///key", where the slash directly follows the scheme.
  if (!bucket) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Bucket name is empty.`
    };
  }

  if (!key) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Key path is empty.`
    };
  }

  return {
    valid: true,
    type: 's3',
    bucket,
    key
  };
}
196
+
197
/**
 * Parse a Hugging Face dataset reference into org, name, and split.
 * Defaults to the 'train' split if none is given. Trailing slashes are
 * ignored. References with empty or extra path segments are rejected rather
 * than silently truncated.
 *
 * @param {string} ref - The HF reference (e.g., "hf://org/name" or "hf://org/name/split")
 * @returns {Object} `{ valid: true, type: 'hf', org, name, split }` on success,
 *   otherwise `{ valid: false, error }`
 * @private
 */
function _parseHfReference(ref) {
  const parts = ref.slice('hf://'.length).replace(/\/+$/, '').split('/');

  // Exactly org/name or org/name/split, with every segment non-empty.
  if (parts.length < 2 || parts.length > 3 || parts.some((part) => part === '')) {
    return {
      valid: false,
      error: `Invalid Hugging Face reference: "${ref}". Expected format: hf://org/name[/split].`
    };
  }

  const [org, name, split = 'train'] = parts;

  return {
    valid: true,
    type: 'hf',
    org,
    name,
    split
  };
}
227
+
228
/**
 * Test whether a parsed JSON value satisfies a schema type name.
 * Recognised names are "string", "number", "array" and "object" (a non-null,
 * non-array object). Any other type name is not enforced and always passes.
 *
 * @param {*} value - The value to check
 * @param {string} expectedType - The schema type name
 * @returns {boolean} True if the value matches (or the type is unrecognised)
 * @private
 */
function _checkType(value, expectedType) {
  if (expectedType === 'string' || expectedType === 'number') {
    return typeof value === expectedType;
  }

  const isArray = Array.isArray(value);

  if (expectedType === 'array') {
    return isArray;
  }

  if (expectedType === 'object') {
    return value !== null && typeof value === 'object' && !isArray;
  }

  return true;
}
249
+
250
/**
 * Describe a value's type for error messages.
 * Same as `typeof`, except that null reports "null" and arrays report "array".
 *
 * @param {*} value - The value to describe
 * @returns {string} The type name
 * @private
 */
function _getType(value) {
  let typeName = typeof value;

  if (value === null) {
    typeName = 'null';
  } else if (Array.isArray(value)) {
    typeName = 'array';
  }

  return typeName;
}
261
+
262
/**
 * Build a human-readable expected format description from a schema.
 * Lists each required key with its declared type, or "any" when no type is declared.
 *
 * @param {Object} schema - The dataset schema
 * @returns {string|null} Description of expected format, or null if `schema.required` is not an array
 * @private
 */
function _buildExpectedFormat(schema) {
  // Check the array shape, not just truthiness: this helper may receive an
  // unvalidated schema, and a non-array `required` would make `.map` throw.
  if (!schema || !Array.isArray(schema.required)) {
    return null;
  }

  const types = schema.types || {};
  const fields = schema.required.map((key) => {
    // Own-property lookup so keys like "toString" do not pick up inherited members.
    const type = Object.hasOwn(types, key) && types[key] ? types[key] : 'any';
    return `"${key}": <${type}>`;
  });

  return `Each line must be a JSON object with: {${fields.join(', ')}}`;
}
@@ -0,0 +1,66 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Tune Output Resolver
6
+ *
7
+ * Detects output type from training type and generates context-aware
8
+ * next-step commands for deploying tune job artifacts.
9
+ *
10
+ * Requirements: 8.3, 8.11
11
+ */
12
+
13
/**
 * Map the training type used for a job to the kind of artifact it produces.
 * Only the exact string 'full-rank' yields 'full-model'. 'lora', and every
 * other value (including undefined or differently-cased input), yields 'adapter'.
 *
 * @param {string} trainingType - The training type ('lora' or 'full-rank')
 * @returns {string} 'full-model' for 'full-rank', otherwise 'adapter'
 */
export function detectOutputType(trainingType) {
  switch (trainingType) {
    case 'full-rank':
      return 'full-model';
    case 'lora':
    default:
      return 'adapter';
  }
}
29
+
30
/**
 * Suggest follow-up CLI commands for deploying a tune job's artifact.
 *
 * Adapter output yields three `./do/adapter add tuned-<technique>` variants:
 *   plain `--from-tune`, `--from-tune <technique>`, and `--weights <artifactPath>`.
 * Full-model output yields, in order:
 *   - `./do/add-ic tuned-v1 --from-tune`
 *   - `./do/add-ic tuned-v1 --model-data <artifactPath>`
 *   - `./do/deploy --force-ic --model-data <artifactPath>`
 * Any other output type yields an empty list. Arguments are interpolated
 * as-is, so a missing value appears literally (e.g. "undefined").
 *
 * @param {string} outputType - The output type ('adapter' or 'full-model')
 * @param {string} technique - The technique used (e.g., 'sft', 'dpo')
 * @param {string} artifactPath - The S3 path to the output artifact
 * @returns {string[]} Array of suggested next-step commands
 */
export function generateNextStepCommands(outputType, technique, artifactPath) {
  switch (outputType) {
    case 'adapter': {
      const addAdapter = `./do/adapter add tuned-${technique}`;
      return [
        `${addAdapter} --from-tune`,
        `${addAdapter} --from-tune ${technique}`,
        `${addAdapter} --weights ${artifactPath}`
      ];
    }
    case 'full-model': {
      const modelData = `--model-data ${artifactPath}`;
      return [
        './do/add-ic tuned-v1 --from-tune',
        `./do/add-ic tuned-v1 ${modelData}`,
        `./do/deploy --force-ic ${modelData}`
      ];
    }
    default:
      return [];
  }
}
@@ -290,6 +290,7 @@ RUN chmod +x /usr/bin/serve_trtllm
290
290
 
291
291
  # Copy startup script
292
292
  COPY code/cuda_compat.sh /usr/bin/cuda_compat.sh
293
+ COPY code/cw_log_forwarder.py /usr/bin/cw_log_forwarder.py
293
294
  COPY code/start_server.sh /usr/bin/start_server.sh
294
295
  RUN chmod +x /usr/bin/start_server.sh /usr/bin/cuda_compat.sh
295
296
 
@@ -307,6 +308,7 @@ COPY code/serving.properties /opt/ml/model/serving.properties
307
308
  # The container will automatically start DJL Serving with the configuration
308
309
  <% } else { %>
309
310
  COPY code/cuda_compat.sh /usr/bin/cuda_compat.sh
311
+ COPY code/cw_log_forwarder.py /usr/bin/cw_log_forwarder.py
310
312
  COPY code/serve /usr/bin/serve
311
313
  RUN chmod 777 /usr/bin/serve /usr/bin/cuda_compat.sh
312
314
 
@@ -0,0 +1,64 @@
1
#!/usr/bin/env python3
"""CloudWatch log forwarder — workaround for IC platform log routing gap.

Pipes stdin to a CloudWatch Logs stream while passing every line through to
stderr. Forwarding is strictly best-effort. When boto3 is missing or the log
stream cannot be prepared, the script degrades to a plain stdin -> stderr
passthrough, so the process writing into the pipe never loses its reader.

Usage: exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1
"""
import os
import sys
import threading
import time

try:
    import boto3
    from botocore.config import Config
except ImportError:
    # Without boto3 we can still pass lines through; exiting here would close
    # the pipe that the serving process writes its output into.
    boto3 = None
    Config = None

LOG_GROUP = os.environ.get(
    "CW_LOG_GROUP",
    f"/aws/sagemaker/InferenceComponents/{os.environ.get('INFERENCE_COMPONENT_NAME', os.environ.get('HOSTNAME', 'unknown'))}",
)
LOG_STREAM = f"AllTraffic/{os.environ.get('HOSTNAME', 'container')}"
REGION = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-west-2"))

FLUSH_INTERVAL_SECONDS = 2
# PutLogEvents limits: at most 10,000 events and 1,048,576 bytes per request,
# where each event counts as its UTF-8 message length plus 26 bytes.
MAX_BATCH_EVENTS = 10_000
MAX_BATCH_BYTES = 1_048_576
EVENT_OVERHEAD_BYTES = 26
# Conservative per-event cap: 256 KiB including the per-event overhead.
MAX_EVENT_BYTES = 256 * 1024 - EVENT_OVERHEAD_BYTES


def warn(message):
    """Write a forwarder diagnostic to stderr, ignoring a closed stderr."""
    try:
        sys.stderr.write(f"[cw_log_forwarder] {message}\n")
    except OSError:
        pass


def error_code(exc):
    """Return the AWS error code carried by a botocore ClientError, else None."""
    response = getattr(exc, "response", None)
    if isinstance(response, dict):
        return response.get("Error", {}).get("Code")
    return None


def to_message(line):
    """Turn a raw stdin line into a CloudWatch message, or None to skip it.

    CloudWatch rejects empty messages (minimum length 1), and one such event
    fails the whole PutLogEvents batch, so blank lines are dropped. The text is
    normalised to valid UTF-8 and truncated to MAX_EVENT_BYTES on a character
    boundary.
    """
    message = line.rstrip("\n")
    if not message:
        return None
    encoded = message.encode("utf-8", errors="replace")
    message = encoded[:MAX_EVENT_BYTES].decode("utf-8", errors="ignore")
    return message or None


def build_batches(events):
    """Split events (in order) into consecutive batches that each fit one PutLogEvents call."""
    batches = []
    current = []
    current_bytes = 0
    for event in events:
        size = len(event["message"].encode("utf-8")) + EVENT_OVERHEAD_BYTES
        if current and (len(current) >= MAX_BATCH_EVENTS or current_bytes + size > MAX_BATCH_BYTES):
            batches.append(current)
            current = []
            current_bytes = 0
        current.append(event)
        current_bytes += size
    if current:
        batches.append(current)
    return batches


class Forwarder:
    """Buffer log lines and ship them to one CloudWatch log stream.

    ``add`` only touches the in-memory buffer, so the stdin passthrough never
    waits on a network call. ``flush`` drains the whole buffer, in batches.
    """

    def __init__(self, client, log_group=LOG_GROUP, log_stream=LOG_STREAM):
        self.client = client
        self.log_group = log_group
        self.log_stream = log_stream
        self._pending = []
        self._pending_lock = threading.Lock()
        # Serialises flushes: batches go out in order, and the final flush
        # waits for any in-flight background flush to finish.
        self._send_lock = threading.Lock()
        self._sequence_token = None
        self._warned = False

    def add(self, line):
        """Queue one stdin line (timestamped now) unless it would be rejected."""
        message = to_message(line)
        if message is None:
            return
        event = {"timestamp": int(time.time() * 1000), "message": message}
        with self._pending_lock:
            self._pending.append(event)

    def flush(self):
        """Send everything queued so far; failed batches are dropped."""
        with self._send_lock:
            with self._pending_lock:
                events, self._pending = self._pending, []
            for batch in build_batches(events):
                self._put(batch)

    def _put(self, batch):
        request = {
            "logGroupName": self.log_group,
            "logStreamName": self.log_stream,
            "logEvents": batch,
        }
        # Only sent when a previous call returned a token.
        if self._sequence_token:
            request["sequenceToken"] = self._sequence_token
        try:
            response = self.client.put_log_events(**request)
        except Exception as exc:
            # Best effort: drop this batch and keep going. Warn only once so
            # a persistent failure does not flood stderr.
            if not self._warned:
                self._warned = True
                warn(f"put_log_events failed ({exc}); dropped {len(batch)} events")
            return
        self._sequence_token = response.get("nextSequenceToken")


def create_client():
    """Return a logs client whose LOG_STREAM is ready, or None for passthrough-only mode."""
    if boto3 is None:
        warn("boto3 is not installed; passthrough only")
        return None
    try:
        client = boto3.client("logs", region_name=REGION, config=Config(retries={"max_attempts": 2}))
    except Exception as exc:
        warn(f"cannot create CloudWatch Logs client ({exc}); passthrough only")
        return None
    try:
        client.create_log_group(logGroupName=LOG_GROUP)
    except Exception:
        # Failure here is tolerated; whether forwarding can work is decided
        # by the stream call below.
        pass
    try:
        client.create_log_stream(logGroupName=LOG_GROUP, logStreamName=LOG_STREAM)
    except Exception as exc:
        # The stream name is derived from HOSTNAME, so a restarted container
        # can find it already present — that stream is still writable.
        if error_code(exc) != "ResourceAlreadyExistsException":
            warn(f"cannot create log stream {LOG_GROUP}/{LOG_STREAM} ({exc}); passthrough only")
            return None
    return client


def passthrough(stream):
    """Copy every line of ``stream`` to stderr until EOF, interrupt, or an I/O error."""
    try:
        for line in stream:
            sys.stderr.write(line)
    except (KeyboardInterrupt, OSError):
        pass


def main():
    # Undecodable bytes must not kill the reader: the writer on the other end
    # of the pipe would then fail on its next write.
    try:
        sys.stdin.reconfigure(errors="replace")
    except (AttributeError, ValueError):
        pass

    client = create_client()
    if client is None:
        passthrough(sys.stdin)
        return

    forwarder = Forwarder(client)

    def flush_periodically():
        while True:
            time.sleep(FLUSH_INTERVAL_SECONDS)
            forwarder.flush()

    threading.Thread(target=flush_periodically, daemon=True).start()

    echo = True
    try:
        for line in sys.stdin:
            if echo:
                try:
                    sys.stderr.write(line)
                except OSError:
                    echo = False  # stderr is gone; keep forwarding to CloudWatch
            forwarder.add(line)
    except KeyboardInterrupt:
        pass
    finally:
        forwarder.flush()


if __name__ == "__main__":
    main()
@@ -2,6 +2,11 @@
2
2
  # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ # CloudWatch log forwarder — workaround for IC platform log routing gap
6
+ exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1
7
+
8
+ echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"
9
+
5
10
  # CUDA compatibility setup (required for newer SageMaker inference AMIs)
6
11
  source /usr/bin/cuda_compat.sh 2>/dev/null || true
7
12
 
@@ -270,8 +275,14 @@ for var in "${env_vars[@]}"; do
270
275
 
271
276
  # Remove prefix, convert to lowercase, and replace underscores with dashes
272
277
  arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
278
+
279
+ # Boolean handling: true = flag only, false = skip entirely
280
+ if [ "$value" = "false" ]; then
281
+ continue
282
+ fi
283
+
273
284
  SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
274
- if [ -n "$value" ]; then
285
+ if [ -n "$value" ] && [ "$value" != "true" ]; then
275
286
  SERVER_ARGS+=("$value")
276
287
  fi
277
288
  done