sagemaker-ops-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker_ops/__init__.py +4 -0
- sagemaker_ops/aws.py +502 -0
- sagemaker_ops/cli.py +262 -0
- sagemaker_ops/tui.py +458 -0
- sagemaker_ops_cli-0.1.0.dist-info/METADATA +241 -0
- sagemaker_ops_cli-0.1.0.dist-info/RECORD +9 -0
- sagemaker_ops_cli-0.1.0.dist-info/WHEEL +5 -0
- sagemaker_ops_cli-0.1.0.dist-info/entry_points.txt +2 -0
- sagemaker_ops_cli-0.1.0.dist-info/top_level.txt +1 -0
sagemaker_ops/aws.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import binascii
|
|
5
|
+
import json
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from datetime import datetime, timedelta, timezone
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Iterable
|
|
10
|
+
|
|
11
|
+
import boto3
|
|
12
|
+
from botocore.exceptions import BotoCoreError, ClientError
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Processing-job statuses queried (one at a time, in this order) when
# listing "active" processing jobs.
ACTIVE_PROCESSING_STATUSES = (
    "InProgress",
    "Stopping",
)
# Pipeline-execution statuses that are shown regardless of how long ago the
# execution last changed.
ACTIVE_PIPELINE_STATUSES = (
    "Executing",
    "Stopping",
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AwsCliError(RuntimeError):
    """Error whose message is meant to be shown directly to the CLI user.

    AWS/botocore failures and invalid user input are re-raised as this type
    (usually chained with ``from``) so callers only need to catch one class.
    """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
class AwsContext:
    """Immutable pairing of one AWS profile/region with its boto3 clients."""

    # Profile label used for display.
    profile: str
    # AWS region the clients were created for.
    region: str
    # SageMaker client (typed as Any; no client stubs are used here).
    sagemaker: Any
    # CloudWatch Logs client (typed as Any).
    logs: Any
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class ProcessingJobView:
    """Immutable, display-oriented snapshot of one SageMaker processing job.

    The dataclass defines no defaults; every field must be supplied.
    """

    # Where the job was found.
    profile: str
    region: str
    # Identity and lifecycle.
    name: str
    status: str
    creation_time: datetime | None
    last_modified_time: datetime | None
    started_time: datetime | None
    ended_time: datetime | None
    # Compute configuration.
    instance_type: str
    instance_count: int | None
    role_arn: str
    # Diagnostics and identifiers.
    failure_reason: str
    arn: str
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
class ProcessingJobsPage:
    """One page of processing jobs plus the opaque token for the next page."""

    # Jobs on this page.
    jobs: list[ProcessingJobView]
    # Continuation token, or None when there is nothing further to fetch.
    next_token: str | None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass(frozen=True)
class PipelineExecutionView:
    """Immutable, display-oriented snapshot of one SageMaker pipeline execution."""

    # Where the execution was found.
    profile: str
    region: str
    # Identity.
    pipeline_name: str
    execution_arn: str
    display_name: str
    # Lifecycle.
    status: str
    start_time: datetime | None
    last_modified_time: datetime | None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(frozen=True)
class PipelineExecutionsPage:
    """One page of pipeline executions plus the token for the next pipelines page."""

    # Executions gathered for this page.
    executions: list[PipelineExecutionView]
    # Continuation token, or None when there is nothing further to fetch.
    next_token: str | None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def load_job_spec(path: Path) -> dict[str, Any]:
    """Load a job request spec from a JSON or YAML file.

    The suffix is matched case-insensitively. The top-level value of the file
    must be an object (mapping) for both formats.

    Args:
        path: Location of a ``.json``, ``.yaml`` or ``.yml`` file.

    Returns:
        The parsed top-level mapping.

    Raises:
        AwsCliError: If the file is missing or unreadable, the suffix is not
            supported, PyYAML is not installed (YAML input only), the content
            cannot be parsed, or the top-level value is not an object.
    """
    if not path.exists():
        raise AwsCliError(f"配置文件不存在: {path}")

    suffix = path.suffix.lower()
    if suffix not in {".json", ".yaml", ".yml"}:
        raise AwsCliError("只支持 .json/.yaml/.yml 配置文件")

    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        # e.g. the path is a directory, unreadable, or not UTF-8.
        raise AwsCliError(f"读取配置文件失败: {path}: {exc}") from exc

    if suffix == ".json":
        try:
            loaded = json.loads(text)
        except json.JSONDecodeError as exc:
            raise AwsCliError(f"JSON 配置解析失败: {path}: {exc}") from exc
        # Same shape requirement as YAML: a list/scalar is not a usable spec.
        if not isinstance(loaded, dict):
            raise AwsCliError("JSON 配置必须是一个对象")
        return loaded

    try:
        import yaml  # optional dependency, only needed for YAML specs
    except ImportError as exc:
        raise AwsCliError("读取 YAML 需要安装: pip install 'sagemaker-ops-cli[yaml]'") from exc
    try:
        loaded = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        raise AwsCliError(f"YAML 配置解析失败: {path}: {exc}") from exc
    if not isinstance(loaded, dict):
        raise AwsCliError("YAML 配置必须是一个对象")
    return loaded
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def parse_parameters(items: Iterable[str]) -> list[dict[str, str]]:
    """Turn ``NAME=VALUE`` strings into SageMaker pipeline parameter dicts.

    Only the first ``=`` separates name from value, so values may contain
    ``=`` themselves. Whitespace around the name is trimmed; the value is
    kept verbatim.

    Raises:
        AwsCliError: If an item has no ``=`` or its name is blank.
    """
    result: list[dict[str, str]] = []
    for raw in items:
        key, sep, val = raw.partition("=")
        if not sep:
            raise AwsCliError(f"Pipeline 参数必须是 NAME=VALUE 格式: {raw}")
        key = key.strip()
        if key == "":
            raise AwsCliError(f"Pipeline 参数名不能为空: {raw}")
        result.append({"Name": key, "Value": val})
    return result
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def build_contexts(
    profiles: tuple[str, ...],
    region: str | None,
    all_profiles: bool = False,
) -> list[AwsContext]:
    """Create one AwsContext (SageMaker + CloudWatch Logs clients) per profile.

    Args:
        profiles: Profile names to use. An empty tuple means a single context
            built from boto3's default credential/config resolution.
        region: Region passed to every session; when None, each session's
            own configured region is used.
        all_profiles: When True, ignore *profiles* and use every entry of
            ``boto3.Session().available_profiles``.

    Raises:
        AwsCliError: If *all_profiles* finds no profiles, a session has no
            resolvable region, or boto3 fails to create a session/client.
    """
    if all_profiles:
        selected = tuple(boto3.Session().available_profiles)
        if not selected:
            raise AwsCliError("没有找到任何 AWS profile")
    else:
        selected = profiles if profiles else (None,)

    result: list[AwsContext] = []
    for name in selected:
        label = name or "default/env"
        try:
            session = boto3.Session(profile_name=name, region_name=region)
            effective_region = session.region_name or region
            if not effective_region:
                raise AwsCliError(f"profile {label} 没有配置 region,请传 --region")
            ctx = AwsContext(
                profile=name or session.profile_name or "default/env",
                region=effective_region,
                sagemaker=session.client("sagemaker"),
                logs=session.client("logs"),
            )
        except (BotoCoreError, ClientError) as exc:
            raise AwsCliError(f"创建 AWS session 失败 profile={label}: {exc}") from exc
        result.append(ctx)
    return result
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def submit_processing_job(ctx: AwsContext, spec: dict[str, Any]) -> dict[str, Any]:
    """Call SageMaker ``create_processing_job`` with *spec* as keyword arguments.

    Returns:
        The raw API response.

    Raises:
        AwsCliError: If boto3/botocore reports an error.
    """
    client = ctx.sagemaker
    try:
        response = client.create_processing_job(**spec)
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"提交 processing job 失败: {exc}") from exc
    return response
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def start_pipeline_execution(
    ctx: AwsContext,
    pipeline_name: str,
    display_name: str | None,
    parameters: list[dict[str, str]],
    client_request_token: str | None,
) -> dict[str, Any]:
    """Start a SageMaker pipeline execution and return the raw API response.

    Optional request fields are only sent when their value is truthy, so
    ``None``/empty values are omitted rather than sent as blanks.

    Raises:
        AwsCliError: If boto3/botocore reports an error.
    """
    optional_fields = (
        ("PipelineExecutionDisplayName", display_name),
        ("PipelineParameters", parameters),
        ("ClientRequestToken", client_request_token),
    )
    request: dict[str, Any] = {"PipelineName": pipeline_name}
    request.update({key: value for key, value in optional_fields if value})

    try:
        response = ctx.sagemaker.start_pipeline_execution(**request)
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"启动 pipeline 失败: {exc}") from exc
    return response
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def list_processing_jobs(ctx: AwsContext, max_results: int = 50) -> list[ProcessingJobView]:
    """Return only the first page of active processing jobs.

    *max_results* is used as the page size (the page function clamps it);
    the continuation token is discarded.
    """
    first_page = list_processing_jobs_page(ctx, page_size=max_results)
    return first_page.jobs
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def list_processing_jobs_page(
    ctx: AwsContext,
    page_size: int = 20,
    next_token: str | None = None,
) -> ProcessingJobsPage:
    """Fetch one page of processing jobs whose status is in ACTIVE_PROCESSING_STATUSES.

    ``ListProcessingJobs`` filters on a single status, so the statuses are
    walked one after another. The opaque continuation token records which
    status is being read and that status's own SageMaker ``NextToken``.
    An empty SageMaker token means "start this status from the beginning".

    Args:
        ctx: Profile/region-bound clients.
        page_size: Desired number of jobs; clamped to 1..100.
        next_token: Token returned by a previous call, or None for the first page.

    Returns:
        The jobs, sorted newest first by creation time, plus a token for the
        next page. The token is None only when every active status has been
        fully read.

    Raises:
        AwsCliError: If the token is invalid or a SageMaker call fails.
    """
    page_size = max(1, min(page_size, 100))
    status_index, aws_next_token = _decode_processing_jobs_token(next_token)
    summaries: list[dict[str, Any]] = []
    output_next_token: str | None = None

    while len(summaries) < page_size and status_index < len(ACTIVE_PROCESSING_STATUSES):
        request: dict[str, Any] = {
            "StatusEquals": ACTIVE_PROCESSING_STATUSES[status_index],
            "SortBy": "CreationTime",
            "SortOrder": "Descending",
            # page_size is already <= 100, the API maximum.
            "MaxResults": page_size - len(summaries),
        }
        if aws_next_token:
            request["NextToken"] = aws_next_token

        try:
            response = ctx.sagemaker.list_processing_jobs(**request)
        except (BotoCoreError, ClientError) as exc:
            raise AwsCliError(f"读取 processing jobs 失败: {exc}") from exc

        summaries.extend(response.get("ProcessingJobSummaries", []))
        aws_next_token = response.get("NextToken")
        if aws_next_token:
            output_next_token = _encode_processing_jobs_token(status_index, aws_next_token)
            break
        status_index += 1

    if output_next_token is None and status_index < len(ACTIVE_PROCESSING_STATUSES):
        # The page filled up exactly as a status ran out. Later statuses have
        # not been listed yet, so resume at the next one from its beginning
        # (empty SageMaker token). Without this they would be silently skipped.
        output_next_token = _encode_processing_jobs_token(status_index, "")

    jobs = [_processing_job_view_from_summary(ctx, summary) for summary in summaries]
    jobs.sort(
        key=lambda job: job.creation_time or datetime.min.replace(tzinfo=timezone.utc),
        reverse=True,
    )
    return ProcessingJobsPage(jobs=jobs, next_token=output_next_token)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _processing_job_view_from_summary(ctx: AwsContext, summary: dict[str, Any]) -> ProcessingJobView:
    """Build a ProcessingJobView from a list summary plus DescribeProcessingJob.

    The describe call is best effort. If it fails, the summary itself is
    used as the detail source, and fields missing from it fall back to
    ``""``/``None``.
    """
    job_name = summary["ProcessingJobName"]
    try:
        info = ctx.sagemaker.describe_processing_job(ProcessingJobName=job_name)
    except (BotoCoreError, ClientError):
        info = summary

    resources = info.get("ProcessingResources", {})
    cluster = resources.get("ClusterConfig", {})
    fallback_status = summary.get("ProcessingJobStatus", "")
    fallback_created = summary.get("CreationTime")
    fallback_arn = summary.get("ProcessingJobArn", "")

    return ProcessingJobView(
        profile=ctx.profile,
        region=ctx.region,
        name=job_name,
        status=info.get("ProcessingJobStatus", fallback_status),
        creation_time=info.get("CreationTime", fallback_created),
        last_modified_time=info.get("LastModifiedTime"),
        started_time=info.get("ProcessingStartTime"),
        ended_time=info.get("ProcessingEndTime"),
        instance_type=cluster.get("InstanceType", ""),
        instance_count=cluster.get("InstanceCount"),
        role_arn=info.get("RoleArn", ""),
        failure_reason=info.get("FailureReason", ""),
        arn=info.get("ProcessingJobArn", fallback_arn),
    )
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _encode_processing_jobs_token(status_index: int, aws_next_token: str) -> str:
    """Pack the status cursor and SageMaker token into URL-safe base64 JSON."""
    body = {"status_index": status_index, "aws_next_token": aws_next_token}
    raw = json.dumps(body).encode("utf-8")
    encoded = base64.urlsafe_b64encode(raw)
    return encoded.decode("ascii")
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _decode_processing_jobs_token(next_token: str | None) -> tuple[int, str | None]:
    """Unpack a processing-jobs continuation token.

    Returns:
        ``(0, None)`` for a missing or empty token (start with the first
        status). Otherwise ``(status_index, aws_next_token)``, where
        *aws_next_token* may be an empty string.

    Raises:
        AwsCliError: If the token is not URL-safe base64 of a JSON object
            with an in-range ``status_index`` and a string ``aws_next_token``.
    """
    if not next_token:
        return 0, None
    try:
        decoded = base64.urlsafe_b64decode(next_token.encode("ascii"))
        payload = json.loads(decoded.decode("utf-8"))
        # A valid JSON array/scalar would otherwise fail below with an
        # uncaught AttributeError on ``payload.get``.
        if not isinstance(payload, dict):
            raise TypeError("token payload is not a JSON object")
        status_index = int(payload.get("status_index", 0))
        aws_next_token = payload.get("aws_next_token")
    except (binascii.Error, json.JSONDecodeError, TypeError, ValueError) as exc:
        raise AwsCliError("processing jobs next token 无效") from exc
    if not 0 <= status_index < len(ACTIVE_PROCESSING_STATUSES) or not isinstance(aws_next_token, str):
        raise AwsCliError("processing jobs next token 无效")
    return status_index, aws_next_token
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def list_active_pipeline_executions(
    ctx: AwsContext,
    pipeline_name: str | None = None,
    per_pipeline: int = 10,
    recent_hours: int = 3,
) -> list[PipelineExecutionView]:
    """Return active or recently changed executions from the first page only.

    Delegates to list_pipeline_executions_page with its default pipeline page
    size, so when *pipeline_name* is None only that first page of pipelines
    is scanned. The continuation token is discarded.
    """
    first_page = list_pipeline_executions_page(
        ctx,
        pipeline_name=pipeline_name,
        per_pipeline=per_pipeline,
        recent_hours=recent_hours,
    )
    return first_page.executions
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def list_pipeline_executions_page(
    ctx: AwsContext,
    pipeline_name: str | None = None,
    per_pipeline: int = 10,
    recent_hours: int = 3,
    pipeline_page_size: int = 10,
    next_token: str | None = None,
) -> PipelineExecutionsPage:
    """Collect active/recent executions for one page of pipelines.

    Pagination is over pipelines (``ListPipelines``), not executions. For
    each pipeline on the page, its newest *per_pipeline* executions are
    examined. An execution is kept if it is active or changed within
    *recent_hours* of now (UTC).

    Returns:
        Executions sorted newest first by last-modified (falling back to
        start) time, plus the pipelines continuation token.
    """
    pipeline_names, continuation = _list_pipeline_names_page(
        ctx,
        pipeline_name=pipeline_name,
        page_size=pipeline_page_size,
        next_token=next_token,
    )
    window_start = datetime.now(timezone.utc) - timedelta(hours=recent_hours)
    floor = datetime.min.replace(tzinfo=timezone.utc)

    collected = [
        view
        for name in pipeline_names
        for view in _list_recent_pipeline_executions_for_name(ctx, name, per_pipeline, window_start)
    ]
    collected.sort(key=lambda view: view.last_modified_time or view.start_time or floor, reverse=True)
    return PipelineExecutionsPage(executions=collected, next_token=continuation)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _list_recent_pipeline_executions_for_name(
    ctx: AwsContext,
    pipeline_name: str,
    per_pipeline: int,
    cutoff: datetime,
) -> list[PipelineExecutionView]:
    """Return the displayable executions among one pipeline's newest runs.

    Only the newest *per_pipeline* executions by creation time are fetched
    (clamped to 1..100), so an active execution older than that window is
    not seen. Each one is enriched with DescribePipelineExecution (best
    effort) and kept when _should_show_pipeline_execution accepts it.

    Raises:
        AwsCliError: If the ListPipelineExecutions call fails.
    """
    max_results = max(1, min(per_pipeline, 100))
    try:
        listing = ctx.sagemaker.list_pipeline_executions(
            PipelineName=pipeline_name,
            SortBy="CreationTime",
            SortOrder="Descending",
            MaxResults=max_results,
        )
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipeline executions 失败 pipeline={pipeline_name}: {exc}") from exc

    visible: list[PipelineExecutionView] = []
    for item in listing.get("PipelineExecutionSummaries", []):
        arn = item.get("PipelineExecutionArn", "")
        state = item.get("PipelineExecutionStatus", "")
        described = _describe_pipeline_execution_safely(ctx, arn) if arn else {}
        started = described.get("StartTime", item.get("StartTime"))
        modified = described.get("LastModifiedTime", item.get("LastModifiedTime"))
        if not _should_show_pipeline_execution(state, started, modified, cutoff):
            continue
        view = PipelineExecutionView(
            profile=ctx.profile,
            region=ctx.region,
            pipeline_name=described.get("PipelineName", pipeline_name),
            execution_arn=arn,
            display_name=item.get(
                "PipelineExecutionDisplayName",
                described.get("PipelineExecutionDisplayName", ""),
            ),
            status=state,
            start_time=started,
            last_modified_time=modified,
        )
        visible.append(view)
    return visible
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _list_pipeline_names_page(
    ctx: AwsContext,
    pipeline_name: str | None,
    page_size: int,
    next_token: str | None,
) -> tuple[list[str], str | None]:
    """Return one page of pipeline names (newest first) and the next token.

    When *pipeline_name* is given, no API call is made and that single name
    is returned with no continuation. Combining it with *next_token* is an
    error.

    Raises:
        AwsCliError: On the name/token combination above or a failed
            ListPipelines call.
    """
    if pipeline_name and next_token:
        raise AwsCliError("指定 --name 时不支持 pipeline next token")
    if pipeline_name:
        return [pipeline_name], None

    kwargs: dict[str, Any] = dict(
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=max(1, min(page_size, 100)),
    )
    if next_token:
        kwargs["NextToken"] = next_token
    try:
        response = ctx.sagemaker.list_pipelines(**kwargs)
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipelines 失败: {exc}") from exc

    summaries = response.get("PipelineSummaries", [])
    return [entry["PipelineName"] for entry in summaries], response.get("NextToken")
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _describe_pipeline_execution_safely(ctx: AwsContext, execution_arn: str) -> dict[str, Any]:
    """Describe a pipeline execution, returning ``{}`` on any boto3/botocore error.

    Used where missing detail should degrade the view rather than abort it.
    """
    try:
        detail = ctx.sagemaker.describe_pipeline_execution(PipelineExecutionArn=execution_arn)
    except (BotoCoreError, ClientError):
        detail = {}
    return detail
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def _should_show_pipeline_execution(
    status: str,
    start_time: datetime | None,
    last_modified_time: datetime | None,
    cutoff: datetime,
) -> bool:
    """Decide whether an execution belongs in the listing.

    Active executions (ACTIVE_PIPELINE_STATUSES) are always shown. Others
    are shown only if their last-modified time, or start time when that is
    missing, is at or after *cutoff*. Naive timestamps are treated as UTC.
    """
    if status in ACTIVE_PIPELINE_STATUSES:
        return True
    seen_at = last_modified_time if last_modified_time else start_time
    if seen_at is None:
        return False
    if seen_at.tzinfo is None:
        seen_at = seen_at.replace(tzinfo=timezone.utc)
    return cutoff <= seen_at
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def list_pipeline_steps(ctx: AwsContext, execution_arn: str) -> list[dict[str, Any]]:
    """Return every step of a pipeline execution, ordered by start time.

    All ``ListPipelineExecutionSteps`` pages are read. Steps without a
    ``StartTime`` sort first.

    Raises:
        AwsCliError: If reading the steps fails.
    """
    paginator = ctx.sagemaker.get_paginator("list_pipeline_execution_steps")
    try:
        collected = [
            step
            for page in paginator.paginate(PipelineExecutionArn=execution_arn)
            for step in page.get("PipelineExecutionSteps", [])
        ]
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipeline steps 失败: {exc}") from exc

    floor = datetime.min.replace(tzinfo=timezone.utc)
    collected.sort(key=lambda step: step.get("StartTime") or floor)
    return collected
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def describe_pipeline_execution(ctx: AwsContext, execution_arn: str) -> dict[str, Any]:
    """Return the raw DescribePipelineExecution response for *execution_arn*.

    Raises:
        AwsCliError: If boto3/botocore reports an error.
    """
    client = ctx.sagemaker
    try:
        response = client.describe_pipeline_execution(PipelineExecutionArn=execution_arn)
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipeline execution 失败: {exc}") from exc
    return response
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def tail_step_logs(ctx: AwsContext, step: dict[str, Any], limit: int = 80) -> list[str]:
    """Return recent CloudWatch log lines for a pipeline step.

    Returns ``[]`` when the step's metadata names no job whose logs can be
    located.
    """
    located = infer_log_source(step)
    if located is None:
        return []
    group_name, prefix = located
    return tail_cloudwatch_logs(ctx, group_name, prefix, limit=limit)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def tail_cloudwatch_logs(
    ctx: AwsContext,
    log_group: str,
    stream_prefix: str,
    limit: int = 80,
) -> list[str]:
    """Return the most recent log lines from streams matching a name prefix.

    Up to five streams whose names start with *stream_prefix* are read,
    requesting the latest *limit* events of each (``startFromHead=False``).
    Their events are merged in timestamp order, and the newest ones are
    formatted as ``"YYYY-mm-dd HH:MM:SSZ message"``.

    Failures are returned as text lines instead of being raised. A missing
    log group yields a single explanatory line, and per-stream read errors
    are listed before the log lines. The result holds at most *limit* lines
    unless there are more error lines than that.

    Args:
        ctx: Clients for the target profile/region (only ``ctx.logs`` is used).
        log_group: CloudWatch Logs group name.
        stream_prefix: Log-stream name prefix to match.
        limit: Maximum number of lines; values below 1 are treated as 1.
    """
    # A limit of 0 would make every ``[-limit:]`` slice return everything.
    limit = max(1, limit)
    try:
        streams = ctx.logs.describe_log_streams(
            logGroupName=log_group,
            logStreamNamePrefix=stream_prefix,
            limit=5,
        ).get("logStreams", [])
    except ctx.logs.exceptions.ResourceNotFoundException:
        return [f"没有找到日志组: {log_group}"]
    except (BotoCoreError, ClientError) as exc:
        return [f"读取日志流失败: {exc}"]

    streams = sorted(streams, key=lambda stream: stream.get("lastEventTimestamp", 0), reverse=True)
    errors: list[str] = []
    events: list[tuple[int, str]] = []
    for stream in streams:
        stream_name = stream["logStreamName"]
        try:
            stream_events = ctx.logs.get_log_events(
                logGroupName=log_group,
                logStreamName=stream_name,
                limit=limit,
                startFromHead=False,
            ).get("events", [])
        except (BotoCoreError, ClientError) as exc:
            errors.append(f"[{stream_name}] 读取失败: {exc}")
            continue
        events.extend((event["timestamp"], event.get("message", "")) for event in stream_events[-limit:])

    # Merge across streams by event time before tailing. Concatenating stream
    # by stream and then taking the last lines would keep the OLDEST stream's
    # output. The stable sort preserves per-stream order for equal timestamps.
    events.sort(key=lambda item: item[0])
    keep = max(0, limit - len(errors))
    tail = events[-keep:] if keep else []
    lines = [
        f"{datetime.fromtimestamp(ts / 1000, tz=timezone.utc):%Y-%m-%d %H:%M:%S}Z {message.rstrip()}"
        for ts, message in tail
    ]
    return errors + lines
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def infer_log_source(step: dict[str, Any]) -> tuple[str, str] | None:
    """Map a pipeline step to its CloudWatch log group and log-stream prefix.

    The first ProcessingJob / TrainingJob / TransformJob metadata entry that
    carries an ``Arn`` wins. The job name is taken from the ARN after its
    last ``/``. It is returned with a trailing ``/`` because SageMaker job
    log streams are named ``<job-name>/...``. A bare name prefix would also
    match streams of a different job whose name merely starts with it
    (e.g. ``prep`` vs ``prep-2``).

    Returns:
        ``(log_group, stream_prefix)``, or None when the step has no such
        job metadata.
    """
    metadata = step.get("Metadata") or {}
    sources = (
        ("ProcessingJob", "/aws/sagemaker/ProcessingJobs"),
        ("TrainingJob", "/aws/sagemaker/TrainingJobs"),
        ("TransformJob", "/aws/sagemaker/TransformJobs"),
    )
    for key, log_group in sources:
        payload = metadata.get(key)
        if not isinstance(payload, dict):
            continue
        arn = payload.get("Arn")
        if arn:
            # NOTE(review): if SageMaker lower-cases names inside ARNs,
            # mixed-case job names would not match their streams; confirm.
            job_name = arn.rsplit("/", 1)[-1]
            return log_group, f"{job_name}/"
    return None
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def format_dt(value: datetime | None) -> str:
    """Render *value* as ``YYYY-mm-dd HH:MM:SSZ`` in UTC.

    Returns ``""`` for None. Naive datetimes are assumed to already be UTC.
    """
    if value is None:
        return ""
    aware = value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
    return f"{aware.astimezone(timezone.utc):%Y-%m-%d %H:%M:%S}Z"
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def format_duration(start: datetime | None, end: datetime | None = None) -> str:
    """Render the elapsed time from *start* to *end* (default: now, UTC).

    Output forms are ``"1h02m"``, ``"3m05s"`` or ``"7s"``. Negative spans are
    reported as ``"0s"``, None *start* gives ``""``, and naive datetimes are
    treated as UTC.
    """
    if start is None:
        return ""
    utc = timezone.utc
    begin = start if start.tzinfo is not None else start.replace(tzinfo=utc)
    finish = end if end is not None else datetime.now(utc)
    if finish.tzinfo is None:
        finish = finish.replace(tzinfo=utc)

    total = max(0, int((finish - begin).total_seconds()))
    h, rest = divmod(total, 3600)
    m, s = divmod(rest, 60)
    if h:
        return f"{h}h{m:02d}m"
    return f"{m}m{s:02d}s" if m else f"{s}s"
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _list_pipeline_names(ctx: AwsContext) -> list[str]:
    """Return every pipeline name, newest first, following all ListPipelines pages.

    Raises:
        AwsCliError: If SageMaker cannot be queried. Raw botocore errors are
            wrapped, as the other SageMaker listing helpers in this file do,
            instead of escaping to the CLI.
    """
    paginator = ctx.sagemaker.get_paginator("list_pipelines")
    try:
        return [
            item["PipelineName"]
            for page in paginator.paginate(SortBy="CreationTime", SortOrder="Descending")
            for item in page.get("PipelineSummaries", [])
        ]
    except (BotoCoreError, ClientError) as exc:
        raise AwsCliError(f"读取 pipelines 失败: {exc}") from exc
|
|
502
|
+
|