apache-airflow-providers-amazon 9.5.0rc2__py3-none-any.whl → 9.6.0rc1__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +13 -15
- airflow/providers/amazon/aws/auth_manager/router/login.py +4 -2
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +53 -1
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/glue.py +17 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +1 -1
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +9 -9
- airflow/providers/amazon/aws/hooks/redshift_data.py +1 -2
- airflow/providers/amazon/aws/hooks/s3.py +0 -4
- airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +123 -86
- airflow/providers/amazon/aws/operators/bedrock.py +119 -0
- airflow/providers/amazon/aws/operators/ec2.py +1 -1
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/rds.py +83 -18
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/sagemaker.py +3 -5
- airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +2 -1
- airflow/providers/amazon/aws/sensors/rds.py +23 -20
- airflow/providers/amazon/aws/sensors/s3.py +1 -1
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/bedrock.py +98 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +9 -1
- airflow/providers/amazon/aws/waiters/bedrock.json +134 -0
- airflow/providers/amazon/get_provider_info.py +0 -124
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/METADATA +18 -18
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/RECORD +39 -39
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/WHEEL +1 -1
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/aws/log/cloudwatch_task_handler.py

@@ -17,20 +17,31 @@
 # under the License.
 from __future__ import annotations
 
+import copy
+import json
+import logging
+import os
 from datetime import date, datetime, timedelta, timezone
 from functools import cached_property
+from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
+import attrs
 import watchtower
 
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
-    from airflow.models import TaskInstance
+    import structlog.typing
+
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
 
 
 def json_serialize_legacy(value: Any) -> str | None:
@@ -62,6 +73,155 @@ def json_serialize(value: Any) -> str | None:
     return watchtower._json_serialize_default(value)
 
 
+@attrs.define(kw_only=True)
+class CloudWatchRemoteLogIO(LoggingMixin):  # noqa: D101
+    base_log_folder: Path = attrs.field(converter=Path)
+    remote_base: str = ""
+    delete_local_copy: bool = True
+
+    log_group_arn: str
+    log_stream_name: str = ""
+    log_group: str = attrs.field(init=False, repr=False)
+    region_name: str = attrs.field(init=False, repr=False)
+
+    @log_group.default
+    def _(self):
+        return self.log_group_arn.split(":")[6]
+
+    @region_name.default
+    def _(self):
+        return self.log_group_arn.split(":")[3]
+
+    @cached_property
+    def hook(self):
+        """Returns AwsLogsHook."""
+        return AwsLogsHook(
+            aws_conn_id=conf.get("logging", "remote_log_conn_id"), region_name=self.region_name
+        )
+
+    @cached_property
+    def handler(self) -> watchtower.CloudWatchLogHandler:
+        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
+        return watchtower.CloudWatchLogHandler(
+            log_group_name=self.log_group,
+            log_stream_name=self.log_stream_name,
+            use_queues=True,
+            boto3_client=self.hook.get_conn(),
+            json_serialize_default=_json_serialize or json_serialize_legacy,
+        )
+
+    @cached_property
+    def processors(self) -> tuple[structlog.typing.Processor, ...]:
+        from logging import getLogRecordFactory
+
+        import structlog.stdlib
+
+        logRecordFactory = getLogRecordFactory()
+        # The handler MUST be initted here, before the processor is actually used to log anything.
+        # Otherwise, logging that occurs during the creation of the handler can create infinite loops.
+        _handler = self.handler
+        from airflow.sdk.log import relative_path_from_logger
+
+        def proc(logger: structlog.typing.WrappedLogger, method_name: str, event: structlog.typing.EventDict):
+            if not logger or not (stream_name := relative_path_from_logger(logger)):
+                return event
+            # Only init the handler stream_name once. We cannot do it above when we init the handler because
+            # we don't yet know the log path at that point.
+            if not _handler.log_stream_name:
+                _handler.log_stream_name = stream_name.as_posix().replace(":", "_")
+            name = event.get("logger_name") or event.get("logger", "")
+            level = structlog.stdlib.NAME_TO_LEVEL.get(method_name.lower(), logging.INFO)
+            msg = copy.copy(event)
+            created = None
+            if ts := msg.pop("timestamp", None):
+                try:
+                    created = datetime.fromisoformat(ts)
+                except Exception:
+                    pass
+            record = logRecordFactory(
+                name, level, pathname="", lineno=0, msg=msg, args=(), exc_info=None, func=None, sinfo=None
+            )
+            if created is not None:
+                ct = created.timestamp()
+                record.created = ct
+                record.msecs = int((ct - int(ct)) * 1000) + 0.0  # Copied from stdlib logging
+            _handler.handle(record)
+            return event
+
+        return (proc,)
+
+    def close(self):
+        self.handler.close()
+
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI):
+        # No-op, as we upload via the processor as we go
+        # But we need to give the handler time to finish off its business
+        self.close()
+        return
+
+    def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
+        logs: LogMessages | None = []
+        messages = [
+            f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}"
+        ]
+        try:
+            if AIRFLOW_V_3_0_PLUS:
+                from airflow.utils.log.file_task_handler import StructuredLogMessage
+
+                logs = [
+                    StructuredLogMessage.model_validate(log)
+                    for log in self.get_cloudwatch_logs(relative_path, ti)
+                ]
+            else:
+                logs = [self.get_cloudwatch_logs(relative_path, ti)]  # type: ignore[arg-value]
+        except Exception as e:
+            logs = None
+            messages.append(str(e))
+
+        return messages, logs
+
+    def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
+        """
+        Return all logs from the given log stream.
+
+        :param stream_name: name of the Cloudwatch log stream to get all logs from
+        :param task_instance: the task instance to get logs about
+        :return: string of all logs from the given log stream
+        """
+        stream_name = stream_name.replace(":", "_")
+        # If there is an end_date to the task instance, fetch logs until that date + 30 seconds
+        # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted
+        end_time = (
+            None
+            if (end_date := getattr(task_instance, "end_date", None)) is None
+            else datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))
+        )
+        events = self.hook.get_log_events(
+            log_group=self.log_group,
+            log_stream_name=stream_name,
+            end_time=end_time,
+        )
+        if AIRFLOW_V_3_0_PLUS:
+            return list(self._event_to_dict(e) for e in events)
+        return "\n".join(self._event_to_str(event) for event in events)
+
+    def _event_to_dict(self, event: dict) -> dict:
+        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat()
+        message = event["message"]
+        try:
+            message = json.loads(message)
+            message["timestamp"] = event_dt
+            return message
+        except Exception:
+            return {"timestamp": event_dt, "event": message}
+
+    def _event_to_str(self, event: dict) -> str:
+        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
+        formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+        message = event["message"]
+        return f"[{formatted_event_dt}] {message}"
+
+
 class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
     """
     CloudwatchTaskHandler is a python log handler that handles and reads task instance logs.
@@ -84,6 +244,11 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
         self.region_name = split_arn[3]
         self.closed = False
 
+        self.io = CloudWatchRemoteLogIO(
+            base_log_folder=base_log_folder,
+            log_group_arn=log_group_arn,
+        )
+
     @cached_property
     def hook(self):
         """Returns AwsLogsHook."""
@@ -97,14 +262,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
 
     def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
         super().set_context(ti)
-        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
-        self.handler = watchtower.CloudWatchLogHandler(
-            log_group_name=self.log_group,
-            log_stream_name=self._render_filename(ti, ti.try_number),
-            use_queues=not getattr(ti, "is_trigger_log_context", False),
-            boto3_client=self.hook.get_conn(),
-            json_serialize_default=_json_serialize or json_serialize_legacy,
-        )
+        self.io.log_stream_name = self._render_filename(ti, ti.try_number)
+
+        self.handler = self.io.handler
 
     def close(self):
         """Close the handler responsible for the upload of the local log file to Cloudwatch."""
@@ -120,49 +280,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
         # Mark closed so we don't double write if close is called twice
         self.closed = True
 
-    def _read(self, task_instance, try_number, metadata=None):
+    def _read_remote_logs(
+        self, task_instance, try_number, metadata=None
+    ) -> tuple[LogSourceInfo, LogMessages]:
         stream_name = self._render_filename(task_instance, try_number)
-        try:
-            return (
-                f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
-                f"log_stream: {stream_name}.\n"
-                f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n",
-                {"end_of_log": True},
-            )
-        except Exception as e:
-            log = (
-                f"*** Unable to read remote logs from Cloudwatch (log_group: {self.log_group}, log_stream: "
-                f"{stream_name})\n*** {e}\n\n"
-            )
-            self.log.error(log)
-            local_log, metadata = super()._read(task_instance, try_number, metadata)
-            log += local_log
-            return log, metadata
-
-    def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str:
-        """
-        Return all logs from the given log stream.
-
-        :param stream_name: name of the Cloudwatch log stream to get all logs from
-        :param task_instance: the task instance to get logs about
-        :return: string of all logs from the given log stream
-        """
-        # If there is an end_date to the task instance, fetch logs until that date + 30 seconds
-        # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted
-        end_time = (
-            None
-            if task_instance.end_date is None
-            else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30))
-        )
-        events = self.hook.get_log_events(
-            log_group=self.log_group,
-            log_stream_name=stream_name,
-            end_time=end_time,
-        )
-        return "\n".join(self._event_to_str(event) for event in events)
-
-    def _event_to_str(self, event: dict) -> str:
-        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
-        formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
-        message = event["message"]
-        return f"[{formatted_event_dt}] {message}"
+        messages, logs = self.io.read(stream_name, task_instance)
+        return messages, logs or []
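
The new CloudWatchRemoteLogIO above derives its log group and region from the log group ARN and only touches AWS lazily through its cached `hook` and `handler` properties. A minimal sketch of that behaviour, assuming the 9.6.0rc1 wheel is installed; the ARN, account id, and base_log_folder are placeholders:

```python
# Hedged sketch: exercising the ARN parsing of CloudWatchRemoteLogIO shown in the diff above.
# The log group ARN, account id, and base_log_folder are hypothetical placeholders.
from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO

io = CloudWatchRemoteLogIO(
    base_log_folder="/opt/airflow/logs",
    log_group_arn="arn:aws:logs:us-east-1:123456789012:log-group:airflow-task-logs",
)

# log_group and region_name are attrs defaults computed by splitting the ARN on ":".
assert io.log_group == "airflow-task-logs"  # ARN field 6
assert io.region_name == "us-east-1"        # ARN field 3

# hook and handler are cached properties, so no AWS call happens until one of them
# (or upload()/read()) is actually used.
```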

airflow/providers/amazon/aws/log/s3_task_handler.py

@@ -24,6 +24,8 @@ import shutil
 from functools import cached_property
 from typing import TYPE_CHECKING
 
+import attrs
+
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
@@ -32,28 +34,34 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+    from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
 
 
-class S3TaskHandler(FileTaskHandler, LoggingMixin):
-    """
-    S3TaskHandler is a python log handler that handles and reads task instance logs.
+@attrs.define
+class S3RemoteLogIO(LoggingMixin):  # noqa: D101
+    remote_base: str
+    base_log_folder: pathlib.Path = attrs.field(converter=pathlib.Path)
+    delete_local_copy: bool
 
-    It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage.
-    """
+    processors = ()
 
-    trigger_should_wrap = True
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI):
+        """Upload the given log path to the remote storage."""
+        path = pathlib.Path(path)
+        if path.is_absolute():
+            local_loc = path
+            remote_loc = os.path.join(self.remote_base, path.relative_to(self.base_log_folder))
+        else:
+            local_loc = self.base_log_folder.joinpath(path)
+            remote_loc = os.path.join(self.remote_base, path)
 
-    def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs):
-        super().__init__(base_log_folder)
-        self.handler: logging.FileHandler | None = None
-        self.remote_base = s3_log_folder
-        self.log_relative_path = ""
-        self._hook = None
-        self.closed = False
-        self.upload_on_close = True
-        self.delete_local_copy = kwargs.get(
-            "delete_local_copy", conf.getboolean("logging", "delete_local_logs")
-        )
+        if local_loc.is_file():
+            # read log and remove old logs to get just the latest additions
+            log = local_loc.read_text()
+            has_uploaded = self.write(log, remote_loc)
+            if has_uploaded and self.delete_local_copy:
+                shutil.rmtree(os.path.dirname(local_loc))
 
     @cached_property
     def hook(self):
@@ -63,73 +71,6 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
             transfer_config_args={"use_threads": False},
         )
 
-    def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
-        super().set_context(ti, identifier=identifier)
-        # Local location and remote location is needed to open and
-        # upload local log file to S3 remote storage.
-        if TYPE_CHECKING:
-            assert self.handler is not None
-
-        full_path = self.handler.baseFilename
-        self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix()
-        is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
-        self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
-        # Clear the file first so that duplicate data is not uploaded
-        # when reusing the same path (e.g. with rescheduled sensors)
-        if self.upload_on_close:
-            with open(self.handler.baseFilename, "w"):
-                pass
-
-    def close(self):
-        """Close and upload local log file to remote storage S3."""
-        # When application exit, system shuts down all handlers by
-        # calling close method. Here we check if logger is already
-        # closed to prevent uploading the log to remote storage multiple
-        # times when `logging.shutdown` is called.
-        if self.closed:
-            return
-
-        super().close()
-
-        if not self.upload_on_close:
-            return
-
-        local_loc = os.path.join(self.local_base, self.log_relative_path)
-        remote_loc = os.path.join(self.remote_base, self.log_relative_path)
-        if os.path.exists(local_loc):
-            # read log and remove old logs to get just the latest additions
-            log = pathlib.Path(local_loc).read_text()
-            write_to_s3 = self.s3_write(log, remote_loc)
-            if write_to_s3 and self.delete_local_copy:
-                shutil.rmtree(os.path.dirname(local_loc))
-
-        # Mark closed so we don't double write if close is called twice
-        self.closed = True
-
-    def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]:
-        # Explicitly getting log relative path is necessary as the given
-        # task instance might be different than task instance passed in
-        # in set_context method.
-        worker_log_rel_path = self._render_filename(ti, try_number)
-
-        logs = []
-        messages = []
-        bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, worker_log_rel_path))
-        keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
-        if keys:
-            keys = sorted(f"s3://{bucket}/{key}" for key in keys)
-            if AIRFLOW_V_3_0_PLUS:
-                messages = keys
-            else:
-                messages.append("Found logs in s3:")
-                messages.extend(f" * {key}" for key in keys)
-            for key in keys:
-                logs.append(self.s3_read(key, return_error=True))
-        else:
-            if not AIRFLOW_V_3_0_PLUS:
-                messages.append(f"No logs found on s3 for ti={ti}")
-        return messages, logs
-
     def s3_log_exists(self, remote_log_location: str) -> bool:
         """
         Check if remote_log_location exists in remote storage.
@@ -158,7 +99,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
             return msg
         return ""
 
-    def s3_write(
+    def write(
         self,
         log: str,
         remote_log_location: str,
@@ -168,7 +109,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
         """
        Write the log to the remote_log_location; return `True` or fails silently and return `False`.
 
-        :param log: the log to write to the remote_log_location
+        :param log: the contents to write to the remote_log_location
        :param remote_log_location: the log's location in remote storage
        :param append: if False, any existing log file is overwritten. If True,
            the new log is appended to any existing logs.
@@ -205,3 +146,99 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
             self.log.exception("Could not write logs to %s", remote_log_location)
             return False
         return True
+
+    def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
+        logs: list[str] = []
+        messages = []
+        bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, relative_path))
+        keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
+        if keys:
+            keys = sorted(f"s3://{bucket}/{key}" for key in keys)
+            if AIRFLOW_V_3_0_PLUS:
+                messages = keys
+            else:
+                messages.append("Found logs in s3:")
+                messages.extend(f" * {key}" for key in keys)
+            for key in keys:
+                logs.append(self.s3_read(key, return_error=True))
+            return messages, logs
+        else:
+            return messages, None
+
+
+class S3TaskHandler(FileTaskHandler, LoggingMixin):
+    """
+    S3TaskHandler is a python log handler that handles and reads task instance logs.
+
+    It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage.
+    """
+
+    def __init__(self, base_log_folder: str, s3_log_folder: str, **kwargs):
+        super().__init__(base_log_folder)
+        self.handler: logging.FileHandler | None = None
+        self.remote_base = s3_log_folder
+        self.log_relative_path = ""
+        self._hook = None
+        self.closed = False
+        self.upload_on_close = True
+        self.io = S3RemoteLogIO(
+            remote_base=s3_log_folder,
+            base_log_folder=base_log_folder,
+            delete_local_copy=kwargs.get(
+                "delete_local_copy", conf.getboolean("logging", "delete_local_logs")
+            ),
+        )
+
+    def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
+        super().set_context(ti, identifier=identifier)
+        # Local location and remote location is needed to open and
+        # upload local log file to S3 remote storage.
+        if TYPE_CHECKING:
+            assert self.handler is not None
+
+        self.ti = ti
+
+        full_path = self.handler.baseFilename
+        self.log_relative_path = pathlib.Path(full_path).relative_to(self.local_base).as_posix()
+        is_trigger_log_context = getattr(ti, "is_trigger_log_context", False)
+        self.upload_on_close = is_trigger_log_context or not getattr(ti, "raw", None)
+        # Clear the file first so that duplicate data is not uploaded
+        # when reusing the same path (e.g. with rescheduled sensors)
+        if self.upload_on_close:
+            with open(self.handler.baseFilename, "w"):
+                pass
+
+    def close(self):
+        """Close and upload local log file to remote storage S3."""
+        # When application exit, system shuts down all handlers by
+        # calling close method. Here we check if logger is already
+        # closed to prevent uploading the log to remote storage multiple
+        # times when `logging.shutdown` is called.
+        if self.closed:
+            return
+
+        super().close()
+
+        if not self.upload_on_close:
+            return
+
+        if hasattr(self, "ti"):
+            self.io.upload(self.log_relative_path, self.ti)
+
+        # Mark closed so we don't double write if close is called twice
+        self.closed = True
+
+    def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]:
+        # Explicitly getting log relative path is necessary as the given
+        # task instance might be different than task instance passed in
+        # in set_context method.
+        worker_log_rel_path = self._render_filename(ti, try_number)
+
+        messages, logs = self.io.read(worker_log_rel_path, ti)
+
+        if logs is None:
+            logs = []
+            if not AIRFLOW_V_3_0_PLUS:
+                messages.append(f"No logs found on s3 for ti={ti}")
+
+        return messages, logs
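
For reference, a small sketch of how the new S3RemoteLogIO.upload() resolves local and remote locations, following the logic in the diff above; the bucket, prefix, and paths are placeholders:

```python
# Hedged sketch of the path resolution in S3RemoteLogIO.upload(); all names are placeholders.
from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO

io = S3RemoteLogIO(
    remote_base="s3://my-airflow-logs/dag-logs",
    base_log_folder="/opt/airflow/logs",
    delete_local_copy=False,
)

# A relative path is resolved against both bases:
#   local:  /opt/airflow/logs/dag_id=demo/run_id=manual/task_id=t1/attempt=1.log
#   remote: s3://my-airflow-logs/dag-logs/dag_id=demo/run_id=manual/task_id=t1/attempt=1.log
# An absolute path is used as-is locally and made relative to base_log_folder for the S3 key.
# io.upload("dag_id=demo/run_id=manual/task_id=t1/attempt=1.log", ti)  # ti: the runtime task instance
```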

airflow/providers/amazon/aws/operators/bedrock.py

@@ -33,6 +33,7 @@ from airflow.providers.amazon.aws.hooks.bedrock import (
 )
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
@@ -869,3 +870,121 @@ class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
 
         self.log.info("\nQuery: %s\nRetrieved: %s", self.retrieval_query, result["retrievalResults"])
         return result
+
+
+class BedrockBatchInferenceOperator(AwsBaseOperator[BedrockHook]):
+    """
+    Create a batch inference job to invoke a model on multiple prompts.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockBatchInferenceOperator`
+
+    :param job_name: A name to give the batch inference job. (templated)
+    :param role_arn: The ARN of the IAM role with permissions to create the knowledge base. (templated)
+    :param model_id: Name or ARN of the model to associate with this provisioned throughput. (templated)
+    :param input_uri: The S3 location of the input data. (templated)
+    :param output_uri: The S3 location of the output data. (templated)
+    :param invoke_kwargs: Additional keyword arguments to pass to the API call. (templated)
+
+    :param wait_for_completion: Whether to wait for cluster to stop. (default: True)
+        NOTE: The way batch inference jobs work, your jobs are added to a queue and done "eventually"
+        so using deferrable mode is much more practical than using wait_for_completion.
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 10)
+    :param deferrable: If True, the operator will wait asynchronously for the cluster to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = BedrockHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+        "role_arn",
+        "model_id",
+        "input_uri",
+        "output_uri",
+        "invoke_kwargs",
+    )
+
+    def __init__(
+        self,
+        job_name: str,
+        role_arn: str,
+        model_id: str,
+        input_uri: str,
+        output_uri: str,
+        invoke_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 10,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.job_name = job_name
+        self.role_arn = role_arn
+        self.model_id = model_id
+        self.input_uri = input_uri
+        self.output_uri = output_uri
+        self.invoke_kwargs = invoke_kwargs or {}
+
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+        self.activity = "Bedrock batch inference job"
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        validated_event = validate_execute_complete_event(event)
+
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running {self.activity}: {validated_event}")
+
+        self.log.info("%s '%s' complete.", self.activity, validated_event["job_arn"])
+
+        return validated_event["job_arn"]
+
+    def execute(self, context: Context) -> str:
+        response = self.hook.conn.create_model_invocation_job(
+            jobName=self.job_name,
+            roleArn=self.role_arn,
+            modelId=self.model_id,
+            inputDataConfig={"s3InputDataConfig": {"s3Uri": self.input_uri}},
+            outputDataConfig={"s3OutputDataConfig": {"s3Uri": self.output_uri}},
+            **self.invoke_kwargs,
+        )
+        job_arn = response["jobArn"]
+        self.log.info("%s '%s' started with ARN: %s", self.activity, self.job_name, job_arn)
+
+        task_description = f"for {self.activity} '{self.job_name}' to complete."
+        if self.deferrable:
+            self.log.info("Deferring %s", task_description)
+            self.defer(
+                trigger=BedrockBatchInferenceCompletedTrigger(
+                    job_arn=job_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting %s", task_description)
+            self.hook.get_waiter(waiter_name="batch_inference_complete").wait(
+                jobIdentifier=job_arn,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+        return job_arn
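
A hedged usage sketch for the new BedrockBatchInferenceOperator; the DAG id, role ARN, model id, and S3 URIs are placeholders, and deferrable mode assumes aiobotocore is installed:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.bedrock import BedrockBatchInferenceOperator

with DAG(dag_id="bedrock_batch_inference_demo", start_date=datetime(2025, 1, 1), schedule=None):
    run_batch_inference = BedrockBatchInferenceOperator(
        task_id="run_batch_inference",
        job_name="claude-batch-{{ ds_nodash }}",  # job_name is a templated field
        role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole",
        model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
        input_uri="s3://my-bucket/bedrock/input.jsonl",
        output_uri="s3://my-bucket/bedrock/output/",
        # Batch jobs are queued and finish "eventually", so deferring is usually
        # preferable to blocking a worker with wait_for_completion.
        deferrable=True,
    )
```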

airflow/providers/amazon/aws/operators/ec2.py

@@ -254,7 +254,7 @@ class EC2CreateInstanceOperator(AwsBaseOperator[EC2Hook]):
             region_name=self.region_name,
             api_type="client_type",
         ) """
-        self.hook.terminate_instances(
+        self.hook.terminate_instances(instance_ids=instance_ids)
         super().on_kill()
 
 
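
For context, a hedged sketch of the hook call that on_kill now makes, issued directly against EC2Hook; the connection id, region, and instance id are placeholders:

```python
# Hedged sketch of the EC2Hook call used by on_kill above; all identifiers are placeholders.
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook

hook = EC2Hook(aws_conn_id="aws_default", region_name="us-east-1", api_type="client_type")
hook.terminate_instances(instance_ids=["i-0123456789abcdef0"])
```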