apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (56)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +69 -97
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +9 -4
  6. airflow/providers/amazon/aws/auth_manager/user.py +7 -4
  7. airflow/providers/amazon/aws/hooks/appflow.py +5 -15
  8. airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
  9. airflow/providers/amazon/aws/hooks/ec2.py +1 -1
  10. airflow/providers/amazon/aws/hooks/eks.py +3 -6
  11. airflow/providers/amazon/aws/hooks/glue.py +6 -2
  12. airflow/providers/amazon/aws/hooks/logs.py +2 -2
  13. airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
  14. airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
  15. airflow/providers/amazon/aws/hooks/redshift_data.py +2 -2
  16. airflow/providers/amazon/aws/hooks/s3.py +3 -1
  17. airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
  18. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
  19. airflow/providers/amazon/aws/links/base_aws.py +8 -1
  20. airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
  21. airflow/providers/amazon/aws/log/s3_task_handler.py +22 -7
  22. airflow/providers/amazon/aws/notifications/chime.py +1 -2
  23. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  24. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  25. airflow/providers/amazon/aws/operators/ec2.py +91 -83
  26. airflow/providers/amazon/aws/operators/mwaa.py +73 -2
  27. airflow/providers/amazon/aws/operators/s3.py +147 -157
  28. airflow/providers/amazon/aws/operators/sagemaker.py +1 -2
  29. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
  30. airflow/providers/amazon/aws/sensors/ec2.py +5 -12
  31. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  32. airflow/providers/amazon/aws/sensors/mwaa.py +160 -0
  33. airflow/providers/amazon/aws/sensors/rds.py +10 -5
  34. airflow/providers/amazon/aws/sensors/s3.py +31 -42
  35. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
  36. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
  37. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
  38. airflow/providers/amazon/aws/triggers/README.md +4 -4
  39. airflow/providers/amazon/aws/triggers/base.py +11 -2
  40. airflow/providers/amazon/aws/triggers/ecs.py +6 -2
  41. airflow/providers/amazon/aws/triggers/eks.py +2 -2
  42. airflow/providers/amazon/aws/triggers/glue.py +1 -1
  43. airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
  44. airflow/providers/amazon/aws/triggers/s3.py +31 -6
  45. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
  46. airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
  47. airflow/providers/amazon/aws/triggers/sqs.py +11 -3
  48. airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
  49. airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
  50. airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
  51. airflow/providers/amazon/get_provider_info.py +45 -4
  52. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/METADATA +38 -31
  53. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/RECORD +55 -48
  54. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/WHEEL +1 -1
  55. airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
  56. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/hooks/mwaa.py
@@ -18,6 +18,7 @@
 
  from __future__ import annotations
 
+ import requests
  from botocore.exceptions import ClientError
 
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):
 
      Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
 
+     If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method
+     that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly
+     make requests to the Airflow API. This fallback method can be set as the default (and only) method used by
+     setting `generate_local_token` to True. Learn more here:
+     https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
+
      Additional arguments (such as ``aws_conn_id``) may be specified and
      are passed down to the underlying AwsBaseHook.
 
@@ -47,6 +54,7 @@ class MwaaHook(AwsBaseHook):
          method: str,
          body: dict | None = None,
          query_params: dict | None = None,
+         generate_local_token: bool = False,
      ) -> dict:
          """
          Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -56,30 +64,86 @@ class MwaaHook(AwsBaseHook):
 
          :param env_name: name of the MWAA environment
          :param path: Apache Airflow REST API endpoint path to be called
-         :param method: HTTP method used for making Airflow REST API calls
+         :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
          :param body: Request body for the Apache Airflow REST API call
          :param query_params: Query parameters to be included in the Apache Airflow REST API call
+         :param generate_local_token: If True, only the local web token method is used without trying boto's
+             `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
+             boto's `invoke_rest_api`
          """
-         body = body or {}
+         # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+         body = {k: v for k, v in body.items() if v is not None} if body else {}
+         query_params = query_params or {}
          api_kwargs = {
              "Name": env_name,
              "Path": path,
              "Method": method,
-             # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
-             "Body": {k: v for k, v in body.items() if v is not None},
-             "QueryParameters": query_params if query_params else {},
+             "Body": body,
+             "QueryParameters": query_params,
          }
+
+         if generate_local_token:
+             return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+
          try:
-             result = self.conn.invoke_rest_api(**api_kwargs)
+             response = self.conn.invoke_rest_api(**api_kwargs)
              # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
              # in XComs and logs, or redundant given the data already included in the response
-             result.pop("ResponseMetadata", None)
-             return result
+             response.pop("ResponseMetadata", None)
+             return response
+
          except ClientError as e:
-             to_log = e.response
-             # ResponseMetadata and Error are removed because they contain data that is either very unlikely to
-             # be useful in XComs and logs, or redundant given the data already included in the response
-             to_log.pop("ResponseMetadata", None)
-             to_log.pop("Error", None)
-             self.log.error(to_log)
-             raise e
+             if (
+                 e.response["Error"]["Code"] == "AccessDeniedException"
+                 and "Airflow role" in e.response["Error"]["Message"]
+             ):
+                 self.log.info(
+                     "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
+                 )
+                 return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+             else:
+                 to_log = e.response
+                 # ResponseMetadata is removed because it contains data that is either very unlikely to be
+                 # useful in XComs and logs, or redundant given the data already included in the response
+                 to_log.pop("ResponseMetadata", None)
+                 self.log.error(to_log)
+                 raise
+
+     def _invoke_rest_api_using_local_session_token(
+         self,
+         **api_kwargs,
+     ) -> dict:
+         try:
+             session, hostname = self._get_session_conn(api_kwargs["Name"])
+
+             response = session.request(
+                 method=api_kwargs["Method"],
+                 url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
+                 params=api_kwargs["QueryParameters"],
+                 json=api_kwargs["Body"],
+                 timeout=10,
+             )
+             response.raise_for_status()
+
+         except requests.HTTPError as e:
+             self.log.error(e.response.json())
+             raise
+
+         return {
+             "RestApiStatusCode": response.status_code,
+             "RestApiResponse": response.json(),
+         }
+
+     # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
+     def _get_session_conn(self, env_name: str) -> tuple:
+         create_token_response = self.conn.create_web_login_token(Name=env_name)
+         web_server_hostname = create_token_response["WebServerHostname"]
+         web_token = create_token_response["WebToken"]
+
+         login_url = f"https://{web_server_hostname}/aws_mwaa/login"
+         login_payload = {"token": web_token}
+         session = requests.Session()
+         login_response = session.post(login_url, data=login_payload, timeout=10)
+         login_response.raise_for_status()
+
+         return session, web_server_hostname
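
Note (not part of the diff): a minimal usage sketch of the updated MwaaHook; the connection id, environment name, and endpoint path below are placeholders.

    from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

    hook = MwaaHook(aws_conn_id="aws_default")

    # Default behaviour: try boto3's InvokeRestApi first; if the caller's IAM policy lacks
    # airflow:InvokeRestApi, the hook falls back to the local web-login-token path.
    run = hook.invoke_rest_api(
        env_name="my-mwaa-environment",
        path="/dags/my_dag/dagRuns",
        method="POST",
        body={"conf": {}},
    )

    # Force the web-login-token path and skip boto3 entirely.
    run = hook.invoke_rest_api(
        env_name="my-mwaa-environment",
        path="/dags/my_dag/dagRuns",
        method="POST",
        body={"conf": {}},
        generate_local_token=True,
    )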
airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -93,7 +93,7 @@ class RedshiftHook(AwsBaseHook):
              return "cluster_not_found"
 
      async def cluster_status_async(self, cluster_identifier: str) -> str:
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
              return response["Clusters"][0]["ClusterStatus"] if response else None
 
airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -275,7 +275,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
          :param statement_id: the UUID of the statement
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              desc = await client.describe_statement(Id=statement_id)
              return desc["Status"] in RUNNING_STATES
 
@@ -288,6 +288,6 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
          :param statement_id: the UUID of the statement
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              resp = await client.describe_statement(Id=statement_id)
              return self.parse_statement_response(resp)
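
Note (not part of the diff): the `async_conn` → `get_async_conn()` change above recurs in the SageMaker hook hunk below; a minimal sketch of the calling pattern, assuming a RedshiftDataHook instance and the module-level RUNNING_STATES constant.

    # get_async_conn() now has to be awaited before being used as an async context manager.
    async def statement_is_running(hook: RedshiftDataHook, statement_id: str) -> bool:
        async with await hook.get_async_conn() as client:
            desc = await client.describe_statement(Id=statement_id)
            return desc["Status"] in RUNNING_STATES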
airflow/providers/amazon/aws/hooks/s3.py
@@ -1494,7 +1494,9 @@ class S3Hook(AwsBaseHook):
              get_hook_lineage_collector().add_output_asset(
                  context=self,
                  scheme="file",
-                 asset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+                 asset_kwargs={
+                     "path": str(file_path) if file_path.is_absolute() else str(file_path.absolute())
+                 },
              )
              file = open(file_path, "wb")
          else:
airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -1318,7 +1318,7 @@ class SageMakerHook(AwsBaseHook):
 
          :param job_name: the name of the training job
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
              return response
 
airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py (new file)
@@ -0,0 +1,188 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+
+ """This module contains the Amazon SageMaker Unified Studio Notebook hook."""
+
+ from __future__ import annotations
+
+ import time
+
+ from sagemaker_studio import ClientConfig
+ from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
+
+ from airflow.exceptions import AirflowException
+ from airflow.hooks.base import BaseHook
+ from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner
+
+
+ class SageMakerNotebookHook(BaseHook):
+     """
+     Interact with Sagemaker Unified Studio Workflows.
+
+     This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API.
+
+     Examples:
+      .. code-block:: python
+
+         from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook
+
+         notebook_hook = SageMakerNotebookHook(
+             input_config={"input_path": "path/to/notebook.ipynb", "input_params": {"param1": "value1"}},
+             output_config={"output_uri": "folder/output/location/prefix", "output_formats": "NOTEBOOK"},
+             execution_name="notebook_execution",
+             waiter_delay=10,
+             waiter_max_attempts=1440,
+         )
+
+     :param execution_name: The name of the notebook job to be executed, this is same as task_id.
+     :param input_config: Configuration for the input file.
+         Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}}
+     :param output_config: Configuration for the output format. It should include an output_formats parameter to specify the output format.
+         Example: {'output_formats': ['NOTEBOOK']}
+     :param compute: compute configuration to use for the notebook execution. This is a required attribute
+         if the execution is on a remote compute.
+         Example: { "instance_type": "ml.m5.large", "volume_size_in_gb": 30, "volume_kms_key_id": "", "image_uri": "string", "container_entrypoint": [ "string" ]}
+     :param termination_condition: conditions to match to terminate the remote execution.
+         Example: { "MaxRuntimeInSeconds": 3600 }
+     :param tags: tags to be associated with the remote execution runs.
+         Example: { "md_analytics": "logs" }
+     :param waiter_delay: Interval in seconds to check the task execution status.
+     :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
+     """
+
+     def __init__(
+         self,
+         execution_name: str,
+         input_config: dict | None = None,
+         output_config: dict | None = None,
+         compute: dict | None = None,
+         termination_condition: dict | None = None,
+         tags: dict | None = None,
+         waiter_delay: int = 10,
+         waiter_max_attempts: int = 1440,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config())
+         self.execution_name = execution_name
+         self.input_config = input_config or {}
+         self.output_config = output_config or {"output_formats": ["NOTEBOOK"]}
+         self.compute = compute
+         self.termination_condition = termination_condition or {}
+         self.tags = tags or {}
+         self.waiter_delay = waiter_delay
+         self.waiter_max_attempts = waiter_max_attempts
+
+     def _get_sagemaker_studio_config(self):
+         config = ClientConfig()
+         config.overrides["execution"] = {"local": is_local_runner()}
+         return config
+
+     def _format_start_execution_input_config(self):
+         config = {
+             "notebook_config": {
+                 "input_path": self.input_config.get("input_path"),
+                 "input_parameters": self.input_config.get("input_params"),
+             },
+         }
+
+         return config
+
+     def _format_start_execution_output_config(self):
+         output_formats = self.output_config.get("output_formats")
+         config = {
+             "notebook_config": {
+                 "output_formats": output_formats,
+             }
+         }
+         return config
+
+     def start_notebook_execution(self):
+         start_execution_params = {
+             "execution_name": self.execution_name,
+             "execution_type": "NOTEBOOK",
+             "input_config": self._format_start_execution_input_config(),
+             "output_config": self._format_start_execution_output_config(),
+             "termination_condition": self.termination_condition,
+             "tags": self.tags,
+         }
+         if self.compute:
+             start_execution_params["compute"] = self.compute
+         else:
+             start_execution_params["compute"] = {"instance_type": "ml.m4.xlarge"}
+
+         print(start_execution_params)
+         return self._sagemaker_studio.execution_client.start_execution(**start_execution_params)
+
+     def wait_for_execution_completion(self, execution_id, context):
+         wait_attempts = 0
+         while wait_attempts < self.waiter_max_attempts:
+             wait_attempts += 1
+             time.sleep(self.waiter_delay)
+             response = self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id)
+             error_message = response.get("error_details", {}).get("error_message")
+             status = response["status"]
+             if "files" in response:
+                 self._set_xcom_files(response["files"], context)
+             if "s3_path" in response:
+                 self._set_xcom_s3_path(response["s3_path"], context)
+
+             ret = self._handle_state(execution_id, status, error_message)
+             if ret:
+                 return ret
+
+         # If timeout, handle state FAILED with timeout message
+         return self._handle_state(execution_id, "FAILED", "Execution timed out")
+
+     def _set_xcom_files(self, files, context):
+         if not context:
+             error_message = "context is required"
+             raise AirflowException(error_message)
+         for file in files:
+             context["ti"].xcom_push(
+                 key=f"{file['display_name']}.{file['file_format']}",
+                 value=file["file_path"],
+             )
+
+     def _set_xcom_s3_path(self, s3_path, context):
+         if not context:
+             error_message = "context is required"
+             raise AirflowException(error_message)
+         context["ti"].xcom_push(
+             key="s3_path",
+             value=s3_path,
+         )
+
+     def _handle_state(self, execution_id, status, error_message):
+         finished_states = ["COMPLETED"]
+         in_progress_states = ["IN_PROGRESS", "STOPPING"]
+
+         if status in in_progress_states:
+             info_message = f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.waiter_delay}"
+             self.log.info(info_message)
+             return None
+         execution_message = f"Exiting Execution {execution_id} State: {status}"
+         if status in finished_states:
+             self.log.info(execution_message)
+             return {"Status": status, "ExecutionId": execution_id}
+         else:
+             log_error_message = f"Execution {execution_id} failed with error: {error_message}"
+             self.log.error(log_error_message)
+             if error_message == "":
+                 error_message = execution_message
+             raise AirflowException(error_message)
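
Note (not part of the diff): a hedged sketch of driving the new hook directly; the paths and names are placeholders, and the "execution_id" response key follows the companion SageMakerNotebookOperator added in this release, so treat it as an assumption. Inside a DAG this hook is normally used through that operator rather than directly.

    hook = SageMakerNotebookHook(
        execution_name="my_notebook_task",  # typically the task_id
        input_config={"input_path": "src/notebook.ipynb", "input_params": {}},
        output_config={"output_formats": ["NOTEBOOK"]},
        waiter_delay=10,
        waiter_max_attempts=1440,
    )
    execution = hook.start_notebook_execution()
    # `context` is the Airflow task context; it is used to push output file locations to XCom.
    result = hook.wait_for_execution_completion(execution["execution_id"], context)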
airflow/providers/amazon/aws/links/base_aws.py
@@ -19,14 +19,21 @@ from __future__ import annotations
 
  from typing import TYPE_CHECKING, ClassVar
 
- from airflow.models import BaseOperatorLink, XCom
  from airflow.providers.amazon.aws.utils.suppress import return_on_error
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 
  if TYPE_CHECKING:
      from airflow.models import BaseOperator
      from airflow.models.taskinstancekey import TaskInstanceKey
      from airflow.utils.context import Context
 
+ if AIRFLOW_V_3_0_PLUS:
+     from airflow.sdk import BaseOperatorLink
+     from airflow.sdk.execution_time.xcom import XCom
+ else:
+     from airflow.models import XCom  # type: ignore[no-redef]
+     from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
+
 
  BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
 
airflow/providers/amazon/aws/links/sagemaker_unified_studio.py (new file)
@@ -0,0 +1,27 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+ class SageMakerUnifiedStudioLink(BaseAwsLink):
+     """Helper class for constructing Amazon SageMaker Unified Studio Links."""
+
+     name = "Amazon SageMaker Unified Studio"
+     key = "sagemaker_unified_studio"
+     format_str = BASE_AWS_CONSOLE_LINK + "/datazone/home?region={region_name}"
airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -26,6 +26,7 @@ from typing import TYPE_CHECKING
 
  from airflow.configuration import conf
  from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.utils.log.file_task_handler import FileTaskHandler
  from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -58,7 +59,8 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
      def hook(self):
          """Returns S3Hook."""
          return S3Hook(
-             aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), transfer_config_args={"use_threads": False}
+             aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"),
+             transfer_config_args={"use_threads": False},
          )
 
      def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
@@ -73,7 +75,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
          is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
          self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
          # Clear the file first so that duplicate data is not uploaded
-         # when re-using the same path (e.g. with rescheduled sensors)
+         # when reusing the same path (e.g. with rescheduled sensors)
          if self.upload_on_close:
              with open(self.handler.baseFilename, "w"):
                  pass
@@ -116,12 +118,16 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
          keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
          if keys:
              keys = sorted(f"s3://{bucket}/{key}" for key in keys)
-             messages.append("Found logs in s3:")
-             messages.extend(f" * {key}" for key in keys)
+             if AIRFLOW_V_3_0_PLUS:
+                 messages = keys
+             else:
+                 messages.append("Found logs in s3:")
+                 messages.extend(f" * {key}" for key in keys)
              for key in keys:
                  logs.append(self.s3_read(key, return_error=True))
          else:
-             messages.append(f"No logs found on s3 for ti={ti}")
+             if not AIRFLOW_V_3_0_PLUS:
+                 messages.append(f"No logs found on s3 for ti={ti}")
          return messages, logs
 
      def s3_log_exists(self, remote_log_location: str) -> bool:
@@ -152,7 +158,13 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
              return msg
          return ""
 
-     def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1) -> bool:
+     def s3_write(
+         self,
+         log: str,
+         remote_log_location: str,
+         append: bool = True,
+         max_retry: int = 1,
+     ) -> bool:
          """
          Write the log to the remote_log_location; return `True` or fails silently and return `False`.
 
@@ -185,7 +197,10 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
                  break
              except Exception:
                  if try_num < max_retry:
-                     self.log.warning("Failed attempt to write logs to %s, will retry", remote_log_location)
+                     self.log.warning(
+                         "Failed attempt to write logs to %s, will retry",
+                         remote_log_location,
+                     )
                  else:
                      self.log.exception("Could not write logs to %s", remote_log_location)
                      return False
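
Note (not part of the diff): the handler above is enabled through Airflow's standard remote-logging settings; a minimal sketch expressed as environment variables (bucket name and connection id are placeholders).

    import os

    # Equivalent to the [logging] section of airflow.cfg.
    os.environ["AIRFLOW__LOGGING__REMOTE_LOGGING"] = "True"
    os.environ["AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER"] = "s3://my-log-bucket/airflow/logs"
    os.environ["AIRFLOW__LOGGING__REMOTE_LOG_CONN_ID"] = "aws_default"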
airflow/providers/amazon/aws/notifications/chime.py
@@ -21,12 +21,11 @@ from functools import cached_property
  from typing import TYPE_CHECKING
 
  from airflow.providers.amazon.aws.hooks.chime import ChimeWebhookHook
+ from airflow.providers.common.compat.notifier import BaseNotifier
 
  if TYPE_CHECKING:
      from airflow.utils.context import Context
 
- from airflow.notifications.basenotifier import BaseNotifier
-
 
  class ChimeNotifier(BaseNotifier):
      """
airflow/providers/amazon/aws/notifications/sns.py
@@ -20,8 +20,8 @@ from __future__ import annotations
  from collections.abc import Sequence
  from functools import cached_property
 
- from airflow.notifications.basenotifier import BaseNotifier
  from airflow.providers.amazon.aws.hooks.sns import SnsHook
+ from airflow.providers.common.compat.notifier import BaseNotifier
 
 
  class SnsNotifier(BaseNotifier):
airflow/providers/amazon/aws/notifications/sqs.py
@@ -20,8 +20,8 @@ from __future__ import annotations
  from collections.abc import Sequence
  from functools import cached_property
 
- from airflow.notifications.basenotifier import BaseNotifier
  from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+ from airflow.providers.common.compat.notifier import BaseNotifier
 
 
  class SqsNotifier(BaseNotifier):
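
Note (not part of the diff): the three notification hunks above only change where BaseNotifier is imported from; notifier usage is unchanged. A minimal sketch with the SNS notifier (the connection id, topic ARN, and message are placeholders).

    from airflow.providers.amazon.aws.notifications.sns import SnsNotifier

    # Attach as a DAG- or task-level callback, e.g. DAG(..., on_failure_callback=notify_failure).
    notify_failure = SnsNotifier(
        aws_conn_id="aws_default",
        target_arn="arn:aws:sns:us-east-1:123456789012:alerts",
        message="Task {{ ti.task_id }} in DAG {{ dag.dag_id }} failed",
    )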