apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0rc1__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (41)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +21 -100
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +3 -2
  6. airflow/providers/amazon/aws/auth_manager/user.py +7 -4
  7. airflow/providers/amazon/aws/hooks/base_aws.py +25 -0
  8. airflow/providers/amazon/aws/hooks/ec2.py +1 -1
  9. airflow/providers/amazon/aws/hooks/glue.py +6 -2
  10. airflow/providers/amazon/aws/hooks/logs.py +2 -2
  11. airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
  12. airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
  13. airflow/providers/amazon/aws/hooks/redshift_data.py +2 -2
  14. airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
  15. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
  16. airflow/providers/amazon/aws/links/base_aws.py +7 -1
  17. airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
  18. airflow/providers/amazon/aws/log/s3_task_handler.py +22 -7
  19. airflow/providers/amazon/aws/operators/s3.py +147 -157
  20. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
  21. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  22. airflow/providers/amazon/aws/sensors/mwaa.py +113 -0
  23. airflow/providers/amazon/aws/sensors/rds.py +10 -5
  24. airflow/providers/amazon/aws/sensors/s3.py +31 -42
  25. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
  26. airflow/providers/amazon/aws/triggers/README.md +4 -4
  27. airflow/providers/amazon/aws/triggers/base.py +1 -1
  28. airflow/providers/amazon/aws/triggers/ecs.py +6 -2
  29. airflow/providers/amazon/aws/triggers/eks.py +2 -2
  30. airflow/providers/amazon/aws/triggers/glue.py +1 -1
  31. airflow/providers/amazon/aws/triggers/s3.py +31 -6
  32. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
  33. airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
  34. airflow/providers/amazon/aws/triggers/sqs.py +11 -3
  35. airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
  36. airflow/providers/amazon/get_provider_info.py +36 -1
  37. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/METADATA +30 -25
  38. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/RECORD +40 -35
  39. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/WHEEL +1 -1
  40. airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
  41. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py (new file)
@@ -0,0 +1,188 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+
+ """This module contains the Amazon SageMaker Unified Studio Notebook hook."""
+
+ from __future__ import annotations
+
+ import time
+
+ from sagemaker_studio import ClientConfig
+ from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
+
+ from airflow.exceptions import AirflowException
+ from airflow.hooks.base import BaseHook
+ from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner
+
+
+ class SageMakerNotebookHook(BaseHook):
+     """
+     Interact with Sagemaker Unified Studio Workflows.
+
+     This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API.
+
+     Examples:
+         .. code-block:: python
+
+             from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook
+
+             notebook_hook = SageMakerNotebookHook(
+                 input_config={"input_path": "path/to/notebook.ipynb", "input_params": {"param1": "value1"}},
+                 output_config={"output_uri": "folder/output/location/prefix", "output_formats": "NOTEBOOK"},
+                 execution_name="notebook_execution",
+                 waiter_delay=10,
+                 waiter_max_attempts=1440,
+             )
+
+     :param execution_name: The name of the notebook job to be executed, this is same as task_id.
+     :param input_config: Configuration for the input file.
+         Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}}
+     :param output_config: Configuration for the output format. It should include an output_formats parameter to specify the output format.
+         Example: {'output_formats': ['NOTEBOOK']}
+     :param compute: compute configuration to use for the notebook execution. This is a required attribute
+         if the execution is on a remote compute.
+         Example: { "instance_type": "ml.m5.large", "volume_size_in_gb": 30, "volume_kms_key_id": "", "image_uri": "string", "container_entrypoint": [ "string" ]}
+     :param termination_condition: conditions to match to terminate the remote execution.
+         Example: { "MaxRuntimeInSeconds": 3600 }
+     :param tags: tags to be associated with the remote execution runs.
+         Example: { "md_analytics": "logs" }
+     :param waiter_delay: Interval in seconds to check the task execution status.
+     :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
+     """
+
+     def __init__(
+         self,
+         execution_name: str,
+         input_config: dict | None = None,
+         output_config: dict | None = None,
+         compute: dict | None = None,
+         termination_condition: dict | None = None,
+         tags: dict | None = None,
+         waiter_delay: int = 10,
+         waiter_max_attempts: int = 1440,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config())
+         self.execution_name = execution_name
+         self.input_config = input_config or {}
+         self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
+         self.compute = compute
+         self.termination_condition = termination_condition or {}
+         self.tags = tags or {}
+         self.waiter_delay = waiter_delay
+         self.waiter_max_attempts = waiter_max_attempts
+
+     def _get_sagemaker_studio_config(self):
+         config = ClientConfig()
+         config.overrides["execution"] = {"local": is_local_runner()}
+         return config
+
+     def _format_start_execution_input_config(self):
+         config = {
+             "notebook_config": {
+                 "input_path": self.input_config.get("input_path"),
+                 "input_parameters": self.input_config.get("input_params"),
+             },
+         }
+
+         return config
+
+     def _format_start_execution_output_config(self):
+         output_formats = self.output_config.get("output_formats")
+         config = {
+             "notebook_config": {
+                 "output_formats": output_formats,
+             }
+         }
+         return config
+
+     def start_notebook_execution(self):
+         start_execution_params = {
+             "execution_name": self.execution_name,
+             "execution_type": "NOTEBOOK",
+             "input_config": self._format_start_execution_input_config(),
+             "output_config": self._format_start_execution_output_config(),
+             "termination_condition": self.termination_condition,
+             "tags": self.tags,
+         }
+         if self.compute:
+             start_execution_params["compute"] = self.compute
+         else:
+             start_execution_params["compute"] = {"instance_type": "ml.m4.xlarge"}
+
+         print(start_execution_params)
+         return self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
+
+     def wait_for_execution_completion(self, execution_id, context):
+         wait_attempts = 0
+         while wait_attempts < self.waiter_max_attempts:
+             wait_attempts += 1
+             time.sleep(self.waiter_delay)
+             response = self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+             error_message = response.get("error_details", {}).get("error_message")
+             status = response["status"]
+             if "files" in response:
+                 self._set_xcom_files(response["files"], context)
+             if "s3_path" in response:
+                 self._set_xcom_s3_path(response["s3_path"], context)
+
+             ret = self._handle_state(execution_id, status, error_message)
+             if ret:
+                 return ret
+
+         # If timeout, handle state FAILED with timeout message
+         return self._handle_state(execution_id, "FAILED", "Execution timed out")
+
+     def _set_xcom_files(self, files, context):
+         if not context:
+             error_message = "context is required"
+             raise AirflowException(error_message)
+         for file in files:
+             context["ti"].xcom_push(
+                 key=f"{file['display_name']}.{file['file_format']}",
+                 value=file["file_path"],
+             )
+
+     def _set_xcom_s3_path(self, s3_path, context):
+         if not context:
+             error_message = "context is required"
+             raise AirflowException(error_message)
+         context["ti"].xcom_push(
+             key="s3_path",
+             value=s3_path,
+         )
+
+     def _handle_state(self, execution_id, status, error_message):
+         finished_states = ["COMPLETED"]
+         in_progress_states = ["IN_PROGRESS", "STOPPING"]
+
+         if status in in_progress_states:
+             info_message = f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.waiter_delay}"
+             self.log.info(info_message)
+             return None
+         execution_message = f"Exiting Execution {execution_id} State: {status}"
+         if status in finished_states:
+             self.log.info(execution_message)
+             return {"Status": status, "ExecutionId": execution_id}
+         else:
+             log_error_message = f"Execution {execution_id} failed with error: {error_message}"
+             self.log.error(log_error_message)
+             if error_message == "":
+                 error_message = execution_message
+             raise AirflowException(error_message)
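
Based on the API introduced above, a task callable could start a notebook run and block until it reaches a terminal state. The sketch below is illustrative only and not taken from the package: the task wiring, the use of get_current_context, and the key used to read the execution id out of the start_execution response are assumptions.

# Illustrative sketch: driving SageMakerNotebookHook from a Python callable.
# Only the hook API shown in the diff above is relied on; everything else is assumed.
from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook


def run_notebook():
    context = get_current_context()  # supplies the "ti" used for the hook's XCom pushes
    hook = SageMakerNotebookHook(
        execution_name="nightly_report",
        input_config={"input_path": "src/report.ipynb", "input_params": {"run_date": "2025-01-01"}},
        output_config={"output_formats": ["NOTEBOOK"]},
        waiter_delay=10,
        waiter_max_attempts=360,
    )
    response = hook.start_notebook_execution()
    execution_id = response["execution_id"]  # response key assumed; not shown in the diff
    # Polls get_execution until COMPLETED/FAILED, pushing result files and s3_path to XCom.
    return hook.wait_for_execution_completion(execution_id, context)
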
airflow/providers/amazon/aws/links/base_aws.py
@@ -19,14 +19,20 @@ from __future__ import annotations
 
  from typing import TYPE_CHECKING, ClassVar
 
- from airflow.models import BaseOperatorLink, XCom
+ from airflow.models import XCom
  from airflow.providers.amazon.aws.utils.suppress import return_on_error
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 
  if TYPE_CHECKING:
      from airflow.models import BaseOperator
      from airflow.models.taskinstancekey import TaskInstanceKey
      from airflow.utils.context import Context
 
+ if AIRFLOW_V_3_0_PLUS:
+     from airflow.sdk import BaseOperatorLink
+ else:
+     from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
+
 
  BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
 
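
The hunk above gates the BaseOperatorLink import on the installed Airflow version. Downstream code can reuse the same guard; a minimal sketch follows (the link class, its name, and the returned URL are hypothetical, and concrete links in this provider subclass BaseAwsLink instead, as the next file shows).

from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
    from airflow.sdk import BaseOperatorLink
else:
    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]


class MyConsoleLink(BaseOperatorLink):
    """Hypothetical extra link that always resolves to a fixed console URL."""

    name = "My Console"

    def get_link(self, operator, *, ti_key):
        return "https://console.aws.amazon.com/"
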
airflow/providers/amazon/aws/links/sagemaker_unified_studio.py (new file)
@@ -0,0 +1,27 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+ class SageMakerUnifiedStudioLink(BaseAwsLink):
+     """Helper class for constructing Amazon SageMaker Unified Studio Links."""
+
+     name = "Amazon SageMaker Unified Studio"
+     key = "sagemaker_unified_studio"
+     format_str = BASE_AWS_CONSOLE_LINK + "/datazone/home?region={region_name}"
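
For reference, the class above only declares how the console URL is assembled; BaseAwsLink fills the placeholders at runtime from the operator's region and AWS partition. Resolving them by hand (the values below are illustrative) shows the resulting link:

# Illustrative only: these placeholders are normally filled in by BaseAwsLink.
BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
format_str = BASE_AWS_CONSOLE_LINK + "/datazone/home?region={region_name}"

print(format_str.format(aws_domain="aws.amazon.com", region_name="us-east-1"))
# https://console.aws.amazon.com/datazone/home?region=us-east-1
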
airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -26,6 +26,7 @@ from typing import TYPE_CHECKING
 
  from airflow.configuration import conf
  from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.utils.log.file_task_handler import FileTaskHandler
  from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -58,7 +59,8 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
      def hook(self):
          """Returns S3Hook."""
          return S3Hook(
-             aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), transfer_config_args={"use_threads": False}
+             aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"),
+             transfer_config_args={"use_threads": False},
          )
 
      def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
@@ -73,7 +75,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
          is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
          self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
          # Clear the file first so that duplicate data is not uploaded
-         # when re-using the same path (e.g. with rescheduled sensors)
+         # when reusing the same path (e.g. with rescheduled sensors)
          if self.upload_on_close:
              with open(self.handler.baseFilename, "w"):
                  pass
@@ -116,12 +118,16 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
          keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
          if keys:
              keys = sorted(f"s3://{bucket}/{key}" for key in keys)
-             messages.append("Found logs in s3:")
-             messages.extend(f" * {key}" for key in keys)
+             if AIRFLOW_V_3_0_PLUS:
+                 messages = keys
+             else:
+                 messages.append("Found logs in s3:")
+                 messages.extend(f" * {key}" for key in keys)
              for key in keys:
                  logs.append(self.s3_read(key, return_error=True))
          else:
-             messages.append(f"No logs found on s3 for ti={ti}")
+             if not AIRFLOW_V_3_0_PLUS:
+                 messages.append(f"No logs found on s3 for ti={ti}")
          return messages, logs
 
      def s3_log_exists(self, remote_log_location: str) -> bool:
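
With the change above, _read_remote_logs fills messages differently per Airflow major version: the raw sorted s3:// keys on 3.x, a human-readable listing on 2.x. A standalone sketch of the two shapes (the key below is made up; real keys come from S3Hook.list_keys):

# Illustrative comparison of the `messages` value on Airflow 3.x vs 2.x.
keys = sorted(["s3://my-bucket/dag_id=example/run_id=manual/task_id=t1/attempt=1.log"])

for airflow_v3 in (True, False):
    if airflow_v3:
        messages = keys  # Airflow 3.x: just the keys
    else:
        messages = ["Found logs in s3:"]
        messages.extend(f" * {key}" for key in keys)  # Airflow 2.x: annotated listing
    print(messages)
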
@@ -152,7 +158,13 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
              return msg
          return ""
 
-     def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1) -> bool:
+     def s3_write(
+         self,
+         log: str,
+         remote_log_location: str,
+         append: bool = True,
+         max_retry: int = 1,
+     ) -> bool:
          """
          Write the log to the remote_log_location; return `True` or fails silently and return `False`.
 
@@ -185,7 +197,10 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
                  break
              except Exception:
                  if try_num < max_retry:
-                     self.log.warning("Failed attempt to write logs to %s, will retry", remote_log_location)
+                     self.log.warning(
+                         "Failed attempt to write logs to %s, will retry",
+                         remote_log_location,
+                     )
                  else:
                      self.log.exception("Could not write logs to %s", remote_log_location)
                      return False