apache-airflow-providers-amazon 9.5.0rc2__py3-none-any.whl → 9.6.0rc1__py3-none-any.whl
This diff compares two publicly released versions of the package, as published to a public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +13 -15
- airflow/providers/amazon/aws/auth_manager/router/login.py +4 -2
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +53 -1
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/glue.py +17 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +1 -1
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +9 -9
- airflow/providers/amazon/aws/hooks/redshift_data.py +1 -2
- airflow/providers/amazon/aws/hooks/s3.py +0 -4
- airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +123 -86
- airflow/providers/amazon/aws/operators/bedrock.py +119 -0
- airflow/providers/amazon/aws/operators/ec2.py +1 -1
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/rds.py +83 -18
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/sagemaker.py +3 -5
- airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +2 -1
- airflow/providers/amazon/aws/sensors/rds.py +23 -20
- airflow/providers/amazon/aws/sensors/s3.py +1 -1
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/bedrock.py +98 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +9 -1
- airflow/providers/amazon/aws/waiters/bedrock.json +134 -0
- airflow/providers/amazon/get_provider_info.py +0 -124
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/METADATA +18 -18
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/RECORD +39 -39
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/WHEEL +1 -1
- {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/aws/operators/eks.py
@@ -338,7 +338,7 @@ class EksCreateClusterOperator(BaseOperator):
             fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
             fargate_selectors=self.fargate_selectors,
             create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
-            subnets=cast(list[str], self.resources_vpc_config.get("subnetIds")),
+            subnets=cast("list[str]", self.resources_vpc_config.get("subnetIds")),
         )

     def deferrable_create_cluster_next(self, context: Context, event: dict[str, Any] | None = None) -> None:
@@ -377,7 +377,7 @@ class EksCreateClusterOperator(BaseOperator):
             fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
             fargate_selectors=self.fargate_selectors,
             create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
-            subnets=cast(list[str], self.resources_vpc_config.get("subnetIds")),
+            subnets=cast("list[str]", self.resources_vpc_config.get("subnetIds")),
         )
         if self.compute == "fargate":
             self.defer(
@@ -503,7 +503,7 @@ class EksCreateNodegroupOperator(BaseOperator):
         nodegroup_subnets_list: list[str] = []
         if self.nodegroup_subnets != "":
             try:
-                nodegroup_subnets_list = cast(list, literal_eval(self.nodegroup_subnets))
+                nodegroup_subnets_list = cast("list", literal_eval(self.nodegroup_subnets))
             except ValueError:
                 self.log.warning(
                     "The nodegroup_subnets should be List or string representing "
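
The `cast(...)` changes above are behaviour-preserving: `typing.cast` simply returns its second argument at runtime, and it accepts the target type as a string, so quoting the type (for example `cast("list[str]", ...)`) avoids evaluating the subscripted generic while type checkers still see the same annotation. A minimal illustration, not taken from the provider code:

    from typing import cast

    raw = {"subnetIds": ["subnet-1", "subnet-2"]}

    # Both spellings return the value unchanged; only static type checkers care.
    subnets = cast("list[str]", raw.get("subnetIds"))
    assert subnets == ["subnet-1", "subnet-2"]
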

airflow/providers/amazon/aws/operators/rds.py
@@ -20,19 +20,19 @@ from __future__ import annotations
 import json
 from collections.abc import Sequence
 from datetime import timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.rds import (
     RdsDbAvailableTrigger,
     RdsDbDeletedTrigger,
     RdsDbStoppedTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -44,9 +44,10 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context


-class RdsBaseOperator(BaseOperator):
+class RdsBaseOperator(AwsBaseOperator[RdsHook]):
     """Base operator that implements common functions for all operators."""

+    aws_hook_class = RdsHook
     ui_color = "#eeaa88"
     ui_fgcolor = "#ffffff"

@@ -63,10 +64,6 @@ class RdsBaseOperator(BaseOperator):

         self._await_interval = 60  # seconds

-    @cached_property
-    def hook(self) -> RdsHook:
-        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
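
The three hunks above replace the hand-rolled `hook` property on `RdsBaseOperator` with the provider's generic `AwsBaseOperator[RdsHook]` base: a subclass declares `aws_hook_class`, and the base class supplies a cached `self.hook` plus the shared `aws_conn_id`, `region_name`, `verify`, and `botocore_config` parameters. A sketch of the pattern, assuming it mirrors the other operators in this provider (the operator name below is made up):

    from airflow.providers.amazon.aws.hooks.rds import RdsHook
    from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
    from airflow.providers.amazon.aws.utils.mixins import aws_template_fields


    class MyRdsOperator(AwsBaseOperator[RdsHook]):
        # The base class builds and caches self.hook from this hook class.
        aws_hook_class = RdsHook
        # aws_template_fields() appends the shared AWS fields (aws_conn_id, region_name, ...).
        template_fields = aws_template_fields("db_identifier")

        def __init__(self, *, db_identifier: str, **kwargs):
            super().__init__(**kwargs)
            self.db_identifier = db_identifier

        def execute(self, context):
            # self.hook is configured from aws_conn_id / region_name / verify / botocore_config.
            return self.hook.conn.describe_db_instances(DBInstanceIdentifier=self.db_identifier)
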
@@ -92,9 +89,19 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
+    template_fields = aws_template_fields("db_snapshot_identifier", "db_identifier", "tags")

     def __init__(
         self,
@@ -167,9 +174,14 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         Only when db_type='instance'
     :param source_region: The ID of the region that contains the snapshot to be copied
     :param wait_for_completion: If True, waits for snapshot copy to complete. (default: True)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = (
+    template_fields = aws_template_fields(
         "source_db_snapshot_identifier",
         "target_db_snapshot_identifier",
         "tags",
@@ -260,9 +272,16 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):

     :param db_type: Type of the DB - either "instance" or "cluster"
     :param db_snapshot_identifier: The identifier for the DB instance or DB cluster snapshot
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_snapshot_identifier",)
+    template_fields = aws_template_fields(
+        "db_snapshot_identifier",
+    )

     def __init__(
         self,
@@ -319,9 +338,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True)
     :param waiter_interval: The number of seconds to wait before checking the export status. (default: 30)
     :param waiter_max_attempts: The number of attempts to make before failing. (default: 40)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = (
+    template_fields = aws_template_fields(
         "export_task_identifier",
         "source_arn",
         "s3_bucket_name",
@@ -394,9 +418,16 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True)
     :param check_interval: The amount of time in seconds to wait between attempts
     :param max_attempts: The maximum number of attempts to be made
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("export_task_identifier",)
+    template_fields = aws_template_fields(
+        "export_task_identifier",
+    )

     def __init__(
         self,
@@ -450,9 +481,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = (
+    template_fields = aws_template_fields(
         "subscription_name",
         "sns_topic_arn",
         "source_type",
@@ -513,9 +549,16 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
         :ref:`howto/operator:RdsDeleteEventSubscriptionOperator`

     :param subscription_name: The name of the RDS event notification subscription you want to delete
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("subscription_name",)
+    template_fields = aws_template_fields(
+        "subscription_name",
+    )

     def __init__(
         self,
@@ -560,9 +603,16 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs")
+    template_fields = aws_template_fields(
+        "db_instance_identifier", "db_instance_class", "engine", "rds_kwargs"
+    )

     def __init__(
         self,
@@ -652,9 +702,14 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_instance_identifier", "rds_kwargs")
+    template_fields = aws_template_fields("db_instance_identifier", "rds_kwargs")

     def __init__(
         self,
@@ -735,9 +790,14 @@ class RdsStartDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", "db_type")

     def __init__(
         self,
@@ -832,9 +892,14 @@ class RdsStopDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", "db_snapshot_identifier", "db_type")

     def __init__(
         self,
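
Because every RDS operator now goes through `AwsBaseOperator`, the connection-related keyword arguments documented in the docstring hunks above can be passed directly at instantiation. A hypothetical usage sketch (the identifier and config values are placeholders):

    from airflow.providers.amazon.aws.operators.rds import RdsStartDbOperator

    start_db = RdsStartDbOperator(
        task_id="start_db",
        db_identifier="my-instance",
        db_type="instance",
        aws_conn_id="aws_default",
        region_name="eu-west-1",
        botocore_config={"retries": {"mode": "standard", "max_attempts": 10}},
    )
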

airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -755,11 +755,18 @@ class RedshiftDeleteClusterOperator(BaseOperator):
                     final_cluster_snapshot_identifier=self.final_cluster_snapshot_identifier,
                 )
                 break
-            except self.redshift_hook.get_conn().exceptions.InvalidClusterStateFault:
+            except self.redshift_hook.conn.exceptions.InvalidClusterStateFault:
                 self._attempts -= 1

                 if self._attempts:
-                    self.log.error("Unable to delete cluster. %d attempts remaining.", self._attempts)
+                    current_state = self.redshift_hook.conn.describe_clusters(
+                        ClusterIdentifier=self.cluster_identifier
+                    )["Clusters"][0]["ClusterStatus"]
+                    self.log.error(
+                        "Cluster in %s state, unable to delete. %d attempts remaining.",
+                        current_state,
+                        self._attempts,
+                    )
                     time.sleep(self._attempt_interval)
                 else:
                     raise
@@ -785,7 +792,7 @@ class RedshiftDeleteClusterOperator(BaseOperator):
            )

        elif self.wait_for_completion:
-            waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
+            waiter = self.redshift_hook.conn.get_waiter("cluster_deleted")
            waiter.wait(
                ClusterIdentifier=self.cluster_identifier,
                WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts},
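
The new logging branch reads the cluster's current status before retrying, using the `describe_clusters` response shape exposed by boto3. A minimal sketch of what that call returns (the cluster identifier is a placeholder):

    import boto3

    client = boto3.client("redshift", region_name="us-east-1")
    resp = client.describe_clusters(ClusterIdentifier="my-cluster")
    # The operator indexes the first (and only) matching cluster for its status.
    print(resp["Clusters"][0]["ClusterStatus"])  # e.g. "available" or "deleting"
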

airflow/providers/amazon/aws/operators/sagemaker.py
@@ -170,7 +170,7 @@ class SageMakerBaseOperator(BaseOperator):
            timestamp = str(
                time.time_ns() // 1000000000
            )  # only keep the relevant datetime (first 10 digits)
-            name = f"{proposed_name[:max_name_len - len(timestamp) - 1]}-{timestamp}"  # we subtract one to make provision for the dash between the truncated name and timestamp
+            name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}"  # we subtract one to make provision for the dash between the truncated name and timestamp
            self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
            return name

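
The slice change above is formatting only; the arithmetic still reserves room for a 10-digit timestamp plus the joining dash so the generated name stays within the limit. A small worked example:

    proposed_name = "my-training-job-with-a-very-long-name"
    max_name_len = 32
    timestamp = "1735689600"  # example 10-digit epoch seconds

    name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}"
    assert len(name) <= max_name_len
    print(name)  # my-training-job-with--1735689600
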
@@ -178,8 +178,7 @@ class SageMakerBaseOperator(BaseOperator):
        """Raise exception if resource type is not 'model' or 'job'."""
        if resource_type not in ("model", "job"):
            raise AirflowException(
-                "Argument resource_type accepts only 'model' and 'job'. "
-                f"Provided value: '{resource_type}'."
+                f"Argument resource_type accepts only 'model' and 'job'. Provided value: '{resource_type}'."
            )

    def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
@@ -559,8 +558,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
            self.operation = "update"
            sagemaker_operation = self.hook.update_endpoint
            self.log.warning(
-                "cannot create already existing endpoint %s, "
-                "updating it with the given config instead",
+                "cannot create already existing endpoint %s, updating it with the given config instead",
                endpoint_info["EndpointName"],
            )
            if "Tags" in endpoint_info:

airflow/providers/amazon/aws/sensors/bedrock.py
@@ -26,6 +26,8 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
+    BedrockBatchInferenceScheduledTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
@@ -34,6 +36,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.triggers.bedrock import BedrockBaseBatchInferenceTrigger
     from airflow.utils.context import Context


@@ -368,3 +371,110 @@ class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
             )
         else:
             super().execute(context=context)
+
+
+class BedrockBatchInferenceSensor(BedrockBaseSensor[BedrockHook]):
+    """
+    Poll the batch inference job status until it reaches a terminal state; fails if creation fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockBatchInferenceSensor`
+
+    :param job_arn: The Amazon Resource Name (ARN) of the batch inference job. (templated)
+    :param success_state: A BedrockBatchInferenceSensor.TargetState; defaults to 'SCHEDULED' (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 5)
+    :param max_retries: Number of times before returning the current state (default: 24)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    class SuccessState:
+        """
+        Target state for the BedrockBatchInferenceSensor.
+
+        Bedrock adds batch inference jobs to a queue, and they may take some time to complete.
+        If you want to wait for the job to complete, use TargetState.COMPLETED, but if you only want
+        to wait until the service confirms that the job is in the queue, use TargetState.SCHEDULED.
+
+        The normal successful progression of states is:
+            Submitted > Validating > Scheduled > InProgress > PartiallyCompleted > Completed
+        """
+
+        SCHEDULED = "scheduled"
+        COMPLETED = "completed"
+
+    INTERMEDIATE_STATES: tuple[str, ...]  # Defined in __init__ based on target state
+    FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopped", "PartiallyCompleted", "Expired")
+    SUCCESS_STATES: tuple[str, ...]  # Defined in __init__ based on target state
+    FAILURE_MESSAGE = "Bedrock batch inference job sensor failed."
+    INVALID_SUCCESS_STATE_MESSAGE = "success_state must be an instance of TargetState."
+
+    aws_hook_class = BedrockHook
+
+    template_fields: Sequence[str] = aws_template_fields("job_arn", "success_state")
+
+    def __init__(
+        self,
+        *,
+        job_arn: str,
+        success_state: SuccessState | str = SuccessState.SCHEDULED,
+        poke_interval: int = 120,
+        max_retries: int = 75,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.job_arn = job_arn
+        self.success_state = success_state
+
+        base_success_states: tuple[str, ...] = ("Completed",)
+        base_intermediate_states: tuple[str, ...] = ("Submitted", "InProgress", "Stopping", "Validating")
+        scheduled_state = ("Scheduled",)
+        self.trigger_class: type[BedrockBaseBatchInferenceTrigger]
+
+        if self.success_state == BedrockBatchInferenceSensor.SuccessState.COMPLETED:
+            intermediate_states = base_intermediate_states + scheduled_state
+            success_states = base_success_states
+            self.trigger_class = BedrockBatchInferenceCompletedTrigger
+        elif self.success_state == BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
+            intermediate_states = base_intermediate_states
+            success_states = base_success_states + scheduled_state
+            self.trigger_class = BedrockBatchInferenceScheduledTrigger
+        else:
+            raise ValueError(
+                "Success states for BedrockBatchInferenceSensor must be set using a BedrockBatchInferenceSensor.SuccessState"
+            )
+
+        BedrockBatchInferenceSensor.INTERMEDIATE_STATES = intermediate_states or base_intermediate_states
+        BedrockBatchInferenceSensor.SUCCESS_STATES = success_states or base_success_states
+
+    def get_state(self) -> str:
+        return self.hook.conn.get_model_invocation_job(jobIdentifier=self.job_arn)["status"]
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=self.trigger_class(
+                    job_arn=self.job_arn,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)

airflow/providers/amazon/aws/sensors/glacier.py
@@ -95,5 +95,5 @@ class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
            return False
        else:
            raise AirflowException(
-                f'Sensor failed. Job status: {response["Action"]}, code status: {response["StatusCode"]}'
+                f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
            )

airflow/providers/amazon/aws/sensors/mwaa.py
@@ -150,7 +150,8 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
                    external_dag_run_id=self.external_dag_run_id,
                    success_states=self.success_states,
                    failure_states=self.failure_states,
-
+                    # somehow the type of poke_interval is derived as float ??
+                    waiter_delay=self.poke_interval,  # type: ignore[arg-type]
                    waiter_max_attempts=self.max_retries,
                    aws_conn_id=self.aws_conn_id,
                ),

airflow/providers/amazon/aws/sensors/rds.py
@@ -17,36 +17,30 @@
 from __future__ import annotations

 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING

 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
-from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
     from airflow.utils.context import Context


-class RdsBaseSensor(BaseSensorOperator):
+class RdsBaseSensor(AwsBaseSensor[RdsHook]):
     """Base operator that implements common functions for all sensors."""

+    aws_hook_class = RdsHook
     ui_color = "#ddbb77"
     ui_fgcolor = "#ffffff"

-    def __init__(
-        self, *args, aws_conn_id: str | None = "aws_conn_id", hook_params: dict | None = None, **kwargs
-    ):
+    def __init__(self, *args, hook_params: dict | None = None, **kwargs):
         self.hook_params = hook_params or {}
-        self.aws_conn_id = aws_conn_id
         self.target_statuses: list[str] = []
         super().__init__(*args, **kwargs)

-    @cached_property
-    def hook(self):
-        return RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
-

 class RdsSnapshotExistenceSensor(RdsBaseSensor):
     """
@@ -59,9 +53,19 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
     :param db_type: Type of the DB - either "instance" or "cluster"
     :param db_snapshot_identifier: The identifier for the DB snapshot
     :param target_statuses: Target status of snapshot
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "db_snapshot_identifier",
         "target_statuses",
     )
@@ -72,10 +76,9 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
         db_type: str,
         db_snapshot_identifier: str,
         target_statuses: list[str] | None = None,
-        aws_conn_id: str | None = "aws_conn_id",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_type = RdsDbType(db_type)
         self.db_snapshot_identifier = db_snapshot_identifier
         self.target_statuses = target_statuses or ["available"]
@@ -107,7 +110,9 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
     :param error_statuses: Target error status of export task to fail the sensor
     """

-    template_fields: Sequence[str] = ("export_task_identifier", "target_statuses", "error_statuses")
+    template_fields: Sequence[str] = aws_template_fields(
+        "export_task_identifier", "target_statuses", "error_statuses"
+    )

     def __init__(
         self,
@@ -115,10 +120,9 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
         export_task_identifier: str,
         target_statuses: list[str] | None = None,
         error_statuses: list[str] | None = None,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)

         self.export_task_identifier = export_task_identifier
         self.target_statuses = target_statuses or [
@@ -159,7 +163,7 @@ class RdsDbSensor(RdsBaseSensor):
     :param target_statuses: Target status of DB
     """

-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "db_identifier",
         "db_type",
         "target_statuses",
@@ -171,10 +175,9 @@ class RdsDbSensor(RdsBaseSensor):
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
         target_statuses: list[str] | None = None,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.target_statuses = target_statuses or ["available"]
         self.db_type = db_type
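
With the RDS sensors now based on `AwsBaseSensor`, the explicit `aws_conn_id` constructor argument is gone and connection settings travel through the shared base-class kwargs instead. A hypothetical example (the identifier is a placeholder):

    from airflow.providers.amazon.aws.sensors.rds import RdsDbSensor

    wait_for_db = RdsDbSensor(
        task_id="wait_for_db",
        db_identifier="my-instance",
        target_statuses=["available"],
        aws_conn_id="aws_default",
        region_name="us-east-1",
    )
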

airflow/providers/amazon/aws/sensors/s3.py
@@ -192,7 +192,7 @@ class S3KeySensor(AwsBaseSensor[S3Hook]):
            self.defer(
                timeout=timedelta(seconds=self.timeout),
                trigger=S3KeyTrigger(
-                    bucket_name=cast(str, self.bucket_name),
+                    bucket_name=cast("str", self.bucket_name),
                    bucket_key=self.bucket_key,
                    wildcard_match=self.wildcard_match,
                    aws_conn_id=self.aws_conn_id,

airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -103,7 +103,7 @@ class MongoToS3Operator(BaseOperator):
        if self.is_pipeline:
            results: CommandCursor[Any] | Cursor = MongoHook(self.mongo_conn_id).aggregate(
                mongo_collection=self.mongo_collection,
-                aggregate_query=cast(list, self.mongo_query),
+                aggregate_query=cast("list", self.mongo_query),
                mongo_db=self.mongo_db,
                allowDiskUse=self.allow_disk_use,
            )
@@ -111,7 +111,7 @@ class MongoToS3Operator(BaseOperator):
        else:
            results = MongoHook(self.mongo_conn_id).find(
                mongo_collection=self.mongo_collection,
-                query=cast(dict, self.mongo_query),
+                query=cast("dict", self.mongo_query),
                projection=self.mongo_projection,
                mongo_db=self.mongo_db,
                find_one=False,

airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -223,7 +223,7 @@ class SqlToS3Operator(BaseOperator):
            return
        for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
            yield (
-                cast(str, group_label),
+                cast("str", group_label),
                grouped_df.get_group(group_label)
                .drop(random_column_name, axis=1, errors="ignore")
                .reset_index(drop=True),