apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +80 -110
- airflow/providers/amazon/aws/auth_manager/router/login.py +11 -4
- airflow/providers/amazon/aws/auth_manager/user.py +7 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/ec2.py +1 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/glue.py +6 -2
- airflow/providers/amazon/aws/hooks/logs.py +2 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +10 -10
- airflow/providers/amazon/aws/hooks/redshift_data.py +3 -4
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/links/base_aws.py +8 -1
- airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +136 -84
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/s3.py +147 -157
- airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +161 -0
- airflow/providers/amazon/aws/sensors/rds.py +10 -5
- airflow/providers/amazon/aws/sensors/s3.py +32 -43
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/README.md +4 -4
- airflow/providers/amazon/aws/triggers/base.py +11 -2
- airflow/providers/amazon/aws/triggers/ecs.py +6 -2
- airflow/providers/amazon/aws/triggers/eks.py +2 -2
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/triggers/s3.py +31 -6
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
- airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
- airflow/providers/amazon/aws/triggers/sqs.py +11 -3
- airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +46 -5
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/METADATA +40 -33
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/RECORD +68 -61
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/entry_points.txt +0 -0
@@ -211,7 +211,7 @@ class GlueJobHook(AwsBaseHook):

         The async version of get_job_state.
         """
-        async with self.
+        async with await self.get_async_conn() as client:
             job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
             return job_run["JobRun"]["JobRunState"]
@@ -236,6 +236,9 @@ class GlueJobHook(AwsBaseHook):
         """
         log_client = self.logs_hook.get_conn()
         paginator = log_client.get_paginator("filter_log_events")
+        job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]
+        # StartTime needs to be an int and is Epoch time in milliseconds
+        start_time = int(job_run["StartedOn"].timestamp() * 1000)

         def display_logs_from(log_group: str, continuation_token: str | None) -> str | None:
             """Mutualize iteration over the 2 different log streams glue jobs write to."""
@@ -245,6 +248,7 @@ class GlueJobHook(AwsBaseHook):
             for response in paginator.paginate(
                 logGroupName=log_group,
                 logStreamNames=[run_id],
+                startTime=start_time,
                 PaginationConfig={"StartingToken": continuation_token},
             ):
                 fetched_logs.extend([event["message"] for event in response["events"]])
@@ -270,7 +274,7 @@ class GlueJobHook(AwsBaseHook):
                 self.log.info("No new log from the Glue Job in %s", log_group)
                 return next_token

-        log_group_prefix =
+        log_group_prefix = job_run["LogGroupName"]
         log_group_default = f"{log_group_prefix}/{DEFAULT_LOG_SUFFIX}"
         log_group_error = f"{log_group_prefix}/{ERROR_LOG_SUFFIX}"
         # one would think that the error log group would contain only errors, but it actually contains
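The Glue log-streaming change above scopes `filter_log_events` to the run's start time instead of reading the log group from its beginning. A minimal sketch of the same pattern outside the hook, assuming plain boto3 clients; the job name, run id, and the `/output` stream suffix are illustrative:

```python
import boto3

glue = boto3.client("glue")
logs = boto3.client("logs")

# Illustrative job name and run id.
job_run = glue.get_job_run(JobName="my-glue-job", RunId="jr_0123456789")["JobRun"]

# CloudWatch expects startTime as epoch milliseconds (int).
start_time = int(job_run["StartedOn"].timestamp() * 1000)

paginator = logs.get_paginator("filter_log_events")
for page in paginator.paginate(
    logGroupName=f"{job_run['LogGroupName']}/output",
    logStreamNames=["jr_0123456789"],
    startTime=start_time,
):
    for event in page["events"]:
        print(event["message"])
```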
@@ -152,7 +152,7 @@ class AwsLogsHook(AwsBaseHook):
             If the value is LastEventTime , the results are ordered by the event time. The default value is LogStreamName.
         :param count: The maximum number of items returned
         """
-        async with self.
+        async with await self.get_async_conn() as client:
             try:
                 response: dict[str, Any] = await client.describe_log_streams(
                     logGroupName=log_group,
@@ -194,7 +194,7 @@ class AwsLogsHook(AwsBaseHook):
         else:
             token_arg = {}

-        async with self.
+        async with await self.get_async_conn() as client:
             response = await client.get_log_events(
                 logGroupName=log_group,
                 logStreamName=log_stream_name,
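Several hooks in this release (Glue, CloudWatch Logs, Redshift, SageMaker) switch from entering `self.async_conn` directly to awaiting `get_async_conn()` first. A minimal sketch of how calling code uses the new pattern; the connection id and log group name are illustrative:

```python
import asyncio

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook


async def latest_stream(log_group: str) -> str | None:
    hook = AwsLogsHook(aws_conn_id="aws_default")
    # New pattern: await get_async_conn(), then use the result as an async context manager.
    async with await hook.get_async_conn() as client:
        response = await client.describe_log_streams(
            logGroupName=log_group,
            orderBy="LastEventTime",
            descending=True,
            limit=1,
        )
    streams = response.get("logStreams", [])
    return streams[0]["logStreamName"] if streams else None


# Illustrative invocation with a placeholder log group name.
# asyncio.run(latest_stream("/aws/my-log-group"))
```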
@@ -18,6 +18,7 @@

 from __future__ import annotations

+import requests
 from botocore.exceptions import ClientError

 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):

     Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`

+    If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method
+    that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly
+    make requests to the Airflow API. This fallback method can be set as the default (and only) method used by
+    setting `generate_local_token` to True. Learn more here:
+    https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
+
     Additional arguments (such as ``aws_conn_id``) may be specified and
     are passed down to the underlying AwsBaseHook.
@@ -47,6 +54,7 @@ class MwaaHook(AwsBaseHook):
         method: str,
         body: dict | None = None,
         query_params: dict | None = None,
+        generate_local_token: bool = False,
     ) -> dict:
         """
         Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -56,30 +64,86 @@ class MwaaHook(AwsBaseHook):

         :param env_name: name of the MWAA environment
         :param path: Apache Airflow REST API endpoint path to be called
-        :param method: HTTP method used for making Airflow REST API calls
+        :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
         :param body: Request body for the Apache Airflow REST API call
         :param query_params: Query parameters to be included in the Apache Airflow REST API call
+        :param generate_local_token: If True, only the local web token method is used without trying boto's
+            `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
+            boto's `invoke_rest_api`
         """
-
+        # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+        body = {k: v for k, v in body.items() if v is not None} if body else {}
+        query_params = query_params or {}
         api_kwargs = {
             "Name": env_name,
             "Path": path,
             "Method": method,
-
-            "
-            "QueryParameters": query_params if query_params else {},
+            "Body": body,
+            "QueryParameters": query_params,
         }
+
+        if generate_local_token:
+            return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+
         try:
-
+            response = self.conn.invoke_rest_api(**api_kwargs)
             # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
             # in XComs and logs, or redundant given the data already included in the response
-
-            return
+            response.pop("ResponseMetadata", None)
+            return response
+
         except ClientError as e:
-
-
-
-
-
-
-
+            if (
+                e.response["Error"]["Code"] == "AccessDeniedException"
+                and "Airflow role" in e.response["Error"]["Message"]
+            ):
+                self.log.info(
+                    "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
+                )
+                return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+            else:
+                to_log = e.response
+                # ResponseMetadata is removed because it contains data that is either very unlikely to be
+                # useful in XComs and logs, or redundant given the data already included in the response
+                to_log.pop("ResponseMetadata", None)
+                self.log.error(to_log)
+                raise
+
+    def _invoke_rest_api_using_local_session_token(
+        self,
+        **api_kwargs,
+    ) -> dict:
+        try:
+            session, hostname = self._get_session_conn(api_kwargs["Name"])
+
+            response = session.request(
+                method=api_kwargs["Method"],
+                url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
+                params=api_kwargs["QueryParameters"],
+                json=api_kwargs["Body"],
+                timeout=10,
+            )
+            response.raise_for_status()
+
+        except requests.HTTPError as e:
+            self.log.error(e.response.json())
+            raise
+
+        return {
+            "RestApiStatusCode": response.status_code,
+            "RestApiResponse": response.json(),
+        }
+
+    # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
+    def _get_session_conn(self, env_name: str) -> tuple:
+        create_token_response = self.conn.create_web_login_token(Name=env_name)
+        web_server_hostname = create_token_response["WebServerHostname"]
+        web_token = create_token_response["WebToken"]
+
+        login_url = f"https://{web_server_hostname}/aws_mwaa/login"
+        login_payload = {"token": web_token}
+        session = requests.Session()
+        login_response = session.post(login_url, data=login_payload, timeout=10)
+        login_response.raise_for_status()
+
+        return session, web_server_hostname
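A brief sketch of calling the updated hook and forcing the local web-token path; the environment name, endpoint path, and connection id are illustrative:

```python
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook(aws_conn_id="aws_default")

# Trigger a DAG run in another MWAA environment through the Airflow REST API.
# With generate_local_token=True the hook skips boto's invoke_rest_api entirely
# and authenticates with a web login token instead.
response = hook.invoke_rest_api(
    env_name="my-mwaa-environment",  # illustrative environment name
    path="/dags/my_dag/dagRuns",     # Airflow REST API endpoint path
    method="POST",
    body={"note": None},             # None values are filtered out before sending
    generate_local_token=True,
)
print(response["RestApiStatusCode"], response["RestApiResponse"])
```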
@@ -67,7 +67,7 @@ class RedshiftHook(AwsBaseHook):
             for the cluster that is being created.
         :param params: Remaining AWS Create cluster API params.
         """
-        response = self.
+        response = self.conn.create_cluster(
             ClusterIdentifier=cluster_identifier,
             NodeType=node_type,
             MasterUsername=master_username,
@@ -87,13 +87,13 @@ class RedshiftHook(AwsBaseHook):
         :param cluster_identifier: unique identifier of a cluster
         """
         try:
-            response = self.
+            response = self.conn.describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
             return response[0]["ClusterStatus"] if response else None
-        except self.
+        except self.conn.exceptions.ClusterNotFoundFault:
             return "cluster_not_found"

     async def cluster_status_async(self, cluster_identifier: str) -> str:
-        async with self.
+        async with await self.get_async_conn() as client:
             response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
             return response["Clusters"][0]["ClusterStatus"] if response else None
@@ -115,7 +115,7 @@ class RedshiftHook(AwsBaseHook):
         """
         final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ""

-        response = self.
+        response = self.conn.delete_cluster(
             ClusterIdentifier=cluster_identifier,
             SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
             FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
@@ -131,7 +131,7 @@ class RedshiftHook(AwsBaseHook):

         :param cluster_identifier: unique identifier of a cluster
         """
-        response = self.
+        response = self.conn.describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
         if "Snapshots" not in response:
             return None
         snapshots = response["Snapshots"]
@@ -149,7 +149,7 @@ class RedshiftHook(AwsBaseHook):
         :param cluster_identifier: unique identifier of a cluster
         :param snapshot_identifier: unique identifier for a snapshot of a cluster
         """
-        response = self.
+        response = self.conn.restore_from_cluster_snapshot(
             ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
         )
         return response["Cluster"] if response["Cluster"] else None
@@ -175,7 +175,7 @@ class RedshiftHook(AwsBaseHook):
         """
         if tags is None:
             tags = []
-        response = self.
+        response = self.conn.create_cluster_snapshot(
             SnapshotIdentifier=snapshot_identifier,
             ClusterIdentifier=cluster_identifier,
             ManualSnapshotRetentionPeriod=retention_period,
@@ -192,11 +192,11 @@ class RedshiftHook(AwsBaseHook):
         :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
         """
         try:
-            response = self.
+            response = self.conn.describe_cluster_snapshots(
                 SnapshotIdentifier=snapshot_identifier,
             )
             snapshot = response.get("Snapshots")[0]
             snapshot_status: str = snapshot.get("Status")
             return snapshot_status
-        except self.
+        except self.conn.exceptions.ClusterSnapshotNotFoundFault:
             return None
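For context, these RedshiftHook methods wrap plain boto3 Redshift calls; a short sketch of calling them, with illustrative cluster and snapshot identifiers:

```python
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")

# Returns a status string such as "available", or "cluster_not_found" if the cluster does not exist.
status = hook.cluster_status(cluster_identifier="my-redshift-cluster")

# Returns the snapshot status string, or None when the snapshot cannot be found.
snapshot_status = hook.get_cluster_snapshot_status(snapshot_identifier="my-cluster-snapshot")
```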
@@ -186,8 +186,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
         )
         raise exception_cls(
-            f"Statement {resp['Id']} terminated with status {status}. "
-            f"Response details: {pformat(resp)}"
+            f"Statement {resp['Id']} terminated with status {status}. Response details: {pformat(resp)}"
         )

         self.log.info("Query status: %s", status)
@@ -275,7 +274,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):

         :param statement_id: the UUID of the statement
         """
-        async with self.
+        async with await self.get_async_conn() as client:
             desc = await client.describe_statement(Id=statement_id)
             return desc["Status"] in RUNNING_STATES
@@ -288,6 +287,6 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):

         :param statement_id: the UUID of the statement
         """
-        async with self.
+        async with await self.get_async_conn() as client:
             resp = await client.describe_statement(Id=statement_id)
             return self.parse_statement_response(resp)
@@ -1494,7 +1494,9 @@ class S3Hook(AwsBaseHook):
             get_hook_lineage_collector().add_output_asset(
                 context=self,
                 scheme="file",
-                asset_kwargs={
+                asset_kwargs={
+                    "path": str(file_path) if file_path.is_absolute() else str(file_path.absolute())
+                },
             )
             file = open(file_path, "wb")
         else:
@@ -131,7 +131,7 @@ def secondary_training_status_message(
     status_strs = []
     for transition in transitions_to_print:
         message = transition["StatusMessage"]
-        time_utc = timezone.convert_to_utc(cast(datetime, job_description["LastModifiedTime"]))
+        time_utc = timezone.convert_to_utc(cast("datetime", job_description["LastModifiedTime"]))
         status_strs.append(f"{time_utc:%Y-%m-%d %H:%M:%S} {transition['Status']} - {message}")

     return "\n".join(status_strs)
@@ -1318,7 +1318,7 @@ class SageMakerHook(AwsBaseHook):

         :param job_name: the name of the training job
         """
-        async with self.
+        async with await self.get_async_conn() as client:
             response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
             return response
@@ -0,0 +1,188 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook hook."""
+
+from __future__ import annotations
+
+import time
+
+from sagemaker_studio import ClientConfig
+from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner
+
+
+class SageMakerNotebookHook(BaseHook):
+    """
+    Interact with Sagemaker Unified Studio Workflows.
+
+    This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API.
+
+    Examples:
+     .. code-block:: python
+
+        from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook
+
+        notebook_hook = SageMakerNotebookHook(
+            input_config={"input_path": "path/to/notebook.ipynb", "input_params": {"param1": "value1"}},
+            output_config={"output_uri": "folder/output/location/prefix", "output_formats": "NOTEBOOK"},
+            execution_name="notebook_execution",
+            waiter_delay=10,
+            waiter_max_attempts=1440,
+        )
+
+    :param execution_name: The name of the notebook job to be executed, this is same as task_id.
+    :param input_config: Configuration for the input file.
+        Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}}
+    :param output_config: Configuration for the output format. It should include an output_formats parameter to specify the output format.
+        Example: {'output_formats': ['NOTEBOOK']}
+    :param compute: compute configuration to use for the notebook execution. This is a required attribute
+        if the execution is on a remote compute.
+        Example: { "instance_type": "ml.m5.large", "volume_size_in_gb": 30, "volume_kms_key_id": "", "image_uri": "string", "container_entrypoint": [ "string" ]}
+    :param termination_condition: conditions to match to terminate the remote execution.
+        Example: { "MaxRuntimeInSeconds": 3600 }
+    :param tags: tags to be associated with the remote execution runs.
+        Example: { "md_analytics": "logs" }
+    :param waiter_delay: Interval in seconds to check the task execution status.
+    :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
+    """
+
+    def __init__(
+        self,
+        execution_name: str,
+        input_config: dict | None = None,
+        output_config: dict | None = None,
+        compute: dict | None = None,
+        termination_condition: dict | None = None,
+        tags: dict | None = None,
+        waiter_delay: int = 10,
+        waiter_max_attempts: int = 1440,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config())
+        self.execution_name = execution_name
+        self.input_config = input_config or {}
+        self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
+        self.compute = compute
+        self.termination_condition = termination_condition or {}
+        self.tags = tags or {}
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+    def _get_sagemaker_studio_config(self):
+        config = ClientConfig()
+        config.overrides["execution"] = {"local": is_local_runner()}
+        return config
+
+    def _format_start_execution_input_config(self):
+        config = {
+            "notebook_config": {
+                "input_path": self.input_config.get("input_path"),
+                "input_parameters": self.input_config.get("input_params"),
+            },
+        }
+
+        return config
+
+    def _format_start_execution_output_config(self):
+        output_formats = self.output_config.get("output_formats")
+        config = {
+            "notebook_config": {
+                "output_formats": output_formats,
+            }
+        }
+        return config
+
+    def start_notebook_execution(self):
+        start_execution_params = {
+            "execution_name": self.execution_name,
+            "execution_type": "NOTEBOOK",
+            "input_config": self._format_start_execution_input_config(),
+            "output_config": self._format_start_execution_output_config(),
+            "termination_condition": self.termination_condition,
+            "tags": self.tags,
+        }
+        if self.compute:
+            start_execution_params["compute"] = self.compute
+        else:
+            start_execution_params["compute"] = {"instance_type": "ml.m4.xlarge"}
+
+        print(start_execution_params)
+        return self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
+
+    def wait_for_execution_completion(self, execution_id, context):
+        wait_attempts = 0
+        while wait_attempts < self.waiter_max_attempts:
+            wait_attempts += 1
+            time.sleep(self.waiter_delay)
+            response = self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+            error_message = response.get("error_details", {}).get("error_message")
+            status = response["status"]
+            if "files" in response:
+                self._set_xcom_files(response["files"], context)
+            if "s3_path" in response:
+                self._set_xcom_s3_path(response["s3_path"], context)
+
+            ret = self._handle_state(execution_id, status, error_message)
+            if ret:
+                return ret
+
+        # If timeout, handle state FAILED with timeout message
+        return self._handle_state(execution_id, "FAILED", "Execution timed out")
+
+    def _set_xcom_files(self, files, context):
+        if not context:
+            error_message = "context is required"
+            raise AirflowException(error_message)
+        for file in files:
+            context["ti"].xcom_push(
+                key=f"{file['display_name']}.{file['file_format']}",
+                value=file["file_path"],
+            )
+
+    def _set_xcom_s3_path(self, s3_path, context):
+        if not context:
+            error_message = "context is required"
+            raise AirflowException(error_message)
+        context["ti"].xcom_push(
+            key="s3_path",
+            value=s3_path,
+        )
+
+    def _handle_state(self, execution_id, status, error_message):
+        finished_states = ["COMPLETED"]
+        in_progress_states = ["IN_PROGRESS", "STOPPING"]
+
+        if status in in_progress_states:
+            info_message = f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.waiter_delay}"
+            self.log.info(info_message)
+            return None
+        execution_message = f"Exiting Execution {execution_id} State: {status}"
+        if status in finished_states:
+            self.log.info(execution_message)
+            return {"Status": status, "ExecutionId": execution_id}
+        else:
+            log_error_message = f"Execution {execution_id} failed with error: {error_message}"
+            self.log.error(log_error_message)
+            if error_message == "":
+                error_message = execution_message
+            raise AirflowException(error_message)
@@ -25,6 +25,5 @@ class AthenaQueryResultsLink(BaseAwsLink):
     name = "Query Results"
     key = "_athena_query_results"
     format_str = (
-        BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}
-        "/query-editor/history/{query_execution_id}"
+        BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#/query-editor/history/{query_execution_id}"
     )
@@ -19,14 +19,21 @@ from __future__ import annotations

 from typing import TYPE_CHECKING, ClassVar

-from airflow.models import BaseOperatorLink, XCom
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

 if TYPE_CHECKING:
     from airflow.models import BaseOperator
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context

+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperatorLink
+    from airflow.sdk.execution_time.xcom import XCom
+else:
+    from airflow.models import XCom  # type: ignore[no-redef]
+    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
+

 BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
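The same version-gated import pattern can be reused by custom operator links that need to run on both Airflow 2 and Airflow 3; a minimal sketch, with an illustrative link class:

```python
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
    # Airflow 3: operator links come from the task SDK.
    from airflow.sdk import BaseOperatorLink
else:
    # Airflow 2: fall back to the legacy location.
    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]


class MyConsoleLink(BaseOperatorLink):
    """Illustrative operator link pointing at a fixed console page."""

    name = "My Console"

    def get_link(self, operator, *, ti_key):
        return "https://console.aws.amazon.com/"
```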
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class SageMakerUnifiedStudioLink(BaseAwsLink):
+    """Helper class for constructing Amazon SageMaker Unified Studio Links."""
+
+    name = "Amazon SageMaker Unified Studio"
+    key = "sagemaker_unified_studio"
+    format_str = BASE_AWS_CONSOLE_LINK + "/datazone/home?region={region_name}"