sagemaker-ops-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ """SageMaker operations CLI."""
2
+
3
+ __version__ = "0.1.0"
4
+
sagemaker_ops/aws.py ADDED
@@ -0,0 +1,502 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import binascii
5
+ import json
6
+ from dataclasses import dataclass
7
+ from datetime import datetime, timedelta, timezone
8
+ from pathlib import Path
9
+ from typing import Any, Iterable
10
+
11
+ import boto3
12
+ from botocore.exceptions import BotoCoreError, ClientError
13
+
14
+
15
# Processing-job statuses queried (in this order) when listing jobs.
ACTIVE_PROCESSING_STATUSES = (
    "InProgress",
    "Stopping",
)
# Pipeline-execution statuses that are always shown, regardless of age.
ACTIVE_PIPELINE_STATUSES = (
    "Executing",
    "Stopping",
)
17
+
18
+
19
class AwsCliError(RuntimeError):
    """Error carrying a user-facing message, raised for the CLI layer.

    The helpers in this module convert botocore failures and invalid user
    input into this single exception type.
    """
21
+
22
+
23
@dataclass(frozen=True)
class AwsContext:
    """Immutable bundle of one profile/region pair and its service clients."""

    # Profile label shown to the user.
    profile: str
    # Region the clients were created for.
    region: str
    # Client object used for SageMaker API calls.
    sagemaker: Any
    # Client object used for CloudWatch Logs API calls.
    logs: Any
29
+
30
+
31
@dataclass(frozen=True)
class ProcessingJobView:
    """Immutable, display-oriented snapshot of one SageMaker processing job."""

    # Where the job was found.
    profile: str
    region: str

    # Identity and current state.
    name: str
    status: str

    # Timeline; any entry may be missing.
    creation_time: datetime | None
    last_modified_time: datetime | None
    started_time: datetime | None
    ended_time: datetime | None

    # Cluster configuration; empty / None when not available.
    instance_type: str
    instance_count: int | None

    # Remaining descriptive fields; empty strings when not available.
    role_arn: str
    failure_reason: str
    arn: str
46
+
47
+
48
@dataclass(frozen=True)
class ProcessingJobsPage:
    """One page of processing jobs together with its continuation token."""

    # Jobs contained in this page.
    jobs: list[ProcessingJobView]
    # Opaque token for requesting the following page; None when there is none.
    next_token: str | None
52
+
53
+
54
@dataclass(frozen=True)
class PipelineExecutionView:
    """Immutable, display-oriented snapshot of one pipeline execution."""

    # Where the execution was found.
    profile: str
    region: str

    # Identity of the execution.
    pipeline_name: str
    execution_arn: str
    display_name: str

    # Current state and timeline; timestamps may be missing.
    status: str
    start_time: datetime | None
    last_modified_time: datetime | None
64
+
65
+
66
@dataclass(frozen=True)
class PipelineExecutionsPage:
    """One page of pipeline executions together with its continuation token."""

    # Executions contained in this page.
    executions: list[PipelineExecutionView]
    # Token for requesting the following page; None when there is none.
    next_token: str | None
70
+
71
+
72
def load_job_spec(path: Path) -> dict[str, Any]:
    """Load a job specification from a ``.json``, ``.yaml`` or ``.yml`` file.

    Args:
        path: Location of the specification file (read as UTF-8).

    Returns:
        The parsed top-level mapping.

    Raises:
        AwsCliError: If the file does not exist, has an unsupported suffix,
            cannot be parsed, does not contain a mapping at the top level,
            or is YAML while PyYAML is not installed.
    """
    if not path.exists():
        raise AwsCliError(f"配置文件不存在: {path}")

    text = path.read_text(encoding="utf-8")
    suffix = path.suffix.lower()
    if suffix == ".json":
        try:
            loaded = json.loads(text)
        except json.JSONDecodeError as exc:
            # Report malformed JSON as a CLI error rather than a traceback.
            raise AwsCliError(f"JSON 配置解析失败: {path}: {exc}") from exc
        # Mirror the YAML branch: the declared return type is a mapping.
        if not isinstance(loaded, dict):
            raise AwsCliError("JSON 配置必须是一个对象")
        return loaded
    if suffix in {".yaml", ".yml"}:
        try:
            # Optional dependency, imported lazily so JSON users do not need it.
            import yaml
        except ImportError as exc:
            raise AwsCliError("读取 YAML 需要安装: pip install 'sagemaker-ops-cli[yaml]'") from exc
        try:
            loaded = yaml.safe_load(text)
        except yaml.YAMLError as exc:
            raise AwsCliError(f"YAML 配置解析失败: {path}: {exc}") from exc
        if not isinstance(loaded, dict):
            raise AwsCliError("YAML 配置必须是一个对象")
        return loaded
    raise AwsCliError("只支持 .json/.yaml/.yml 配置文件")
90
+
91
+
92
def parse_parameters(items: Iterable[str]) -> list[dict[str, str]]:
    """Turn ``NAME=VALUE`` strings into ``{"Name": ..., "Value": ...}`` dicts.

    Only the first ``=`` separates name from value. The name is stripped of
    surrounding whitespace; the value is kept verbatim.

    Raises:
        AwsCliError: If an item has no ``=`` or its name is blank.
    """

    def _to_parameter(raw: str) -> dict[str, str]:
        # Convert a single NAME=VALUE item, validating its shape.
        key, separator, value = raw.partition("=")
        if not separator:
            raise AwsCliError(f"Pipeline 参数必须是 NAME=VALUE 格式: {raw}")
        key = key.strip()
        if not key:
            raise AwsCliError(f"Pipeline 参数名不能为空: {raw}")
        return {"Name": key, "Value": value}

    return [_to_parameter(raw) for raw in items]
103
+
104
+
105
def build_contexts(
    profiles: tuple[str, ...],
    region: str | None,
    all_profiles: bool = False,
) -> list[AwsContext]:
    """Create one ``AwsContext`` per requested AWS profile.

    Args:
        profiles: Explicit profile names; when empty, a single session is
            created with ``profile_name=None``.
        region: Region passed to every session; also used as a fallback when
            the session reports no region.
        all_profiles: When true, ignore ``profiles`` and use every profile
            reported by ``boto3.Session().available_profiles``.

    Raises:
        AwsCliError: If no profile is available, a region cannot be
            resolved, or session/client creation fails.
    """
    if all_profiles:
        profile_names = tuple(boto3.Session().available_profiles)
        if not profile_names:
            raise AwsCliError("没有找到任何 AWS profile")
    else:
        profile_names = profiles if profiles else (None,)

    built: list[AwsContext] = []
    for profile_name in profile_names:
        # Label used in messages when no explicit profile name was given.
        fallback_label = profile_name or "default/env"
        try:
            aws_session = boto3.Session(profile_name=profile_name, region_name=region)
            effective_region = aws_session.region_name or region
            if not effective_region:
                raise AwsCliError(f"profile {fallback_label} 没有配置 region,请传 --region")
            context = AwsContext(
                profile=profile_name or aws_session.profile_name or "default/env",
                region=effective_region,
                sagemaker=aws_session.client("sagemaker"),
                logs=aws_session.client("logs"),
            )
            built.append(context)
        except (BotoCoreError, ClientError) as exc:
            raise AwsCliError(f"创建 AWS session 失败 profile={fallback_label}: {exc}") from exc
    return built
137
+
138
+
139
def submit_processing_job(ctx: AwsContext, spec: dict[str, Any]) -> dict[str, Any]:
    """Submit a processing job, passing ``spec`` as keyword arguments.

    Returns:
        The response of ``create_processing_job``.

    Raises:
        AwsCliError: If the call fails with a botocore error.
    """
    try:
        response = ctx.sagemaker.create_processing_job(**spec)
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"提交 processing job 失败: {exc}") from exc
    else:
        return response
144
+
145
+
146
def start_pipeline_execution(
    ctx: AwsContext,
    pipeline_name: str,
    display_name: str | None,
    parameters: list[dict[str, str]],
    client_request_token: str | None,
) -> dict[str, Any]:
    """Start an execution of ``pipeline_name``.

    The display name, parameters and client request token are sent only when
    they are truthy.

    Returns:
        The response of ``start_pipeline_execution``.

    Raises:
        AwsCliError: If the call fails with a botocore error.
    """
    optional_fields = (
        ("PipelineExecutionDisplayName", display_name),
        ("PipelineParameters", parameters),
        ("ClientRequestToken", client_request_token),
    )
    request: dict[str, Any] = {"PipelineName": pipeline_name}
    request.update({field: value for field, value in optional_fields if value})

    try:
        response = ctx.sagemaker.start_pipeline_execution(**request)
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"启动 pipeline 失败: {exc}") from exc
    else:
        return response
165
+
166
+
167
def list_processing_jobs(ctx: AwsContext, max_results: int = 50) -> list[ProcessingJobView]:
    """Return the jobs of the first page from ``list_processing_jobs_page``.

    ``max_results`` is forwarded as that function's ``page_size``.
    """
    first_page = list_processing_jobs_page(ctx, page_size=max_results)
    return first_page.jobs
169
+
170
+
171
def list_processing_jobs_page(
    ctx: AwsContext,
    page_size: int = 20,
    next_token: str | None = None,
) -> ProcessingJobsPage:
    """Return one page of processing jobs across all active statuses.

    The statuses in ``ACTIVE_PROCESSING_STATUSES`` are queried one after
    another until the page is full. The returned token records which status
    to resume from and the AWS token for that status.

    Args:
        ctx: Context whose ``sagemaker`` client is used.
        page_size: Desired number of jobs; clamped to the range 1..100.
        next_token: Token returned by a previous call, or None to start over.

    Returns:
        The jobs sorted by creation time (newest first) and the token for
        the following page, or None when every status has been exhausted.

    Raises:
        AwsCliError: If ``next_token`` is invalid or a list call fails.
    """
    page_size = max(1, min(page_size, 100))
    status_index, aws_next_token = _decode_processing_jobs_token(next_token)
    summaries: list[dict[str, Any]] = []
    output_next_token: str | None = None

    while len(summaries) < page_size and status_index < len(ACTIVE_PROCESSING_STATUSES):
        status = ACTIVE_PROCESSING_STATUSES[status_index]
        request: dict[str, Any] = {
            "StatusEquals": status,
            "SortBy": "CreationTime",
            "SortOrder": "Descending",
            # Only ask for what is still missing from this page.
            "MaxResults": min(100, page_size - len(summaries)),
        }
        if aws_next_token:
            request["NextToken"] = aws_next_token

        try:
            response = ctx.sagemaker.list_processing_jobs(**request)
        except (BotoCoreError, ClientError) as exc:
            raise AwsCliError(f"读取 processing jobs 失败: {exc}") from exc

        summaries.extend(response.get("ProcessingJobSummaries", []))
        aws_next_token = response.get("NextToken")
        if aws_next_token:
            # More results remain for this status: resume here next time.
            output_next_token = _encode_processing_jobs_token(status_index, aws_next_token)
            break
        status_index += 1

    if output_next_token is None and status_index < len(ACTIVE_PROCESSING_STATUSES):
        # The page filled up exactly as one status ran out, so later statuses
        # were never queried. Emit a token that restarts at the next status;
        # the empty AWS token is falsy, so no NextToken is sent on resume.
        output_next_token = _encode_processing_jobs_token(status_index, "")

    jobs = [_processing_job_view_from_summary(ctx, summary) for summary in summaries]
    return ProcessingJobsPage(
        jobs=sorted(jobs, key=lambda job: job.creation_time or datetime.min.replace(tzinfo=timezone.utc), reverse=True),
        next_token=output_next_token,
    )
209
+
210
+
211
def _processing_job_view_from_summary(ctx: AwsContext, summary: dict[str, Any]) -> ProcessingJobView:
    """Build a ``ProcessingJobView`` for one list-result summary.

    The job is described to obtain more fields; if that call fails with a
    botocore error, the summary itself is used as the detail source.
    """
    job_name = summary["ProcessingJobName"]
    try:
        described = ctx.sagemaker.describe_processing_job(ProcessingJobName=job_name)
    except (BotoCoreError, ClientError):
        # Best effort: fall back to whatever the summary already contains.
        described = summary

    cluster_config = described.get("ProcessingResources", {}).get("ClusterConfig", {})
    fields: dict[str, Any] = {
        "profile": ctx.profile,
        "region": ctx.region,
        "name": job_name,
        "status": described.get("ProcessingJobStatus", summary.get("ProcessingJobStatus", "")),
        "creation_time": described.get("CreationTime", summary.get("CreationTime")),
        "last_modified_time": described.get("LastModifiedTime"),
        "started_time": described.get("ProcessingStartTime"),
        "ended_time": described.get("ProcessingEndTime"),
        "instance_type": cluster_config.get("InstanceType", ""),
        "instance_count": cluster_config.get("InstanceCount"),
        "role_arn": described.get("RoleArn", ""),
        "failure_reason": described.get("FailureReason", ""),
        "arn": described.get("ProcessingJobArn", summary.get("ProcessingJobArn", "")),
    }
    return ProcessingJobView(**fields)
233
+
234
+
235
+ def _encode_processing_jobs_token(status_index: int, aws_next_token: str) -> str:
236
+ payload = json.dumps({"status_index": status_index, "aws_next_token": aws_next_token}).encode("utf-8")
237
+ return base64.urlsafe_b64encode(payload).decode("ascii")
238
+
239
+
240
+ def _decode_processing_jobs_token(next_token: str | None) -> tuple[int, str | None]:
241
+ if not next_token:
242
+ return 0, None
243
+ try:
244
+ decoded = base64.urlsafe_b64decode(next_token.encode("ascii"))
245
+ payload = json.loads(decoded.decode("utf-8"))
246
+ status_index = int(payload.get("status_index", 0))
247
+ aws_next_token = payload.get("aws_next_token")
248
+ except (binascii.Error, json.JSONDecodeError, TypeError, ValueError) as exc:
249
+ raise AwsCliError("processing jobs next token 无效") from exc
250
+ if status_index < 0 or status_index >= len(ACTIVE_PROCESSING_STATUSES) or not isinstance(aws_next_token, str):
251
+ raise AwsCliError("processing jobs next token 无效")
252
+ return status_index, aws_next_token
253
+
254
+
255
def list_active_pipeline_executions(
    ctx: AwsContext,
    pipeline_name: str | None = None,
    per_pipeline: int = 10,
    recent_hours: int = 3,
) -> list[PipelineExecutionView]:
    """Return the executions of the first page of pipeline executions.

    All arguments are forwarded to ``list_pipeline_executions_page``; its
    paging arguments keep their defaults.
    """
    first_page = list_pipeline_executions_page(
        ctx,
        pipeline_name=pipeline_name,
        per_pipeline=per_pipeline,
        recent_hours=recent_hours,
    )
    return first_page.executions
267
+
268
+
269
def list_pipeline_executions_page(
    ctx: AwsContext,
    pipeline_name: str | None = None,
    per_pipeline: int = 10,
    recent_hours: int = 3,
    pipeline_page_size: int = 10,
    next_token: str | None = None,
) -> PipelineExecutionsPage:
    """Collect executions for one page of pipelines.

    Pipeline names come from ``_list_pipeline_names_page``. For each name,
    the executions that are active or newer than ``recent_hours`` are
    gathered, then all of them are sorted newest first by last-modified time
    (falling back to start time).

    Returns:
        The sorted executions plus the token for the next page of pipelines.
    """
    pipeline_names, following_token = _list_pipeline_names_page(
        ctx,
        pipeline_name=pipeline_name,
        page_size=pipeline_page_size,
        next_token=next_token,
    )
    cutoff = datetime.now(timezone.utc) - timedelta(hours=recent_hours)

    collected: list[PipelineExecutionView] = [
        execution
        for name in pipeline_names
        for execution in _list_recent_pipeline_executions_for_name(ctx, name, per_pipeline, cutoff)
    ]

    # Executions without any timestamp sort last.
    oldest_possible = datetime.min.replace(tzinfo=timezone.utc)

    def _recency(item: PipelineExecutionView) -> datetime:
        return item.last_modified_time or item.start_time or oldest_possible

    collected.sort(key=_recency, reverse=True)
    return PipelineExecutionsPage(executions=collected, next_token=following_token)
297
+
298
+
299
def _list_recent_pipeline_executions_for_name(
    ctx: AwsContext,
    pipeline_name: str,
    per_pipeline: int,
    cutoff: datetime,
) -> list[PipelineExecutionView]:
    """List the newest executions of one pipeline that should be displayed.

    At most ``per_pipeline`` executions (clamped to 1..100) are requested,
    newest first. Each one is described on a best-effort basis and kept only
    if ``_should_show_pipeline_execution`` accepts it.

    Raises:
        AwsCliError: If listing the executions fails.
    """
    max_results = max(1, min(per_pipeline, 100))
    try:
        response = ctx.sagemaker.list_pipeline_executions(
            PipelineName=pipeline_name,
            SortBy="CreationTime",
            SortOrder="Descending",
            MaxResults=max_results,
        )
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipeline executions 失败 pipeline={pipeline_name}: {exc}") from exc

    views: list[PipelineExecutionView] = []
    for summary in response.get("PipelineExecutionSummaries", []):
        status = summary.get("PipelineExecutionStatus", "")
        execution_arn = summary.get("PipelineExecutionArn", "")
        # Without an ARN there is nothing to describe.
        detail = {}
        if execution_arn:
            detail = _describe_pipeline_execution_safely(ctx, execution_arn)
        # Detail values take precedence over summary values for timestamps.
        started = detail.get("StartTime", summary.get("StartTime"))
        modified = detail.get("LastModifiedTime", summary.get("LastModifiedTime"))
        if _should_show_pipeline_execution(status, started, modified, cutoff):
            view = PipelineExecutionView(
                profile=ctx.profile,
                region=ctx.region,
                pipeline_name=detail.get("PipelineName", pipeline_name),
                execution_arn=execution_arn,
                display_name=summary.get("PipelineExecutionDisplayName", detail.get("PipelineExecutionDisplayName", "")),
                status=status,
                start_time=started,
                last_modified_time=modified,
            )
            views.append(view)
    return views
338
+
339
+
340
+ def _list_pipeline_names_page(
341
+ ctx: AwsContext,
342
+ pipeline_name: str | None,
343
+ page_size: int,
344
+ next_token: str | None,
345
+ ) -> tuple[list[str], str | None]:
346
+ if pipeline_name:
347
+ if next_token:
348
+ raise AwsCliError("指定 --name 时不支持 pipeline next token")
349
+ return [pipeline_name], None
350
+
351
+ request: dict[str, Any] = {
352
+ "SortBy": "CreationTime",
353
+ "SortOrder": "Descending",
354
+ "MaxResults": max(1, min(page_size, 100)),
355
+ }
356
+ if next_token:
357
+ request["NextToken"] = next_token
358
+ try:
359
+ response = ctx.sagemaker.list_pipelines(**request)
360
+ except (BotoCoreError, ClientError) as exc:
361
+ raise AwsCliError(f"读取 pipelines 失败: {exc}") from exc
362
+
363
+ names = [item["PipelineName"] for item in response.get("PipelineSummaries", [])]
364
+ return names, response.get("NextToken")
365
+
366
+
367
+ def _describe_pipeline_execution_safely(ctx: AwsContext, execution_arn: str) -> dict[str, Any]:
368
+ try:
369
+ return ctx.sagemaker.describe_pipeline_execution(PipelineExecutionArn=execution_arn)
370
+ except (BotoCoreError, ClientError):
371
+ return {}
372
+
373
+
374
def _should_show_pipeline_execution(
    status: str,
    start_time: datetime | None,
    last_modified_time: datetime | None,
    cutoff: datetime,
) -> bool:
    """Decide whether an execution belongs in the listing.

    Active executions are always shown. Others are shown only when their
    last-modified time (or, failing that, start time) is at or after
    ``cutoff``. Naive timestamps are treated as UTC.
    """
    if status in ACTIVE_PIPELINE_STATUSES:
        return True

    reference = last_modified_time or start_time
    if reference is None:
        # No timestamp at all: recency cannot be established.
        return False

    if reference.tzinfo is None:
        reference = reference.replace(tzinfo=timezone.utc)
    return reference >= cutoff
388
+
389
+
390
def list_pipeline_steps(ctx: AwsContext, execution_arn: str) -> list[dict[str, Any]]:
    """Return every step of a pipeline execution, ordered by start time.

    Steps without a ``StartTime`` sort first.

    Raises:
        AwsCliError: If paging through the steps fails.
    """
    paginator = ctx.sagemaker.get_paginator("list_pipeline_execution_steps")
    try:
        collected: list[dict[str, Any]] = [
            step
            for page in paginator.paginate(PipelineExecutionArn=execution_arn)
            for step in page.get("PipelineExecutionSteps", [])
        ]
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipeline steps 失败: {exc}") from exc

    not_started = datetime.min.replace(tzinfo=timezone.utc)
    collected.sort(key=lambda step: step.get("StartTime") or not_started)
    return collected
399
+
400
+
401
def describe_pipeline_execution(ctx: AwsContext, execution_arn: str) -> dict[str, Any]:
    """Describe a pipeline execution.

    Raises:
        AwsCliError: If the call fails with a botocore error.
    """
    try:
        detail = ctx.sagemaker.describe_pipeline_execution(PipelineExecutionArn=execution_arn)
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipeline execution 失败: {exc}") from exc
    else:
        return detail
406
+
407
+
408
def tail_step_logs(ctx: AwsContext, step: dict[str, Any], limit: int = 80) -> list[str]:
    """Return recent log lines for a pipeline step.

    The log group and stream prefix come from ``infer_log_source``; when it
    cannot determine them, an empty list is returned.
    """
    source = infer_log_source(step)
    if source is None:
        return []
    return tail_cloudwatch_logs(ctx, *source, limit=limit)
414
+
415
+
416
def tail_cloudwatch_logs(
    ctx: AwsContext,
    log_group: str,
    stream_prefix: str,
    limit: int = 80,
) -> list[str]:
    """Return the most recent log lines of the streams matching a prefix.

    Up to five streams whose name starts with ``stream_prefix`` are read.
    Their events are merged in chronological order, so the result holds the
    newest ``limit`` lines across all streams rather than the lines of
    whichever stream happened to be read last.

    Args:
        ctx: Context whose ``logs`` client is used.
        log_group: Name of the log group.
        stream_prefix: Prefix of the stream names to read.
        limit: Maximum number of events fetched per stream and of lines
            returned.

    Returns:
        Lines formatted as ``"<UTC timestamp>Z <message>"``, oldest first,
        followed by a notice for every stream that could not be read. If the
        streams cannot be listed, a single explanatory line is returned. No
        exception is raised for botocore errors.
    """
    try:
        streams = ctx.logs.describe_log_streams(
            logGroupName=log_group,
            logStreamNamePrefix=stream_prefix,
            limit=5,
        ).get("logStreams", [])
    except ctx.logs.exceptions.ResourceNotFoundException:
        return [f"没有找到日志组: {log_group}"]
    except (BotoCoreError, ClientError) as exc:
        return [f"读取日志流失败: {exc}"]

    streams = sorted(streams, key=lambda stream: stream.get("lastEventTimestamp", 0), reverse=True)
    stamped_lines: list[tuple[int, str]] = []
    notices: list[str] = []
    for stream in streams:
        stream_name = stream["logStreamName"]
        try:
            events = ctx.logs.get_log_events(
                logGroupName=log_group,
                logStreamName=stream_name,
                limit=limit,
                startFromHead=False,
            ).get("events", [])
        except (BotoCoreError, ClientError) as exc:
            notices.append(f"[{stream_name}] 读取失败: {exc}")
            continue
        for event in events[-limit:]:
            # Event timestamps are in milliseconds since the epoch.
            timestamp = datetime.fromtimestamp(event["timestamp"] / 1000, tz=timezone.utc)
            line = f"{timestamp:%Y-%m-%d %H:%M:%S}Z {event.get('message', '').rstrip()}"
            stamped_lines.append((event["timestamp"], line))

    # Stable sort on the timestamp only: events with equal timestamps keep
    # the order in which they were read.
    stamped_lines.sort(key=lambda pair: pair[0])
    lines = [line for _, line in stamped_lines]
    # Notices go last so that trimming to the tail cannot hide them.
    lines.extend(notices)
    return lines[-limit:]
451
+
452
+
453
def infer_log_source(step: dict[str, Any]) -> tuple[str, str] | None:
    """Derive ``(log_group, stream_prefix)`` from a step's metadata.

    The metadata entries ``ProcessingJob``, ``TrainingJob`` and
    ``TransformJob`` are checked in that order. The first one that is a dict
    with a truthy ``Arn`` wins, and the text after the ARN's last ``/`` is
    used as the stream prefix. Returns None when no entry qualifies.
    """
    metadata = step.get("Metadata") or {}
    # Insertion order defines the lookup priority.
    log_group_by_kind = {
        "ProcessingJob": "/aws/sagemaker/ProcessingJobs",
        "TrainingJob": "/aws/sagemaker/TrainingJobs",
        "TransformJob": "/aws/sagemaker/TransformJobs",
    }
    for kind, log_group in log_group_by_kind.items():
        job = metadata.get(kind)
        if isinstance(job, dict) and job.get("Arn"):
            return log_group, job["Arn"].rpartition("/")[2]
    return None
468
+
469
+
470
def format_dt(value: datetime | None) -> str:
    """Format a timestamp as ``YYYY-MM-DD HH:MM:SSZ`` in UTC.

    None yields an empty string; naive values are treated as UTC.
    """
    if value is None:
        return ""
    aware = value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
    in_utc = aware.astimezone(timezone.utc)
    return f"{in_utc:%Y-%m-%d %H:%M:%SZ}"
476
+
477
+
478
def format_duration(start: datetime | None, end: datetime | None = None) -> str:
    """Format the time elapsed between ``start`` and ``end`` compactly.

    ``end`` defaults to the current time. Naive values are treated as UTC and
    negative spans are clamped to zero. The result is ``"XhYYm"``,
    ``"XmYYs"`` or ``"Xs"`` depending on magnitude; a missing ``start``
    yields an empty string.
    """
    if start is None:
        return ""

    def _as_utc(moment: datetime) -> datetime:
        # Attach UTC to naive timestamps so the subtraction is well-defined.
        return moment if moment.tzinfo is not None else moment.replace(tzinfo=timezone.utc)

    begin = _as_utc(start)
    finish = _as_utc(end or datetime.now(timezone.utc))
    elapsed = max(0, int((finish - begin).total_seconds()))

    total_minutes, secs = divmod(elapsed, 60)
    hours, minutes = divmod(total_minutes, 60)
    if hours:
        return f"{hours}h{minutes:02d}m"
    if minutes:
        return f"{minutes}m{secs:02d}s"
    return f"{secs}s"
494
+
495
+
496
+ def _list_pipeline_names(ctx: AwsContext) -> list[str]:
497
+ paginator = ctx.sagemaker.get_paginator("list_pipelines")
498
+ names: list[str] = []
499
+ for page in paginator.paginate(SortBy="CreationTime", SortOrder="Descending"):
500
+ names.extend(item["PipelineName"] for item in page.get("PipelineSummaries", []))
501
+ return names
502
+