@aws/ml-container-creator 0.6.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,897 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Generate deploy_notebook.ipynb from environment variables."""
6
+ import json
7
+ import os
8
+ import sys
9
+
10
+
11
def env(name, default=""):
    """Return the value of environment variable *name*.

    *default* is used only when the variable is unset; a variable that is
    set to the empty string comes back as ``""``.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value
14
+
15
+
16
def make_markdown_cell(source_lines):
    """Build a markdown cell dict for the notebook's ``cells`` list.

    Args:
        source_lines: List of strings forming the cell text. It is stored
            as-is (not copied) under the ``source`` key.

    Returns:
        dict with keys ``cell_type`` ("markdown"), ``metadata`` (a fresh
        empty dict) and ``source``, in that order.
    """
    cell = dict(cell_type="markdown", metadata={})
    cell["source"] = source_lines
    return cell
23
+
24
+
25
def make_code_cell(source_lines):
    """Build a never-executed code cell dict for the notebook's ``cells`` list.

    Args:
        source_lines: List of strings forming the cell's source code.

    Returns:
        dict with keys ``cell_type`` ("code"), ``metadata`` (empty dict),
        ``source``, ``outputs`` (empty list) and ``execution_count`` (None),
        in that order.
    """
    cell = {"cell_type": "code", "metadata": {}}
    cell.update(source=source_lines, outputs=[], execution_count=None)
    return cell
34
+
35
+
36
# ── Section 1: Setup ─────────────────────────────────────────────────────────
# Title, SDK install, and a shared-clients cell. Every later cell relies on the
# names defined by the third cell (role, region, account_id, sm_client,
# smr_client).
cells = [
    # Title: project name plus the model server / instance / region it targets.
    make_markdown_cell([
        f"# Deploy {env('PROJECT_NAME')} on SageMaker\n",
        "\n",
        f"**Model Server**: {env('MODEL_SERVER')}  \n",
        f"**Instance**: {env('INSTANCE_TYPE')}  \n",
        f"**Region**: {env('AWS_REGION')}\n",
    ]),
    # Pip install cell
    make_code_cell(["%pip install -qU sagemaker boto3"]),
    # Imports cell: session, execution role, account/region and API clients.
    make_code_cell((
        "import json\n"
        "import time\n"
        "import boto3\n"
        "import sagemaker\n"
        "from sagemaker import get_execution_role\n"
        "from sagemaker.session import Session\n"
        "\n"
        "sagemaker_session = Session()\n"
        "role = get_execution_role()\n"
        "account_id = boto3.client('sts').get_caller_identity()['Account']\n"
        "region = sagemaker_session.boto_region_name\n"
        "\n"
        "sm_client = boto3.client('sagemaker', region_name=region)\n"
        "smr_client = boto3.client('sagemaker-runtime', region_name=region)"
    ).splitlines(keepends=True)),
]
71
+
72
+ # ── Section 2: Configuration ─────────────────────────────────────────────────
73
+
74
+ # Project variables baked as Python literals
75
+ cells.append(make_code_cell([
76
+ f'PROJECT_NAME = "{env("PROJECT_NAME")}"\n',
77
+ f'AWS_REGION = "{env("AWS_REGION")}"\n',
78
+ <% if (deploymentTarget !== 'hyperpod-eks' && !(typeof existingEndpointName !== 'undefined' && existingEndpointName)) { %>
79
+ f'INSTANCE_TYPE = "{env("INSTANCE_TYPE")}"\n',
80
+ <% } %>
81
+ f'ENDPOINT_NAME = f"{{PROJECT_NAME}}-ep-{{int(time.time())}}"\n',
82
+ <% if (typeof inferenceAmiVersion !== 'undefined' && inferenceAmiVersion) { %>
83
+ f'INFERENCE_AMI_VERSION = "{env("INFERENCE_AMI_VERSION")}"\n',
84
+ <% } else { %>
85
+ f'INFERENCE_AMI_VERSION = "{env("INFERENCE_AMI_VERSION", "")}"\n',
86
+ <% } %>
87
+ f'HEALTH_CHECK_TIMEOUT = 850\n',
88
+ f'IC_GPU_COUNT = {env("IC_GPU_COUNT", "1")}\n',
89
+ f'IC_MIN_MEMORY_MB = {env("IC_MEMORY_SIZE", "1024")}'
90
+ ]))
91
+
92
+ # Environment variables dict cell
93
+ cells.append(make_code_cell([
94
+ "env = {\n",
95
+ <% if (orderedEnvVars && orderedEnvVars.length > 0) { %>
96
+ <% orderedEnvVars.forEach(function(item, index) { %>
97
+ f' "<%= item.key %>": "{env("<%= item.key %>", "<%= item.value %>")}",\n',
98
+ <% }); %>
99
+ <% } %>
100
+ "}"
101
+ ]))
102
+
103
+ # ── Section 2b: Secrets Handling ─────────────────────────────────────────────
104
+
105
+ <% if (typeof hfTokenArn !== 'undefined' && hfTokenArn) { %>
106
+ # HF_TOKEN_ARN is configured — resolve via Secrets Manager
107
+ cells.append(make_code_cell([
108
+ "import boto3 as _boto3_secrets\n",
109
+ "\n",
110
+ f'HF_TOKEN_ARN = "{env("HF_TOKEN_ARN")}"\n',
111
+ "\n",
112
+ "secrets_client = _boto3_secrets.client('secretsmanager', region_name=AWS_REGION)\n",
113
+ "hf_token = secrets_client.get_secret_value(SecretId=HF_TOKEN_ARN)['SecretString']\n",
114
+ 'env["HF_TOKEN"] = hf_token'
115
+ ]))
116
+ <% } else if (hfToken) { %>
117
+ # HF_TOKEN is configured — read from environment variable at notebook runtime
118
+ cells.append(make_markdown_cell([
119
+ "### \u26a0\ufe0f HuggingFace Token Required\n",
120
+ "\n",
121
+ "Set the `HF_TOKEN` environment variable before running the next cell. \n",
122
+ "In SageMaker Studio, use the **Environment** tab or run: \n",
123
+ '`export HF_TOKEN="hf_your_token_here"`'
124
+ ]))
125
+
126
+ cells.append(make_code_cell([
127
+ "import os\n",
128
+ "\n",
129
+ 'env["HF_TOKEN"] = os.environ["HF_TOKEN"]'
130
+ ]))
131
+ <% } %>
132
+ <% if (typeof ngcTokenArn !== 'undefined' && ngcTokenArn) { %>
133
+ # NGC_API_KEY_ARN is configured — resolve via Secrets Manager
134
+ cells.append(make_code_cell([
135
+ "import boto3 as _boto3_secrets\n",
136
+ "\n",
137
+ f'NGC_API_KEY_ARN = "{env("NGC_API_KEY_ARN")}"\n',
138
+ "\n",
139
+ "secrets_client = _boto3_secrets.client('secretsmanager', region_name=AWS_REGION)\n",
140
+ "ngc_key = secrets_client.get_secret_value(SecretId=NGC_API_KEY_ARN)['SecretString']\n",
141
+ 'env["NGC_API_KEY"] = ngc_key'
142
+ ]))
143
+ <% } else if (ngcApiKey) { %>
144
+ # NGC_API_KEY is configured — read from environment variable at notebook runtime
145
+ cells.append(make_markdown_cell([
146
+ "### \u26a0\ufe0f NVIDIA NGC API Key Required\n",
147
+ "\n",
148
+ "Set the `NGC_API_KEY` environment variable before running the next cell. \n",
149
+ "In SageMaker Studio, use the **Environment** tab or run: \n",
150
+ '`export NGC_API_KEY="your_ngc_key_here"`'
151
+ ]))
152
+
153
+ cells.append(make_code_cell([
154
+ "import os\n",
155
+ "\n",
156
+ 'env["NGC_API_KEY"] = os.environ["NGC_API_KEY"]'
157
+ ]))
158
+ <% } %>
159
+
160
+ # ── Section 3: Build & Push ──────────────────────────────────────────────────
161
+
162
+ <% if (modelServer !== 'lmi' && modelServer !== 'djl') { %>
163
+ cells.append(make_markdown_cell([
164
+ "## Build & Push Container\n",
165
+ "\n",
166
+ "Build the container via CodeBuild and push to ECR.\n"
167
+ ]))
168
+
169
+ cells.append(make_code_cell([
170
+ 'CODEBUILD_PROJECT_NAME = f"{PROJECT_NAME}-build"\n',
171
+ 'cb_client = boto3.client("codebuild", region_name=AWS_REGION)\n',
172
+ '\n',
173
+ 'build = cb_client.start_build(\n',
174
+ ' projectName=CODEBUILD_PROJECT_NAME,\n',
175
+ ' sourceVersion="main",\n',
176
+ ')\n',
177
+ 'build_id = build["build"]["id"]\n',
178
+ 'print(f"Build started: {build_id}")\n',
179
+ '\n',
180
+ 'while True:\n',
181
+ ' resp = cb_client.batch_get_builds(ids=[build_id])\n',
182
+ ' status = resp["builds"][0]["buildStatus"]\n',
183
+ ' phase = resp["builds"][0].get("currentPhase", "UNKNOWN")\n',
184
+ ' if status == "SUCCEEDED":\n',
185
+ ' print(f"\\u2705 Build succeeded")\n',
186
+ ' break\n',
187
+ ' elif status in ("FAILED", "FAULT", "TIMED_OUT", "STOPPED"):\n',
188
+ ' print(f"\\u274c Build {status}")\n',
189
+ ' break\n',
190
+ ' print(f" {phase}... ({status})")\n',
191
+ ' time.sleep(30)\n',
192
+ ]))
193
+
194
+ cells.append(make_code_cell([
195
+ f'image_uri = f"{{account_id}}.dkr.ecr.{{region}}.amazonaws.com/{env("PROJECT_NAME")}:{env("PROJECT_NAME")}-latest"\n',
196
+ 'print(f"Image URI: {image_uri}")'
197
+ ]))
198
+ <% } else { %>
199
+ cells.append(make_markdown_cell([
200
+ "## Container Image\n",
201
+ "\n",
202
+ "Using AWS Deep Learning Container (DLC) image.\n"
203
+ ]))
204
+
205
+ cells.append(make_code_cell([
206
+ f'image_uri = sagemaker.image_uris.retrieve(\n',
207
+ f' framework="<%= modelServer %>",\n',
208
+ f' region=region,\n',
209
+ f' version="latest",\n',
210
+ f' instance_type=INSTANCE_TYPE,\n',
211
+ f')\n',
212
+ 'print(f"DLC Image URI: {image_uri}")'
213
+ ]))
214
+ <% } %>
215
+
216
# ── Section 4: Model ─────────────────────────────────────────────────────────
# Registers a SageMaker model from image_uri with the container `env` dict
# assembled in the configuration (and secrets) cells.
cells.extend([
    make_markdown_cell((
        "## Create SageMaker Model\n"
        "\n"
        "Define the model container and environment variables.\n"
    ).splitlines(keepends=True)),
    make_code_cell((
        'model_name = f"{PROJECT_NAME}-model-{int(time.time())}"\n'
        '\n'
        'sm_client.create_model(\n'
        '    ModelName=model_name,\n'
        '    ExecutionRoleArn=role,\n'
        '    PrimaryContainer={\n'
        '        "Image": image_uri,\n'
        '        "Environment": env,\n'
        '    },\n'
        ')\n'
        '\n'
        'print(f"✅ Model created: {model_name}")'
    ).splitlines(keepends=True)),
])
238
+
239
+ <% if (deploymentTarget === 'realtime-inference') { %>
240
+ # ── Section 5: Endpoint ──────────────────────────────────────────────────────
241
+
242
+ cells.append(make_markdown_cell([
243
+ "## Create Endpoint\n",
244
+ "\n",
245
+ "Create an endpoint configuration and deploy the endpoint.\n"
246
+ ]))
247
+
248
+ # Create endpoint config
249
+ cells.append(make_code_cell([
250
+ 'endpoint_config_name = f"{PROJECT_NAME}-epc-{int(time.time())}"\n',
251
+ '\n',
252
+ 'production_variant = {\n',
253
+ ' "VariantName": "AllTraffic",\n',
254
+ ' "InstanceType": INSTANCE_TYPE,\n',
255
+ ' "InitialInstanceCount": 1,\n',
256
+ ' "ContainerStartupHealthCheckTimeoutInSeconds": HEALTH_CHECK_TIMEOUT,\n',
257
+ '}\n',
258
+ '\n',
259
+ '# Include InferenceAmiVersion if configured\n',
260
+ 'if INFERENCE_AMI_VERSION:\n',
261
+ ' production_variant["InferenceAmiVersion"] = INFERENCE_AMI_VERSION\n',
262
+ '\n',
263
+ 'sm_client.create_endpoint_config(\n',
264
+ ' EndpointConfigName=endpoint_config_name,\n',
265
+ ' ExecutionRoleArn=role,\n',
266
+ ' ProductionVariants=[production_variant],\n',
267
+ ')\n',
268
+ '\n',
269
+ 'print(f"✅ Endpoint config created: {endpoint_config_name}")'
270
+ ]))
271
+
272
+ # Create endpoint
273
+ cells.append(make_code_cell([
274
+ 'sm_client.create_endpoint(\n',
275
+ ' EndpointName=ENDPOINT_NAME,\n',
276
+ ' EndpointConfigName=endpoint_config_name,\n',
277
+ ')\n',
278
+ '\n',
279
+ 'print(f"Creating endpoint: {ENDPOINT_NAME}...")\n',
280
+ '\n',
281
+ '# Wait for InService\n',
282
+ 'while True:\n',
283
+ ' resp = sm_client.describe_endpoint(EndpointName=ENDPOINT_NAME)\n',
284
+ ' status = resp["EndpointStatus"]\n',
285
+ ' if status == "InService":\n',
286
+ ' print(f"✅ Endpoint InService: {ENDPOINT_NAME}")\n',
287
+ ' break\n',
288
+ ' elif status == "Failed":\n',
289
+ ' print(f"❌ Endpoint failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
290
+ ' break\n',
291
+ ' print(f" Status: {status}...")\n',
292
+ ' time.sleep(30)'
293
+ ]))
294
+
295
# ── Section 6: Inference Component ───────────────────────────────────────────

cells.append(make_markdown_cell([
    "## Create Inference Component\n",
    "\n",
    "Attach the model to the endpoint with compute resource allocation.  \n",
    "This is separate from the endpoint to support multi-LoRA adapter extensibility.\n"
]))

# Attach the model as an inference component and wait for it. A "Failed"
# status raises, so the test and adapter cells do not run against a missing
# component.
cells.append(make_code_cell([
    'ic_name = f"{PROJECT_NAME}-ic-{int(time.time())}"\n',
    '\n',
    'sm_client.create_inference_component(\n',
    '    InferenceComponentName=ic_name,\n',
    '    EndpointName=ENDPOINT_NAME,\n',
    '    VariantName="AllTraffic",\n',
    '    Specification={\n',
    '        "ModelName": model_name,\n',
    '        "Container": {\n',
    '            "Image": image_uri,\n',
    '            "Environment": env,\n',
    '        },\n',
    '        "ComputeResourceRequirements": {\n',
    '            "NumberOfAcceleratorDevicesRequired": IC_GPU_COUNT,\n',
    '            "MinMemoryRequiredInMb": IC_MIN_MEMORY_MB,\n',
    '        },\n',
    '    },\n',
    '    RuntimeConfig={\n',
    '        "CopyCount": 1,\n',
    '    },\n',
    ')\n',
    '\n',
    'print(f"Creating inference component: {ic_name}...")\n',
    '\n',
    '# Wait for IC InService\n',
    'while True:\n',
    '    resp = sm_client.describe_inference_component(InferenceComponentName=ic_name)\n',
    '    status = resp["InferenceComponentStatus"]\n',
    '    if status == "InService":\n',
    '        print(f"✅ Inference Component InService: {ic_name}")\n',
    '        break\n',
    '    elif status == "Failed":\n',
    '        raise RuntimeError(f"❌ IC failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
    '    print(f"  Status: {status}...")\n',
    '    time.sleep(30)'
]))
342
+
343
+ # ── Section 7: Test ──────────────────────────────────────────────────────────
344
+
345
+ cells.append(make_markdown_cell([
346
+ "## Test Inference\n",
347
+ "\n",
348
+ "Send a test request to the deployed model.\n"
349
+ ]))
350
+
351
+ <% if (framework === 'transformers') { %>
352
+ cells.append(make_code_cell([
353
+ 'payload = {\n',
354
+ ' "messages": [{"role": "user", "content": "What is machine learning?"}],\n',
355
+ ' "max_tokens": 100,\n',
356
+ ' "temperature": 0.7,\n',
357
+ '}\n',
358
+ '\n',
359
+ 'response = smr_client.invoke_endpoint(\n',
360
+ ' EndpointName=ENDPOINT_NAME,\n',
361
+ ' InferenceComponentName=ic_name,\n',
362
+ ' ContentType="application/json",\n',
363
+ ' Body=json.dumps(payload),\n',
364
+ ')\n',
365
+ '\n',
366
+ 'result = json.loads(response["Body"].read().decode())\n',
367
+ 'print(result["choices"][0]["message"]["content"])'
368
+ ]))
369
+ <% } else { %>
370
+ cells.append(make_code_cell([
371
+ 'payload = {\n',
372
+ ' "inputs": "What is machine learning?",\n',
373
+ ' "parameters": {"max_new_tokens": 100},\n',
374
+ '}\n',
375
+ '\n',
376
+ 'response = smr_client.invoke_endpoint(\n',
377
+ ' EndpointName=ENDPOINT_NAME,\n',
378
+ ' InferenceComponentName=ic_name,\n',
379
+ ' ContentType="application/json",\n',
380
+ ' Body=json.dumps(payload),\n',
381
+ ')\n',
382
+ '\n',
383
+ 'result = json.loads(response["Body"].read().decode())\n',
384
+ 'print(json.dumps(result, indent=2))'
385
+ ]))
386
+ <% } %>
387
+ <% if (enableLora) { %>
388
+ # ── Section 8: LoRA Adapter ──────────────────────────────────────────────────
389
+
390
+ cells.append(make_markdown_cell([
391
+ "## 🔗 LoRA Adapter\n",
392
+ "\n",
393
+ "LoRA (Low-Rank Adaptation) adapters let you serve multiple fine-tuned \"personalities\"\n",
394
+ "from a single base model. Adapter ICs share the base IC's GPU resources — no additional\n",
395
+ "compute allocation is needed. You can add/remove adapters without redeploying the endpoint.\n"
396
+ ]))
397
+
398
+ cells.append(make_code_cell([
399
+ '# ✏️ Edit these values to configure your adapter\n',
400
+ 'ADAPTER_NAME = "my-adapter"\n',
401
+ f'ADAPTER_WEIGHTS_URI = "{env("TUNE_ADAPTER_PATH_SFT", env("TUNE_ADAPTER_PATH_DPO", env("TUNE_OUTPUT_PATH_LATEST", "s3://your-bucket/adapters/my-adapter/adapter.tar.gz")))}"\n',
402
+ ]))
403
+
404
+ cells.append(make_code_cell([
405
+ 'adapter_ic_name = f"{PROJECT_NAME}-adapter-{ADAPTER_NAME}"\n',
406
+ '\n',
407
+ 'sm_client.create_inference_component(\n',
408
+ ' InferenceComponentName=adapter_ic_name,\n',
409
+ ' EndpointName=ENDPOINT_NAME,\n',
410
+ ' Specification={\n',
411
+ ' "BaseInferenceComponentName": ic_name,\n',
412
+ ' "Container": {\n',
413
+ ' "ArtifactUrl": ADAPTER_WEIGHTS_URI,\n',
414
+ ' },\n',
415
+ ' },\n',
416
+ ')\n',
417
+ '\n',
418
+ 'print(f"Creating adapter IC: {adapter_ic_name}...")\n',
419
+ '\n',
420
+ '# Wait for adapter IC InService\n',
421
+ 'while True:\n',
422
+ ' resp = sm_client.describe_inference_component(InferenceComponentName=adapter_ic_name)\n',
423
+ ' status = resp["InferenceComponentStatus"]\n',
424
+ ' if status == "InService":\n',
425
+ ' print(f"✅ Adapter IC InService: {adapter_ic_name}")\n',
426
+ ' break\n',
427
+ ' elif status == "Failed":\n',
428
+ ' print(f"❌ Adapter IC failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
429
+ ' break\n',
430
+ ' print(f" Status: {status}...")\n',
431
+ ' time.sleep(30)'
432
+ ]))
433
+
434
+ <% if (framework === 'transformers') { %>
435
+ cells.append(make_code_cell([
436
+ 'payload = {\n',
437
+ ' "messages": [{"role": "user", "content": "What is machine learning?"}],\n',
438
+ ' "max_tokens": 100,\n',
439
+ ' "temperature": 0.7,\n',
440
+ '}\n',
441
+ '\n',
442
+ 'response = smr_client.invoke_endpoint(\n',
443
+ ' EndpointName=ENDPOINT_NAME,\n',
444
+ ' InferenceComponentName=adapter_ic_name,\n',
445
+ ' ContentType="application/json",\n',
446
+ ' Body=json.dumps(payload),\n',
447
+ ')\n',
448
+ '\n',
449
+ 'result = json.loads(response["Body"].read().decode())\n',
450
+ 'print(f"Adapter \'{ADAPTER_NAME}\' response:")\n',
451
+ 'print(result["choices"][0]["message"]["content"])'
452
+ ]))
453
+ <% } else { %>
454
+ cells.append(make_code_cell([
455
+ 'payload = {\n',
456
+ ' "inputs": "What is machine learning?",\n',
457
+ ' "parameters": {"max_new_tokens": 100},\n',
458
+ '}\n',
459
+ '\n',
460
+ 'response = smr_client.invoke_endpoint(\n',
461
+ ' EndpointName=ENDPOINT_NAME,\n',
462
+ ' InferenceComponentName=adapter_ic_name,\n',
463
+ ' ContentType="application/json",\n',
464
+ ' Body=json.dumps(payload),\n',
465
+ ')\n',
466
+ '\n',
467
+ 'result = json.loads(response["Body"].read().decode())\n',
468
+ 'print(f"Adapter \'{ADAPTER_NAME}\' response:")\n',
469
+ 'print(json.dumps(result, indent=2))'
470
+ ]))
471
+ <% } %>
472
+
473
+ cells.append(make_markdown_cell([
474
+ "### Adding More Adapters\n",
475
+ "\n",
476
+ "To add another adapter, duplicate this section with a different `ADAPTER_NAME` and\n",
477
+ "`ADAPTER_WEIGHTS_URI`. Each adapter shares the base IC's GPU resources.\n",
478
+ "\n",
479
+ "The CLI equivalent is: `./do/adapter add <name> --weights <s3-uri>`\n"
480
+ ]))
481
+ <% } %>
482
+ <% if (tuneSupported) { %>
483
+ # ── Section 9: Fine-Tune ─────────────────────────────────────────────────────
484
+
485
+ cells.append(make_markdown_cell([
486
+ "## 🎯 Managed Fine-Tuning\n",
487
+ "\n",
488
+ "SageMaker Managed Model Customization provides serverless fine-tuning — no instance\n",
489
+ "selection or container management needed. You provide a dataset and technique; SageMaker\n",
490
+ "handles infrastructure and optimization. The output is either LoRA adapter weights or a\n",
491
+ "full merged model, depending on the training type.\n"
492
+ ]))
493
+
494
+ cells.append(make_code_cell([
495
+ '# ✏️ Edit these values to configure your fine-tuning job\n',
496
+ 'TECHNIQUE = "sft" # Options: sft, dpo, rlaif, rlvr\n',
497
+ 'TRAINING_TYPE = "lora" # Options: lora, full-rank\n',
498
+ 'DATASET_S3_URI = "s3://your-bucket/datasets/train.jsonl" # ← Replace with your dataset\n',
499
+ f'TUNE_OUTPUT_BUCKET = "{env("TUNE_S3_BUCKET")}"'
500
+ ]))
501
+
502
+ cells.append(make_code_cell([
503
+ 'from sagemaker.modules.train import ModelTrainer\n',
504
+ '\n',
505
+ f'MODEL_NAME = "{env("MODEL_NAME")}"\n',
506
+ '\n',
507
+ 'job_name = f"{PROJECT_NAME}-tune-{TECHNIQUE}-{int(time.time())}"\n',
508
+ '\n',
509
+ 'trainer = ModelTrainer(\n',
510
+ ' model_id=MODEL_NAME,\n',
511
+ ' training_dataset={"s3Uri": DATASET_S3_URI},\n',
512
+ ' technique=TECHNIQUE,\n',
513
+ ' training_type=TRAINING_TYPE,\n',
514
+ ' output_path=f"s3://{TUNE_OUTPUT_BUCKET}/{PROJECT_NAME}/output/",\n',
515
+ ' role=role,\n',
516
+ ')\n',
517
+ 'trainer.train()\n',
518
+ '\n',
519
+ 'print(f"✅ Training job submitted: {job_name}")'
520
+ ]))
521
+
522
+ cells.append(make_code_cell([
523
+ 'import time as _time\n',
524
+ 'start_time = _time.time()\n',
525
+ '\n',
526
+ 'while True:\n',
527
+ ' status = trainer.describe()["TrainingJobStatus"]\n',
528
+ ' elapsed = int(_time.time() - start_time)\n',
529
+ ' elapsed_str = f"{elapsed // 60}m {elapsed % 60}s"\n',
530
+ '\n',
531
+ ' if status == "Completed":\n',
532
+ ' print(f"✅ Training completed in {elapsed_str}")\n',
533
+ ' break\n',
534
+ ' elif status == "Failed":\n',
535
+ ' reason = trainer.describe().get("FailureReason", "Unknown")\n',
536
+ ' print(f"❌ Training failed after {elapsed_str}: {reason}")\n',
537
+ ' break\n',
538
+ ' print(f" Status: {status} (elapsed: {elapsed_str})...")\n',
539
+ ' _time.sleep(60)'
540
+ ]))
541
+
542
+ cells.append(make_code_cell([
543
+ 'job_desc = trainer.describe()\n',
544
+ 'output_path = job_desc["ModelArtifacts"]["S3ModelArtifacts"]\n',
545
+ 'print(f"Output artifacts: {output_path}")'
546
+ ]))
547
+
548
+ cells.append(make_markdown_cell([
549
+ "### Next Steps\n",
550
+ "\n",
551
+ "**If LoRA output** (TRAINING_TYPE=\"lora\"): \n",
552
+ "- Run the adapter section above with `ADAPTER_WEIGHTS_URI` set to the output path \n",
553
+ "- Or use the CLI: `./do/adapter add tuned-sft --from-tune`\n",
554
+ "\n",
555
+ "**If full-rank output** (TRAINING_TYPE=\"full-rank\"): \n",
556
+ "- Deploy as a new inference component: `./do/add-ic tuned-v1 --from-tune`\n"
557
+ ]))
558
+ <% } %>
559
+
560
# ── Section 10: Cleanup ───────────────────────────────────────────────────────

cells.append(make_markdown_cell([
    "## ⚠️ Cleanup\n",
    "\n",
    "**Warning**: This will delete all deployed resources. Only run when you're done testing.\n"
]))

# Tear down in reverse dependency order. Rather than fixed sleeps, each
# inference-component deletion is polled via ListInferenceComponents until the
# component is gone. This way the base IC / endpoint are not deleted while
# dependents still exist.
cells.append(make_code_cell([
    '# Delete in reverse dependency order\n',
    '\n',
    'def _wait_for_ic_deleted(name, timeout_s=900):\n',
    '    """Poll until inference component `name` is no longer listed on ENDPOINT_NAME."""\n',
    '    deadline = time.time() + timeout_s\n',
    '    while True:\n',
    '        listed = sm_client.list_inference_components(EndpointNameEquals=ENDPOINT_NAME)\n',
    '        if name not in {ic["InferenceComponentName"] for ic in listed["InferenceComponents"]}:\n',
    '            return\n',
    '        if time.time() > deadline:\n',
    '            raise TimeoutError(f"Inference component {name} is still deleting")\n',
    '        time.sleep(15)\n',
    '\n',
    '# 1. Delete adapter IC (if created in Section 8)\n',
    'try:\n',
    '    sm_client.delete_inference_component(InferenceComponentName=adapter_ic_name)\n',
    '    print(f"Deleting adapter IC: {adapter_ic_name}...")\n',
    '    _wait_for_ic_deleted(adapter_ic_name)\n',
    'except NameError:\n',
    '    pass  # adapter_ic_name not defined — Section 8 was not run\n',
    'except Exception as e:\n',
    '    print(f"Note: {e}")\n',
    '\n',
    '# 2. Delete base IC\n',
    'sm_client.delete_inference_component(InferenceComponentName=ic_name)\n',
    'print(f"Deleting base IC: {ic_name}...")\n',
    '_wait_for_ic_deleted(ic_name)\n',
    '\n',
    '# 3. Delete endpoint\n',
    'sm_client.delete_endpoint(EndpointName=ENDPOINT_NAME)\n',
    'print(f"Deleting endpoint: {ENDPOINT_NAME}...")\n',
    '\n',
    '# 4. Delete endpoint config\n',
    'sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n',
    'print(f"Deleting endpoint config: {endpoint_config_name}...")\n',
    '\n',
    '# 5. Delete model\n',
    'sm_client.delete_model(ModelName=model_name)\n',
    'print(f"Deleting model: {model_name}...")\n',
    '\n',
    'print("\\n✅ All resources cleaned up")'
]))
600
+ <% } else if (deploymentTarget === 'async-inference') { %>
601
+ # ── Section 5: Endpoint (Async) ──────────────────────────────────────────────
602
+
603
+ cells.append(make_markdown_cell([
604
+ "## Create Async Endpoint\n",
605
+ "\n",
606
+ "Create an endpoint configuration with async inference settings and deploy the endpoint. \n",
607
+ "Async inference is ideal for large payloads or long-running predictions — results are\n",
608
+ "written to S3 when ready.\n"
609
+ ]))
610
+
611
+ # Create endpoint config with AsyncInferenceConfig
612
+ cells.append(make_code_cell([
613
+ 'endpoint_config_name = f"{PROJECT_NAME}-epc-{int(time.time())}"\n',
614
+ '\n',
615
+ 'sm_client.create_endpoint_config(\n',
616
+ ' EndpointConfigName=endpoint_config_name,\n',
617
+ ' ExecutionRoleArn=role,\n',
618
+ ' ProductionVariants=[\n',
619
+ ' {\n',
620
+ ' "VariantName": "AllTraffic",\n',
621
+ ' "InstanceType": INSTANCE_TYPE,\n',
622
+ ' "InitialInstanceCount": 1,\n',
623
+ ' "ContainerStartupHealthCheckTimeoutInSeconds": HEALTH_CHECK_TIMEOUT,\n',
624
+ ' }\n',
625
+ ' ],\n',
626
+ ' AsyncInferenceConfig={\n',
627
+ ' "OutputConfig": {\n',
628
+ ' "S3OutputPath": f"s3://{PROJECT_NAME}-async-output/{PROJECT_NAME}/",\n',
629
+ ' },\n',
630
+ ' "ClientConfig": {\n',
631
+ ' "MaxConcurrentInvocationsPerInstance": 4,\n',
632
+ ' },\n',
633
+ ' },\n',
634
+ ')\n',
635
+ '\n',
636
+ 'print(f"✅ Async endpoint config created: {endpoint_config_name}")'
637
+ ]))
638
+
639
+ # Create endpoint
640
+ cells.append(make_code_cell([
641
+ 'sm_client.create_endpoint(\n',
642
+ ' EndpointName=ENDPOINT_NAME,\n',
643
+ ' EndpointConfigName=endpoint_config_name,\n',
644
+ ')\n',
645
+ '\n',
646
+ 'print(f"Creating endpoint: {ENDPOINT_NAME}...")\n',
647
+ '\n',
648
+ '# Wait for InService\n',
649
+ 'while True:\n',
650
+ ' resp = sm_client.describe_endpoint(EndpointName=ENDPOINT_NAME)\n',
651
+ ' status = resp["EndpointStatus"]\n',
652
+ ' if status == "InService":\n',
653
+ ' print(f"✅ Endpoint InService: {ENDPOINT_NAME}")\n',
654
+ ' break\n',
655
+ ' elif status == "Failed":\n',
656
+ ' print(f"❌ Endpoint failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
657
+ ' break\n',
658
+ ' print(f" Status: {status}...")\n',
659
+ ' time.sleep(30)'
660
+ ]))
661
+
662
+ # ── Section 6: Inference Component — SKIPPED for async ───────────────────────
663
+ # Async inference does not support inference components.
664
+
665
+ # ── Section 7: Test (Async) ──────────────────────────────────────────────────
666
+
667
+ cells.append(make_markdown_cell([
668
+ "## Test Async Inference\n",
669
+ "\n",
670
+ "Upload input to S3, invoke the async endpoint, then poll S3 for the result.\n"
671
+ ]))
672
+
673
+ cells.append(make_code_cell([
674
+ 'import boto3\n',
675
+ '\n',
676
+ 's3_client = boto3.client("s3", region_name=AWS_REGION)\n',
677
+ 'async_input_bucket = f"{PROJECT_NAME}-async-output"\n',
678
+ 'async_input_key = f"{PROJECT_NAME}/input/test-request.json"\n',
679
+ '\n',
680
+ <% if (framework === 'transformers') { %>
681
+ 'payload = {\n',
682
+ ' "messages": [{"role": "user", "content": "What is machine learning?"}],\n',
683
+ ' "max_tokens": 100,\n',
684
+ ' "temperature": 0.7,\n',
685
+ '}\n',
686
+ <% } else { %>
687
+ 'payload = {\n',
688
+ ' "inputs": "What is machine learning?",\n',
689
+ ' "parameters": {"max_new_tokens": 100},\n',
690
+ '}\n',
691
+ <% } %>
692
+ '\n',
693
+ '# Upload input payload to S3\n',
694
+ 's3_client.put_object(\n',
695
+ ' Bucket=async_input_bucket,\n',
696
+ ' Key=async_input_key,\n',
697
+ ' Body=json.dumps(payload),\n',
698
+ ' ContentType="application/json",\n',
699
+ ')\n',
700
+ 'input_s3_uri = f"s3://{async_input_bucket}/{async_input_key}"\n',
701
+ 'print(f"Input uploaded to: {input_s3_uri}")'
702
+ ]))
703
+
704
+ cells.append(make_code_cell([
705
+ '# Invoke async endpoint\n',
706
+ 'response = smr_client.invoke_endpoint_async(\n',
707
+ ' EndpointName=ENDPOINT_NAME,\n',
708
+ ' ContentType="application/json",\n',
709
+ ' InputLocation=input_s3_uri,\n',
710
+ ')\n',
711
+ 'output_location = response["OutputLocation"]\n',
712
+ 'print(f"Output will be at: {output_location}")\n',
713
+ '\n',
714
+ '# Poll S3 for the result\n',
715
+ 'import urllib.parse\n',
716
+ 'parsed = urllib.parse.urlparse(output_location)\n',
717
+ 'output_bucket = parsed.netloc\n',
718
+ 'output_key = parsed.path.lstrip("/")\n',
719
+ '\n',
720
+ 'print("Waiting for result...")\n',
721
+ 'while True:\n',
722
+ ' try:\n',
723
+ ' result_obj = s3_client.get_object(Bucket=output_bucket, Key=output_key)\n',
724
+ ' result = json.loads(result_obj["Body"].read().decode())\n',
725
+ ' print("\\n✅ Async inference result:")\n',
726
+ ' print(json.dumps(result, indent=2))\n',
727
+ ' break\n',
728
+ ' except s3_client.exceptions.NoSuchKey:\n',
729
+ ' print(" Waiting for output...")\n',
730
+ ' time.sleep(10)'
731
+ ]))
732
+
733
# ── Sections 8-9: SKIPPED for async ─────────────────────────────────────────
# LoRA adapters and fine-tuning require a realtime endpoint with inference
# components, so the async notebook has no adapter or tuning cells.

# ── Section 10: Cleanup (Async) ──────────────────────────────────────────────
# Endpoint -> endpoint config -> model. There is no inference component to
# remove, and S3 outputs / SNS topics are deliberately left in place.
cells.extend([
    make_markdown_cell((
        "## ⚠️ Cleanup\n"
        "\n"
        "**Warning**: This will delete all deployed resources. Only run when you're done testing.  \n"
        "Note: The S3 output bucket and any SNS topics are NOT deleted (user-managed).\n"
    ).splitlines(keepends=True)),
    make_code_cell((
        '# Delete in reverse dependency order (no IC for async)\n'
        '\n'
        '# 1. Delete endpoint\n'
        'sm_client.delete_endpoint(EndpointName=ENDPOINT_NAME)\n'
        'print(f"Deleting endpoint: {ENDPOINT_NAME}...")\n'
        '\n'
        '# 2. Delete endpoint config\n'
        'sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n'
        'print(f"Deleting endpoint config: {endpoint_config_name}...")\n'
        '\n'
        '# 3. Delete model\n'
        'sm_client.delete_model(ModelName=model_name)\n'
        'print(f"Deleting model: {model_name}...")\n'
        '\n'
        'print("\\n✅ All resources cleaned up")\n'
        'print("Note: S3 output bucket and SNS topics are not deleted (user-managed).")'
    ).splitlines(keepends=True)),
])
763
+ <% } else if (deploymentTarget === 'batch-transform') { %>
764
+ # ── Section 5: Batch Transform Job ───────────────────────────────────────────
765
+
766
+ cells.append(make_markdown_cell([
767
+ "## Create Batch Transform Job\n",
768
+ "\n",
769
+ "Run batch inference on your input data stored in S3. \n",
770
+ "The transform job reads input from S3, runs inference, and writes output back to S3.\n"
771
+ ]))
772
+
773
+ cells.append(make_code_cell([
774
+ 'INPUT_S3_URI = "s3://your-bucket/input/" # ← Replace with your input data S3 path\n',
775
+ 'OUTPUT_S3_URI = "s3://your-bucket/output/" # ← Replace with your desired output S3 path\n',
776
+ '\n',
777
+ 'transform_job_name = f"{PROJECT_NAME}-transform-{int(time.time())}"\n',
778
+ '\n',
779
+ 'sm_client.create_transform_job(\n',
780
+ ' TransformJobName=transform_job_name,\n',
781
+ ' ModelName=model_name,\n',
782
+ ' TransformInput={\n',
783
+ ' "DataSource": {\n',
784
+ ' "S3DataSource": {\n',
785
+ ' "S3DataType": "S3Prefix",\n',
786
+ ' "S3Uri": INPUT_S3_URI,\n',
787
+ ' }\n',
788
+ ' },\n',
789
+ ' "ContentType": "application/json",\n',
790
+ ' },\n',
791
+ ' TransformOutput={\n',
792
+ ' "S3OutputPath": OUTPUT_S3_URI,\n',
793
+ ' },\n',
794
+ ' TransformResources={\n',
795
+ ' "InstanceType": INSTANCE_TYPE,\n',
796
+ ' "InstanceCount": 1,\n',
797
+ ' },\n',
798
+ ')\n',
799
+ '\n',
800
+ 'print(f"Transform job started: {transform_job_name}")\n',
801
+ '\n',
802
+ '# Wait for transform job to complete\n',
803
+ 'while True:\n',
804
+ ' resp = sm_client.describe_transform_job(TransformJobName=transform_job_name)\n',
805
+ ' status = resp["TransformJobStatus"]\n',
806
+ ' if status == "Completed":\n',
807
+ ' print(f"✅ Transform job completed: {transform_job_name}")\n',
808
+ ' break\n',
809
+ ' elif status == "Failed":\n',
810
+ ' print(f"❌ Transform job failed: {resp.get(\'FailureReason\', \'Unknown\')}")\n',
811
+ ' break\n',
812
+ ' elif status == "Stopped":\n',
813
+ ' print(f"⚠️ Transform job stopped")\n',
814
+ ' break\n',
815
+ ' print(f" Status: {status}...")\n',
816
+ ' time.sleep(30)'
817
+ ]))
818
+
819
# ── Section 7: Test (Download Output) ────────────────────────────────────────
# Lists the objects under OUTPUT_S3_URI (printing at most ten keys) and shows
# the first 2000 characters of the first one.
cells.extend([
    make_markdown_cell((
        "## Review Transform Output\n"
        "\n"
        "Download and display a sample result from the batch transform output.\n"
    ).splitlines(keepends=True)),
    make_code_cell((
        'import boto3\n'
        'from urllib.parse import urlparse\n'
        '\n'
        's3 = boto3.client("s3", region_name=AWS_REGION)\n'
        '\n'
        '# Parse the output S3 URI\n'
        'parsed = urlparse(OUTPUT_S3_URI)\n'
        'bucket = parsed.netloc\n'
        'prefix = parsed.path.lstrip("/")\n'
        '\n'
        '# List output files\n'
        'response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)\n'
        'output_files = [obj["Key"] for obj in response.get("Contents", [])]\n'
        'print(f"Output files ({len(output_files)}):")\n'
        'for f in output_files[:10]:\n'
        '    print(f"  s3://{bucket}/{f}")\n'
        '\n'
        '# Download and display first result\n'
        'if output_files:\n'
        '    first_file = output_files[0]\n'
        '    obj = s3.get_object(Bucket=bucket, Key=first_file)\n'
        '    content = obj["Body"].read().decode("utf-8")\n'
        '    print(f"\\nSample output from: {first_file}")\n'
        '    print("-" * 60)\n'
        '    print(content[:2000])\n'
        'else:\n'
        '    print("No output files found.")'
    ).splitlines(keepends=True)),
])
856
+
857
+ # ── Section 10: Cleanup ───────────────────────────────────────────────────────
858
+
859
+ cells.append(make_markdown_cell([
860
+ "## ⚠️ Cleanup\n",
861
+ "\n",
862
+ "**Warning**: This will delete the model resource. Only run when you're done. \n",
863
+ "Note: The transform job is ephemeral and does not need deletion.\n"
864
+ ]))
865
+
866
+ cells.append(make_code_cell([
867
+ 'sm_client.delete_model(ModelName=model_name)\n',
868
+ 'print(f"Deleting model: {model_name}...")\n',
869
+ '\n',
870
+ 'print("\\n✅ All resources cleaned up")'
871
+ ]))
872
+ <% } %>
873
+
874
# ── Write notebook ───────────────────────────────────────────────────────────

# The file declares nbformat 4.5 (nbformat_minor=5), whose schema makes a cell
# "id" (matching ^[a-zA-Z0-9-_]+$) required. Jupyter/nbformat currently only
# warn and patch missing ids in, but strict schema validation fails. Ids derived
# from the cell position are deterministic, so regenerated notebooks diff
# cleanly.
for cell_index, cell in enumerate(cells):
    cell.setdefault("id", f"cell-{cell_index}")

notebook = {
    "nbformat": 4,
    "nbformat_minor": 5,
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3 (ipykernel)",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "name": "python",
            "version": "3.10.0"
        }
    },
    "cells": cells
}

output_path = "deploy_notebook.ipynb"
# json.dump's default ensure_ascii=True escapes the emoji in cell text. The
# encoding is still spelled out so the platform default never matters.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(notebook, f, indent=1)

print(f"\u2705 Notebook exported: ./{output_path}")