apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +1 -1
  3. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  4. airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
  5. airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
  6. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
  7. airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
  8. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
  9. airflow/providers/amazon/aws/hooks/athena.py +15 -2
  10. airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
  11. airflow/providers/amazon/aws/hooks/logs.py +85 -1
  12. airflow/providers/amazon/aws/hooks/neptune.py +85 -0
  13. airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
  14. airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
  15. airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
  16. airflow/providers/amazon/aws/hooks/s3.py +4 -6
  17. airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
  18. airflow/providers/amazon/aws/operators/eks.py +8 -6
  19. airflow/providers/amazon/aws/operators/neptune.py +218 -0
  20. airflow/providers/amazon/aws/operators/sagemaker.py +74 -15
  21. airflow/providers/amazon/aws/sensors/batch.py +2 -2
  22. airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
  23. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  24. airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
  25. airflow/providers/amazon/aws/triggers/neptune.py +115 -0
  26. airflow/providers/amazon/aws/triggers/rds.py +9 -7
  27. airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
  28. airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
  29. airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
  30. airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
  31. airflow/providers/amazon/aws/utils/mixins.py +5 -1
  32. airflow/providers/amazon/aws/waiters/neptune.json +85 -0
  33. airflow/providers/amazon/get_provider_info.py +22 -2
  34. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/METADATA +6 -6
  35. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/RECORD +37 -33
  36. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/WHEEL +0 -0
  37. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/entry_points.txt +0 -0
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -17,10 +17,10 @@
 from __future__ import annotations
 
 import asyncio
-import warnings
 from typing import Any, Sequence
 
 import botocore.exceptions
+from deprecated import deprecated
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
@@ -195,16 +195,17 @@ class RedshiftHook(AwsBaseHook):
             return None
 
 
+@deprecated(
+    reason=(
+        "airflow.providers.amazon.aws.hook.base_aws.RedshiftAsyncHook "
+        "has been deprecated and will be removed in future"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class RedshiftAsyncHook(AwsBaseAsyncHook):
     """Interact with AWS Redshift using aiobotocore library."""
 
     def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "airflow.providers.amazon.aws.hook.base_aws.RedshiftAsyncHook has been deprecated and "
-            "will be removed in future",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
        kwargs["client_type"] = "redshift"
         super().__init__(*args, **kwargs)
 
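For reference, the @deprecated decorator from the Deprecated package emits its configured warning category when the decorated class is instantiated, so the behaviour matches the removed warnings.warn call in __init__. A minimal, self-contained sketch of the pattern (OldHook and MyProviderWarning are illustrative names, not provider code):

import warnings

from deprecated import deprecated  # the "Deprecated" package imported above


class MyProviderWarning(DeprecationWarning):
    """Stand-in for AirflowProviderDeprecationWarning."""


@deprecated(reason="OldHook has been deprecated, use NewHook instead", category=MyProviderWarning)
class OldHook:
    def __init__(self, client_type: str = "redshift"):
        self.client_type = client_type


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OldHook()  # instantiating the class emits MyProviderWarning
    assert issubclass(caught[-1].category, MyProviderWarning)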
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -200,7 +200,7 @@ class RedshiftSQLHook(DbApiHook):
         return redshift_connector.connect(**conn_kwargs)
 
     def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
-        """Returns Redshift specific information for OpenLineage."""
+        """Return Redshift specific information for OpenLineage."""
         from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
         authority = self._get_openlineage_redshift_authority_part(connection)
@@ -252,9 +252,9 @@ class RedshiftSQLHook(DbApiHook):
         return hostname
 
     def get_openlineage_database_dialect(self, connection: Connection) -> str:
-        """Returns redshift dialect."""
+        """Return redshift dialect."""
         return "redshift"
 
     def get_openlineage_default_schema(self) -> str | None:
-        """Returns current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
+        """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
         return self.get_first("SELECT CURRENT_SCHEMA();")[0]
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -1369,6 +1369,10 @@ class S3Hook(AwsBaseHook):
         """
         Download a file from the S3 location to the local file system.
 
+        Note:
+            This function shadows the 'download_file' method of S3 API, but it is not the same.
+            If you want to use the original method from S3 API, please use 'S3Hook.get_conn().download_file()'
+
         .. seealso::
             - :external+boto3:py:meth:`S3.Object.download_fileobj`
 
@@ -1386,12 +1390,6 @@
             Default: True.
         :return: the file name.
         """
-        self.log.info(
-            "This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
-            "want to use the original method from S3 API, please call "
-            "'S3Hook.get_conn().download_file()'"
-        )
-
         self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)
 
         try:
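The distinction the relocated note describes can be shown side by side; a hedged sketch, with the bucket, key and destination file names as placeholders:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Airflow's wrapper: downloads into a (temporary) local path and returns the resulting file name.
local_file = hook.download_file(key="data/report.csv", bucket_name="my-bucket")

# The underlying boto3 client method: you choose the destination file name yourself.
hook.get_conn().download_file(Bucket="my-bucket", Key="data/report.csv", Filename="/tmp/report.csv")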
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -26,8 +26,9 @@ import warnings
 from collections import Counter, namedtuple
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Generator, cast
+from typing import Any, AsyncGenerator, Callable, Generator, cast
 
+from asgiref.sync import sync_to_async
 from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -310,10 +311,12 @@ class SageMakerHook(AwsBaseHook):
                 max_ingestion_time,
             )
 
-            billable_time = (
-                describe_response["TrainingEndTime"] - describe_response["TrainingStartTime"]
-            ) * describe_response["ResourceConfig"]["InstanceCount"]
-            self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
+            billable_seconds = SageMakerHook.count_billable_seconds(
+                training_start_time=describe_response["TrainingStartTime"],
+                training_end_time=describe_response["TrainingEndTime"],
+                instance_count=describe_response["ResourceConfig"]["InstanceCount"],
+            )
+            self.log.info("Billable seconds: %d", billable_seconds)
 
         return response
 
@@ -811,10 +814,12 @@ class SageMakerHook(AwsBaseHook):
             if status in failed_states:
                 reason = last_description.get("FailureReason", "(No reason provided)")
                 raise AirflowException(f"Error training {job_name}: {status} Reason: {reason}")
-            billable_time = (
-                last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
-            ) * instance_count
-            self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
+            billable_seconds = SageMakerHook.count_billable_seconds(
+                training_start_time=last_description["TrainingStartTime"],
+                training_end_time=last_description["TrainingEndTime"],
+                instance_count=instance_count,
+            )
+            self.log.info("Billable seconds: %d", billable_seconds)
 
     def list_training_jobs(
         self, name_contains: str | None = None, max_results: int | None = None, **kwargs
@@ -1300,3 +1305,125 @@ class SageMakerHook(AwsBaseHook):
         if "BestCandidate" in res:
             return res["BestCandidate"]
         return None
+
+    @staticmethod
+    def count_billable_seconds(
+        training_start_time: datetime, training_end_time: datetime, instance_count: int
+    ) -> int:
+        billable_time = (training_end_time - training_start_time) * instance_count
+        return int(billable_time.total_seconds()) + 1
+
+    async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
+        """
+        Return the training job info associated with the name.
+
+        :param job_name: the name of the training job
+        """
+        async with self.async_conn as client:
+            response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
+            return response
+
+    async def describe_training_job_with_log_async(
+        self,
+        job_name: str,
+        positions: dict[str, Any],
+        stream_names: list[str],
+        instance_count: int,
+        state: int,
+        last_description: dict[str, Any],
+        last_describe_job_call: float,
+    ) -> tuple[int, dict[str, Any], float]:
+        """
+        Return the training job info associated with job_name and print CloudWatch logs.
+
+        :param job_name: name of the job to check status
+        :param positions: A list of pairs of (timestamp, skip) which represents the last record
+            read from each stream.
+        :param stream_names: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :param instance_count: Count of the instance created for the job initially
+        :param state: log state
+        :param last_description: Latest description of the training job
+        :param last_describe_job_call: previous job called time
+        """
+        log_group = "/aws/sagemaker/TrainingJobs"
+
+        if len(stream_names) < instance_count:
+            logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+            streams = await logs_hook.describe_log_streams_async(
+                log_group=log_group,
+                stream_prefix=job_name + "/",
+                order_by="LogStreamName",
+                count=instance_count,
+            )
+
+            stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
+            positions.update([(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions])
+
+        if len(stream_names) > 0:
+            async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
+                self.log.info(event["message"])
+                ts, count = positions[stream_names[idx]]
+                if event["timestamp"] == ts:
+                    positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
+                else:
+                    positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)
+
+        if state == LogState.COMPLETE:
+            return state, last_description, last_describe_job_call
+
+        if state == LogState.JOB_COMPLETE:
+            state = LogState.COMPLETE
+        elif time.time() - last_describe_job_call >= 30:
+            description = await self.describe_training_job_async(job_name)
+            last_describe_job_call = time.time()
+
+            if await sync_to_async(secondary_training_status_changed)(description, last_description):
+                self.log.info(
+                    await sync_to_async(secondary_training_status_message)(description, last_description)
+                )
+            last_description = description
+
+            status = description["TrainingJobStatus"]
+
+            if status not in self.non_terminal_states:
+                state = LogState.JOB_COMPLETE
+        return state, last_description, last_describe_job_call
+
+    async def get_multi_stream(
+        self, log_group: str, streams: list[str], positions: dict[str, Any]
+    ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
+        """Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
+
+        :param log_group: The name of the log group.
+        :param streams: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :param positions: A list of pairs of (timestamp, skip) which represents the last record
+            read from each stream.
+        """
+        positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
+        events: list[Any | None] = []
+
+        logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        event_iters = [
+            logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
+            for s in streams
+        ]
+        for event_stream in event_iters:
+            if not event_stream:
+                events.append(None)
+                continue
+
+            try:
+                events.append(await event_stream.__anext__())
+            except StopAsyncIteration:
+                events.append(None)
+
+        while any(events):
+            i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
+            yield i, events[i]
+
+            try:
+                events[i] = await event_iters[i].__anext__()
+            except StopAsyncIteration:
+                events[i] = None
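A quick sanity check of the extracted count_billable_seconds helper, which the two refactored call sites above now share (the timestamps below are made up):

from datetime import datetime

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

# 330 seconds of wall-clock training on 2 instances -> 660 instance-seconds,
# plus 1 to round up partial seconds, matching the original inline computation.
seconds = SageMakerHook.count_billable_seconds(
    training_start_time=datetime(2024, 1, 1, 10, 0, 0),
    training_end_time=datetime(2024, 1, 1, 10, 5, 30),
    instance_count=2,
)
assert seconds == 661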
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -25,6 +25,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any, List, Sequence, cast
 
 from botocore.exceptions import ClientError, WaiterError
+from deprecated import deprecated
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -263,13 +264,14 @@ class EksCreateClusterOperator(BaseOperator):
         return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
 
     @property
-    def eks_hook(self):
-        warnings.warn(
+    @deprecated(
+        reason=(
             "`eks_hook` property is deprecated and will be removed in the future. "
-            "Please use `hook` property instead.",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
+            "Please use `hook` property instead."
+        ),
+        category=AirflowProviderDeprecationWarning,
+    )
+    def eks_hook(self):
         return self.hook
 
     def execute(self, context: Context):
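Decorator order matters in the hunk above: @property stays outermost so that @deprecated wraps the getter and the warning fires on attribute access. A minimal sketch of the ordering (the Example class and its attributes are illustrative, not provider code):

from deprecated import deprecated


class Example:
    @property
    def hook(self) -> str:
        return "the new accessor"

    @property
    @deprecated(reason="`eks_hook` is deprecated, use `hook` instead", category=DeprecationWarning)
    def eks_hook(self) -> str:
        # Reading this property goes through the deprecated wrapper, emitting
        # DeprecationWarning, then delegates to `hook`.
        return self.hook


print(Example().eks_hook)  # warns, then prints "the new accessor"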
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/neptune.py
@@ -0,0 +1,218 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.neptune import (
+    NeptuneClusterAvailableTrigger,
+    NeptuneClusterStoppedTrigger,
+)
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
+    """Starts an Amazon Neptune DB cluster.
+
+    Amazon Neptune Database is a serverless graph database designed for superior scalability
+    and availability. Neptune Database provides built-in security, continuous backups, and
+    integrations with other AWS services
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:NeptuneStartDbClusterOperator`
+
+    :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be started.
+    :param wait_for_completion: Whether to wait for the cluster to start. (default: True)
+    :param deferrable: If True, the operator will wait asynchronously for the cluster to start.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param waiter_delay: Time in seconds to wait between status checks.
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    :return: dictionary with Neptune cluster id
+    """
+
+    aws_hook_class = NeptuneHook
+    template_fields: Sequence[str] = aws_template_fields("cluster_id")
+
+    def __init__(
+        self,
+        db_cluster_id: str,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.cluster_id = db_cluster_id
+        self.wait_for_completion = wait_for_completion
+        self.deferrable = deferrable
+        self.delay = waiter_delay
+        self.max_attempts = waiter_max_attempts
+
+    def execute(self, context: Context) -> dict[str, str]:
+        self.log.info("Starting Neptune cluster: %s", self.cluster_id)
+
+        # Check to make sure the cluster is not already available.
+        status = self.hook.get_cluster_status(self.cluster_id)
+        if status.lower() in NeptuneHook.AVAILABLE_STATES:
+            self.log.info("Neptune cluster %s is already available.", self.cluster_id)
+            return {"db_cluster_id": self.cluster_id}
+
+        resp = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
+        status = resp.get("DBClusters", {}).get("Status", "Unknown")
+
+        if self.deferrable:
+            self.log.info("Deferring for cluster start: %s", self.cluster_id)
+
+            self.defer(
+                trigger=NeptuneClusterAvailableTrigger(
+                    aws_conn_id=self.aws_conn_id,
+                    db_cluster_id=self.cluster_id,
+                    waiter_delay=self.delay,
+                    waiter_max_attempts=self.max_attempts,
+                ),
+                method_name="execute_complete",
+            )
+
+        elif self.wait_for_completion:
+            self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
+            self.hook.wait_for_cluster_availability(self.cluster_id, self.delay, self.max_attempts)
+
+        return {"db_cluster_id": self.cluster_id}
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
+        status = ""
+        cluster_id = ""
+
+        if event:
+            status = event.get("status", "")
+            cluster_id = event.get("cluster_id", "")
+
+        self.log.info("Neptune cluster %s available with status: %s", cluster_id, status)
+
+        return {"db_cluster_id": cluster_id}
+
+
+class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
+    """
+    Stops an Amazon Neptune DB cluster.
+
+    Amazon Neptune Database is a serverless graph database designed for superior scalability
+    and availability. Neptune Database provides built-in security, continuous backups, and
+    integrations with other AWS services
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:NeptuneStartDbClusterOperator`
+
+    :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be stopped.
+    :param wait_for_completion: Whether to wait for cluster to stop. (default: True)
+    :param deferrable: If True, the operator will wait asynchronously for the cluster to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param waiter_delay: Time in seconds to wait between status checks.
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    :return: dictionary with Neptune cluster id
+    """
+
+    aws_hook_class = NeptuneHook
+    template_fields: Sequence[str] = aws_template_fields("cluster_id")
+
+    def __init__(
+        self,
+        db_cluster_id: str,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.cluster_id = db_cluster_id
+        self.wait_for_completion = wait_for_completion
+        self.deferrable = deferrable
+        self.delay = waiter_delay
+        self.max_attempts = waiter_max_attempts
+
+    def execute(self, context: Context) -> dict[str, str]:
+        self.log.info("Stopping Neptune cluster: %s", self.cluster_id)
+
+        # Check to make sure the cluster is not already stopped.
+        status = self.hook.get_cluster_status(self.cluster_id)
+        if status.lower() in NeptuneHook.STOPPED_STATES:
+            self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
+            return {"db_cluster_id": self.cluster_id}
+
+        resp = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
+        status = resp.get("DBClusters", {}).get("Status", "Unknown")
+
+        if self.deferrable:
+            self.log.info("Deferring for cluster stop: %s", self.cluster_id)
+
+            self.defer(
+                trigger=NeptuneClusterStoppedTrigger(
+                    aws_conn_id=self.aws_conn_id,
+                    db_cluster_id=self.cluster_id,
+                    waiter_delay=self.delay,
+                    waiter_max_attempts=self.max_attempts,
+                ),
+                method_name="execute_complete",
+            )
+
+        elif self.wait_for_completion:
+            self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
+            self.hook.wait_for_cluster_stopped(self.cluster_id, self.delay, self.max_attempts)
+
+        return {"db_cluster_id": self.cluster_id}
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
+        status = ""
+        cluster_id = ""
+
+        if event:
+            status = event.get("status", "")
+            cluster_id = event.get("cluster_id", "")
+
+        self.log.info("Neptune cluster %s stopped with status: %s", cluster_id, status)
+
+        return {"db_cluster_id": cluster_id}
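A hedged usage sketch for the two new operators in a DAG; the DAG id, schedule and cluster identifier are placeholders, and the deferrable path additionally requires aiobotocore:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.neptune import (
    NeptuneStartDbClusterOperator,
    NeptuneStopDbClusterOperator,
)

with DAG(
    dag_id="example_neptune_cluster",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    start_cluster = NeptuneStartDbClusterOperator(
        task_id="start_neptune_cluster",
        db_cluster_id="my-neptune-cluster",  # placeholder identifier
        deferrable=True,  # frees the worker slot while the waiter polls
    )

    stop_cluster = NeptuneStopDbClusterOperator(
        task_id="stop_neptune_cluster",
        db_cluster_id="my-neptune-cluster",
        wait_for_completion=True,
    )

    start_cluster >> stop_cluster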
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -29,9 +29,14 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import (
+    LogState,
+    SageMakerHook,
+    secondary_training_status_message,
+)
 from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerPipelineTrigger,
+    SageMakerTrainingPrintLogTrigger,
     SageMakerTrigger,
 )
 from airflow.providers.amazon.aws.utils import trim_none_values
@@ -899,9 +904,11 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
-                timeout=datetime.timedelta(seconds=self.max_ingestion_time)
-                if self.max_ingestion_time is not None
-                else None,
+                timeout=(
+                    datetime.timedelta(seconds=self.max_ingestion_time)
+                    if self.max_ingestion_time is not None
+                    else None
+                ),
             )
             description = {}  # never executed but makes static checkers happy
         elif self.wait_for_completion:
@@ -1085,28 +1092,80 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Sagemaker Training Job creation failed: {response}")
 
         if self.deferrable and self.wait_for_completion:
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=SageMakerTrigger(
+            description = self.hook.describe_training_job(self.config["TrainingJobName"])
+            status = description["TrainingJobStatus"]
+
+            if self.print_log:
+                instance_count = description["ResourceConfig"]["InstanceCount"]
+                last_describe_job_call = time.monotonic()
+                job_already_completed = status not in self.hook.non_terminal_states
+                _, description, last_describe_job_call = self.hook.describe_training_job_with_log(
+                    self.config["TrainingJobName"],
+                    {},
+                    [],
+                    instance_count,
+                    LogState.COMPLETE if job_already_completed else LogState.TAILING,
+                    description,
+                    last_describe_job_call,
+                )
+                self.log.info(secondary_training_status_message(description, None))
+
+            if status in self.hook.failed_states:
+                reason = description.get("FailureReason", "(No reason provided)")
+                raise AirflowException(f"SageMaker job failed because {reason}")
+            elif status == "Completed":
+                log_message = f"{self.task_id} completed successfully."
+                if self.print_log:
+                    billable_seconds = SageMakerHook.count_billable_seconds(
+                        training_start_time=description["TrainingStartTime"],
+                        training_end_time=description["TrainingEndTime"],
+                        instance_count=instance_count,
+                    )
+                    log_message = f"Billable seconds: {billable_seconds}\n{log_message}"
+                self.log.info(log_message)
+                return {"Training": serialize(description)}
+
+            timeout = self.execution_timeout
+            if self.max_ingestion_time:
+                timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
+            trigger: SageMakerTrainingPrintLogTrigger | SageMakerTrigger
+            if self.print_log:
+                trigger = SageMakerTrainingPrintLogTrigger(
+                    job_name=self.config["TrainingJobName"],
+                    poke_interval=self.check_interval,
+                    aws_conn_id=self.aws_conn_id,
+                )
+            else:
+                trigger = SageMakerTrigger(
                     job_name=self.config["TrainingJobName"],
                     job_type="Training",
                     poke_interval=self.check_interval,
                     max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                ),
+                )
+
+            self.defer(
+                timeout=timeout,
+                trigger=trigger,
                 method_name="execute_complete",
             )
 
-        self.serialized_training_data = serialize(
-            self.hook.describe_training_job(self.config["TrainingJobName"])
-        )
-        return {"Training": self.serialized_training_data}
+        return self.serialize_result()
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.error(err_msg)
+            raise AirflowException(err_msg)
 
-    def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info(event["message"])
+
+        self.log.info(event["message"])
+        return self.serialize_result()
+
+    def serialize_result(self) -> dict[str, dict]:
         self.serialized_training_data = serialize(
             self.hook.describe_training_job(self.config["TrainingJobName"])
         )
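In deferrable mode the operator now picks its trigger based on print_log: SageMakerTrainingPrintLogTrigger when log tailing is enabled, SageMakerTrigger otherwise. A hedged sketch of how that surfaces in a DAG task (the training config is a truncated placeholder; a real job needs the full SageMaker CreateTrainingJob payload):

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator

# Placeholder config; TrainingJobName is the only key shown here.
training_config = {"TrainingJobName": "my-training-job"}

train = SageMakerTrainingOperator(
    task_id="train_model",
    config=training_config,
    wait_for_completion=True,
    deferrable=True,    # defer instead of blocking a worker while the job runs
    print_log=True,     # tail CloudWatch logs -> SageMakerTrainingPrintLogTrigger is used
    check_interval=60,  # poke interval passed to the trigger
)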
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -60,8 +60,8 @@ class BatchSensor(BaseSensorOperator):
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
-        poke_interval: float = 5,
-        max_retries: int = 5,
+        poke_interval: float = 30,
+        max_retries: int = 4200,
         **kwargs,
     ):
         super().__init__(**kwargs)
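With the new defaults, and assuming the trigger spaces its attempts by poke_interval, a deferred BatchSensor can poll every 30 seconds for up to 4200 attempts, roughly 35 hours, instead of the previous 5 x 5 = 25 seconds. A hedged sketch of overriding them for a shorter window (the task id and job id are placeholders):

from airflow.providers.amazon.aws.sensors.batch import BatchSensor

wait_for_batch_job = BatchSensor(
    task_id="wait_for_batch_job",
    job_id="{{ ti.xcom_pull(task_ids='submit_batch_job') }}",  # placeholder XCom pull
    deferrable=True,
    poke_interval=60,  # seconds between status checks (default is now 30)
    max_retries=120,   # 60 s x 120 = ~2 hours of polling instead of the ~35 hour default
)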