apache-airflow-providers-amazon 9.4.0__py3-none-any.whl → 9.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +80 -110
- airflow/providers/amazon/aws/auth_manager/router/login.py +11 -4
- airflow/providers/amazon/aws/auth_manager/user.py +7 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/ec2.py +1 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/glue.py +6 -2
- airflow/providers/amazon/aws/hooks/logs.py +2 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +10 -10
- airflow/providers/amazon/aws/hooks/redshift_data.py +3 -4
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/links/base_aws.py +8 -1
- airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +136 -84
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/s3.py +147 -157
- airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +161 -0
- airflow/providers/amazon/aws/sensors/rds.py +10 -5
- airflow/providers/amazon/aws/sensors/s3.py +32 -43
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/README.md +4 -4
- airflow/providers/amazon/aws/triggers/base.py +11 -2
- airflow/providers/amazon/aws/triggers/ecs.py +6 -2
- airflow/providers/amazon/aws/triggers/eks.py +2 -2
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/triggers/s3.py +31 -6
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
- airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
- airflow/providers/amazon/aws/triggers/sqs.py +11 -3
- airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +46 -5
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/METADATA +38 -31
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/RECORD +68 -61
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/entry_points.txt +0 -0
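
The hunks that follow cover the logging handlers and notifiers in detail. As a quick, illustrative sanity check (not part of the diff), the installed provider version can be confirmed with the standard library; the expected values come straight from the wheel names above:

```python
# Illustrative check only: confirms which of the two wheels named above is installed.
from importlib.metadata import version

installed = version("apache-airflow-providers-amazon")
print(installed)  # "9.4.0" before the upgrade, "9.5.0" after
```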
airflow/providers/amazon/aws/log/cloudwatch_task_handler.py

@@ -17,20 +17,31 @@
 # under the License.
 from __future__ import annotations
 
+import copy
+import json
+import logging
+import os
 from datetime import date, datetime, timedelta, timezone
 from functools import cached_property
+from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
+import attrs
 import watchtower
 
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
-    from airflow.models import TaskInstance
+    import structlog.typing
+
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
 
 
 def json_serialize_legacy(value: Any) -> str | None:
@@ -62,6 +73,155 @@ def json_serialize(value: Any) -> str | None:
     return watchtower._json_serialize_default(value)
 
 
+@attrs.define(kw_only=True)
+class CloudWatchRemoteLogIO(LoggingMixin):  # noqa: D101
+    base_log_folder: Path = attrs.field(converter=Path)
+    remote_base: str = ""
+    delete_local_copy: bool = True
+
+    log_group_arn: str
+    log_stream_name: str = ""
+    log_group: str = attrs.field(init=False, repr=False)
+    region_name: str = attrs.field(init=False, repr=False)
+
+    @log_group.default
+    def _(self):
+        return self.log_group_arn.split(":")[6]
+
+    @region_name.default
+    def _(self):
+        return self.log_group_arn.split(":")[3]
+
+    @cached_property
+    def hook(self):
+        """Returns AwsLogsHook."""
+        return AwsLogsHook(
+            aws_conn_id=conf.get("logging", "remote_log_conn_id"), region_name=self.region_name
+        )
+
+    @cached_property
+    def handler(self) -> watchtower.CloudWatchLogHandler:
+        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
+        return watchtower.CloudWatchLogHandler(
+            log_group_name=self.log_group,
+            log_stream_name=self.log_stream_name,
+            use_queues=True,
+            boto3_client=self.hook.get_conn(),
+            json_serialize_default=_json_serialize or json_serialize_legacy,
+        )
+
+    @cached_property
+    def processors(self) -> tuple[structlog.typing.Processor, ...]:
+        from logging import getLogRecordFactory
+
+        import structlog.stdlib
+
+        logRecordFactory = getLogRecordFactory()
+        # The handler MUST be initted here, before the processor is actually used to log anything.
+        # Otherwise, logging that occurs during the creation of the handler can create infinite loops.
+        _handler = self.handler
+        from airflow.sdk.log import relative_path_from_logger
+
+        def proc(logger: structlog.typing.WrappedLogger, method_name: str, event: structlog.typing.EventDict):
+            if not logger or not (stream_name := relative_path_from_logger(logger)):
+                return event
+            # Only init the handler stream_name once. We cannot do it above when we init the handler because
+            # we don't yet know the log path at that point.
+            if not _handler.log_stream_name:
+                _handler.log_stream_name = stream_name.as_posix().replace(":", "_")
+            name = event.get("logger_name") or event.get("logger", "")
+            level = structlog.stdlib.NAME_TO_LEVEL.get(method_name.lower(), logging.INFO)
+            msg = copy.copy(event)
+            created = None
+            if ts := msg.pop("timestamp", None):
+                try:
+                    created = datetime.fromisoformat(ts)
+                except Exception:
+                    pass
+            record = logRecordFactory(
+                name, level, pathname="", lineno=0, msg=msg, args=(), exc_info=None, func=None, sinfo=None
+            )
+            if created is not None:
+                ct = created.timestamp()
+                record.created = ct
+                record.msecs = int((ct - int(ct)) * 1000) + 0.0  # Copied from stdlib logging
+            _handler.handle(record)
+            return event
+
+        return (proc,)
+
+    def close(self):
+        self.handler.close()
+
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI):
+        # No-op, as we upload via the processor as we go
+        # But we need to give the handler time to finish off its business
+        self.close()
+        return
+
+    def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
+        logs: LogMessages | None = []
+        messages = [
+            f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}"
+        ]
+        try:
+            if AIRFLOW_V_3_0_PLUS:
+                from airflow.utils.log.file_task_handler import StructuredLogMessage
+
+                logs = [
+                    StructuredLogMessage.model_validate(log)
+                    for log in self.get_cloudwatch_logs(relative_path, ti)
+                ]
+            else:
+                logs = [self.get_cloudwatch_logs(relative_path, ti)]  # type: ignore[arg-value]
+        except Exception as e:
+            logs = None
+            messages.append(str(e))
+
+        return messages, logs
+
+    def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
+        """
+        Return all logs from the given log stream.
+
+        :param stream_name: name of the Cloudwatch log stream to get all logs from
+        :param task_instance: the task instance to get logs about
+        :return: string of all logs from the given log stream
+        """
+        stream_name = stream_name.replace(":", "_")
+        # If there is an end_date to the task instance, fetch logs until that date + 30 seconds
+        # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted
+        end_time = (
+            None
+            if (end_date := getattr(task_instance, "end_date", None)) is None
+            else datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))
+        )
+        events = self.hook.get_log_events(
+            log_group=self.log_group,
+            log_stream_name=stream_name,
+            end_time=end_time,
+        )
+        if AIRFLOW_V_3_0_PLUS:
+            return list(self._event_to_dict(e) for e in events)
+        return "\n".join(self._event_to_str(event) for event in events)
+
+    def _event_to_dict(self, event: dict) -> dict:
+        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat()
+        message = event["message"]
+        try:
+            message = json.loads(message)
+            message["timestamp"] = event_dt
+            return message
+        except Exception:
+            return {"timestamp": event_dt, "event": message}
+
+    def _event_to_str(self, event: dict) -> str:
+        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
+        formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+        message = event["message"]
+        return f"[{formatted_event_dt}] {message}"
+
+
 class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
     """
     CloudwatchTaskHandler is a python log handler that handles and reads task instance logs.
@@ -84,6 +244,11 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
         self.region_name = split_arn[3]
         self.closed = False
 
+        self.io = CloudWatchRemoteLogIO(
+            base_log_folder=base_log_folder,
+            log_group_arn=log_group_arn,
+        )
+
     @cached_property
     def hook(self):
         """Returns AwsLogsHook."""
@@ -97,14 +262,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
 
     def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
         super().set_context(ti)
-        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
-        self.handler = watchtower.CloudWatchLogHandler(
-            log_group_name=self.log_group,
-            log_stream_name=self._render_filename(ti, ti.try_number),
-            use_queues=not getattr(ti, "is_trigger_log_context", False),
-            boto3_client=self.hook.get_conn(),
-            json_serialize_default=_json_serialize or json_serialize_legacy,
-        )
+        self.io.log_stream_name = self._render_filename(ti, ti.try_number)
+
+        self.handler = self.io.handler
 
     def close(self):
         """Close the handler responsible for the upload of the local log file to Cloudwatch."""
@@ -120,49 +280,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
         # Mark closed so we don't double write if close is called twice
         self.closed = True
 
-    def _read(self, task_instance, try_number, metadata=None):
+    def _read_remote_logs(
+        self, task_instance, try_number, metadata=None
+    ) -> tuple[LogSourceInfo, LogMessages]:
         stream_name = self._render_filename(task_instance, try_number)
-        try:
-            return (
-                f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
-                f"log_stream: {stream_name}.\n"
-                f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n",
-                {"end_of_log": True},
-            )
-        except Exception as e:
-            log = (
-                f"*** Unable to read remote logs from Cloudwatch (log_group: {self.log_group}, log_stream: "
-                f"{stream_name})\n*** {e}\n\n"
-            )
-            self.log.error(log)
-            local_log, metadata = super()._read(task_instance, try_number, metadata)
-            log += local_log
-            return log, metadata
-
-    def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str:
-        """
-        Return all logs from the given log stream.
-
-        :param stream_name: name of the Cloudwatch log stream to get all logs from
-        :param task_instance: the task instance to get logs about
-        :return: string of all logs from the given log stream
-        """
-        # If there is an end_date to the task instance, fetch logs until that date + 30 seconds
-        # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted
-        end_time = (
-            None
-            if task_instance.end_date is None
-            else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30))
-        )
-        events = self.hook.get_log_events(
-            log_group=self.log_group,
-            log_stream_name=stream_name,
-            end_time=end_time,
-        )
-        return "\n".join(self._event_to_str(event) for event in events)
-
-    def _event_to_str(self, event: dict) -> str:
-        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
-        formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
-        message = event["message"]
-        return f"[{formatted_event_dt}] {message}"
+        messages, logs = self.io.read(stream_name, task_instance)
+        return messages, logs or []
airflow/providers/amazon/aws/log/s3_task_handler.py

@@ -24,106 +24,53 @@ import shutil
 from functools import cached_property
 from typing import TYPE_CHECKING
 
+import attrs
+
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
 
 
-class S3TaskHandler(FileTaskHandler, LoggingMixin):
-    """
-    S3TaskHandler is a python log handler that handles and reads task instance logs.
+@attrs.define
+class S3RemoteLogIO(LoggingMixin):  # noqa: D101
+    remote_base: str
+    base_log_folder: pathlib.Path = attrs.field(converter=pathlib.Path)
+    delete_local_copy: bool
 
-    It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage.
-    """
+    processors = ()
 
-    trigger_should_wrap = True
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI):
+        """Upload the given log path to the remote storage."""
+        path = pathlib.Path(path)
+        if path.is_absolute():
+            local_loc = path
+            remote_loc = os.path.join(self.remote_base, path.relative_to(self.base_log_folder))
+        else:
+            local_loc = self.base_log_folder.joinpath(path)
+            remote_loc = os.path.join(self.remote_base, path)
 
-    def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs):
-        super().__init__(base_log_folder)
-        self.handler: logging.FileHandler | None = None
-        self.remote_base = s3_log_folder
-        self.log_relative_path = ""
-        self._hook = None
-        self.closed = False
-        self.upload_on_close = True
-        self.delete_local_copy = kwargs.get(
-            "delete_local_copy", conf.getboolean("logging", "delete_local_logs")
-        )
+        if local_loc.is_file():
+            # read log and remove old logs to get just the latest additions
+            log = local_loc.read_text()
+            has_uploaded = self.write(log, remote_loc)
+            if has_uploaded and self.delete_local_copy:
+                shutil.rmtree(os.path.dirname(local_loc))
 
     @cached_property
     def hook(self):
         """Returns S3Hook."""
         return S3Hook(
-            aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), transfer_config_args={"use_threads": False}
+            aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"),
+            transfer_config_args={"use_threads": False},
         )
 
-    def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
-        super().set_context(ti, identifier=identifier)
-        # Local location and remote location is needed to open and
-        # upload local log file to S3 remote storage.
-        if TYPE_CHECKING:
-            assert self.handler is not None
-
-        full_path = self.handler.baseFilename
-        self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix()
-        is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
-        self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
-        # Clear the file first so that duplicate data is not uploaded
-        # when re-using the same path (e.g. with rescheduled sensors)
-        if self.upload_on_close:
-            with open(self.handler.baseFilename, "w"):
-                pass
-
-    def close(self):
-        """Close and upload local log file to remote storage S3."""
-        # When application exit, system shuts down all handlers by
-        # calling close method. Here we check if logger is already
-        # closed to prevent uploading the log to remote storage multiple
-        # times when `logging.shutdown` is called.
-        if self.closed:
-            return
-
-        super().close()
-
-        if not self.upload_on_close:
-            return
-
-        local_loc = os.path.join(self.local_base, self.log_relative_path)
-        remote_loc = os.path.join(self.remote_base, self.log_relative_path)
-        if os.path.exists(local_loc):
-            # read log and remove old logs to get just the latest additions
-            log = pathlib.Path(local_loc).read_text()
-            write_to_s3 = self.s3_write(log, remote_loc)
-            if write_to_s3 and self.delete_local_copy:
-                shutil.rmtree(os.path.dirname(local_loc))
-
-        # Mark closed so we don't double write if close is called twice
-        self.closed = True
-
-    def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]:
-        # Explicitly getting log relative path is necessary as the given
-        # task instance might be different than task instance passed in
-        # in set_context method.
-        worker_log_rel_path = self._render_filename(ti, try_number)
-
-        logs = []
-        messages = []
-        bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, worker_log_rel_path))
-        keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
-        if keys:
-            keys = sorted(f"s3://{bucket}/{key}" for key in keys)
-            messages.append("Found logs in s3:")
-            messages.extend(f" * {key}" for key in keys)
-            for key in keys:
-                logs.append(self.s3_read(key, return_error=True))
-        else:
-            messages.append(f"No logs found on s3 for ti={ti}")
-        return messages, logs
-
     def s3_log_exists(self, remote_log_location: str) -> bool:
         """
         Check if remote_log_location exists in remote storage.
@@ -152,11 +99,17 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
                 return msg
         return ""
 
-    def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1) -> bool:
+    def write(
+        self,
+        log: str,
+        remote_log_location: str,
+        append: bool = True,
+        max_retry: int = 1,
+    ) -> bool:
         """
         Write the log to the remote_log_location; return `True` or fails silently and return `False`.
 
-        :param log: the log to write to the remote_log_location
+        :param log: the contents to write to the remote_log_location
         :param remote_log_location: the log's location in remote storage
         :param append: if False, any existing log file is overwritten. If True,
             the new log is appended to any existing logs.
@@ -185,8 +138,107 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
                 break
             except Exception:
                 if try_num < max_retry:
-                    self.log.warning("Failed attempt to write logs to %s, will retry", remote_log_location)
+                    self.log.warning(
+                        "Failed attempt to write logs to %s, will retry",
+                        remote_log_location,
+                    )
                 else:
                     self.log.exception("Could not write logs to %s", remote_log_location)
                     return False
         return True
+
+    def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
+        logs: list[str] = []
+        messages = []
+        bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, relative_path))
+        keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
+        if keys:
+            keys = sorted(f"s3://{bucket}/{key}" for key in keys)
+            if AIRFLOW_V_3_0_PLUS:
+                messages = keys
+            else:
+                messages.append("Found logs in s3:")
+                messages.extend(f" * {key}" for key in keys)
+            for key in keys:
+                logs.append(self.s3_read(key, return_error=True))
+            return messages, logs
+        else:
+            return messages, None
+
+
+class S3TaskHandler(FileTaskHandler, LoggingMixin):
+    """
+    S3TaskHandler is a python log handler that handles and reads task instance logs.
+
+    It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage.
+    """
+
+    def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs):
+        super().__init__(base_log_folder)
+        self.handler: logging.FileHandler | None = None
+        self.remote_base = s3_log_folder
+        self.log_relative_path = ""
+        self._hook = None
+        self.closed = False
+        self.upload_on_close = True
+        self.io = S3RemoteLogIO(
+            remote_base=s3_log_folder,
+            base_log_folder=base_log_folder,
+            delete_local_copy=kwargs.get(
+                "delete_local_copy", conf.getboolean("logging", "delete_local_logs")
+            ),
+        )
+
+    def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
+        super().set_context(ti, identifier=identifier)
+        # Local location and remote location is needed to open and
+        # upload local log file to S3 remote storage.
+        if TYPE_CHECKING:
+            assert self.handler is not None
+
+        self.ti = ti
+
+        full_path = self.handler.baseFilename
+        self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix()
+        is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
+        self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
+        # Clear the file first so that duplicate data is not uploaded
+        # when reusing the same path (e.g. with rescheduled sensors)
+        if self.upload_on_close:
+            with open(self.handler.baseFilename, "w"):
+                pass
+
+    def close(self):
+        """Close and upload local log file to remote storage S3."""
+        # When application exit, system shuts down all handlers by
+        # calling close method. Here we check if logger is already
+        # closed to prevent uploading the log to remote storage multiple
+        # times when `logging.shutdown` is called.
+        if self.closed:
+            return
+
+        super().close()
+
+        if not self.upload_on_close:
+            return
+
+        if hasattr(self, "ti"):
+            self.io.upload(self.log_relative_path, self.ti)
+
+        # Mark closed so we don't double write if close is called twice
+        self.closed = True
+
+    def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]:
+        # Explicitly getting log relative path is necessary as the given
+        # task instance might be different than task instance passed in
+        # in set_context method.
+        worker_log_rel_path = self._render_filename(ti, try_number)
+
+        messages, logs = self.io.read(worker_log_rel_path, ti)
+
+        if logs is None:
+            logs = []
+            if not AIRFLOW_V_3_0_PLUS:
+                messages.append(f"No logs found on s3 for ti={ti}")
+
+        return messages, logs
airflow/providers/amazon/aws/notifications/chime.py

@@ -21,12 +21,11 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.hooks.chime import ChimeWebhookHook
+from airflow.providers.common.compat.notifier import BaseNotifier
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-from airflow.notifications.basenotifier import BaseNotifier
-
 
 class ChimeNotifier(BaseNotifier):
     """
airflow/providers/amazon/aws/notifications/sns.py

@@ -20,8 +20,8 @@ from __future__ import annotations
 from collections.abc import Sequence
 from functools import cached_property
 
-from airflow.notifications.basenotifier import BaseNotifier
 from airflow.providers.amazon.aws.hooks.sns import SnsHook
+from airflow.providers.common.compat.notifier import BaseNotifier
 
 
 class SnsNotifier(BaseNotifier):
airflow/providers/amazon/aws/notifications/sqs.py

@@ -20,8 +20,8 @@ from __future__ import annotations
 from collections.abc import Sequence
 from functools import cached_property
 
-from airflow.notifications.basenotifier import BaseNotifier
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+from airflow.providers.common.compat.notifier import BaseNotifier
 
 
 class SqsNotifier(BaseNotifier):
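
The three notifier changes above are the same one-line swap: BaseNotifier is now imported from airflow.providers.common.compat.notifier, which resolves to the right base class on both Airflow 2.x and 3.x. Custom notifiers can use the identical import; the class below is a toy example invented here, with only the import path taken from the diff:

```python
# Toy notifier: only the BaseNotifier import path comes from this release's diff.
from airflow.providers.common.compat.notifier import BaseNotifier


class StdoutNotifier(BaseNotifier):
    """Print a one-line notification; a stand-in for SNS/SQS/Chime during local testing."""

    def __init__(self, prefix: str = "[notify]"):
        super().__init__()
        self.prefix = prefix

    def notify(self, context):
        # `context` is the standard Airflow task context dictionary.
        print(f"{self.prefix} task {context['ti'].task_id} reached a terminal state")
```

Like the shipped SnsNotifier, SqsNotifier, and ChimeNotifier, an instance of such a class can be passed to a DAG's or task's on_success_callback / on_failure_callback.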