apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +80 -110
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +11 -4
  6. airflow/providers/amazon/aws/auth_manager/user.py +7 -4
  7. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
  8. airflow/providers/amazon/aws/hooks/appflow.py +5 -15
  9. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
  10. airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
  11. airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
  12. airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
  13. airflow/providers/amazon/aws/hooks/dms.py +3 -1
  14. airflow/providers/amazon/aws/hooks/ec2.py +1 -1
  15. airflow/providers/amazon/aws/hooks/eks.py +3 -6
  16. airflow/providers/amazon/aws/hooks/glue.py +6 -2
  17. airflow/providers/amazon/aws/hooks/logs.py +2 -2
  18. airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
  19. airflow/providers/amazon/aws/hooks/redshift_cluster.py +10 -10
  20. airflow/providers/amazon/aws/hooks/redshift_data.py +3 -4
  21. airflow/providers/amazon/aws/hooks/s3.py +3 -1
  22. airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
  23. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
  24. airflow/providers/amazon/aws/links/athena.py +1 -2
  25. airflow/providers/amazon/aws/links/base_aws.py +8 -1
  26. airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
  27. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
  28. airflow/providers/amazon/aws/log/s3_task_handler.py +136 -84
  29. airflow/providers/amazon/aws/notifications/chime.py +1 -2
  30. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  31. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  32. airflow/providers/amazon/aws/operators/ec2.py +91 -83
  33. airflow/providers/amazon/aws/operators/eks.py +3 -3
  34. airflow/providers/amazon/aws/operators/mwaa.py +73 -2
  35. airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
  36. airflow/providers/amazon/aws/operators/s3.py +147 -157
  37. airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
  38. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
  39. airflow/providers/amazon/aws/sensors/ec2.py +5 -12
  40. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  41. airflow/providers/amazon/aws/sensors/glacier.py +1 -1
  42. airflow/providers/amazon/aws/sensors/mwaa.py +161 -0
  43. airflow/providers/amazon/aws/sensors/rds.py +10 -5
  44. airflow/providers/amazon/aws/sensors/s3.py +32 -43
  45. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
  46. airflow/providers/amazon/aws/sensors/step_function.py +2 -1
  47. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
  48. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
  49. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
  50. airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
  51. airflow/providers/amazon/aws/triggers/README.md +4 -4
  52. airflow/providers/amazon/aws/triggers/base.py +11 -2
  53. airflow/providers/amazon/aws/triggers/ecs.py +6 -2
  54. airflow/providers/amazon/aws/triggers/eks.py +2 -2
  55. airflow/providers/amazon/aws/triggers/glue.py +1 -1
  56. airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
  57. airflow/providers/amazon/aws/triggers/s3.py +31 -6
  58. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
  59. airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
  60. airflow/providers/amazon/aws/triggers/sqs.py +11 -3
  61. airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
  62. airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
  63. airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
  64. airflow/providers/amazon/get_provider_info.py +46 -5
  65. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/METADATA +40 -33
  66. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/RECORD +68 -61
  67. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/WHEEL +1 -1
  68. airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
  69. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/aws/log/cloudwatch_task_handler.py

@@ -17,20 +17,31 @@
  # under the License.
  from __future__ import annotations

+ import copy
+ import json
+ import logging
+ import os
  from datetime import date, datetime, timedelta, timezone
  from functools import cached_property
+ from pathlib import Path
  from typing import TYPE_CHECKING, Any

+ import attrs
  import watchtower

  from airflow.configuration import conf
  from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
  from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.utils.log.file_task_handler import FileTaskHandler
  from airflow.utils.log.logging_mixin import LoggingMixin

  if TYPE_CHECKING:
-     from airflow.models import TaskInstance
+     import structlog.typing
+
+     from airflow.models.taskinstance import TaskInstance
+     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+     from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo


  def json_serialize_legacy(value: Any) -> str | None:
@@ -62,6 +73,155 @@ def json_serialize(value: Any) -> str | None:
      return watchtower._json_serialize_default(value)


+ @attrs.define(kw_only=True)
+ class CloudWatchRemoteLogIO(LoggingMixin):  # noqa: D101
+     base_log_folder: Path = attrs.field(converter=Path)
+     remote_base: str = ""
+     delete_local_copy: bool = True
+
+     log_group_arn: str
+     log_stream_name: str = ""
+     log_group: str = attrs.field(init=False, repr=False)
+     region_name: str = attrs.field(init=False, repr=False)
+
+     @log_group.default
+     def _(self):
+         return self.log_group_arn.split(":")[6]
+
+     @region_name.default
+     def _(self):
+         return self.log_group_arn.split(":")[3]
+
+     @cached_property
+     def hook(self):
+         """Returns AwsLogsHook."""
+         return AwsLogsHook(
+             aws_conn_id=conf.get("logging", "remote_log_conn_id"), region_name=self.region_name
+         )
+
+     @cached_property
+     def handler(self) -> watchtower.CloudWatchLogHandler:
+         _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
+         return watchtower.CloudWatchLogHandler(
+             log_group_name=self.log_group,
+             log_stream_name=self.log_stream_name,
+             use_queues=True,
+             boto3_client=self.hook.get_conn(),
+             json_serialize_default=_json_serialize or json_serialize_legacy,
+         )
+
+     @cached_property
+     def processors(self) -> tuple[structlog.typing.Processor, ...]:
+         from logging import getLogRecordFactory
+
+         import structlog.stdlib
+
+         logRecordFactory = getLogRecordFactory()
+         # The handler MUST be initted here, before the processor is actually used to log anything.
+         # Otherwise, logging that occurs during the creation of the handler can create infinite loops.
+         _handler = self.handler
+         from airflow.sdk.log import relative_path_from_logger
+
+         def proc(logger: structlog.typing.WrappedLogger, method_name: str, event: structlog.typing.EventDict):
+             if not logger or not (stream_name := relative_path_from_logger(logger)):
+                 return event
+             # Only init the handler stream_name once. We cannot do it above when we init the handler because
+             # we don't yet know the log path at that point.
+             if not _handler.log_stream_name:
+                 _handler.log_stream_name = stream_name.as_posix().replace(":", "_")
+             name = event.get("logger_name") or event.get("logger", "")
+             level = structlog.stdlib.NAME_TO_LEVEL.get(method_name.lower(), logging.INFO)
+             msg = copy.copy(event)
+             created = None
+             if ts := msg.pop("timestamp", None):
+                 try:
+                     created = datetime.fromisoformat(ts)
+                 except Exception:
+                     pass
+             record = logRecordFactory(
+                 name, level, pathname="", lineno=0, msg=msg, args=(), exc_info=None, func=None, sinfo=None
+             )
+             if created is not None:
+                 ct = created.timestamp()
+                 record.created = ct
+                 record.msecs = int((ct - int(ct)) * 1000) + 0.0  # Copied from stdlib logging
+             _handler.handle(record)
+             return event
+
+         return (proc,)
+
+     def close(self):
+         self.handler.close()
+
+     def upload(self, path: os.PathLike | str, ti: RuntimeTI):
+         # No-op, as we upload via the processor as we go
+         # But we need to give the handler time to finish off its business
+         self.close()
+         return
+
+     def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
+         logs: LogMessages | None = []
+         messages = [
+             f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}"
+         ]
+         try:
+             if AIRFLOW_V_3_0_PLUS:
+                 from airflow.utils.log.file_task_handler import StructuredLogMessage
+
+                 logs = [
+                     StructuredLogMessage.model_validate(log)
+                     for log in self.get_cloudwatch_logs(relative_path, ti)
+                 ]
+             else:
+                 logs = [self.get_cloudwatch_logs(relative_path, ti)]  # type: ignore[arg-value]
+         except Exception as e:
+             logs = None
+             messages.append(str(e))
+
+         return messages, logs
+
+     def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
+         """
+         Return all logs from the given log stream.
+
+         :param stream_name: name of the Cloudwatch log stream to get all logs from
+         :param task_instance: the task instance to get logs about
+         :return: string of all logs from the given log stream
+         """
+         stream_name = stream_name.replace(":", "_")
+         # If there is an end_date to the task instance, fetch logs until that date + 30 seconds
+         # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted
+         end_time = (
+             None
+             if (end_date := getattr(task_instance, "end_date", None)) is None
+             else datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))
+         )
+         events = self.hook.get_log_events(
+             log_group=self.log_group,
+             log_stream_name=stream_name,
+             end_time=end_time,
+         )
+         if AIRFLOW_V_3_0_PLUS:
+             return list(self._event_to_dict(e) for e in events)
+         return "\n".join(self._event_to_str(event) for event in events)
+
+     def _event_to_dict(self, event: dict) -> dict:
+         event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat()
+         message = event["message"]
+         try:
+             message = json.loads(message)
+             message["timestamp"] = event_dt
+             return message
+         except Exception:
+             return {"timestamp": event_dt, "event": message}
+
+     def _event_to_str(self, event: dict) -> str:
+         event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
+         formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+         message = event["message"]
+         return f"[{formatted_event_dt}] {message}"
+
+
  class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
      """
      CloudwatchTaskHandler is a python log handler that handles and reads task instance logs.
@@ -84,6 +244,11 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
          self.region_name = split_arn[3]
          self.closed = False

+         self.io = CloudWatchRemoteLogIO(
+             base_log_folder=base_log_folder,
+             log_group_arn=log_group_arn,
+         )
+
      @cached_property
      def hook(self):
          """Returns AwsLogsHook."""
@@ -97,14 +262,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):

      def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
          super().set_context(ti)
-         _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
-         self.handler = watchtower.CloudWatchLogHandler(
-             log_group_name=self.log_group,
-             log_stream_name=self._render_filename(ti, ti.try_number),
-             use_queues=not getattr(ti, "is_trigger_log_context", False),
-             boto3_client=self.hook.get_conn(),
-             json_serialize_default=_json_serialize or json_serialize_legacy,
-         )
+         self.io.log_stream_name = self._render_filename(ti, ti.try_number)
+
+         self.handler = self.io.handler

      def close(self):
          """Close the handler responsible for the upload of the local log file to Cloudwatch."""
@@ -120,49 +280,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
          # Mark closed so we don't double write if close is called twice
          self.closed = True

-     def _read(self, task_instance, try_number, metadata=None):
+     def _read_remote_logs(
+         self, task_instance, try_number, metadata=None
+     ) -> tuple[LogSourceInfo, LogMessages]:
          stream_name = self._render_filename(task_instance, try_number)
-         try:
-             return (
-                 f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
-                 f"log_stream: {stream_name}.\n"
-                 f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n",
-                 {"end_of_log": True},
-             )
-         except Exception as e:
-             log = (
-                 f"*** Unable to read remote logs from Cloudwatch (log_group: {self.log_group}, log_stream: "
-                 f"{stream_name})\n*** {e}\n\n"
-             )
-             self.log.error(log)
-             local_log, metadata = super()._read(task_instance, try_number, metadata)
-             log += local_log
-             return log, metadata
-
-     def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str:
-         """
-         Return all logs from the given log stream.
-
-         :param stream_name: name of the Cloudwatch log stream to get all logs from
-         :param task_instance: the task instance to get logs about
-         :return: string of all logs from the given log stream
-         """
-         # If there is an end_date to the task instance, fetch logs until that date + 30 seconds
-         # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted
-         end_time = (
-             None
-             if task_instance.end_date is None
-             else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30))
-         )
-         events = self.hook.get_log_events(
-             log_group=self.log_group,
-             log_stream_name=stream_name,
-             end_time=end_time,
-         )
-         return "\n".join(self._event_to_str(event) for event in events)
-
-     def _event_to_str(self, event: dict) -> str:
-         event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
-         formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
-         message = event["message"]
-         return f"[{formatted_event_dt}] {message}"
+         messages, logs = self.io.read(stream_name, task_instance)
+         return messages, logs or []
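
The net effect of the cloudwatch_task_handler.py changes is that CloudwatchTaskHandler now delegates both the write path (the watchtower handler) and the read path to the new CloudWatchRemoteLogIO helper. A minimal sketch of driving the helper directly, assuming a configured remote_log_conn_id and AWS credentials; the ARN, folder, and stream name below are placeholders, not values from this release:

    from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO

    # Placeholder values; CloudwatchTaskHandler normally derives these from the
    # logging configuration and the rendered task-instance filename.
    io = CloudWatchRemoteLogIO(
        base_log_folder="/tmp/airflow/logs",
        log_group_arn="arn:aws:logs:us-east-1:123456789012:log-group:airflow-task-logs",
    )
    # set_context() performs the equivalent of this assignment per task instance.
    io.log_stream_name = "dag_id=demo/run_id=manual__2025-01-01/task_id=hello/attempt=1.log"

    # read() only uses the task instance to bound the query by end_date (if present),
    # so any object with an optional end_date attribute, or None, works here.
    messages, logs = io.read(io.log_stream_name, ti=None)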

airflow/providers/amazon/aws/log/s3_task_handler.py

@@ -24,106 +24,53 @@ import shutil
  from functools import cached_property
  from typing import TYPE_CHECKING

+ import attrs
+
  from airflow.configuration import conf
  from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.utils.log.file_task_handler import FileTaskHandler
  from airflow.utils.log.logging_mixin import LoggingMixin

  if TYPE_CHECKING:
      from airflow.models.taskinstance import TaskInstance
+     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+     from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo


- class S3TaskHandler(FileTaskHandler, LoggingMixin):
-     """
-     S3TaskHandler is a python log handler that handles and reads task instance logs.
+ @attrs.define
+ class S3RemoteLogIO(LoggingMixin):  # noqa: D101
+     remote_base: str
+     base_log_folder: pathlib.Path = attrs.field(converter=pathlib.Path)
+     delete_local_copy: bool

-     It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage.
-     """
+     processors = ()

-     trigger_should_wrap = True
+     def upload(self, path: os.PathLike | str, ti: RuntimeTI):
+         """Upload the given log path to the remote storage."""
+         path = pathlib.Path(path)
+         if path.is_absolute():
+             local_loc = path
+             remote_loc = os.path.join(self.remote_base, path.relative_to(self.base_log_folder))
+         else:
+             local_loc = self.base_log_folder.joinpath(path)
+             remote_loc = os.path.join(self.remote_base, path)

-     def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs):
-         super().__init__(base_log_folder)
-         self.handler: logging.FileHandler | None = None
-         self.remote_base = s3_log_folder
-         self.log_relative_path = ""
-         self._hook = None
-         self.closed = False
-         self.upload_on_close = True
-         self.delete_local_copy = kwargs.get(
-             "delete_local_copy", conf.getboolean("logging", "delete_local_logs")
-         )
+         if local_loc.is_file():
+             # read log and remove old logs to get just the latest additions
+             log = local_loc.read_text()
+             has_uploaded = self.write(log, remote_loc)
+             if has_uploaded and self.delete_local_copy:
+                 shutil.rmtree(os.path.dirname(local_loc))

      @cached_property
      def hook(self):
          """Returns S3Hook."""
          return S3Hook(
-             aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), transfer_config_args={"use_threads": False}
+             aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"),
+             transfer_config_args={"use_threads": False},
          )

-     def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
-         super().set_context(ti, identifier=identifier)
-         # Local location and remote location is needed to open and
-         # upload local log file to S3 remote storage.
-         if TYPE_CHECKING:
-             assert self.handler is not None
-
-         full_path = self.handler.baseFilename
-         self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix()
-         is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
-         self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
-         # Clear the file first so that duplicate data is not uploaded
-         # when re-using the same path (e.g. with rescheduled sensors)
-         if self.upload_on_close:
-             with open(self.handler.baseFilename, "w"):
-                 pass
-
-     def close(self):
-         """Close and upload local log file to remote storage S3."""
-         # When application exit, system shuts down all handlers by
-         # calling close method. Here we check if logger is already
-         # closed to prevent uploading the log to remote storage multiple
-         # times when `logging.shutdown` is called.
-         if self.closed:
-             return
-
-         super().close()
-
-         if not self.upload_on_close:
-             return
-
-         local_loc = os.path.join(self.local_base, self.log_relative_path)
-         remote_loc = os.path.join(self.remote_base, self.log_relative_path)
-         if os.path.exists(local_loc):
-             # read log and remove old logs to get just the latest additions
-             log = pathlib.Path(local_loc).read_text()
-             write_to_s3 = self.s3_write(log, remote_loc)
-             if write_to_s3 and self.delete_local_copy:
-                 shutil.rmtree(os.path.dirname(local_loc))
-
-         # Mark closed so we don't double write if close is called twice
-         self.closed = True
-
-     def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]:
-         # Explicitly getting log relative path is necessary as the given
-         # task instance might be different than task instance passed in
-         # in set_context method.
-         worker_log_rel_path = self._render_filename(ti, try_number)
-
-         logs = []
-         messages = []
-         bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, worker_log_rel_path))
-         keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
-         if keys:
-             keys = sorted(f"s3://{bucket}/{key}" for key in keys)
-             messages.append("Found logs in s3:")
-             messages.extend(f" * {key}" for key in keys)
-             for key in keys:
-                 logs.append(self.s3_read(key, return_error=True))
-         else:
-             messages.append(f"No logs found on s3 for ti={ti}")
-         return messages, logs
-
      def s3_log_exists(self, remote_log_location: str) -> bool:
          """
          Check if remote_log_location exists in remote storage.
@@ -152,11 +99,17 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
                  return msg
          return ""

-     def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1) -> bool:
+     def write(
+         self,
+         log: str,
+         remote_log_location: str,
+         append: bool = True,
+         max_retry: int = 1,
+     ) -> bool:
          """
          Write the log to the remote_log_location; return `True` or fails silently and return `False`.

-         :param log: the log to write to the remote_log_location
+         :param log: the contents to write to the remote_log_location
          :param remote_log_location: the log's location in remote storage
          :param append: if False, any existing log file is overwritten. If True,
              the new log is appended to any existing logs.
@@ -185,8 +138,107 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
                  break
              except Exception:
                  if try_num < max_retry:
-                     self.log.warning("Failed attempt to write logs to %s, will retry", remote_log_location)
+                     self.log.warning(
+                         "Failed attempt to write logs to %s, will retry",
+                         remote_log_location,
+                     )
                  else:
                      self.log.exception("Could not write logs to %s", remote_log_location)
                      return False
          return True
+
+     def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
+         logs: list[str] = []
+         messages = []
+         bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, relative_path))
+         keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
+         if keys:
+             keys = sorted(f"s3://{bucket}/{key}" for key in keys)
+             if AIRFLOW_V_3_0_PLUS:
+                 messages = keys
+             else:
+                 messages.append("Found logs in s3:")
+                 messages.extend(f" * {key}" for key in keys)
+             for key in keys:
+                 logs.append(self.s3_read(key, return_error=True))
+             return messages, logs
+         else:
+             return messages, None
+
+
+ class S3TaskHandler(FileTaskHandler, LoggingMixin):
+     """
+     S3TaskHandler is a python log handler that handles and reads task instance logs.
+
+     It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage.
+     """
+
+     def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs):
+         super().__init__(base_log_folder)
+         self.handler: logging.FileHandler | None = None
+         self.remote_base = s3_log_folder
+         self.log_relative_path = ""
+         self._hook = None
+         self.closed = False
+         self.upload_on_close = True
+         self.io = S3RemoteLogIO(
+             remote_base=s3_log_folder,
+             base_log_folder=base_log_folder,
+             delete_local_copy=kwargs.get(
+                 "delete_local_copy", conf.getboolean("logging", "delete_local_logs")
+             ),
+         )
+
+     def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
+         super().set_context(ti, identifier=identifier)
+         # Local location and remote location is needed to open and
+         # upload local log file to S3 remote storage.
+         if TYPE_CHECKING:
+             assert self.handler is not None
+
+         self.ti = ti
+
+         full_path = self.handler.baseFilename
+         self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix()
+         is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
+         self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
+         # Clear the file first so that duplicate data is not uploaded
+         # when reusing the same path (e.g. with rescheduled sensors)
+         if self.upload_on_close:
+             with open(self.handler.baseFilename, "w"):
+                 pass
+
+     def close(self):
+         """Close and upload local log file to remote storage S3."""
+         # When application exit, system shuts down all handlers by
+         # calling close method. Here we check if logger is already
+         # closed to prevent uploading the log to remote storage multiple
+         # times when `logging.shutdown` is called.
+         if self.closed:
+             return
+
+         super().close()
+
+         if not self.upload_on_close:
+             return
+
+         if hasattr(self, "ti"):
+             self.io.upload(self.log_relative_path, self.ti)
+
+         # Mark closed so we don't double write if close is called twice
+         self.closed = True
+
+     def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]:
+         # Explicitly getting log relative path is necessary as the given
+         # task instance might be different than task instance passed in
+         # in set_context method.
+         worker_log_rel_path = self._render_filename(ti, try_number)
+
+         messages, logs = self.io.read(worker_log_rel_path, ti)
+
+         if logs is None:
+             logs = []
+             if not AIRFLOW_V_3_0_PLUS:
+                 messages.append(f"No logs found on s3 for ti={ti}")
+
+         return messages, logs
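
S3TaskHandler follows the same pattern as the CloudWatch handler above: the S3-specific upload and read logic moves into S3RemoteLogIO, which the handler builds in __init__ and calls from close() and _read_remote_logs(). A minimal sketch of exercising the helper directly, assuming a configured REMOTE_LOG_CONN_ID with access to the bucket; the bucket and paths below are placeholders:

    from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO

    # Placeholder values; in a deployment these come from the [logging] section
    # (remote_base_log_folder, base_log_folder, delete_local_logs).
    io = S3RemoteLogIO(
        remote_base="s3://my-airflow-logs/logs",
        base_log_folder="/tmp/airflow/logs",
        delete_local_copy=False,
    )

    rel_path = "dag_id=demo/run_id=manual__2025-01-01/task_id=hello/attempt=1.log"

    # upload() resolves rel_path against base_log_folder and, if the file exists,
    # writes it to <remote_base>/<rel_path> via S3Hook; ti is unused by this method.
    io.upload(rel_path, ti=None)

    # read() lists keys under the matching prefix and returns (messages, logs);
    # logs is None when no keys are found.
    messages, logs = io.read(rel_path, ti=None)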

airflow/providers/amazon/aws/notifications/chime.py

@@ -21,12 +21,11 @@ from functools import cached_property
  from typing import TYPE_CHECKING

  from airflow.providers.amazon.aws.hooks.chime import ChimeWebhookHook
+ from airflow.providers.common.compat.notifier import BaseNotifier

  if TYPE_CHECKING:
      from airflow.utils.context import Context

-     from airflow.notifications.basenotifier import BaseNotifier
-

  class ChimeNotifier(BaseNotifier):
      """

airflow/providers/amazon/aws/notifications/sns.py

@@ -20,8 +20,8 @@ from __future__ import annotations
  from collections.abc import Sequence
  from functools import cached_property

- from airflow.notifications.basenotifier import BaseNotifier
  from airflow.providers.amazon.aws.hooks.sns import SnsHook
+ from airflow.providers.common.compat.notifier import BaseNotifier


  class SnsNotifier(BaseNotifier):

airflow/providers/amazon/aws/notifications/sqs.py

@@ -20,8 +20,8 @@ from __future__ import annotations
  from collections.abc import Sequence
  from functools import cached_property

- from airflow.notifications.basenotifier import BaseNotifier
  from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+ from airflow.providers.common.compat.notifier import BaseNotifier


  class SqsNotifier(BaseNotifier):
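
The three notifier changes above (chime.py, sns.py, sqs.py) all make the same move: BaseNotifier is now imported from the common.compat provider instead of airflow.notifications.basenotifier, so the classes resolve the correct base on both Airflow 2.x and 3.x. A minimal sketch of the same pattern in user code, assuming apache-airflow-providers-common-compat is installed; PrintNotifier is a hypothetical example, not part of this provider:

    from airflow.providers.common.compat.notifier import BaseNotifier

    class PrintNotifier(BaseNotifier):
        """Toy notifier that prints its rendered message."""

        template_fields = ("message",)

        def __init__(self, message: str = ""):
            super().__init__()
            self.message = message

        def notify(self, context):
            # BaseNotifier invokes notify() with the task context after
            # rendering template fields such as "message".
            print(self.message)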