@aws/ml-container-creator 0.6.1 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@aws/ml-container-creator",
3
- "version": "0.6.1",
3
+ "version": "0.7.1",
4
4
  "description": "Generator for SageMaker AI BYOC paradigm for predictive inference use-cases.",
5
5
  "type": "module",
6
6
  "main": "src/app.js",
@@ -290,6 +290,7 @@ RUN chmod +x /usr/bin/serve_trtllm
290
290
 
291
291
  # Copy startup script
292
292
  COPY code/cuda_compat.sh /usr/bin/cuda_compat.sh
293
+ COPY code/cw_log_forwarder.py /usr/bin/cw_log_forwarder.py
293
294
  COPY code/start_server.sh /usr/bin/start_server.sh
294
295
  RUN chmod +x /usr/bin/start_server.sh /usr/bin/cuda_compat.sh
295
296
 
@@ -307,6 +308,7 @@ COPY code/serving.properties /opt/ml/model/serving.properties
307
308
  # The container will automatically start DJL Serving with the configuration
308
309
  <% } else { %>
309
310
  COPY code/cuda_compat.sh /usr/bin/cuda_compat.sh
311
+ COPY code/cw_log_forwarder.py /usr/bin/cw_log_forwarder.py
310
312
  COPY code/serve /usr/bin/serve
311
313
  RUN chmod 777 /usr/bin/serve /usr/bin/cuda_compat.sh
312
314
 
@@ -0,0 +1,64 @@
1
#!/usr/bin/env python3
"""CloudWatch log forwarder — workaround for IC platform log routing gap.
Pipes stdin to a CW log stream while passing through to stderr.
Usage: exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1

The parent shell redirects its output into this process's stdin, so this
process keeps reading stdin until EOF no matter what goes wrong: if it exited
early, the server's next write to stdout would fail on a closed pipe. Every
CloudWatch problem (boto3 missing, no permissions, API errors) therefore
degrades to plain stderr passthrough instead of terminating.
"""
import collections
import os
import sys
import threading
import time

try:
    import boto3
    from botocore.config import Config
except ImportError:  # boto3 not installed in this image: passthrough only
    boto3 = None
    Config = None

LOG_GROUP = os.environ.get(
    "CW_LOG_GROUP",
    "/aws/sagemaker/InferenceComponents/"
    f"{os.environ.get('INFERENCE_COMPONENT_NAME', os.environ.get('HOSTNAME', 'unknown'))}",
)
LOG_STREAM = f"AllTraffic/{os.environ.get('HOSTNAME', 'container')}"
REGION = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-west-2"))

# Seconds between background flushes.
FLUSH_INTERVAL_S = 2
# PutLogEvents quotas: at most 10,000 events and 1,048,576 bytes per call,
# where each event counts as its UTF-8 message length plus 26 bytes.
MAX_BATCH_EVENTS = 10_000
MAX_BATCH_BYTES = 1_048_576
EVENT_OVERHEAD_BYTES = 26
# Keep every event (message + overhead) within a conservative 256 KB size.
MAX_MESSAGE_BYTES = 256 * 1024 - EVENT_OVERHEAD_BYTES
# Upper bound on lines held in memory while CloudWatch is slow; the deque
# drops the oldest lines beyond this instead of growing without limit.
MAX_BUFFERED_EVENTS = 100_000


def _warn(text):
    """Write a best-effort diagnostic line to stderr; never raises."""
    try:
        sys.stderr.write(f"[cw_log_forwarder] {text}\n")
    except (OSError, ValueError):
        pass


def prepare_message(line):
    """Turn one stdin line into a PutLogEvents message, or None to skip it.

    Strips the trailing newline and drops empty lines (PutLogEvents rejects a
    zero-length message, which would fail the whole batch). Lines longer than
    MAX_MESSAGE_BYTES of UTF-8 are truncated without splitting a character.
    """
    message = line.rstrip("\n")
    if not message:
        return None
    encoded = message.encode("utf-8", "replace")
    if len(encoded) > MAX_MESSAGE_BYTES:
        message = encoded[:MAX_MESSAGE_BYTES].decode("utf-8", "ignore")
    return message


def take_batch(buf):
    """Pop the oldest buffered events that fit in one PutLogEvents call.

    *buf* is a deque of ``(timestamp_ms, message)`` tuples in arrival order.
    Returns a list of ``{"timestamp": ..., "message": ...}`` dicts respecting
    MAX_BATCH_EVENTS and MAX_BATCH_BYTES; empty when *buf* is empty.
    """
    batch, size = [], 0
    while buf and len(batch) < MAX_BATCH_EVENTS:
        timestamp, message = buf[0]
        cost = len(message.encode("utf-8", "replace")) + EVENT_OVERHEAD_BYTES
        if batch and size + cost > MAX_BATCH_BYTES:
            break
        buf.popleft()
        batch.append({"timestamp": timestamp, "message": message})
        size += cost
    return batch


def _passthrough():
    """Copy the rest of stdin to stderr unchanged (no CloudWatch)."""
    try:
        for line in sys.stdin:
            sys.stderr.write(line)
    except (KeyboardInterrupt, BrokenPipeError):
        pass


def _make_client():
    """Return a CloudWatch Logs client, or None if one cannot be created."""
    if boto3 is None:
        _warn("boto3 not available; passthrough only")
        return None
    try:
        return boto3.client("logs", region_name=REGION, config=Config(retries={"max_attempts": 2}))
    except Exception as exc:
        _warn(f"cannot create CloudWatch Logs client ({exc}); passthrough only")
        return None


def _ensure_stream(client):
    """Create LOG_GROUP/LOG_STREAM if needed; return False if the stream is unusable."""
    try:
        client.create_log_group(logGroupName=LOG_GROUP)
    except Exception:
        pass  # ignored: the group may already exist or be created by the platform
    try:
        client.create_log_stream(logGroupName=LOG_GROUP, logStreamName=LOG_STREAM)
    except client.exceptions.ResourceAlreadyExistsException:
        pass  # e.g. a restarted container with the same HOSTNAME: reuse the stream
    except Exception as exc:
        _warn(f"cannot create log stream {LOG_GROUP}/{LOG_STREAM} ({exc}); passthrough only")
        return False
    return True


def main():
    """Tee stdin to stderr and, best-effort, to the CloudWatch log stream."""
    # Undecodable bytes must not raise: that would end the forwarder early.
    sys.stdin.reconfigure(errors="replace")

    client = _make_client()
    if client is None or not _ensure_stream(client):
        _passthrough()
        return

    buf = collections.deque(maxlen=MAX_BUFFERED_EVENTS)
    buf_lock = threading.Lock()   # guards buf: reader appends, flush pops
    send_lock = threading.Lock()  # one flush at a time (timer thread vs. final flush)
    stop = threading.Event()
    reported = threading.Event()  # report only the first delivery failure

    def flush():
        """Drain buf to CloudWatch, one PutLogEvents call per batch.

        The network call runs without buf_lock held so the stdin reader never
        waits on CloudWatch. No sequenceToken is sent: PutLogEvents no longer
        requires one.
        """
        with send_lock:
            while True:
                with buf_lock:
                    events = take_batch(buf)
                if not events:
                    return
                try:
                    client.put_log_events(
                        logGroupName=LOG_GROUP,
                        logStreamName=LOG_STREAM,
                        logEvents=events,
                    )
                except Exception as exc:  # best-effort delivery: drop this batch
                    if not reported.is_set():
                        reported.set()
                        _warn(f"put_log_events failed, dropping buffered logs: {exc}")

    def sender():
        """Flush every FLUSH_INTERVAL_S seconds until stop is set."""
        while not stop.wait(FLUSH_INTERVAL_S):
            flush()

    threading.Thread(target=sender, daemon=True).start()
    try:
        for line in sys.stdin:
            sys.stderr.write(line)
            message = prepare_message(line)
            if message is not None:
                with buf_lock:
                    buf.append((int(time.time() * 1000), message))
    except (KeyboardInterrupt, BrokenPipeError):
        pass
    finally:
        stop.set()
        flush()  # send everything still buffered, not just one batch


if __name__ == "__main__":
    main()
@@ -2,6 +2,11 @@
2
2
  # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ # CloudWatch log forwarder — workaround for IC platform log routing gap
6
+ exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1
7
+
8
+ echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"
9
+
5
10
  # CUDA compatibility setup (required for newer SageMaker inference AMIs)
6
11
  source /usr/bin/cuda_compat.sh 2>/dev/null || true
7
12
 
@@ -270,8 +275,14 @@ for var in "${env_vars[@]}"; do
270
275
 
271
276
  # Remove prefix, convert to lowercase, and replace underscores with dashes
272
277
  arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
278
+
279
+ # Boolean handling: true = flag only, false = skip entirely
280
+ if [ "$value" = "false" ]; then
281
+ continue
282
+ fi
283
+
273
284
  SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
274
- if [ -n "$value" ]; then
285
+ if [ -n "$value" ] && [ "$value" != "true" ]; then
275
286
  SERVER_ARGS+=("$value")
276
287
  fi
277
288
  done
@@ -0,0 +1,897 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Generate deploy_notebook.ipynb from environment variables."""
6
+ import json
7
+ import os
8
+ import sys
9
+
10
+
11
def env(name, default=""):
    """Return the value of environment variable *name*, or *default* when it is unset."""
    value = os.environ.get(name)
    return default if value is None else value
14
+
15
+
16
def make_markdown_cell(source_lines):
    """Build a notebook markdown cell dict whose source is *source_lines* (used as-is)."""
    return dict(cell_type="markdown", metadata={}, source=source_lines)
23
+
24
+
25
def make_code_cell(source_lines):
    """Build an unexecuted notebook code cell dict whose source is *source_lines*."""
    cell = dict(cell_type="code", metadata={}, source=source_lines)
    # A fresh cell has no outputs and has never been executed.
    cell.update(outputs=[], execution_count=None)
    return cell
34
+
35
+
36
+ cells = []
37
+
38
+ # ── Section 1: Setup ─────────────────────────────────────────────────────────
39
+
40
+ # Title markdown cell
41
+ cells.append(make_markdown_cell([
42
+ f"# Deploy {env('PROJECT_NAME')} on SageMaker\n",
43
+ "\n",
44
+ f"**Model Server**: {env('MODEL_SERVER')} \n",
45
+ f"**Instance**: {env('INSTANCE_TYPE')} \n",
46
+ f"**Region**: {env('AWS_REGION')}\n"
47
+ ]))
48
+
49
+ # Pip install cell
50
+ cells.append(make_code_cell([
51
+ "%pip install -qU sagemaker boto3"
52
+ ]))
53
+
54
+ # Imports cell
55
+ cells.append(make_code_cell([
56
+ "import json\n",
57
+ "import time\n",
58
+ "import boto3\n",
59
+ "import sagemaker\n",
60
+ "from sagemaker import get_execution_role\n",
61
+ "from sagemaker.session import Session\n",
62
+ "\n",
63
+ "sagemaker_session = Session()\n",
64
+ "role = get_execution_role()\n",
65
+ "account_id = boto3.client('sts').get_caller_identity()['Account']\n",
66
+ "region = sagemaker_session.boto_region_name\n",
67
+ "\n",
68
+ "sm_client = boto3.client('sagemaker', region_name=region)\n",
69
+ "smr_client = boto3.client('sagemaker-runtime', region_name=region)"
70
+ ]))
71
+
72
+ # ── Section 2: Configuration ─────────────────────────────────────────────────
73
+
74
+ # Project variables baked as Python literals
75
+ cells.append(make_code_cell([
76
+ f'PROJECT_NAME = "{env("PROJECT_NAME")}"\n',
77
+ f'AWS_REGION = "{env("AWS_REGION")}"\n',
78
+ <% if (deploymentTarget !== 'hyperpod-eks' && !(typeof existingEndpointName !== 'undefined' && existingEndpointName)) { %>
79
+ f'INSTANCE_TYPE = "{env("INSTANCE_TYPE")}"\n',
80
+ <% } %>
81
+ f'ENDPOINT_NAME = f"{{PROJECT_NAME}}-ep-{{int(time.time())}}"\n',
82
+ <% if (typeof inferenceAmiVersion !== 'undefined' && inferenceAmiVersion) { %>
83
+ f'INFERENCE_AMI_VERSION = "{env("INFERENCE_AMI_VERSION")}"\n',
84
+ <% } else { %>
85
+ f'INFERENCE_AMI_VERSION = "{env("INFERENCE_AMI_VERSION", "")}"\n',
86
+ <% } %>
87
+ f'HEALTH_CHECK_TIMEOUT = 850\n',
88
+ f'IC_GPU_COUNT = {env("IC_GPU_COUNT", "1")}\n',
89
+ f'IC_MIN_MEMORY_MB = {env("IC_MEMORY_SIZE", "1024")}'
90
+ ]))
91
+
92
+ # Environment variables dict cell
93
+ cells.append(make_code_cell([
94
+ "env = {\n",
95
+ <% if (orderedEnvVars && orderedEnvVars.length > 0) { %>
96
+ <% orderedEnvVars.forEach(function(item, index) { %>
97
+ f' "<%= item.key %>": "{env("<%= item.key %>", "<%= item.value %>")}",\n',
98
+ <% }); %>
99
+ <% } %>
100
+ "}"
101
+ ]))
102
+
103
+ # ── Section 2b: Secrets Handling ─────────────────────────────────────────────
104
+
105
+ <% if (typeof hfTokenArn !== 'undefined' && hfTokenArn) { %>
106
+ # HF_TOKEN_ARN is configured — resolve via Secrets Manager
107
+ cells.append(make_code_cell([
108
+ "import boto3 as _boto3_secrets\n",
109
+ "\n",
110
+ f'HF_TOKEN_ARN = "{env("HF_TOKEN_ARN")}"\n',
111
+ "\n",
112
+ "secrets_client = _boto3_secrets.client('secretsmanager', region_name=AWS_REGION)\n",
113
+ "hf_token = secrets_client.get_secret_value(SecretId=HF_TOKEN_ARN)['SecretString']\n",
114
+ 'env["HF_TOKEN"] = hf_token'
115
+ ]))
116
+ <% } else if (hfToken) { %>
117
+ # HF_TOKEN is configured — read from environment variable at notebook runtime
118
+ cells.append(make_markdown_cell([
119
+ "### \u26a0\ufe0f HuggingFace Token Required\n",
120
+ "\n",
121
+ "Set the `HF_TOKEN` environment variable before running the next cell. \n",
122
+ "In SageMaker Studio, use the **Environment** tab or run: \n",
123
+ '`export HF_TOKEN="hf_your_token_here"`'
124
+ ]))
125
+
126
+ cells.append(make_code_cell([
127
+ "import os\n",
128
+ "\n",
129
+ 'env["HF_TOKEN"] = os.environ["HF_TOKEN"]'
130
+ ]))
131
+ <% } %>
132
+ <% if (typeof ngcTokenArn !== 'undefined' && ngcTokenArn) { %>
133
+ # NGC_API_KEY_ARN is configured — resolve via Secrets Manager
134
+ cells.append(make_code_cell([
135
+ "import boto3 as _boto3_secrets\n",
136
+ "\n",
137
+ f'NGC_API_KEY_ARN = "{env("NGC_API_KEY_ARN")}"\n',
138
+ "\n",
139
+ "secrets_client = _boto3_secrets.client('secretsmanager', region_name=AWS_REGION)\n",
140
+ "ngc_key = secrets_client.get_secret_value(SecretId=NGC_API_KEY_ARN)['SecretString']\n",
141
+ 'env["NGC_API_KEY"] = ngc_key'
142
+ ]))
143
+ <% } else if (ngcApiKey) { %>
144
+ # NGC_API_KEY is configured — read from environment variable at notebook runtime
145
+ cells.append(make_markdown_cell([
146
+ "### \u26a0\ufe0f NVIDIA NGC API Key Required\n",
147
+ "\n",
148
+ "Set the `NGC_API_KEY` environment variable before running the next cell. \n",
149
+ "In SageMaker Studio, use the **Environment** tab or run: \n",
150
+ '`export NGC_API_KEY="your_ngc_key_here"`'
151
+ ]))
152
+
153
+ cells.append(make_code_cell([
154
+ "import os\n",
155
+ "\n",
156
+ 'env["NGC_API_KEY"] = os.environ["NGC_API_KEY"]'
157
+ ]))
158
+ <% } %>
159
+
160
+ # ── Section 3: Build & Push ──────────────────────────────────────────────────
161
+
162
+ <% if (modelServer !== 'lmi' && modelServer !== 'djl') { %>
163
+ cells.append(make_markdown_cell([
164
+ "## Build & Push Container\n",
165
+ "\n",
166
+ "Build the container via CodeBuild and push to ECR.\n"
167
+ ]))
168
+
169
+ cells.append(make_code_cell([
170
+ 'CODEBUILD_PROJECT_NAME = f"{PROJECT_NAME}-build"\n',
171
+ 'cb_client = boto3.client("codebuild", region_name=AWS_REGION)\n',
172
+ '\n',
173
+ 'build = cb_client.start_build(\n',
174
+ ' projectName=CODEBUILD_PROJECT_NAME,\n',
175
+ ' sourceVersion="main",\n',
176
+ ')\n',
177
+ 'build_id = build["build"]["id"]\n',
178
+ 'print(f"Build started: {build_id}")\n',
179
+ '\n',
180
+ 'while True:\n',
181
+ ' resp = cb_client.batch_get_builds(ids=[build_id])\n',
182
+ ' status = resp["builds"][0]["buildStatus"]\n',
183
+ ' phase = resp["builds"][0].get("currentPhase", "UNKNOWN")\n',
184
+ ' if status == "SUCCEEDED":\n',
185
+ ' print(f"\\u2705 Build succeeded")\n',
186
+ ' break\n',
187
+ ' elif status in ("FAILED", "FAULT", "TIMED_OUT", "STOPPED"):\n',
188
+ ' print(f"\\u274c Build {status}")\n',
189
+ ' break\n',
190
+ ' print(f" {phase}... ({status})")\n',
191
+ ' time.sleep(30)\n',
192
+ ]))
193
+
194
+ cells.append(make_code_cell([
195
+ f'image_uri = f"{{account_id}}.dkr.ecr.{{region}}.amazonaws.com/{env("PROJECT_NAME")}:{env("PROJECT_NAME")}-latest"\n',
196
+ 'print(f"Image URI: {image_uri}")'
197
+ ]))
198
+ <% } else { %>
199
+ cells.append(make_markdown_cell([
200
+ "## Container Image\n",
201
+ "\n",
202
+ "Using AWS Deep Learning Container (DLC) image.\n"
203
+ ]))
204
+
205
+ cells.append(make_code_cell([
206
+ f'image_uri = sagemaker.image_uris.retrieve(\n',
207
+ f' framework="<%= modelServer %>",\n',
208
+ f' region=region,\n',
209
+ f' version="latest",\n',
210
+ f' instance_type=INSTANCE_TYPE,\n',
211
+ f')\n',
212
+ 'print(f"DLC Image URI: {image_uri}")'
213
+ ]))
214
+ <% } %>
215
+
216
+ # ── Section 4: Model ─────────────────────────────────────────────────────────
217
+
218
+ cells.append(make_markdown_cell([
219
+ "## Create SageMaker Model\n",
220
+ "\n",
221
+ "Define the model container and environment variables.\n"
222
+ ]))
223
+
224
+ cells.append(make_code_cell([
225
+ 'model_name = f"{PROJECT_NAME}-model-{int(time.time())}"\n',
226
+ '\n',
227
+ 'sm_client.create_model(\n',
228
+ ' ModelName=model_name,\n',
229
+ ' ExecutionRoleArn=role,\n',
230
+ ' PrimaryContainer={\n',
231
+ ' "Image": image_uri,\n',
232
+ ' "Environment": env,\n',
233
+ ' },\n',
234
+ ')\n',
235
+ '\n',
236
+ 'print(f"✅ Model created: {model_name}")'
237
+ ]))
238
+
239
+ <% if (deploymentTarget === 'realtime-inference') { %>
240
+ # ── Section 5: Endpoint ──────────────────────────────────────────────────────
241
+
242
+ cells.append(make_markdown_cell([
243
+ "## Create Endpoint\n",
244
+ "\n",
245
+ "Create an endpoint configuration and deploy the endpoint.\n"
246
+ ]))
247
+
248
+ # Create endpoint config
249
+ cells.append(make_code_cell([
250
+ 'endpoint_config_name = f"{PROJECT_NAME}-epc-{int(time.time())}"\n',
251
+ '\n',
252
+ 'production_variant = {\n',
253
+ ' "VariantName": "AllTraffic",\n',
254
+ ' "InstanceType": INSTANCE_TYPE,\n',
255
+ ' "InitialInstanceCount": 1,\n',
256
+ ' "ContainerStartupHealthCheckTimeoutInSeconds": HEALTH_CHECK_TIMEOUT,\n',
257
+ '}\n',
258
+ '\n',
259
+ '# Include InferenceAmiVersion if configured\n',
260
+ 'if INFERENCE_AMI_VERSION:\n',
261
+ ' production_variant["InferenceAmiVersion"] = INFERENCE_AMI_VERSION\n',
262
+ '\n',
263
+ 'sm_client.create_endpoint_config(\n',
264
+ ' EndpointConfigName=endpoint_config_name,\n',
265
+ ' ExecutionRoleArn=role,\n',
266
+ ' ProductionVariants=[production_variant],\n',
267
+ ')\n',
268
+ '\n',
269
+ 'print(f"✅ Endpoint config created: {endpoint_config_name}")'
270
+ ]))
271
+
272
+ # Create endpoint
273
+ cells.append(make_code_cell([
274
+ 'sm_client.create_endpoint(\n',
275
+ ' EndpointName=ENDPOINT_NAME,\n',
276
+ ' EndpointConfigName=endpoint_config_name,\n',
277
+ ')\n',
278
+ '\n',
279
+ 'print(f"Creating endpoint: {ENDPOINT_NAME}...")\n',
280
+ '\n',
281
+ '# Wait for InService\n',
282
+ 'while True:\n',
283
+ ' resp = sm_client.describe_endpoint(EndpointName=ENDPOINT_NAME)\n',
284
+ ' status = resp["EndpointStatus"]\n',
285
+ ' if status == "InService":\n',
286
+ ' print(f"✅ Endpoint InService: {ENDPOINT_NAME}")\n',
287
+ ' break\n',
288
+ ' elif status == "Failed":\n',
289
+ ' print(f"❌ Endpoint failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
290
+ ' break\n',
291
+ ' print(f" Status: {status}...")\n',
292
+ ' time.sleep(30)'
293
+ ]))
294
+
295
+ # ── Section 6: Inference Component ───────────────────────────────────────────
296
+
297
+ cells.append(make_markdown_cell([
298
+ "## Create Inference Component\n",
299
+ "\n",
300
+ "Attach the model to the endpoint with compute resource allocation. \n",
301
+ "This is separate from the endpoint to support multi-LoRA adapter extensibility.\n"
302
+ ]))
303
+
304
+ cells.append(make_code_cell([
305
+ 'ic_name = f"{PROJECT_NAME}-ic-{int(time.time())}"\n',
306
+ '\n',
307
+ 'sm_client.create_inference_component(\n',
308
+ ' InferenceComponentName=ic_name,\n',
309
+ ' EndpointName=ENDPOINT_NAME,\n',
310
+ ' VariantName="AllTraffic",\n',
311
+ ' Specification={\n',
312
+ ' "ModelName": model_name,\n',
313
+ ' "Container": {\n',
314
+ ' "Image": image_uri,\n',
315
+ ' "Environment": env,\n',
316
+ ' },\n',
317
+ ' "ComputeResourceRequirements": {\n',
318
+ ' "NumberOfAcceleratorDevicesRequired": IC_GPU_COUNT,\n',
319
+ ' "MinMemoryRequiredInMb": IC_MIN_MEMORY_MB,\n',
320
+ ' },\n',
321
+ ' },\n',
322
+ ' RuntimeConfig={\n',
323
+ ' "CopyCount": 1,\n',
324
+ ' },\n',
325
+ ')\n',
326
+ '\n',
327
+ 'print(f"Creating inference component: {ic_name}...")\n',
328
+ '\n',
329
+ '# Wait for IC InService\n',
330
+ 'while True:\n',
331
+ ' resp = sm_client.describe_inference_component(InferenceComponentName=ic_name)\n',
332
+ ' status = resp["InferenceComponentStatus"]\n',
333
+ ' if status == "InService":\n',
334
+ ' print(f"✅ Inference Component InService: {ic_name}")\n',
335
+ ' break\n',
336
+ ' elif status == "Failed":\n',
337
+ ' print(f"❌ IC failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
338
+ ' break\n',
339
+ ' print(f" Status: {status}...")\n',
340
+ ' time.sleep(30)'
341
+ ]))
342
+
343
+ # ── Section 7: Test ──────────────────────────────────────────────────────────
344
+
345
+ cells.append(make_markdown_cell([
346
+ "## Test Inference\n",
347
+ "\n",
348
+ "Send a test request to the deployed model.\n"
349
+ ]))
350
+
351
+ <% if (framework === 'transformers') { %>
352
+ cells.append(make_code_cell([
353
+ 'payload = {\n',
354
+ ' "messages": [{"role": "user", "content": "What is machine learning?"}],\n',
355
+ ' "max_tokens": 100,\n',
356
+ ' "temperature": 0.7,\n',
357
+ '}\n',
358
+ '\n',
359
+ 'response = smr_client.invoke_endpoint(\n',
360
+ ' EndpointName=ENDPOINT_NAME,\n',
361
+ ' InferenceComponentName=ic_name,\n',
362
+ ' ContentType="application/json",\n',
363
+ ' Body=json.dumps(payload),\n',
364
+ ')\n',
365
+ '\n',
366
+ 'result = json.loads(response["Body"].read().decode())\n',
367
+ 'print(result["choices"][0]["message"]["content"])'
368
+ ]))
369
+ <% } else { %>
370
+ cells.append(make_code_cell([
371
+ 'payload = {\n',
372
+ ' "inputs": "What is machine learning?",\n',
373
+ ' "parameters": {"max_new_tokens": 100},\n',
374
+ '}\n',
375
+ '\n',
376
+ 'response = smr_client.invoke_endpoint(\n',
377
+ ' EndpointName=ENDPOINT_NAME,\n',
378
+ ' InferenceComponentName=ic_name,\n',
379
+ ' ContentType="application/json",\n',
380
+ ' Body=json.dumps(payload),\n',
381
+ ')\n',
382
+ '\n',
383
+ 'result = json.loads(response["Body"].read().decode())\n',
384
+ 'print(json.dumps(result, indent=2))'
385
+ ]))
386
+ <% } %>
387
+ <% if (enableLora) { %>
388
+ # ── Section 8: LoRA Adapter ──────────────────────────────────────────────────
389
+
390
+ cells.append(make_markdown_cell([
391
+ "## 🔗 LoRA Adapter\n",
392
+ "\n",
393
+ "LoRA (Low-Rank Adaptation) adapters let you serve multiple fine-tuned \"personalities\"\n",
394
+ "from a single base model. Adapter ICs share the base IC's GPU resources — no additional\n",
395
+ "compute allocation is needed. You can add/remove adapters without redeploying the endpoint.\n"
396
+ ]))
397
+
398
+ cells.append(make_code_cell([
399
+ '# ✏️ Edit these values to configure your adapter\n',
400
+ 'ADAPTER_NAME = "my-adapter"\n',
401
+ f'ADAPTER_WEIGHTS_URI = "{env("TUNE_ADAPTER_PATH_SFT", env("TUNE_ADAPTER_PATH_DPO", env("TUNE_OUTPUT_PATH_LATEST", "s3://your-bucket/adapters/my-adapter/adapter.tar.gz")))}"\n',
402
+ ]))
403
+
404
+ cells.append(make_code_cell([
405
+ 'adapter_ic_name = f"{PROJECT_NAME}-adapter-{ADAPTER_NAME}"\n',
406
+ '\n',
407
+ 'sm_client.create_inference_component(\n',
408
+ ' InferenceComponentName=adapter_ic_name,\n',
409
+ ' EndpointName=ENDPOINT_NAME,\n',
410
+ ' Specification={\n',
411
+ ' "BaseInferenceComponentName": ic_name,\n',
412
+ ' "Container": {\n',
413
+ ' "ArtifactUrl": ADAPTER_WEIGHTS_URI,\n',
414
+ ' },\n',
415
+ ' },\n',
416
+ ')\n',
417
+ '\n',
418
+ 'print(f"Creating adapter IC: {adapter_ic_name}...")\n',
419
+ '\n',
420
+ '# Wait for adapter IC InService\n',
421
+ 'while True:\n',
422
+ ' resp = sm_client.describe_inference_component(InferenceComponentName=adapter_ic_name)\n',
423
+ ' status = resp["InferenceComponentStatus"]\n',
424
+ ' if status == "InService":\n',
425
+ ' print(f"✅ Adapter IC InService: {adapter_ic_name}")\n',
426
+ ' break\n',
427
+ ' elif status == "Failed":\n',
428
+ ' print(f"❌ Adapter IC failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
429
+ ' break\n',
430
+ ' print(f" Status: {status}...")\n',
431
+ ' time.sleep(30)'
432
+ ]))
433
+
434
+ <% if (framework === 'transformers') { %>
435
+ cells.append(make_code_cell([
436
+ 'payload = {\n',
437
+ ' "messages": [{"role": "user", "content": "What is machine learning?"}],\n',
438
+ ' "max_tokens": 100,\n',
439
+ ' "temperature": 0.7,\n',
440
+ '}\n',
441
+ '\n',
442
+ 'response = smr_client.invoke_endpoint(\n',
443
+ ' EndpointName=ENDPOINT_NAME,\n',
444
+ ' InferenceComponentName=adapter_ic_name,\n',
445
+ ' ContentType="application/json",\n',
446
+ ' Body=json.dumps(payload),\n',
447
+ ')\n',
448
+ '\n',
449
+ 'result = json.loads(response["Body"].read().decode())\n',
450
+ 'print(f"Adapter \'{ADAPTER_NAME}\' response:")\n',
451
+ 'print(result["choices"][0]["message"]["content"])'
452
+ ]))
453
+ <% } else { %>
454
+ cells.append(make_code_cell([
455
+ 'payload = {\n',
456
+ ' "inputs": "What is machine learning?",\n',
457
+ ' "parameters": {"max_new_tokens": 100},\n',
458
+ '}\n',
459
+ '\n',
460
+ 'response = smr_client.invoke_endpoint(\n',
461
+ ' EndpointName=ENDPOINT_NAME,\n',
462
+ ' InferenceComponentName=adapter_ic_name,\n',
463
+ ' ContentType="application/json",\n',
464
+ ' Body=json.dumps(payload),\n',
465
+ ')\n',
466
+ '\n',
467
+ 'result = json.loads(response["Body"].read().decode())\n',
468
+ 'print(f"Adapter \'{ADAPTER_NAME}\' response:")\n',
469
+ 'print(json.dumps(result, indent=2))'
470
+ ]))
471
+ <% } %>
472
+
473
+ cells.append(make_markdown_cell([
474
+ "### Adding More Adapters\n",
475
+ "\n",
476
+ "To add another adapter, duplicate this section with a different `ADAPTER_NAME` and\n",
477
+ "`ADAPTER_WEIGHTS_URI`. Each adapter shares the base IC's GPU resources.\n",
478
+ "\n",
479
+ "The CLI equivalent is: `./do/adapter add <name> --weights <s3-uri>`\n"
480
+ ]))
481
+ <% } %>
482
+ <% if (tuneSupported) { %>
483
+ # ── Section 9: Fine-Tune ─────────────────────────────────────────────────────
484
+
485
+ cells.append(make_markdown_cell([
486
+ "## 🎯 Managed Fine-Tuning\n",
487
+ "\n",
488
+ "SageMaker Managed Model Customization provides serverless fine-tuning — no instance\n",
489
+ "selection or container management needed. You provide a dataset and technique; SageMaker\n",
490
+ "handles infrastructure and optimization. The output is either LoRA adapter weights or a\n",
491
+ "full merged model, depending on the training type.\n"
492
+ ]))
493
+
494
+ cells.append(make_code_cell([
495
+ '# ✏️ Edit these values to configure your fine-tuning job\n',
496
+ 'TECHNIQUE = "sft" # Options: sft, dpo, rlaif, rlvr\n',
497
+ 'TRAINING_TYPE = "lora" # Options: lora, full-rank\n',
498
+ 'DATASET_S3_URI = "s3://your-bucket/datasets/train.jsonl" # ← Replace with your dataset\n',
499
+ f'TUNE_OUTPUT_BUCKET = "{env("TUNE_S3_BUCKET")}"'
500
+ ]))
501
+
502
+ cells.append(make_code_cell([
503
+ 'from sagemaker.modules.train import ModelTrainer\n',
504
+ '\n',
505
+ f'MODEL_NAME = "{env("MODEL_NAME")}"\n',
506
+ '\n',
507
+ 'job_name = f"{PROJECT_NAME}-tune-{TECHNIQUE}-{int(time.time())}"\n',
508
+ '\n',
509
+ 'trainer = ModelTrainer(\n',
510
+ ' model_id=MODEL_NAME,\n',
511
+ ' training_dataset={"s3Uri": DATASET_S3_URI},\n',
512
+ ' technique=TECHNIQUE,\n',
513
+ ' training_type=TRAINING_TYPE,\n',
514
+ ' output_path=f"s3://{TUNE_OUTPUT_BUCKET}/{PROJECT_NAME}/output/",\n',
515
+ ' role=role,\n',
516
+ ')\n',
517
+ 'trainer.train()\n',
518
+ '\n',
519
+ 'print(f"✅ Training job submitted: {job_name}")'
520
+ ]))
521
+
522
+ cells.append(make_code_cell([
523
+ 'import time as _time\n',
524
+ 'start_time = _time.time()\n',
525
+ '\n',
526
+ 'while True:\n',
527
+ ' status = trainer.describe()["TrainingJobStatus"]\n',
528
+ ' elapsed = int(_time.time() - start_time)\n',
529
+ ' elapsed_str = f"{elapsed // 60}m {elapsed % 60}s"\n',
530
+ '\n',
531
+ ' if status == "Completed":\n',
532
+ ' print(f"✅ Training completed in {elapsed_str}")\n',
533
+ ' break\n',
534
+ ' elif status == "Failed":\n',
535
+ ' reason = trainer.describe().get("FailureReason", "Unknown")\n',
536
+ ' print(f"❌ Training failed after {elapsed_str}: {reason}")\n',
537
+ ' break\n',
538
+ ' print(f" Status: {status} (elapsed: {elapsed_str})...")\n',
539
+ ' _time.sleep(60)'
540
+ ]))
541
+
542
+ cells.append(make_code_cell([
543
+ 'job_desc = trainer.describe()\n',
544
+ 'output_path = job_desc["ModelArtifacts"]["S3ModelArtifacts"]\n',
545
+ 'print(f"Output artifacts: {output_path}")'
546
+ ]))
547
+
548
+ cells.append(make_markdown_cell([
549
+ "### Next Steps\n",
550
+ "\n",
551
+ "**If LoRA output** (TRAINING_TYPE=\"lora\"): \n",
552
+ "- Run the adapter section above with `ADAPTER_WEIGHTS_URI` set to the output path \n",
553
+ "- Or use the CLI: `./do/adapter add tuned-sft --from-tune`\n",
554
+ "\n",
555
+ "**If full-rank output** (TRAINING_TYPE=\"full-rank\"): \n",
556
+ "- Deploy as a new inference component: `./do/add-ic tuned-v1 --from-tune`\n"
557
+ ]))
558
+ <% } %>
559
+
560
+ # ── Section 10: Cleanup ───────────────────────────────────────────────────────
561
+
562
+ cells.append(make_markdown_cell([
563
+ "## ⚠️ Cleanup\n",
564
+ "\n",
565
+ "**Warning**: This will delete all deployed resources. Only run when you're done testing.\n"
566
+ ]))
567
+
568
+ cells.append(make_code_cell([
569
+ '# Delete in reverse dependency order\n',
570
+ '\n',
571
+ '# 1. Delete adapter IC (if created in Section 8)\n',
572
+ 'try:\n',
573
+ ' sm_client.delete_inference_component(InferenceComponentName=adapter_ic_name)\n',
574
+ ' print(f"Deleting adapter IC: {adapter_ic_name}...")\n',
575
+ ' time.sleep(15)\n',
576
+ 'except NameError:\n',
577
+ ' pass # adapter_ic_name not defined — Section 8 was not run\n',
578
+ 'except Exception as e:\n',
579
+ ' print(f"Note: {e}")\n',
580
+ '\n',
581
+ '# 2. Delete base IC\n',
582
+ 'sm_client.delete_inference_component(InferenceComponentName=ic_name)\n',
583
+ 'print(f"Deleting base IC: {ic_name}...")\n',
584
+ 'time.sleep(30)\n',
585
+ '\n',
586
+ '# 3. Delete endpoint\n',
587
+ 'sm_client.delete_endpoint(EndpointName=ENDPOINT_NAME)\n',
588
+ 'print(f"Deleting endpoint: {ENDPOINT_NAME}...")\n',
589
+ '\n',
590
+ '# 4. Delete endpoint config\n',
591
+ 'sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n',
592
+ 'print(f"Deleting endpoint config: {endpoint_config_name}...")\n',
593
+ '\n',
594
+ '# 5. Delete model\n',
595
+ 'sm_client.delete_model(ModelName=model_name)\n',
596
+ 'print(f"Deleting model: {model_name}...")\n',
597
+ '\n',
598
+ 'print("\\n✅ All resources cleaned up")'
599
+ ]))
600
+ <% } else if (deploymentTarget === 'async-inference') { %>
601
+ # ── Section 5: Endpoint (Async) ──────────────────────────────────────────────
602
+
603
+ cells.append(make_markdown_cell([
604
+ "## Create Async Endpoint\n",
605
+ "\n",
606
+ "Create an endpoint configuration with async inference settings and deploy the endpoint. \n",
607
+ "Async inference is ideal for large payloads or long-running predictions — results are\n",
608
+ "written to S3 when ready.\n"
609
+ ]))
610
+
611
+ # Create endpoint config with AsyncInferenceConfig
612
+ cells.append(make_code_cell([
613
+ 'endpoint_config_name = f"{PROJECT_NAME}-epc-{int(time.time())}"\n',
614
+ '\n',
615
+ 'sm_client.create_endpoint_config(\n',
616
+ ' EndpointConfigName=endpoint_config_name,\n',
617
+ ' ExecutionRoleArn=role,\n',
618
+ ' ProductionVariants=[\n',
619
+ ' {\n',
620
+ ' "VariantName": "AllTraffic",\n',
621
+ ' "InstanceType": INSTANCE_TYPE,\n',
622
+ ' "InitialInstanceCount": 1,\n',
623
+ ' "ContainerStartupHealthCheckTimeoutInSeconds": HEALTH_CHECK_TIMEOUT,\n',
624
+ ' }\n',
625
+ ' ],\n',
626
+ ' AsyncInferenceConfig={\n',
627
+ ' "OutputConfig": {\n',
628
+ ' "S3OutputPath": f"s3://{PROJECT_NAME}-async-output/{PROJECT_NAME}/",\n',
629
+ ' },\n',
630
+ ' "ClientConfig": {\n',
631
+ ' "MaxConcurrentInvocationsPerInstance": 4,\n',
632
+ ' },\n',
633
+ ' },\n',
634
+ ')\n',
635
+ '\n',
636
+ 'print(f"✅ Async endpoint config created: {endpoint_config_name}")'
637
+ ]))
638
+
639
+ # Create endpoint
640
+ cells.append(make_code_cell([
641
+ 'sm_client.create_endpoint(\n',
642
+ ' EndpointName=ENDPOINT_NAME,\n',
643
+ ' EndpointConfigName=endpoint_config_name,\n',
644
+ ')\n',
645
+ '\n',
646
+ 'print(f"Creating endpoint: {ENDPOINT_NAME}...")\n',
647
+ '\n',
648
+ '# Wait for InService\n',
649
+ 'while True:\n',
650
+ ' resp = sm_client.describe_endpoint(EndpointName=ENDPOINT_NAME)\n',
651
+ ' status = resp["EndpointStatus"]\n',
652
+ ' if status == "InService":\n',
653
+ ' print(f"✅ Endpoint InService: {ENDPOINT_NAME}")\n',
654
+ ' break\n',
655
+ ' elif status == "Failed":\n',
656
+ ' print(f"❌ Endpoint failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
657
+ ' break\n',
658
+ ' print(f" Status: {status}...")\n',
659
+ ' time.sleep(30)'
660
+ ]))
661
+
662
+ # ── Section 6: Inference Component — SKIPPED for async ───────────────────────
663
+ # Async inference does not support inference components.
664
+
665
+ # ── Section 7: Test (Async) ──────────────────────────────────────────────────
666
+
667
+ cells.append(make_markdown_cell([
668
+ "## Test Async Inference\n",
669
+ "\n",
670
+ "Upload input to S3, invoke the async endpoint, then poll S3 for the result.\n"
671
+ ]))
672
+
673
+ cells.append(make_code_cell([
674
+ 'import boto3\n',
675
+ '\n',
676
+ 's3_client = boto3.client("s3", region_name=AWS_REGION)\n',
677
+ 'async_input_bucket = f"{PROJECT_NAME}-async-output"\n',
678
+ 'async_input_key = f"{PROJECT_NAME}/input/test-request.json"\n',
679
+ '\n',
680
+ <% if (framework === 'transformers') { %>
681
+ 'payload = {\n',
682
+ ' "messages": [{"role": "user", "content": "What is machine learning?"}],\n',
683
+ ' "max_tokens": 100,\n',
684
+ ' "temperature": 0.7,\n',
685
+ '}\n',
686
+ <% } else { %>
687
+ 'payload = {\n',
688
+ ' "inputs": "What is machine learning?",\n',
689
+ ' "parameters": {"max_new_tokens": 100},\n',
690
+ '}\n',
691
+ <% } %>
692
+ '\n',
693
+ '# Upload input payload to S3\n',
694
+ 's3_client.put_object(\n',
695
+ ' Bucket=async_input_bucket,\n',
696
+ ' Key=async_input_key,\n',
697
+ ' Body=json.dumps(payload),\n',
698
+ ' ContentType="application/json",\n',
699
+ ')\n',
700
+ 'input_s3_uri = f"s3://{async_input_bucket}/{async_input_key}"\n',
701
+ 'print(f"Input uploaded to: {input_s3_uri}")'
702
+ ]))
703
+
704
+ cells.append(make_code_cell([
705
+ '# Invoke async endpoint\n',
706
+ 'response = smr_client.invoke_endpoint_async(\n',
707
+ ' EndpointName=ENDPOINT_NAME,\n',
708
+ ' ContentType="application/json",\n',
709
+ ' InputLocation=input_s3_uri,\n',
710
+ ')\n',
711
+ 'output_location = response["OutputLocation"]\n',
712
+ 'print(f"Output will be at: {output_location}")\n',
713
+ '\n',
714
+ '# Poll S3 for the result\n',
715
+ 'import urllib.parse\n',
716
+ 'parsed = urllib.parse.urlparse(output_location)\n',
717
+ 'output_bucket = parsed.netloc\n',
718
+ 'output_key = parsed.path.lstrip("/")\n',
719
+ '\n',
720
+ 'print("Waiting for result...")\n',
721
+ 'while True:\n',
722
+ ' try:\n',
723
+ ' result_obj = s3_client.get_object(Bucket=output_bucket, Key=output_key)\n',
724
+ ' result = json.loads(result_obj["Body"].read().decode())\n',
725
+ ' print("\\n✅ Async inference result:")\n',
726
+ ' print(json.dumps(result, indent=2))\n',
727
+ ' break\n',
728
+ ' except s3_client.exceptions.NoSuchKey:\n',
729
+ ' print(" Waiting for output...")\n',
730
+ ' time.sleep(10)'
731
+ ]))
732
+
733
+ # ── Sections 8-9: SKIPPED for async ─────────────────────────────────────────
734
+ # LoRA adapters and fine-tuning require a realtime endpoint with inference components.
735
+
736
+ # ── Section 10: Cleanup (Async) ──────────────────────────────────────────────
737
+
738
+ cells.append(make_markdown_cell([
739
+ "## ⚠️ Cleanup\n",
740
+ "\n",
741
+ "**Warning**: This will delete all deployed resources. Only run when you're done testing. \n",
742
+ "Note: The S3 output bucket and any SNS topics are NOT deleted (user-managed).\n"
743
+ ]))
744
+
745
+ cells.append(make_code_cell([
746
+ '# Delete in reverse dependency order (no IC for async)\n',
747
+ '\n',
748
+ '# 1. Delete endpoint\n',
749
+ 'sm_client.delete_endpoint(EndpointName=ENDPOINT_NAME)\n',
750
+ 'print(f"Deleting endpoint: {ENDPOINT_NAME}...")\n',
751
+ '\n',
752
+ '# 2. Delete endpoint config\n',
753
+ 'sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n',
754
+ 'print(f"Deleting endpoint config: {endpoint_config_name}...")\n',
755
+ '\n',
756
+ '# 3. Delete model\n',
757
+ 'sm_client.delete_model(ModelName=model_name)\n',
758
+ 'print(f"Deleting model: {model_name}...")\n',
759
+ '\n',
760
+ 'print("\\n✅ All resources cleaned up")\n',
761
+ 'print("Note: S3 output bucket and SNS topics are not deleted (user-managed).")'
762
+ ]))
763
+ <% } else if (deploymentTarget === 'batch-transform') { %>
764
+ # ── Section 5: Batch Transform Job ───────────────────────────────────────────
765
+
766
+ cells.append(make_markdown_cell([
767
+ "## Create Batch Transform Job\n",
768
+ "\n",
769
+ "Run batch inference on your input data stored in S3. \n",
770
+ "The transform job reads input from S3, runs inference, and writes output back to S3.\n"
771
+ ]))
772
+
773
+ cells.append(make_code_cell([
774
+ 'INPUT_S3_URI = "s3://your-bucket/input/" # ← Replace with your input data S3 path\n',
775
+ 'OUTPUT_S3_URI = "s3://your-bucket/output/" # ← Replace with your desired output S3 path\n',
776
+ '\n',
777
+ 'transform_job_name = f"{PROJECT_NAME}-transform-{int(time.time())}"\n',
778
+ '\n',
779
+ 'sm_client.create_transform_job(\n',
780
+ ' TransformJobName=transform_job_name,\n',
781
+ ' ModelName=model_name,\n',
782
+ ' TransformInput={\n',
783
+ ' "DataSource": {\n',
784
+ ' "S3DataSource": {\n',
785
+ ' "S3DataType": "S3Prefix",\n',
786
+ ' "S3Uri": INPUT_S3_URI,\n',
787
+ ' }\n',
788
+ ' },\n',
789
+ ' "ContentType": "application/json",\n',
790
+ ' },\n',
791
+ ' TransformOutput={\n',
792
+ ' "S3OutputPath": OUTPUT_S3_URI,\n',
793
+ ' },\n',
794
+ ' TransformResources={\n',
795
+ ' "InstanceType": INSTANCE_TYPE,\n',
796
+ ' "InstanceCount": 1,\n',
797
+ ' },\n',
798
+ ')\n',
799
+ '\n',
800
+ 'print(f"Transform job started: {transform_job_name}")\n',
801
+ '\n',
802
+ '# Wait for transform job to complete\n',
803
+ 'while True:\n',
804
+ ' resp = sm_client.describe_transform_job(TransformJobName=transform_job_name)\n',
805
+ ' status = resp["TransformJobStatus"]\n',
806
+ ' if status == "Completed":\n',
807
+ ' print(f"✅ Transform job completed: {transform_job_name}")\n',
808
+ ' break\n',
809
+ ' elif status == "Failed":\n',
810
+ ' print(f"❌ Transform job failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
811
+ ' break\n',
812
+ ' elif status == "Stopped":\n',
813
+ ' print(f"⚠️ Transform job stopped")\n',
814
+ ' break\n',
815
+ ' print(f" Status: {status}...")\n',
816
+ ' time.sleep(30)'
817
+ ]))
818
+
819
+ # ── Section 7: Test (Download Output) ────────────────────────────────────────
820
+
821
+ cells.append(make_markdown_cell([
822
+ "## Review Transform Output\n",
823
+ "\n",
824
+ "Download and display a sample result from the batch transform output.\n"
825
+ ]))
826
+
827
+ cells.append(make_code_cell([
828
+ 'import boto3\n',
829
+ 'from urllib.parse import urlparse\n',
830
+ '\n',
831
+ 's3 = boto3.client("s3", region_name=AWS_REGION)\n',
832
+ '\n',
833
+ '# Parse the output S3 URI\n',
834
+ 'parsed = urlparse(OUTPUT_S3_URI)\n',
835
+ 'bucket = parsed.netloc\n',
836
+ 'prefix = parsed.path.lstrip("/")\n',
837
+ '\n',
838
+ '# List output files\n',
839
+ 'response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)\n',
840
+ 'output_files = [obj["Key"] for obj in response.get("Contents", [])]\n',
841
+ 'print(f"Output files ({len(output_files)}):")\n',
842
+ 'for f in output_files[:10]:\n',
843
+ ' print(f" s3://{bucket}/{f}")\n',
844
+ '\n',
845
+ '# Download and display first result\n',
846
+ 'if output_files:\n',
847
+ ' first_file = output_files[0]\n',
848
+ ' obj = s3.get_object(Bucket=bucket, Key=first_file)\n',
849
+ ' content = obj["Body"].read().decode("utf-8")\n',
850
+ ' print(f"\\nSample output from: {first_file}")\n',
851
+ ' print("-" * 60)\n',
852
+ ' print(content[:2000])\n',
853
+ 'else:\n',
854
+ ' print("No output files found.")'
855
+ ]))
856
+
857
+ # ── Section 10: Cleanup ───────────────────────────────────────────────────────
858
+
859
+ cells.append(make_markdown_cell([
860
+ "## ⚠️ Cleanup\n",
861
+ "\n",
862
+ "**Warning**: This will delete the model resource. Only run when you're done. \n",
863
+ "Note: The transform job is ephemeral and does not need deletion.\n"
864
+ ]))
865
+
866
+ cells.append(make_code_cell([
867
+ 'sm_client.delete_model(ModelName=model_name)\n',
868
+ 'print(f"Deleting model: {model_name}...")\n',
869
+ '\n',
870
+ 'print("\\n✅ All resources cleaned up")'
871
+ ]))
872
+ <% } %>
873
+
874
+ # ── Write notebook ───────────────────────────────────────────────────────────
875
+
876
+ notebook = {
877
+ "nbformat": 4,
878
+ "nbformat_minor": 5,
879
+ "metadata": {
880
+ "kernelspec": {
881
+ "display_name": "Python 3 (ipykernel)",
882
+ "language": "python",
883
+ "name": "python3"
884
+ },
885
+ "language_info": {
886
+ "name": "python",
887
+ "version": "3.10.0"
888
+ }
889
+ },
890
+ "cells": cells
891
+ }
892
+
893
+ output_path = "deploy_notebook.ipynb"
894
+ with open(output_path, "w") as f:
895
+ json.dump(notebook, f, indent=1)
896
+
897
+ print(f"\u2705 Notebook exported: ./{output_path}")
@@ -2,16 +2,33 @@
2
2
  # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- # Export current configuration as a CLI command or JSON object
6
- # Usage: ./do/export [--json]
5
+ # Export current configuration as a CLI command, JSON object, or Jupyter notebook
6
+ # Usage: ./do/export [--json | --notebook]
7
7
 
8
8
  # Source configuration (suppress the summary output)
9
9
  SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
10
10
  source "${SCRIPT_DIR}/config" > /dev/null 2>&1
11
11
 
12
+ # ── Notebook output mode ──────────────────────────────────────────────────────
13
+
14
+ if [ "${1:-}" = "--notebook" ]; then
15
+ # Ensure not combined with --json
16
+ if [ "${2:-}" = "--json" ]; then
17
+ echo "Error: --notebook and --json are mutually exclusive" >&2
18
+ exit 1
19
+ fi
20
+ python3 "${SCRIPT_DIR}/../deploy_notebook_generator.py"
21
+ exit 0
22
+ fi
23
+
12
24
  # ── JSON output mode ─────────────────────────────────────────────────────────
13
25
 
14
26
  if [ "${1:-}" = "--json" ]; then
27
+ # Ensure not combined with --notebook
28
+ if [ "${2:-}" = "--notebook" ]; then
29
+ echo "Error: --notebook and --json are mutually exclusive" >&2
30
+ exit 1
31
+ fi
15
32
  # Build a JSON object with all configuration parameters.
16
33
  # Uses ConfigManager camelCase keys so the output can be fed directly
17
34
  # back into the generator via --config=<file>.
@@ -152,7 +152,9 @@ create_endpoint_config() {
152
152
  variant_json="${variant_json}}]"
153
153
  else
154
154
  # Standard path: single instance type
155
- variant_json="[{\"VariantName\":\"AllTraffic\",\"InstanceType\":\"${INSTANCE_TYPE}\",\"InitialInstanceCount\":1"
155
+ # RoutingConfig is required for IC-based endpoints — without it the IC scheduler
156
+ # cannot place containers and the IC stays in Creating with no logs.
157
+ variant_json="[{\"VariantName\":\"AllTraffic\",\"InstanceType\":\"${INSTANCE_TYPE}\",\"InitialInstanceCount\":1,\"RoutingConfig\":{\"RoutingStrategy\":\"LEAST_OUTSTANDING_REQUESTS\"}"
156
158
 
157
159
  # Optional: AMI version
158
160
  if [ -n "${INFERENCE_AMI_VERSION:-}" ]; then
@@ -46,10 +46,14 @@ create_inference_component() {
46
46
 
47
47
  # Build container spec JSON
48
48
  local container_spec="{\"Image\":\"${ECR_REPOSITORY}:${IC_IMAGE_TAG:-${PROJECT_NAME}-latest}\""
49
+ # Always inject IC name for CW log forwarder
50
+ local ic_env="\"INFERENCE_COMPONENT_NAME\":\"${ic_name}\""
49
51
  if [ -n "${CONTAINER_ENV_JSON}${IC_CONTAINER_ENV_EXTRA:-}" ]; then
50
52
  local env_json="${CONTAINER_ENV_JSON}"
51
53
  [ -n "${IC_CONTAINER_ENV_EXTRA:-}" ] && env_json="${env_json:+${env_json},}${IC_CONTAINER_ENV_EXTRA}"
52
- container_spec="${container_spec},\"Environment\":{${env_json}}"
54
+ container_spec="${container_spec},\"Environment\":{${ic_env},${env_json}}"
55
+ else
56
+ container_spec="${container_spec},\"Environment\":{${ic_env}}"
53
57
  fi
54
58
  container_spec="${container_spec}}"
55
59