apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +80 -110
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +11 -4
  6. airflow/providers/amazon/aws/auth_manager/user.py +7 -4
  7. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
  8. airflow/providers/amazon/aws/hooks/appflow.py +5 -15
  9. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
  10. airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
  11. airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
  12. airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
  13. airflow/providers/amazon/aws/hooks/dms.py +3 -1
  14. airflow/providers/amazon/aws/hooks/ec2.py +1 -1
  15. airflow/providers/amazon/aws/hooks/eks.py +3 -6
  16. airflow/providers/amazon/aws/hooks/glue.py +6 -2
  17. airflow/providers/amazon/aws/hooks/logs.py +2 -2
  18. airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
  19. airflow/providers/amazon/aws/hooks/redshift_cluster.py +10 -10
  20. airflow/providers/amazon/aws/hooks/redshift_data.py +3 -4
  21. airflow/providers/amazon/aws/hooks/s3.py +3 -1
  22. airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
  23. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
  24. airflow/providers/amazon/aws/links/athena.py +1 -2
  25. airflow/providers/amazon/aws/links/base_aws.py +8 -1
  26. airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
  27. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
  28. airflow/providers/amazon/aws/log/s3_task_handler.py +136 -84
  29. airflow/providers/amazon/aws/notifications/chime.py +1 -2
  30. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  31. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  32. airflow/providers/amazon/aws/operators/ec2.py +91 -83
  33. airflow/providers/amazon/aws/operators/eks.py +3 -3
  34. airflow/providers/amazon/aws/operators/mwaa.py +73 -2
  35. airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
  36. airflow/providers/amazon/aws/operators/s3.py +147 -157
  37. airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
  38. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
  39. airflow/providers/amazon/aws/sensors/ec2.py +5 -12
  40. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  41. airflow/providers/amazon/aws/sensors/glacier.py +1 -1
  42. airflow/providers/amazon/aws/sensors/mwaa.py +161 -0
  43. airflow/providers/amazon/aws/sensors/rds.py +10 -5
  44. airflow/providers/amazon/aws/sensors/s3.py +32 -43
  45. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
  46. airflow/providers/amazon/aws/sensors/step_function.py +2 -1
  47. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
  48. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
  49. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
  50. airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
  51. airflow/providers/amazon/aws/triggers/README.md +4 -4
  52. airflow/providers/amazon/aws/triggers/base.py +11 -2
  53. airflow/providers/amazon/aws/triggers/ecs.py +6 -2
  54. airflow/providers/amazon/aws/triggers/eks.py +2 -2
  55. airflow/providers/amazon/aws/triggers/glue.py +1 -1
  56. airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
  57. airflow/providers/amazon/aws/triggers/s3.py +31 -6
  58. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
  59. airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
  60. airflow/providers/amazon/aws/triggers/sqs.py +11 -3
  61. airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
  62. airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
  63. airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
  64. airflow/providers/amazon/get_provider_info.py +46 -5
  65. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/METADATA +40 -33
  66. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/RECORD +68 -61
  67. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/WHEEL +1 -1
  68. airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
  69. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/entry_points.txt +0 -0
@@ -211,7 +211,7 @@ class GlueJobHook(AwsBaseHook):
 
         The async version of get_job_state.
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
             return job_run["JobRun"]["JobRunState"]
 
@@ -236,6 +236,9 @@ class GlueJobHook(AwsBaseHook):
         """
         log_client = self.logs_hook.get_conn()
         paginator = log_client.get_paginator("filter_log_events")
+        job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]
+        # StartTime needs to be an int and is Epoch time in milliseconds
+        start_time = int(job_run["StartedOn"].timestamp() * 1000)
 
         def display_logs_from(log_group: str, continuation_token: str | None) -> str | None:
             """Mutualize iteration over the 2 different log streams glue jobs write to."""
@@ -245,6 +248,7 @@ class GlueJobHook(AwsBaseHook):
                 for response in paginator.paginate(
                     logGroupName=log_group,
                     logStreamNames=[run_id],
+                    startTime=start_time,
                     PaginationConfig={"StartingToken": continuation_token},
                 ):
                     fetched_logs.extend([event["message"] for event in response["events"]])
@@ -270,7 +274,7 @@ class GlueJobHook(AwsBaseHook):
                 self.log.info("No new log from the Glue Job in %s", log_group)
             return next_token
 
-        log_group_prefix = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"]
+        log_group_prefix = job_run["LogGroupName"]
         log_group_default = f"{log_group_prefix}/{DEFAULT_LOG_SUFFIX}"
         log_group_error = f"{log_group_prefix}/{ERROR_LOG_SUFFIX}"
         # one would think that the error log group would contain only errors, but it actually contains
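The Glue change above narrows log fetching to events emitted after the job run actually started: `filter_log_events` expects `startTime` as an integer epoch value in milliseconds, hence the `int(... * 1000)` conversion of the run's `StartedOn` datetime. A minimal standalone sketch of the same pattern with plain boto3; the job name, run id, and the `/output` log-group suffix are placeholders/assumptions, not taken from this diff:

```python
import boto3

glue = boto3.client("glue")
logs = boto3.client("logs")

job_run = glue.get_job_run(JobName="my-glue-job", RunId="jr_example")["JobRun"]
# StartedOn is a timezone-aware datetime; CloudWatch Logs wants epoch milliseconds as an int.
start_time = int(job_run["StartedOn"].timestamp() * 1000)

paginator = logs.get_paginator("filter_log_events")
messages: list[str] = []
for page in paginator.paginate(
    logGroupName=f"{job_run['LogGroupName']}/output",  # assumed default log-group suffix
    logStreamNames=["jr_example"],  # Glue writes to a stream named after the run id
    startTime=start_time,
):
    messages.extend(event["message"] for event in page["events"])
```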
@@ -152,7 +152,7 @@ class AwsLogsHook(AwsBaseHook):
         If the value is LastEventTime , the results are ordered by the event time. The default value is LogStreamName.
         :param count: The maximum number of items returned
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             try:
                 response: dict[str, Any] = await client.describe_log_streams(
                     logGroupName=log_group,
@@ -194,7 +194,7 @@ class AwsLogsHook(AwsBaseHook):
             else:
                 token_arg = {}
 
-            async with self.async_conn as client:
+            async with await self.get_async_conn() as client:
                 response = await client.get_log_events(
                     logGroupName=log_group,
                     logStreamName=log_stream_name,
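Several hooks in this release (Glue, CloudWatch Logs, Redshift, Redshift Data, SageMaker) replace the `self.async_conn` property with an explicit `await self.get_async_conn()` call, as the hunks above show. A minimal sketch of the new calling pattern from user code, assuming an async context such as a deferrable trigger; the job identifiers are placeholders:

```python
import asyncio

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook


async def fetch_job_state() -> str:
    hook = GlueJobHook(aws_conn_id="aws_default")
    # get_async_conn() is a coroutine returning an aiobotocore client context
    # manager, so it is awaited before entering the async with block.
    async with await hook.get_async_conn() as client:
        job_run = await client.get_job_run(JobName="my-glue-job", RunId="jr_example")
        return job_run["JobRun"]["JobRunState"]


if __name__ == "__main__":
    print(asyncio.run(fetch_job_state()))
```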
@@ -18,6 +18,7 @@
 
 from __future__ import annotations
 
+import requests
 from botocore.exceptions import ClientError
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):
 
     Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
 
+    If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method
+    that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly
+    make requests to the Airflow API. This fallback method can be set as the default (and only) method used by
+    setting `generate_local_token` to True. Learn more here:
+    https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
+
     Additional arguments (such as ``aws_conn_id``) may be specified and
     are passed down to the underlying AwsBaseHook.
 
@@ -47,6 +54,7 @@ class MwaaHook(AwsBaseHook):
         method: str,
         body: dict | None = None,
         query_params: dict | None = None,
+        generate_local_token: bool = False,
     ) -> dict:
         """
         Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -56,30 +64,86 @@ class MwaaHook(AwsBaseHook):
 
         :param env_name: name of the MWAA environment
         :param path: Apache Airflow REST API endpoint path to be called
-        :param method: HTTP method used for making Airflow REST API calls
+        :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
         :param body: Request body for the Apache Airflow REST API call
         :param query_params: Query parameters to be included in the Apache Airflow REST API call
+        :param generate_local_token: If True, only the local web token method is used without trying boto's
+            `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
+            boto's `invoke_rest_api`
         """
-        body = body or {}
+        # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+        body = {k: v for k, v in body.items() if v is not None} if body else {}
+        query_params = query_params or {}
         api_kwargs = {
             "Name": env_name,
             "Path": path,
             "Method": method,
-            # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
-            "Body": {k: v for k, v in body.items() if v is not None},
-            "QueryParameters": query_params if query_params else {},
+            "Body": body,
+            "QueryParameters": query_params,
         }
+
+        if generate_local_token:
+            return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+
         try:
-            result = self.conn.invoke_rest_api(**api_kwargs)
+            response = self.conn.invoke_rest_api(**api_kwargs)
             # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
             # in XComs and logs, or redundant given the data already included in the response
-            result.pop("ResponseMetadata", None)
-            return result
+            response.pop("ResponseMetadata", None)
+            return response
+
         except ClientError as e:
-            to_log = e.response
-            # ResponseMetadata and Error are removed because they contain data that is either very unlikely to
-            # be useful in XComs and logs, or redundant given the data already included in the response
-            to_log.pop("ResponseMetadata", None)
-            to_log.pop("Error", None)
-            self.log.error(to_log)
-            raise e
+            if (
+                e.response["Error"]["Code"] == "AccessDeniedException"
+                and "Airflow role" in e.response["Error"]["Message"]
+            ):
+                self.log.info(
+                    "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
+                )
+                return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+            else:
+                to_log = e.response
+                # ResponseMetadata is removed because it contains data that is either very unlikely to be
+                # useful in XComs and logs, or redundant given the data already included in the response
+                to_log.pop("ResponseMetadata", None)
+                self.log.error(to_log)
+                raise
+
+    def _invoke_rest_api_using_local_session_token(
+        self,
+        **api_kwargs,
+    ) -> dict:
+        try:
+            session, hostname = self._get_session_conn(api_kwargs["Name"])
+
+            response = session.request(
+                method=api_kwargs["Method"],
+                url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
+                params=api_kwargs["QueryParameters"],
+                json=api_kwargs["Body"],
+                timeout=10,
+            )
+            response.raise_for_status()
+
+        except requests.HTTPError as e:
+            self.log.error(e.response.json())
+            raise
+
+        return {
+            "RestApiStatusCode": response.status_code,
+            "RestApiResponse": response.json(),
+        }
+
+    # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
+    def _get_session_conn(self, env_name: str) -> tuple:
+        create_token_response = self.conn.create_web_login_token(Name=env_name)
+        web_server_hostname = create_token_response["WebServerHostname"]
+        web_token = create_token_response["WebToken"]
+
+        login_url = f"https://{web_server_hostname}/aws_mwaa/login"
+        login_payload = {"token": web_token}
+        session = requests.Session()
+        login_response = session.post(login_url, data=login_payload, timeout=10)
+        login_response.raise_for_status()
+
+        return session, web_server_hostname
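The new fallback path lets `MwaaHook` keep serving Airflow REST calls when the caller's IAM policy lacks `airflow:InvokeRestApi`: it creates a web login token, signs into the MWAA web server, and replays the request against `/api/v1`. A minimal usage sketch, assuming an environment named `my-mwaa-env` (a placeholder) and valid AWS credentials:

```python
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook(aws_conn_id="aws_default")

# Default behaviour: boto3's invoke_rest_api first, falling back to the
# web-login-token session only on the AccessDeniedException shown above.
dags = hook.invoke_rest_api(env_name="my-mwaa-env", path="/dags", method="GET")

# Force the local web-token path and skip boto3's API entirely.
version = hook.invoke_rest_api(
    env_name="my-mwaa-env",
    path="/version",
    method="GET",
    generate_local_token=True,
)
print(dags["RestApiStatusCode"], version["RestApiResponse"])
```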
@@ -67,7 +67,7 @@ class RedshiftHook(AwsBaseHook):
             for the cluster that is being created.
         :param params: Remaining AWS Create cluster API params.
         """
-        response = self.get_conn().create_cluster(
+        response = self.conn.create_cluster(
             ClusterIdentifier=cluster_identifier,
             NodeType=node_type,
             MasterUsername=master_username,
@@ -87,13 +87,13 @@ class RedshiftHook(AwsBaseHook):
         :param cluster_identifier: unique identifier of a cluster
         """
         try:
-            response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
+            response = self.conn.describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
             return response[0]["ClusterStatus"] if response else None
-        except self.get_conn().exceptions.ClusterNotFoundFault:
+        except self.conn.exceptions.ClusterNotFoundFault:
             return "cluster_not_found"
 
     async def cluster_status_async(self, cluster_identifier: str) -> str:
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
             return response["Clusters"][0]["ClusterStatus"] if response else None
 
@@ -115,7 +115,7 @@ class RedshiftHook(AwsBaseHook):
         """
         final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ""
 
-        response = self.get_conn().delete_cluster(
+        response = self.conn.delete_cluster(
             ClusterIdentifier=cluster_identifier,
             SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
             FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
@@ -131,7 +131,7 @@ class RedshiftHook(AwsBaseHook):
 
         :param cluster_identifier: unique identifier of a cluster
         """
-        response = self.get_conn().describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
+        response = self.conn.describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
         if "Snapshots" not in response:
             return None
         snapshots = response["Snapshots"]
@@ -149,7 +149,7 @@ class RedshiftHook(AwsBaseHook):
         :param cluster_identifier: unique identifier of a cluster
         :param snapshot_identifier: unique identifier for a snapshot of a cluster
         """
-        response = self.get_conn().restore_from_cluster_snapshot(
+        response = self.conn.restore_from_cluster_snapshot(
             ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
         )
         return response["Cluster"] if response["Cluster"] else None
@@ -175,7 +175,7 @@ class RedshiftHook(AwsBaseHook):
         """
         if tags is None:
             tags = []
-        response = self.get_conn().create_cluster_snapshot(
+        response = self.conn.create_cluster_snapshot(
             SnapshotIdentifier=snapshot_identifier,
             ClusterIdentifier=cluster_identifier,
             ManualSnapshotRetentionPeriod=retention_period,
@@ -192,11 +192,11 @@ class RedshiftHook(AwsBaseHook):
         :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
         """
         try:
-            response = self.get_conn().describe_cluster_snapshots(
+            response = self.conn.describe_cluster_snapshots(
                 SnapshotIdentifier=snapshot_identifier,
             )
             snapshot = response.get("Snapshots")[0]
             snapshot_status: str = snapshot.get("Status")
             return snapshot_status
-        except self.get_conn().exceptions.ClusterSnapshotNotFoundFault:
+        except self.conn.exceptions.ClusterSnapshotNotFoundFault:
             return None
@@ -186,8 +186,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
                 RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
             )
             raise exception_cls(
-                f"Statement {resp['Id']} terminated with status {status}. "
-                f"Response details: {pformat(resp)}"
+                f"Statement {resp['Id']} terminated with status {status}. Response details: {pformat(resp)}"
             )
 
         self.log.info("Query status: %s", status)
@@ -275,7 +274,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         :param statement_id: the UUID of the statement
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             desc = await client.describe_statement(Id=statement_id)
             return desc["Status"] in RUNNING_STATES
 
@@ -288,6 +287,6 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         :param statement_id: the UUID of the statement
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             resp = await client.describe_statement(Id=statement_id)
             return self.parse_statement_response(resp)
@@ -1494,7 +1494,9 @@ class S3Hook(AwsBaseHook):
             get_hook_lineage_collector().add_output_asset(
                 context=self,
                 scheme="file",
-                asset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+                asset_kwargs={
+                    "path": str(file_path) if file_path.is_absolute() else str(file_path.absolute())
+                },
             )
             file = open(file_path, "wb")
         else:
@@ -131,7 +131,7 @@ def secondary_training_status_message(
     status_strs = []
     for transition in transitions_to_print:
         message = transition["StatusMessage"]
-        time_utc = timezone.convert_to_utc(cast(datetime, job_description["LastModifiedTime"]))
+        time_utc = timezone.convert_to_utc(cast("datetime", job_description["LastModifiedTime"]))
        status_strs.append(f"{time_utc:%Y-%m-%d %H:%M:%S} {transition['Status']} - {message}")
 
     return "\n".join(status_strs)
@@ -1318,7 +1318,7 @@ class SageMakerHook(AwsBaseHook):
 
         :param job_name: the name of the training job
         """
-        async with self.async_conn as client:
+        async with await self.get_async_conn() as client:
             response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
             return response
 
@@ -0,0 +1,188 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook hook."""
+
+from __future__ import annotations
+
+import time
+
+from sagemaker_studio import ClientConfig
+from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner
+
+
+class SageMakerNotebookHook(BaseHook):
+    """
+    Interact with Sagemaker Unified Studio Workflows.
+
+    This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API.
+
+    Examples:
+        .. code-block:: python
+
+            from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook
+
+            notebook_hook = SageMakerNotebookHook(
+                input_config={"input_path": "path/to/notebook.ipynb", "input_params": {"param1": "value1"}},
+                output_config={"output_uri": "folder/output/location/prefix", "output_formats": "NOTEBOOK"},
+                execution_name="notebook_execution",
+                waiter_delay=10,
+                waiter_max_attempts=1440,
+            )
+
+    :param execution_name: The name of the notebook job to be executed, this is same as task_id.
+    :param input_config: Configuration for the input file.
+        Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}}
+    :param output_config: Configuration for the output format. It should include an output_formats parameter to specify the output format.
+        Example: {'output_formats': ['NOTEBOOK']}
+    :param compute: compute configuration to use for the notebook execution. This is a required attribute
+        if the execution is on a remote compute.
+        Example: { "instance_type": "ml.m5.large", "volume_size_in_gb": 30, "volume_kms_key_id": "", "image_uri": "string", "container_entrypoint": [ "string" ]}
+    :param termination_condition: conditions to match to terminate the remote execution.
+        Example: { "MaxRuntimeInSeconds": 3600 }
+    :param tags: tags to be associated with the remote execution runs.
+        Example: { "md_analytics": "logs" }
+    :param waiter_delay: Interval in seconds to check the task execution status.
+    :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
+    """
+
+    def __init__(
+        self,
+        execution_name: str,
+        input_config: dict | None = None,
+        output_config: dict | None = None,
+        compute: dict | None = None,
+        termination_condition: dict | None = None,
+        tags: dict | None = None,
+        waiter_delay: int = 10,
+        waiter_max_attempts: int = 1440,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config())
+        self.execution_name = execution_name
+        self.input_config = input_config or {}
+        self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
+        self.compute = compute
+        self.termination_condition = termination_condition or {}
+        self.tags = tags or {}
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+    def _get_sagemaker_studio_config(self):
+        config = ClientConfig()
+        config.overrides["execution"] = {"local": is_local_runner()}
+        return config
+
+    def _format_start_execution_input_config(self):
+        config = {
+            "notebook_config": {
+                "input_path": self.input_config.get("input_path"),
+                "input_parameters": self.input_config.get("input_params"),
+            },
+        }
+
+        return config
+
+    def _format_start_execution_output_config(self):
+        output_formats = self.output_config.get("output_formats")
+        config = {
+            "notebook_config": {
+                "output_formats": output_formats,
+            }
+        }
+        return config
+
+    def start_notebook_execution(self):
+        start_execution_params = {
+            "execution_name": self.execution_name,
+            "execution_type": "NOTEBOOK",
+            "input_config": self._format_start_execution_input_config(),
+            "output_config": self._format_start_execution_output_config(),
+            "termination_condition": self.termination_condition,
+            "tags": self.tags,
+        }
+        if self.compute:
+            start_execution_params["compute"] = self.compute
+        else:
+            start_execution_params["compute"] = {"instance_type": "ml.m4.xlarge"}
+
+        print(start_execution_params)
+        return self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
+
+    def wait_for_execution_completion(self, execution_id, context):
+        wait_attempts = 0
+        while wait_attempts < self.waiter_max_attempts:
+            wait_attempts += 1
+            time.sleep(self.waiter_delay)
+            response = self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+            error_message = response.get("error_details", {}).get("error_message")
+            status = response["status"]
+            if "files" in response:
+                self._set_xcom_files(response["files"], context)
+            if "s3_path" in response:
+                self._set_xcom_s3_path(response["s3_path"], context)
+
+            ret = self._handle_state(execution_id, status, error_message)
+            if ret:
+                return ret
+
+        # If timeout, handle state FAILED with timeout message
+        return self._handle_state(execution_id, "FAILED", "Execution timed out")
+
+    def _set_xcom_files(self, files, context):
+        if not context:
+            error_message = "context is required"
+            raise AirflowException(error_message)
+        for file in files:
+            context["ti"].xcom_push(
+                key=f"{file['display_name']}.{file['file_format']}",
+                value=file["file_path"],
+            )
+
+    def _set_xcom_s3_path(self, s3_path, context):
+        if not context:
+            error_message = "context is required"
+            raise AirflowException(error_message)
+        context["ti"].xcom_push(
+            key="s3_path",
+            value=s3_path,
+        )
+
+    def _handle_state(self, execution_id, status, error_message):
+        finished_states = ["COMPLETED"]
+        in_progress_states = ["IN_PROGRESS", "STOPPING"]
+
+        if status in in_progress_states:
+            info_message = f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.waiter_delay}"
+            self.log.info(info_message)
+            return None
+        execution_message = f"Exiting Execution {execution_id} State: {status}"
+        if status in finished_states:
+            self.log.info(execution_message)
+            return {"Status": status, "ExecutionId": execution_id}
+        else:
+            log_error_message = f"Execution {execution_id} failed with error: {error_message}"
+            self.log.error(log_error_message)
+            if error_message == "":
+                error_message = execution_message
+            raise AirflowException(error_message)
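Beyond the constructor example in the new hook's docstring, the intended flow is to start an execution and then poll it to a terminal state. A hedged sketch of that flow, assuming it runs inside a task with a live `context` providing `ti`; the `execution_id` response key and the notebook path are assumptions for illustration, not guaranteed by this hook alone:

```python
from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook


def run_notebook(context) -> dict:
    hook = SageMakerNotebookHook(
        execution_name="notebook_execution",
        input_config={"input_path": "src/notebook.ipynb", "input_params": {"param1": "value1"}},
        output_config={"output_formats": ["NOTEBOOK"]},
        waiter_delay=10,
        waiter_max_attempts=1440,
    )
    response = hook.start_notebook_execution()
    execution_id = response["execution_id"]  # assumed response key
    # Polls get_execution every waiter_delay seconds, pushes output files and the
    # s3_path to XCom, and raises AirflowException on FAILED or on timeout.
    return hook.wait_for_execution_completion(execution_id, context)
```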
@@ -25,6 +25,5 @@ class AthenaQueryResultsLink(BaseAwsLink):
     name = "Query Results"
     key = "_athena_query_results"
     format_str = (
-        BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#"
-        "/query-editor/history/{query_execution_id}"
+        BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#/query-editor/history/{query_execution_id}"
     )
@@ -19,14 +19,21 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
 
-from airflow.models import BaseOperatorLink, XCom
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
     from airflow.models import BaseOperator
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context
 
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperatorLink
+    from airflow.sdk.execution_time.xcom import XCom
+else:
+    from airflow.models import XCom  # type: ignore[no-redef]
+    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
+
 
 BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
 
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class SageMakerUnifiedStudioLink(BaseAwsLink):
+    """Helper class for constructing Amazon SageMaker Unified Studio Links."""
+
+    name = "Amazon SageMaker Unified Studio"
+    key = "sagemaker_unified_studio"
+    format_str = BASE_AWS_CONSOLE_LINK + "/datazone/home?region={region_name}"