apache-airflow-providers-amazon 9.5.0rc2__py3-none-any.whl → 9.6.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (39)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +13 -15
  3. airflow/providers/amazon/aws/auth_manager/router/login.py +4 -2
  4. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +53 -1
  5. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
  6. airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
  7. airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
  8. airflow/providers/amazon/aws/hooks/dms.py +3 -1
  9. airflow/providers/amazon/aws/hooks/glue.py +17 -2
  10. airflow/providers/amazon/aws/hooks/mwaa.py +1 -1
  11. airflow/providers/amazon/aws/hooks/redshift_cluster.py +9 -9
  12. airflow/providers/amazon/aws/hooks/redshift_data.py +1 -2
  13. airflow/providers/amazon/aws/hooks/s3.py +0 -4
  14. airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
  15. airflow/providers/amazon/aws/links/athena.py +1 -2
  16. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
  17. airflow/providers/amazon/aws/log/s3_task_handler.py +123 -86
  18. airflow/providers/amazon/aws/operators/bedrock.py +119 -0
  19. airflow/providers/amazon/aws/operators/ec2.py +1 -1
  20. airflow/providers/amazon/aws/operators/eks.py +3 -3
  21. airflow/providers/amazon/aws/operators/rds.py +83 -18
  22. airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
  23. airflow/providers/amazon/aws/operators/sagemaker.py +3 -5
  24. airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
  25. airflow/providers/amazon/aws/sensors/glacier.py +1 -1
  26. airflow/providers/amazon/aws/sensors/mwaa.py +2 -1
  27. airflow/providers/amazon/aws/sensors/rds.py +23 -20
  28. airflow/providers/amazon/aws/sensors/s3.py +1 -1
  29. airflow/providers/amazon/aws/sensors/step_function.py +2 -1
  30. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
  31. airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
  32. airflow/providers/amazon/aws/triggers/bedrock.py +98 -0
  33. airflow/providers/amazon/aws/utils/waiter_with_logging.py +9 -1
  34. airflow/providers/amazon/aws/waiters/bedrock.json +134 -0
  35. airflow/providers/amazon/get_provider_info.py +0 -124
  36. {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0.dist-info}/METADATA +21 -21
  37. {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0.dist-info}/RECORD +39 -39
  38. {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0.dist-info}/WHEEL +1 -1
  39. {apache_airflow_providers_amazon-9.5.0rc2.dist-info → apache_airflow_providers_amazon-9.6.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/eks.py

@@ -338,7 +338,7 @@ class EksCreateClusterOperator(BaseOperator):
  fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
  fargate_selectors=self.fargate_selectors,
  create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
- subnets=cast(list[str], self.resources_vpc_config.get("subnetIds")),
+ subnets=cast("list[str]", self.resources_vpc_config.get("subnetIds")),
  )

  def deferrable_create_cluster_next(self, context: Context, event: dict[str, Any] | None = None) -> None:
@@ -377,7 +377,7 @@ class EksCreateClusterOperator(BaseOperator):
  fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
  fargate_selectors=self.fargate_selectors,
  create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
- subnets=cast(list[str], self.resources_vpc_config.get("subnetIds")),
+ subnets=cast("list[str]", self.resources_vpc_config.get("subnetIds")),
  )
  if self.compute == "fargate":
  self.defer(
@@ -503,7 +503,7 @@ class EksCreateNodegroupOperator(BaseOperator):
  nodegroup_subnets_list: list[str] = []
  if self.nodegroup_subnets != "":
  try:
- nodegroup_subnets_list = cast(list, literal_eval(self.nodegroup_subnets))
+ nodegroup_subnets_list = cast("list", literal_eval(self.nodegroup_subnets))
  except ValueError:
  self.log.warning(
  "The nodegroup_subnets should be List or string representing "
airflow/providers/amazon/aws/operators/rds.py

@@ -20,19 +20,19 @@ from __future__ import annotations
  import json
  from collections.abc import Sequence
  from datetime import timedelta
- from functools import cached_property
  from typing import TYPE_CHECKING, Any

  from airflow.configuration import conf
  from airflow.exceptions import AirflowException
- from airflow.models import BaseOperator
  from airflow.providers.amazon.aws.hooks.rds import RdsHook
+ from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
  from airflow.providers.amazon.aws.triggers.rds import (
  RdsDbAvailableTrigger,
  RdsDbDeletedTrigger,
  RdsDbStoppedTrigger,
  )
  from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
  from airflow.providers.amazon.aws.utils.rds import RdsDbType
  from airflow.providers.amazon.aws.utils.tags import format_tags
  from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -44,9 +44,10 @@ if TYPE_CHECKING:
  from airflow.utils.context import Context


- class RdsBaseOperator(BaseOperator):
+ class RdsBaseOperator(AwsBaseOperator[RdsHook]):
  """Base operator that implements common functions for all operators."""

+ aws_hook_class = RdsHook
  ui_color = "#eeaa88"
  ui_fgcolor = "#ffffff"

@@ -63,10 +64,6 @@ class RdsBaseOperator(BaseOperator):

  self._await_interval = 60 # seconds

- @cached_property
- def hook(self) -> RdsHook:
- return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
  def execute(self, context: Context) -> str:
  """Different implementations for snapshots, tasks and events."""
  raise NotImplementedError
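Because every RDS operator inherits from this base class, moving it onto `AwsBaseOperator[RdsHook]` means the hook is built by the shared AWS base machinery, and the connection arguments documented in the docstring additions below (`aws_conn_id`, `region_name`, `verify`, `botocore_config`) are accepted by each operator directly. A hedged usage sketch, with an invented task id and identifiers, assuming the RdsCreateDbSnapshotOperator parameters shown in this diff:

```python
from airflow.providers.amazon.aws.operators.rds import RdsCreateDbSnapshotOperator

create_snapshot = RdsCreateDbSnapshotOperator(
    task_id="create_db_snapshot",
    db_identifier="my-rds-instance",                # invented identifier
    db_snapshot_identifier="my-rds-instance-snap",  # invented identifier
    wait_for_completion=True,
    aws_conn_id="aws_default",
    region_name="eu-west-1",  # now routed through AwsBaseOperator rather than a local hook property
)
```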
@@ -92,9 +89,19 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
  :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
  `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
  :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
+ template_fields = aws_template_fields("db_snapshot_identifier", "db_identifier", "tags")

  def __init__(
  self,
@@ -167,9 +174,14 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
  Only when db_type='instance'
  :param source_region: The ID of the region that contains the snapshot to be copied
  :param wait_for_completion: If True, waits for snapshot copy to complete. (default: True)
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = (
+ template_fields = aws_template_fields(
  "source_db_snapshot_identifier",
  "target_db_snapshot_identifier",
  "tags",
@@ -260,9 +272,16 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):

  :param db_type: Type of the DB - either "instance" or "cluster"
  :param db_snapshot_identifier: The identifier for the DB instance or DB cluster snapshot
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("db_snapshot_identifier",)
+ template_fields = aws_template_fields(
+ "db_snapshot_identifier",
+ )

  def __init__(
  self,
@@ -319,9 +338,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
  :param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True)
  :param waiter_interval: The number of seconds to wait before checking the export status. (default: 30)
  :param waiter_max_attempts: The number of attempts to make before failing. (default: 40)
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = (
+ template_fields = aws_template_fields(
  "export_task_identifier",
  "source_arn",
  "s3_bucket_name",
@@ -394,9 +418,16 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
  :param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True)
  :param check_interval: The amount of time in seconds to wait between attempts
  :param max_attempts: The maximum number of attempts to be made
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("export_task_identifier",)
+ template_fields = aws_template_fields(
+ "export_task_identifier",
+ )

  def __init__(
  self,
@@ -450,9 +481,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
  :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
  `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
  :param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = (
+ template_fields = aws_template_fields(
  "subscription_name",
  "sns_topic_arn",
  "source_type",
@@ -513,9 +549,16 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
  :ref:`howto/operator:RdsDeleteEventSubscriptionOperator`

  :param subscription_name: The name of the RDS event notification subscription you want to delete
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("subscription_name",)
+ template_fields = aws_template_fields(
+ "subscription_name",
+ )

  def __init__(
  self,
@@ -560,9 +603,16 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
  :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
  This implies waiting for completion. This mode requires aiobotocore module to be installed.
  (default: False)
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs")
+ template_fields = aws_template_fields(
+ "db_instance_identifier", "db_instance_class", "engine", "rds_kwargs"
+ )

  def __init__(
  self,
@@ -652,9 +702,14 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
  :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
  This implies waiting for completion. This mode requires aiobotocore module to be installed.
  (default: False)
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("db_instance_identifier", "rds_kwargs")
+ template_fields = aws_template_fields("db_instance_identifier", "rds_kwargs")

  def __init__(
  self,
@@ -735,9 +790,14 @@ class RdsStartDbOperator(RdsBaseOperator):
  :param waiter_max_attempts: The maximum number of attempts to check DB instance state
  :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
  This implies waiting for completion. This mode requires aiobotocore module to be installed.
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("db_identifier", "db_type")
+ template_fields = aws_template_fields("db_identifier", "db_type")

  def __init__(
  self,
@@ -832,9 +892,14 @@ class RdsStopDbOperator(RdsBaseOperator):
  :param waiter_max_attempts: The maximum number of attempts to check DB instance state
  :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
  This implies waiting for completion. This mode requires aiobotocore module to be installed.
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+ template_fields = aws_template_fields("db_identifier", "db_snapshot_identifier", "db_type")

  def __init__(
  self,
airflow/providers/amazon/aws/operators/redshift_cluster.py

@@ -755,11 +755,18 @@ class RedshiftDeleteClusterOperator(BaseOperator):
  final_cluster_snapshot_identifier=self.final_cluster_snapshot_identifier,
  )
  break
- except self.redshift_hook.get_conn().exceptions.InvalidClusterStateFault:
+ except self.redshift_hook.conn.exceptions.InvalidClusterStateFault:
  self._attempts -= 1

  if self._attempts:
- self.log.error("Unable to delete cluster. %d attempts remaining.", self._attempts)
+ current_state = self.redshift_hook.conn.describe_clusters(
+ ClusterIdentifier=self.cluster_identifier
+ )["Clusters"][0]["ClusterStatus"]
+ self.log.error(
+ "Cluster in %s state, unable to delete. %d attempts remaining.",
+ current_state,
+ self._attempts,
+ )
  time.sleep(self._attempt_interval)
  else:
  raise
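The retry loop now reports which cluster state blocked the deletion. A hedged sketch of the underlying call, based only on the hunk above and the boto3 `describe_clusters` API (the cluster name is invented); `RedshiftHook.conn` is the hook's cached boto3 client:

```python
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")
# Same call chain the operator uses to surface the blocking state in its error log.
current_state = hook.conn.describe_clusters(ClusterIdentifier="my-cluster")["Clusters"][0][
    "ClusterStatus"
]
print(f"Cluster is currently in state: {current_state}")
```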
@@ -785,7 +792,7 @@ class RedshiftDeleteClusterOperator(BaseOperator):
  )

  elif self.wait_for_completion:
- waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
+ waiter = self.redshift_hook.conn.get_waiter("cluster_deleted")
  waiter.wait(
  ClusterIdentifier=self.cluster_identifier,
  WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts},
airflow/providers/amazon/aws/operators/sagemaker.py

@@ -170,7 +170,7 @@ class SageMakerBaseOperator(BaseOperator):
  timestamp = str(
  time.time_ns() // 1000000000
  ) # only keep the relevant datetime (first 10 digits)
- name = f"{proposed_name[:max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp
+ name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp
  self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
  return name

@@ -178,8 +178,7 @@
  """Raise exception if resource type is not 'model' or 'job'."""
  if resource_type not in ("model", "job"):
  raise AirflowException(
- "Argument resource_type accepts only 'model' and 'job'. "
- f"Provided value: '{resource_type}'."
+ f"Argument resource_type accepts only 'model' and 'job'. Provided value: '{resource_type}'."
  )

  def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
@@ -559,8 +558,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
  self.operation = "update"
  sagemaker_operation = self.hook.update_endpoint
  self.log.warning(
- "cannot create already existing endpoint %s, "
- "updating it with the given config instead",
+ "cannot create already existing endpoint %s, updating it with the given config instead",
  endpoint_info["EndpointName"],
  )
  if "Tags" in endpoint_info:
airflow/providers/amazon/aws/sensors/bedrock.py

@@ -26,6 +26,8 @@ from airflow.exceptions import AirflowException
  from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
  from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
  from airflow.providers.amazon.aws.triggers.bedrock import (
+ BedrockBatchInferenceCompletedTrigger,
+ BedrockBatchInferenceScheduledTrigger,
  BedrockCustomizeModelCompletedTrigger,
  BedrockIngestionJobTrigger,
  BedrockKnowledgeBaseActiveTrigger,
@@ -34,6 +36,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
  from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

  if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.triggers.bedrock import BedrockBaseBatchInferenceTrigger
  from airflow.utils.context import Context

@@ -368,3 +371,110 @@ class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
  )
  else:
  super().execute(context=context)
+
+
+ class BedrockBatchInferenceSensor(BedrockBaseSensor[BedrockHook]):
+ """
+ Poll the batch inference job status until it reaches a terminal state; fails if creation fails.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:BedrockBatchInferenceSensor`
+
+ :param job_arn: The Amazon Resource Name (ARN) of the batch inference job. (templated)
+ :param success_state: A BedrockBatchInferenceSensor.TargetState; defaults to 'SCHEDULED' (templated)
+
+ :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
+ module to be installed.
+ (default: False, but can be overridden in config file by setting default_deferrable to True)
+ :param poke_interval: Polling period in seconds to check for the status of the job. (default: 5)
+ :param max_retries: Number of times before returning the current state (default: 24)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ """
+
+ class SuccessState:
+ """
+ Target state for the BedrockBatchInferenceSensor.
+
+ Bedrock adds batch inference jobs to a queue, and they may take some time to complete.
+ If you want to wait for the job to complete, use TargetState.COMPLETED, but if you only want
+ to wait until the service confirms that the job is in the queue, use TargetState.SCHEDULED.
+
+ The normal successful progression of states is:
+ Submitted > Validating > Scheduled > InProgress > PartiallyCompleted > Completed
+ """
+
+ SCHEDULED = "scheduled"
+ COMPLETED = "completed"
+
+ INTERMEDIATE_STATES: tuple[str, ...] # Defined in __init__ based on target state
+ FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopped", "PartiallyCompleted", "Expired")
+ SUCCESS_STATES: tuple[str, ...] # Defined in __init__ based on target state
+ FAILURE_MESSAGE = "Bedrock batch inference job sensor failed."
+ INVALID_SUCCESS_STATE_MESSAGE = "success_state must be an instance of TargetState."
+
+ aws_hook_class = BedrockHook
+
+ template_fields: Sequence[str] = aws_template_fields("job_arn", "success_state")
+
+ def __init__(
+ self,
+ *,
+ job_arn: str,
+ success_state: SuccessState | str = SuccessState.SCHEDULED,
+ poke_interval: int = 120,
+ max_retries: int = 75,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.poke_interval = poke_interval
+ self.max_retries = max_retries
+ self.job_arn = job_arn
+ self.success_state = success_state
+
+ base_success_states: tuple[str, ...] = ("Completed",)
+ base_intermediate_states: tuple[str, ...] = ("Submitted", "InProgress", "Stopping", "Validating")
+ scheduled_state = ("Scheduled",)
+ self.trigger_class: type[BedrockBaseBatchInferenceTrigger]
+
+ if self.success_state == BedrockBatchInferenceSensor.SuccessState.COMPLETED:
+ intermediate_states = base_intermediate_states + scheduled_state
+ success_states = base_success_states
+ self.trigger_class = BedrockBatchInferenceCompletedTrigger
+ elif self.success_state == BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
+ intermediate_states = base_intermediate_states
+ success_states = base_success_states + scheduled_state
+ self.trigger_class = BedrockBatchInferenceScheduledTrigger
+ else:
+ raise ValueError(
+ "Success states for BedrockBatchInferenceSensor must be set using a BedrockBatchInferenceSensor.SuccessState"
+ )
+
+ BedrockBatchInferenceSensor.INTERMEDIATE_STATES = intermediate_states or base_intermediate_states
+ BedrockBatchInferenceSensor.SUCCESS_STATES = success_states or base_success_states
+
+ def get_state(self) -> str:
+ return self.hook.conn.get_model_invocation_job(jobIdentifier=self.job_arn)["status"]
+
+ def execute(self, context: Context) -> Any:
+ if self.deferrable:
+ self.defer(
+ trigger=self.trigger_class(
+ job_arn=self.job_arn,
+ waiter_delay=int(self.poke_interval),
+ waiter_max_attempts=self.max_retries,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="poke",
+ )
+ else:
+ super().execute(context=context)
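A hedged usage sketch of the new sensor, mirroring the constructor shown above; the job ARN and task id are invented. Waiting for `SCHEDULED` returns as soon as Bedrock confirms the job is queued, while `COMPLETED` keeps the sensor (or its trigger, in deferrable mode) running until the job finishes:

```python
from airflow.providers.amazon.aws.sensors.bedrock import BedrockBatchInferenceSensor

wait_for_batch_inference = BedrockBatchInferenceSensor(
    task_id="wait_for_batch_inference",
    job_arn="arn:aws:bedrock:us-east-1:123456789012:model-invocation-job/abc123",  # invented ARN
    success_state=BedrockBatchInferenceSensor.SuccessState.COMPLETED,
    deferrable=True,    # hands polling off to BedrockBatchInferenceCompletedTrigger
    poke_interval=120,
    max_retries=75,
)
```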
airflow/providers/amazon/aws/sensors/glacier.py

@@ -95,5 +95,5 @@ class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
  return False
  else:
  raise AirflowException(
- f'Sensor failed. Job status: {response["Action"]}, code status: {response["StatusCode"]}'
+ f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
  )
airflow/providers/amazon/aws/sensors/mwaa.py

@@ -150,7 +150,8 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
  external_dag_run_id=self.external_dag_run_id,
  success_states=self.success_states,
  failure_states=self.failure_states,
- waiter_delay=self.poke_interval,
+ # somehow the type of poke_interval is derived as float ??
+ waiter_delay=self.poke_interval, # type: ignore[arg-type]
  waiter_max_attempts=self.max_retries,
  aws_conn_id=self.aws_conn_id,
  ),
airflow/providers/amazon/aws/sensors/rds.py

@@ -17,36 +17,30 @@
  from __future__ import annotations

  from collections.abc import Sequence
- from functools import cached_property
  from typing import TYPE_CHECKING

  from airflow.exceptions import AirflowException, AirflowNotFoundException
  from airflow.providers.amazon.aws.hooks.rds import RdsHook
+ from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
  from airflow.providers.amazon.aws.utils.rds import RdsDbType
- from airflow.sensors.base import BaseSensorOperator

  if TYPE_CHECKING:
  from airflow.utils.context import Context


- class RdsBaseSensor(BaseSensorOperator):
+ class RdsBaseSensor(AwsBaseSensor[RdsHook]):
  """Base operator that implements common functions for all sensors."""

+ aws_hook_class = RdsHook
  ui_color = "#ddbb77"
  ui_fgcolor = "#ffffff"

- def __init__(
- self, *args, aws_conn_id: str | None = "aws_conn_id", hook_params: dict | None = None, **kwargs
- ):
+ def __init__(self, *args, hook_params: dict | None = None, **kwargs):
  self.hook_params = hook_params or {}
- self.aws_conn_id = aws_conn_id
  self.target_statuses: list[str] = []
  super().__init__(*args, **kwargs)

- @cached_property
- def hook(self):
- return RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
-

  class RdsSnapshotExistenceSensor(RdsBaseSensor):
  """
@@ -59,9 +53,19 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
  :param db_type: Type of the DB - either "instance" or "cluster"
  :param db_snapshot_identifier: The identifier for the DB snapshot
  :param target_statuses: Target status of snapshot
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

- template_fields: Sequence[str] = (
+ template_fields: Sequence[str] = aws_template_fields(
  "db_snapshot_identifier",
  "target_statuses",
  )
@@ -72,10 +76,9 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
  db_type: str,
  db_snapshot_identifier: str,
  target_statuses: list[str] | None = None,
- aws_conn_id: str | None = "aws_conn_id",
  **kwargs,
  ):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
  self.db_type = RdsDbType(db_type)
  self.db_snapshot_identifier = db_snapshot_identifier
  self.target_statuses = target_statuses or ["available"]
@@ -107,7 +110,9 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
  :param error_statuses: Target error status of export task to fail the sensor
  """

- template_fields: Sequence[str] = ("export_task_identifier", "target_statuses", "error_statuses")
+ template_fields: Sequence[str] = aws_template_fields(
+ "export_task_identifier", "target_statuses", "error_statuses"
+ )

  def __init__(
  self,
@@ -115,10 +120,9 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
  export_task_identifier: str,
  target_statuses: list[str] | None = None,
  error_statuses: list[str] | None = None,
- aws_conn_id: str | None = "aws_default",
  **kwargs,
  ):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)

  self.export_task_identifier = export_task_identifier
  self.target_statuses = target_statuses or [
@@ -159,7 +163,7 @@ class RdsDbSensor(RdsBaseSensor):
  :param target_statuses: Target status of DB
  """

- template_fields: Sequence[str] = (
+ template_fields: Sequence[str] = aws_template_fields(
  "db_identifier",
  "db_type",
  "target_statuses",
@@ -171,10 +175,9 @@ class RdsDbSensor(RdsBaseSensor):
  db_identifier: str,
  db_type: RdsDbType | str = RdsDbType.INSTANCE,
  target_statuses: list[str] | None = None,
- aws_conn_id: str | None = "aws_default",
  **kwargs,
  ):
- super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ super().__init__(**kwargs)
  self.db_identifier = db_identifier
  self.target_statuses = target_statuses or ["available"]
  self.db_type = db_type
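With the RDS sensors moved onto `AwsBaseSensor[RdsHook]`, the same standard AWS arguments apply here as for the operators. A hedged sketch with an invented identifier and task id, using only the parameters visible in the hunks above:

```python
from airflow.providers.amazon.aws.sensors.rds import RdsDbSensor

wait_for_db = RdsDbSensor(
    task_id="wait_for_db_available",
    db_identifier="my-rds-instance",  # invented identifier
    db_type="instance",
    target_statuses=["available"],
    aws_conn_id="aws_default",
    region_name="eu-west-1",          # now supported via the AwsBaseSensor machinery
)
```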
airflow/providers/amazon/aws/sensors/s3.py

@@ -192,7 +192,7 @@ class S3KeySensor(AwsBaseSensor[S3Hook]):
  self.defer(
  timeout=timedelta(seconds=self.timeout),
  trigger=S3KeyTrigger(
- bucket_name=cast(str, self.bucket_name),
+ bucket_name=cast("str", self.bucket_name),
  bucket_key=self.bucket_key,
  wildcard_match=self.wildcard_match,
  aws_conn_id=self.aws_conn_id,
airflow/providers/amazon/aws/sensors/step_function.py

@@ -81,5 +81,6 @@ class StepFunctionExecutionSensor(AwsBaseSensor[StepFunctionHook]):
  return False

  self.log.info("Doing xcom_push of output")
- self.xcom_push(context, "output", output)
+
+ context["ti"].xcom_push(key="output", value=output)
  return True
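The sensor now publishes its output through the TaskInstance held in the execution context instead of calling `xcom_push` on itself. A minimal hedged sketch of that pattern for any custom operator (the class name and payload are illustrative only):

```python
from airflow.models import BaseOperator


class PushResultOperator(BaseOperator):
    """Illustrative only: push a value to XCom via the task instance from the context."""

    def execute(self, context):
        result = {"status": "SUCCEEDED"}  # stand-in payload
        context["ti"].xcom_push(key="output", value=result)
        return result
```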
airflow/providers/amazon/aws/transfers/mongo_to_s3.py

@@ -103,7 +103,7 @@ class MongoToS3Operator(BaseOperator):
  if self.is_pipeline:
  results: CommandCursor[Any] | Cursor = MongoHook(self.mongo_conn_id).aggregate(
  mongo_collection=self.mongo_collection,
- aggregate_query=cast(list, self.mongo_query),
+ aggregate_query=cast("list", self.mongo_query),
  mongo_db=self.mongo_db,
  allowDiskUse=self.allow_disk_use,
  )
@@ -111,7 +111,7 @@ class MongoToS3Operator(BaseOperator):
  else:
  results = MongoHook(self.mongo_conn_id).find(
  mongo_collection=self.mongo_collection,
- query=cast(dict, self.mongo_query),
+ query=cast("dict", self.mongo_query),
  projection=self.mongo_projection,
  mongo_db=self.mongo_db,
  find_one=False,
airflow/providers/amazon/aws/transfers/sql_to_s3.py

@@ -223,7 +223,7 @@ class SqlToS3Operator(BaseOperator):
  return
  for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
  yield (
- cast(str, group_label),
+ cast("str", group_label),
  grouped_df.get_group(group_label)
  .drop(random_column_name, axis=1, errors="ignore")
  .reset_index(drop=True),