apache-airflow-providers-amazon 8.25.0rc1__py3-none-any.whl → 8.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +10 -0
  3. airflow/providers/amazon/aws/executors/batch/batch_executor.py +19 -16
  4. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +22 -15
  5. airflow/providers/amazon/aws/hooks/athena.py +18 -9
  6. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -1
  7. airflow/providers/amazon/aws/hooks/base_aws.py +34 -10
  8. airflow/providers/amazon/aws/hooks/chime.py +2 -1
  9. airflow/providers/amazon/aws/hooks/datasync.py +6 -3
  10. airflow/providers/amazon/aws/hooks/ecr.py +2 -1
  11. airflow/providers/amazon/aws/hooks/ecs.py +12 -6
  12. airflow/providers/amazon/aws/hooks/glacier.py +8 -4
  13. airflow/providers/amazon/aws/hooks/kinesis.py +2 -1
  14. airflow/providers/amazon/aws/hooks/logs.py +4 -2
  15. airflow/providers/amazon/aws/hooks/redshift_cluster.py +24 -12
  16. airflow/providers/amazon/aws/hooks/redshift_data.py +4 -2
  17. airflow/providers/amazon/aws/hooks/redshift_sql.py +6 -3
  18. airflow/providers/amazon/aws/hooks/s3.py +70 -53
  19. airflow/providers/amazon/aws/hooks/sagemaker.py +82 -41
  20. airflow/providers/amazon/aws/hooks/secrets_manager.py +6 -3
  21. airflow/providers/amazon/aws/hooks/sts.py +2 -1
  22. airflow/providers/amazon/aws/operators/athena.py +21 -8
  23. airflow/providers/amazon/aws/operators/batch.py +12 -6
  24. airflow/providers/amazon/aws/operators/datasync.py +2 -1
  25. airflow/providers/amazon/aws/operators/ecs.py +1 -0
  26. airflow/providers/amazon/aws/operators/emr.py +6 -86
  27. airflow/providers/amazon/aws/operators/glue.py +4 -2
  28. airflow/providers/amazon/aws/operators/glue_crawler.py +22 -19
  29. airflow/providers/amazon/aws/operators/neptune.py +2 -1
  30. airflow/providers/amazon/aws/operators/redshift_cluster.py +2 -1
  31. airflow/providers/amazon/aws/operators/s3.py +11 -1
  32. airflow/providers/amazon/aws/operators/sagemaker.py +8 -10
  33. airflow/providers/amazon/aws/sensors/base_aws.py +2 -1
  34. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +25 -17
  35. airflow/providers/amazon/aws/sensors/glue_crawler.py +16 -12
  36. airflow/providers/amazon/aws/sensors/s3.py +11 -5
  37. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +6 -3
  38. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +2 -1
  39. airflow/providers/amazon/aws/transfers/s3_to_sql.py +2 -1
  40. airflow/providers/amazon/aws/triggers/ecs.py +3 -1
  41. airflow/providers/amazon/aws/triggers/glue.py +15 -3
  42. airflow/providers/amazon/aws/triggers/glue_crawler.py +8 -1
  43. airflow/providers/amazon/aws/utils/connection_wrapper.py +10 -5
  44. airflow/providers/amazon/aws/utils/mixins.py +2 -1
  45. airflow/providers/amazon/aws/utils/redshift.py +2 -1
  46. airflow/providers/amazon/get_provider_info.py +2 -1
  47. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/METADATA +9 -9
  48. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/RECORD +50 -50
  49. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/WHEEL +0 -0
  50. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/emr.py

@@ -27,7 +27,6 @@ from uuid import uuid4
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
-from airflow.models.mappedoperator import MappedOperator
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import (
     EmrClusterLink,
@@ -1259,91 +1258,12 @@ class EmrServerlessStartJobOperator(BaseOperator):
         "configuration_overrides": "json",
     }

-    @property
-    def operator_extra_links(self):
-        """
-        Dynamically add extra links depending on the job type and if they're enabled.
-
-        If S3 or CloudWatch monitoring configurations exist, add links directly to the relevant consoles.
-        Only add dashboard links if they're explicitly enabled. These are one-time links that any user
-        can access, but expire on first click or one hour, whichever comes first.
-        """
-        op_extra_links = []
-
-        if isinstance(self, MappedOperator):
-            operator_class = self.operator_class
-            enable_application_ui_links = self.partial_kwargs.get(
-                "enable_application_ui_links"
-            ) or self.expand_input.value.get("enable_application_ui_links")
-            job_driver = self.partial_kwargs.get("job_driver", {}) or self.expand_input.value.get(
-                "job_driver", {}
-            )
-            configuration_overrides = self.partial_kwargs.get(
-                "configuration_overrides"
-            ) or self.expand_input.value.get("configuration_overrides")
-
-            # Configuration overrides can either be a list or a dictionary, depending on whether it's passed in as partial or expand.
-            if isinstance(configuration_overrides, list):
-                if any(
-                    [
-                        operator_class.is_monitoring_in_job_override(
-                            self=operator_class,
-                            config_key="s3MonitoringConfiguration",
-                            job_override=job_override,
-                        )
-                        for job_override in configuration_overrides
-                    ]
-                ):
-                    op_extra_links.extend([EmrServerlessS3LogsLink()])
-                if any(
-                    [
-                        operator_class.is_monitoring_in_job_override(
-                            self=operator_class,
-                            config_key="cloudWatchLoggingConfiguration",
-                            job_override=job_override,
-                        )
-                        for job_override in configuration_overrides
-                    ]
-                ):
-                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
-            else:
-                if operator_class.is_monitoring_in_job_override(
-                    self=operator_class,
-                    config_key="s3MonitoringConfiguration",
-                    job_override=configuration_overrides,
-                ):
-                    op_extra_links.extend([EmrServerlessS3LogsLink()])
-                if operator_class.is_monitoring_in_job_override(
-                    self=operator_class,
-                    config_key="cloudWatchLoggingConfiguration",
-                    job_override=configuration_overrides,
-                ):
-                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
-
-        else:
-            operator_class = self
-            enable_application_ui_links = self.enable_application_ui_links
-            configuration_overrides = self.configuration_overrides
-            job_driver = self.job_driver
-
-            if operator_class.is_monitoring_in_job_override(
-                "s3MonitoringConfiguration", configuration_overrides
-            ):
-                op_extra_links.extend([EmrServerlessS3LogsLink()])
-            if operator_class.is_monitoring_in_job_override(
-                "cloudWatchLoggingConfiguration", configuration_overrides
-            ):
-                op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
-
-        if enable_application_ui_links:
-            op_extra_links.extend([EmrServerlessDashboardLink()])
-            if isinstance(job_driver, list):
-                if any("sparkSubmit" in ind_job_driver for ind_job_driver in job_driver):
-                    op_extra_links.extend([EmrServerlessLogsLink()])
-            elif "sparkSubmit" in job_driver:
-                op_extra_links.extend([EmrServerlessLogsLink()])
-
-        return tuple(op_extra_links)
+    operator_extra_links = (
+        EmrServerlessS3LogsLink(),
+        EmrServerlessCloudWatchLogsLink(),
+        EmrServerlessDashboardLink(),
+        EmrServerlessLogsLink(),
+    )

     def __init__(
         self,
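The second hunk replaces EmrServerlessStartJobOperator's dynamic operator_extra_links property with a static class-level tuple, so all four EMR Serverless links are registered unconditionally and the MappedOperator special-casing (and its import) goes away. For context, a minimal, hedged sketch of the static extra-links pattern this relies on; the link and operator classes below are illustrative stand-ins, not the provider's real implementation:

```python
# Hedged sketch of Airflow's class-level operator_extra_links pattern.
# ExampleConsoleLink/ExampleOperator are hypothetical; import paths follow Airflow 2.x.
from airflow.models import BaseOperator, BaseOperatorLink


class ExampleConsoleLink(BaseOperatorLink):
    name = "Example Console"

    def get_link(self, operator, *, ti_key):
        # A real link would build its URL from values the task pushed to XCom.
        return "https://console.aws.amazon.com/"


class ExampleOperator(BaseOperator):
    # Declared on the class, the links are serialized with the DAG and behave
    # the same for plain tasks and for mapped (.partial()/.expand()) tasks,
    # which is what the removed property had to special-case.
    operator_extra_links = (ExampleConsoleLink(),)
```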
airflow/providers/amazon/aws/operators/glue.py

@@ -43,7 +43,8 @@ if TYPE_CHECKING:


 class GlueJobOperator(BaseOperator):
-    """Create an AWS Glue Job.
+    """
+    Create an AWS Glue Job.

     AWS Glue is a serverless Spark ETL service for running Spark Jobs on the AWS
     cloud. Language support: Python and Scala.
@@ -179,7 +180,8 @@ class GlueJobOperator(BaseOperator):
         )

     def execute(self, context: Context):
-        """Execute AWS Glue Job from Airflow.
+        """
+        Execute AWS Glue Job from Airflow.

         :return: the current Glue job ID.
         """
airflow/providers/amazon/aws/operators/glue_crawler.py

@@ -17,22 +17,22 @@
 # under the License.
 from __future__ import annotations

-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
     from airflow.utils.context import Context

-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook


-class GlueCrawlerOperator(BaseOperator):
+class GlueCrawlerOperator(AwsBaseOperator[GlueCrawlerHook]):
     """
     Creates, updates and triggers an AWS Glue Crawler.

@@ -45,45 +45,45 @@ class GlueCrawlerOperator(BaseOperator):
         :ref:`howto/operator:GlueCrawlerOperator`

     :param config: Configurations for the AWS Glue crawler
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
     :param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status
     :param wait_for_completion: Whether to wait for crawl execution completion. (default: True)
     :param deferrable: If True, the operator will wait asynchronously for the crawl to complete.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields: Sequence[str] = ("config",)
+    aws_hook_class = GlueCrawlerHook
+
+    template_fields: Sequence[str] = aws_template_fields(
+        "config",
+    )
     ui_color = "#ededed"

     def __init__(
         self,
         config,
-        aws_conn_id="aws_default",
-        region_name: str | None = None,
         poll_interval: int = 5,
         wait_for_completion: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.poll_interval = poll_interval
         self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
-        self.region_name = region_name
         self.config = config

-    @cached_property
-    def hook(self) -> GlueCrawlerHook:
-        """Create and return a GlueCrawlerHook."""
-        return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name)
-
-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str:
         """
         Execute AWS Glue Crawler from Airflow.

@@ -103,6 +103,9 @@ class GlueCrawlerOperator(BaseOperator):
                     crawler_name=crawler_name,
                     waiter_delay=self.poll_interval,
                     aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
                 ),
                 method_name="execute_complete",
             )
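GlueCrawlerOperator now derives from AwsBaseOperator[GlueCrawlerHook], so the hook, template fields and connection arguments come from the shared AWS base class instead of a hand-rolled cached_property, and the deferral trigger receives region_name, verify and botocore_config. A hedged usage sketch (task id, crawler config and connection values are placeholders):

```python
from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator

# Hedged usage sketch; crawler config, connection id, region and retry settings
# are placeholders, not values taken from the release.
crawl_raw_data = GlueCrawlerOperator(
    task_id="crawl_raw_data",
    config={"Name": "raw-data-crawler"},  # unchanged: the crawler configuration
    aws_conn_id="aws_default",            # now inherited from AwsBaseOperator
    region_name="eu-west-1",              # now inherited from AwsBaseOperator
    verify=True,                          # new: SSL verification passed to boto3
    botocore_config={"retries": {"mode": "standard", "max_attempts": 10}},  # new: botocore client config
    wait_for_completion=True,
)
```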
airflow/providers/amazon/aws/operators/neptune.py

@@ -81,7 +81,8 @@ def handle_waitable_exception(


 class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
-    """Starts an Amazon Neptune DB cluster.
+    """
+    Starts an Amazon Neptune DB cluster.

     Amazon Neptune Database is a serverless graph database designed for superior scalability
     and availability. Neptune Database provides built-in security, continuous backups, and
airflow/providers/amazon/aws/operators/redshift_cluster.py

@@ -38,7 +38,8 @@ if TYPE_CHECKING:


 class RedshiftCreateClusterOperator(BaseOperator):
-    """Creates a new cluster with the specified parameters.
+    """
+    Creates a new cluster with the specified parameters.

     .. seealso::
         For more information on how to use this operator, take a look at the guide:
airflow/providers/amazon/aws/operators/s3.py

@@ -610,6 +610,7 @@ class S3FileTransformOperator(BaseOperator):
     :param dest_s3_key: The key to be written from S3. (templated)
     :param transform_script: location of the executable transformation script
     :param select_expression: S3 Select expression
+    :param select_expr_serialization_config: A dictionary that contains input and output serialization configurations for S3 Select.
     :param script_args: arguments for transformation script (templated)
     :param source_aws_conn_id: source s3 connection
     :param source_verify: Whether or not to verify SSL certificates for S3 connection.
@@ -641,6 +642,7 @@
         dest_s3_key: str,
         transform_script: str | None = None,
         select_expression=None,
+        select_expr_serialization_config: dict[str, dict[str, dict]] | None = None,
         script_args: Sequence[str] | None = None,
         source_aws_conn_id: str | None = "aws_default",
         source_verify: bool | str | None = None,
@@ -659,6 +661,7 @@
         self.replace = replace
         self.transform_script = transform_script
         self.select_expression = select_expression
+        self.select_expr_serialization_config = select_expr_serialization_config or {}
         self.script_args = script_args or []
         self.output_encoding = sys.getdefaultencoding()

@@ -678,7 +681,14 @@
             self.log.info("Dumping S3 file %s contents to local file %s", self.source_s3_key, f_source.name)

             if self.select_expression is not None:
-                content = source_s3.select_key(key=self.source_s3_key, expression=self.select_expression)
+                input_serialization = self.select_expr_serialization_config.get("input_serialization")
+                output_serialization = self.select_expr_serialization_config.get("output_serialization")
+                content = source_s3.select_key(
+                    key=self.source_s3_key,
+                    expression=self.select_expression,
+                    input_serialization=input_serialization,
+                    output_serialization=output_serialization,
+                )
                 f_source.write(content.encode("utf-8"))
             else:
                 source_s3_key_object.download_fileobj(Fileobj=f_source)
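S3FileTransformOperator gains a select_expr_serialization_config argument, so the operator can pass explicit input/output serialization to S3 Select instead of relying on the hook's defaults. A hedged sketch of how the new argument might be used; bucket and key names are placeholders, and the nested dictionaries follow boto3's SelectObjectContent InputSerialization/OutputSerialization shape:

```python
from airflow.providers.amazon.aws.operators.s3 import S3FileTransformOperator

# Hedged usage sketch; all names and values below are placeholders.
filter_csv = S3FileTransformOperator(
    task_id="filter_csv",
    source_s3_key="s3://my-source-bucket/input/data.csv",
    dest_s3_key="s3://my-dest-bucket/output/sample.csv",
    select_expression="SELECT * FROM s3object s LIMIT 100",
    select_expr_serialization_config={
        "input_serialization": {"CSV": {"FileHeaderInfo": "USE"}},
        "output_serialization": {"CSV": {}},
    },
    transform_script="/bin/cp",  # pass the selected rows through unchanged
    replace=True,
)
```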
airflow/providers/amazon/aws/operators/sagemaker.py

@@ -60,7 +60,8 @@ def serialize(result: dict) -> dict:


 class SageMakerBaseOperator(BaseOperator):
-    """This is the base operator for all SageMaker operators.
+    """
+    This is the base operator for all SageMaker operators.

     :param config: The configuration necessary to start a training job (templated)
     """
@@ -360,7 +361,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Error while running job: {event}")

         self.log.info(event["message"])
-        self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        self.serialized_job = serialize(self.hook.describe_processing_job(event["job_name"]))
         self.log.info("%s completed successfully.", self.task_id)
         return {"Processing": self.serialized_job}

@@ -611,12 +612,11 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):

         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        endpoint_info = self.config.get("Endpoint", self.config)
+
+        response = self.hook.describe_endpoint(event["job_name"])
         return {
-            "EndpointConfig": serialize(
-                self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
-            ),
-            "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+            "EndpointConfig": serialize(self.hook.describe_endpoint_config(response["EndpointConfigName"])),
+            "Endpoint": serialize(self.hook.describe_endpoint(response["EndpointName"])),
         }


@@ -996,9 +996,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):

         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        return {
-            "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
-        }
+        return {"Tuning": serialize(self.hook.describe_tuning_job(event["job_name"]))}


 class SageMakerModelOperator(SageMakerBaseOperator):
airflow/providers/amazon/aws/sensors/base_aws.py

@@ -30,7 +30,8 @@ from airflow.utils.types import NOTSET, ArgNotSet


 class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):
-    """Base AWS (Amazon) Sensor Class for build sensors in top of AWS Hooks.
+    """
+    Base AWS (Amazon) Sensor Class for build sensors in top of AWS Hooks.

     .. warning::
         Only for internal usage, this class might be changed, renamed or removed in the future
airflow/providers/amazon/aws/sensors/glue_catalog_partition.py

@@ -18,7 +18,6 @@
 from __future__ import annotations

 from datetime import timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence

 from deprecated import deprecated
@@ -26,18 +25,23 @@ from deprecated import deprecated
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
     from airflow.utils.context import Context


-class GlueCatalogPartitionSensor(BaseSensorOperator):
+class GlueCatalogPartitionSensor(AwsBaseSensor[GlueCatalogHook]):
     """
     Waits for a partition to show up in AWS Glue Catalog.

+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueCatalogPartitionSensor`
+
     :param table_name: The name of the table to wait for, supports the dot
         notation (my_database.my_table)
     :param expression: The partition clause to wait for. This is passed as
@@ -46,19 +50,27 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
         AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
         See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
         #aws-glue-api-catalog-partitions-GetPartitions
-    :param aws_conn_id: ID of the Airflow connection where
-        credentials and extra configuration are stored
-    :param region_name: Optional aws region name (example: us-east-1). Uses region from connection
-        if not specified.
     :param database_name: The name of the catalog database where the partitions reside.
     :param poke_interval: Time in seconds that the job should wait in
         between each tries
     :param deferrable: If true, then the sensor will wait asynchronously for the partition to
         show up in the AWS Glue Catalog.
         (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields: Sequence[str] = (
+    aws_hook_class = GlueCatalogHook
+
+    template_fields: Sequence[str] = aws_template_fields(
         "database_name",
         "table_name",
         "expression",
@@ -70,19 +82,16 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
         *,
         table_name: str,
         expression: str = "ds='{{ ds }}'",
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         database_name: str = "default",
         poke_interval: int = 60 * 3,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(poke_interval=poke_interval, **kwargs)
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
+        super().__init__(**kwargs)
         self.table_name = table_name
         self.expression = expression
         self.database_name = database_name
+        self.poke_interval = poke_interval
         self.deferrable = deferrable

     def execute(self, context: Context) -> Any:
@@ -93,7 +102,10 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
                     table_name=self.table_name,
                     expression=self.expression,
                     aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
                     waiter_delay=int(self.poke_interval),
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.timeout),
@@ -126,7 +138,3 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
     def get_hook(self) -> GlueCatalogHook:
         """Get the GlueCatalogHook."""
         return self.hook
-
-    @cached_property
-    def hook(self) -> GlueCatalogHook:
-        return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
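GlueCatalogPartitionSensor follows the same migration to AwsBaseSensor[GlueCatalogHook]: the explicit aws_conn_id/region_name handling and the cached_property hook are gone, and the deferral trigger now receives region_name, verify and botocore_config. A hedged usage sketch (table, connection and region values are placeholders):

```python
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor

# Hedged usage sketch; database/table, connection id and region are placeholders.
wait_for_partition = GlueCatalogPartitionSensor(
    task_id="wait_for_partition",
    table_name="my_database.my_table",
    expression="ds='{{ ds }}'",
    aws_conn_id="aws_default",  # now supplied by AwsBaseSensor
    region_name="us-east-1",    # now supplied by AwsBaseSensor
    poke_interval=60 * 3,
    deferrable=True,
)
```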
airflow/providers/amazon/aws/sensors/glue_crawler.py

@@ -17,20 +17,20 @@
 # under the License.
 from __future__ import annotations

-from functools import cached_property
 from typing import TYPE_CHECKING, Sequence

 from deprecated import deprecated

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
     from airflow.utils.context import Context


-class GlueCrawlerSensor(BaseSensorOperator):
+class GlueCrawlerSensor(AwsBaseSensor[GlueCrawlerHook]):
     """
     Waits for an AWS Glue crawler to reach any of the statuses below.

@@ -41,19 +41,27 @@ class GlueCrawlerSensor(BaseSensorOperator):
         :ref:`howto/sensor:GlueCrawlerSensor`

     :param crawler_name: The AWS Glue crawler unique name
-    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields: Sequence[str] = ("crawler_name",)
+    aws_hook_class = GlueCrawlerHook

-    def __init__(self, *, crawler_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None:
+    template_fields: Sequence[str] = aws_template_fields(
+        "crawler_name",
+    )
+
+    def __init__(self, *, crawler_name: str, **kwargs) -> None:
         super().__init__(**kwargs)
         self.crawler_name = crawler_name
-        self.aws_conn_id = aws_conn_id
         self.success_statuses = "SUCCEEDED"
         self.errored_statuses = ("FAILED", "CANCELLED")

@@ -79,7 +87,3 @@ class GlueCrawlerSensor(BaseSensorOperator):
     def get_hook(self) -> GlueCrawlerHook:
         """Return a new or pre-existing GlueCrawlerHook."""
         return self.hook
-
-    @cached_property
-    def hook(self) -> GlueCrawlerHook:
-        return GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
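GlueCrawlerSensor gets the same treatment: AwsBaseSensor supplies aws_conn_id, region_name, verify and botocore_config, so the sensor's own aws_conn_id plumbing is removed. A short hedged sketch (crawler and connection names are placeholders):

```python
from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor

# Hedged usage sketch; crawler name, connection id and region are placeholders.
wait_for_crawl = GlueCrawlerSensor(
    task_id="wait_for_crawl",
    crawler_name="raw-data-crawler",
    aws_conn_id="aws_default",  # inherited from AwsBaseSensor
    region_name="eu-west-1",    # newly accepted via AwsBaseSensor
)
```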
airflow/providers/amazon/aws/sensors/s3.py

@@ -18,6 +18,7 @@
 from __future__ import annotations

 import fnmatch
+import inspect
 import os
 import re
 from datetime import datetime, timedelta
@@ -57,13 +58,13 @@ class S3KeySensor(BaseSensorOperator):
         refers to this bucket
     :param wildcard_match: whether the bucket_key should be interpreted as a
         Unix wildcard pattern
-    :param check_fn: Function that receives the list of the S3 objects,
+    :param check_fn: Function that receives the list of the S3 objects with the context values,
         and returns a boolean:
         - ``True``: the criteria is met
        - ``False``: the criteria isn't met
         **Example**: Wait for any S3 object size more than 1 megabyte ::

-            def check_fn(files: List) -> bool:
+            def check_fn(files: List, **kwargs) -> bool:
                 return any(f.get('Size', 0) > 1048576 for f in files)
     :param aws_conn_id: a reference to the s3 connection
     :param verify: Whether to verify SSL certificates for S3 connection.
@@ -112,7 +113,7 @@
         self.use_regex = use_regex
         self.metadata_keys = metadata_keys if metadata_keys else ["Size"]

-    def _check_key(self, key):
+    def _check_key(self, key, context: Context):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)

@@ -167,15 +168,20 @@
             files = [metadata]

         if self.check_fn is not None:
+            # For backwards compatibility, check if the function takes a context argument
+            signature = inspect.signature(self.check_fn)
+            if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
+                return self.check_fn(files, **context)
+            # Otherwise, just pass the files
             return self.check_fn(files)

         return True

     def poke(self, context: Context):
         if isinstance(self.bucket_key, str):
-            return self._check_key(self.bucket_key)
+            return self._check_key(self.bucket_key, context=context)
         else:
-            return all(self._check_key(key) for key in self.bucket_key)
+            return all(self._check_key(key, context=context) for key in self.bucket_key)

     def execute(self, context: Context) -> None:
         """Airflow runs this method on the worker and defers using the trigger."""
airflow/providers/amazon/aws/transfers/mongo_to_s3.py

@@ -34,7 +34,8 @@ if TYPE_CHECKING:


 class MongoToS3Operator(BaseOperator):
-    """Move data from MongoDB to S3.
+    """
+    Move data from MongoDB to S3.

     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -128,7 +129,8 @@ class MongoToS3Operator(BaseOperator):

     @staticmethod
     def _stringify(iterable: Iterable, joinable: str = "\n") -> str:
-        """Stringify an iterable of dicts.
+        """
+        Stringify an iterable of dicts.

         This dumps each dict with JSON, and joins them with ``joinable``.
         """
@@ -136,7 +138,8 @@ class MongoToS3Operator(BaseOperator):

     @staticmethod
     def transform(docs: Any) -> Any:
-        """Transform the data for transfer.
+        """
+        Transform the data for transfer.

         This method is meant to be extended by child classes to perform
         transformations unique to those operators needs. Processes pyMongo
airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py

@@ -44,7 +44,8 @@ class KeySchema(TypedDict):


 class S3ToDynamoDBOperator(BaseOperator):
-    """Load Data from S3 into a DynamoDB.
+    """
+    Load Data from S3 into a DynamoDB.

     Data stored in S3 can be uploaded to a new or existing DynamoDB. Supported file formats CSV, DynamoDB JSON and
     Amazon ION.
airflow/providers/amazon/aws/transfers/s3_to_sql.py

@@ -30,7 +30,8 @@ if TYPE_CHECKING:


 class S3ToSqlOperator(BaseOperator):
-    """Load Data from S3 into a SQL Database.
+    """
+    Load Data from S3 into a SQL Database.

     You need to provide a parser function that takes a filename as an input
     and returns an iterable of rows
airflow/providers/amazon/aws/triggers/ecs.py

@@ -179,7 +179,9 @@ class TaskDoneTrigger(BaseTrigger):
                     cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}
                 )
                 # we reach this point only if the waiter met a success criteria
-                yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
+                yield TriggerEvent(
+                    {"status": "success", "task_arn": self.task_arn, "cluster": self.cluster}
+                )
                 return
             except WaiterError as error:
                 if "terminal failure" in str(error):