apache-airflow-providers-amazon 8.23.0__py3-none-any.whl → 8.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. airflow/providers/amazon/LICENSE +4 -4
  2. airflow/providers/amazon/__init__.py +1 -1
  3. airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
  4. airflow/providers/amazon/aws/hooks/batch_client.py +3 -0
  5. airflow/providers/amazon/aws/hooks/dynamodb.py +34 -1
  6. airflow/providers/amazon/aws/hooks/glue.py +123 -0
  7. airflow/providers/amazon/aws/operators/batch.py +8 -0
  8. airflow/providers/amazon/aws/operators/bedrock.py +6 -20
  9. airflow/providers/amazon/aws/operators/ecs.py +5 -5
  10. airflow/providers/amazon/aws/operators/emr.py +38 -30
  11. airflow/providers/amazon/aws/operators/glue.py +408 -2
  12. airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
  13. airflow/providers/amazon/aws/sensors/glue.py +260 -2
  14. airflow/providers/amazon/aws/sensors/s3.py +35 -5
  15. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
  16. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +257 -0
  17. airflow/providers/amazon/aws/triggers/glue.py +76 -2
  18. airflow/providers/amazon/aws/waiters/dynamodb.json +37 -0
  19. airflow/providers/amazon/aws/waiters/glue.json +98 -0
  20. airflow/providers/amazon/get_provider_info.py +26 -13
  21. {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/METADATA +19 -18
  22. {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/RECORD +24 -23
  23. {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/WHEEL +0 -0
  24. {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/LICENSE

@@ -215,7 +215,7 @@ Third party Apache 2.0 licenses
 
 The following components are provided under the Apache 2.0 License.
 See project link for details. The text of each license is also included
-at licenses/LICENSE-[project].txt.
+at 3rd-party-licenses/LICENSE-[project].txt.
 
 (ALv2 License) hue v4.3.0 (https://github.com/cloudera/hue/)
 (ALv2 License) jqclock v2.3.0 (https://github.com/JohnRDOrazio/jQuery-Clock-Plugin)
@@ -227,7 +227,7 @@ MIT licenses
 ========================================================================
 
 The following components are provided under the MIT License. See project link for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (MIT License) jquery v3.5.1 (https://jquery.org/license/)
 (MIT License) dagre-d3 v0.6.4 (https://github.com/cpettitt/dagre-d3)
@@ -243,11 +243,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
 BSD 3-Clause licenses
 ========================================================================
 The following components are provided under the BSD 3-Clause license. See project links for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (BSD 3 License) d3 v5.16.0 (https://d3js.org)
 (BSD 3 License) d3-shape v2.1.0 (https://github.com/d3/d3-shape)
 (BSD 3 License) cgroupspy 0.2.1 (https://github.com/cloudsigma/cgroupspy)
 
 ========================================================================
-See licenses/LICENSES-ui.txt for packages used in `/airflow/www`
+See 3rd-party-licenses/LICENSES-ui.txt for packages used in `/airflow/www`

airflow/providers/amazon/__init__.py

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "8.23.0"
+__version__ = "8.24.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"

airflow/providers/amazon/aws/hooks/base_aws.py

@@ -156,7 +156,9 @@ class BaseSessionFactory(LoggingMixin):
 
         return async_get_session()
 
-    def create_session(self, deferrable: bool = False) -> boto3.session.Session:
+    def create_session(
+        self, deferrable: bool = False
+    ) -> boto3.session.Session | aiobotocore.session.AioSession:
         """Create boto3 or aiobotocore Session from connection config."""
         if not self.conn:
             self.log.info(
@@ -198,7 +200,7 @@ class BaseSessionFactory(LoggingMixin):
 
     def _create_session_with_assume_role(
         self, session_kwargs: dict[str, Any], deferrable: bool = False
-    ) -> boto3.session.Session:
+    ) -> boto3.session.Session | aiobotocore.session.AioSession:
         if self.conn.assume_role_method == "assume_role_with_web_identity":
             # Deferred credentials have no initial credentials
             credential_fetcher = self._get_web_identity_credential_fetcher()
@@ -239,7 +241,10 @@ class BaseSessionFactory(LoggingMixin):
         session._credentials = credentials
         session.set_config_variable("region", self.basic_session.region_name)
 
-        return boto3.session.Session(botocore_session=session, **session_kwargs)
+        if not deferrable:
+            return boto3.session.Session(botocore_session=session, **session_kwargs)
+
+        return session
 
     def _refresh_credentials(self) -> dict[str, Any]:
         self.log.debug("Refreshing credentials")
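
The hunks above widen the return type: with deferrable=True the session factory now hands back the raw aiobotocore AioSession instead of wrapping it in a boto3 Session. A speculative sketch of how that surfaces through a hook, assuming the hook's get_session(deferrable=...) passthrough; the connection id and the S3 call are placeholders, not taken from the diff:

    import asyncio

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    async def list_buckets():
        hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
        # Assumption: with deferrable=True this is an aiobotocore AioSession (per the
        # hunk above), so async clients are opened with create_client() rather than
        # boto3.Session.client().
        session = hook.get_session(deferrable=True)
        async with session.create_client("s3", region_name=hook.region_name) as client:
            return await client.list_buckets()

    if __name__ == "__main__":
        asyncio.run(list_buckets())
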

airflow/providers/amazon/aws/hooks/batch_client.py

@@ -102,6 +102,7 @@ class BatchProtocol(Protocol):
         arrayProperties: dict,
         parameters: dict,
         containerOverrides: dict,
+        ecsPropertiesOverride: dict,
         tags: dict,
     ) -> dict:
         """
@@ -119,6 +120,8 @@ class BatchProtocol(Protocol):
 
         :param containerOverrides: the same parameter that boto3 will receive
 
+        :param ecsPropertiesOverride: the same parameter that boto3 will receive
+
         :param tags: the same parameter that boto3 will receive
 
         :return: an API response

airflow/providers/amazon/aws/hooks/dynamodb.py

@@ -19,11 +19,17 @@
 
 from __future__ import annotations
 
-from typing import Iterable
+from functools import cached_property
+from typing import TYPE_CHECKING, Iterable
+
+from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
+if TYPE_CHECKING:
+    from botocore.client import BaseClient
+
 
 class DynamoDBHook(AwsBaseHook):
     """
@@ -50,6 +56,11 @@ class DynamoDBHook(AwsBaseHook):
         kwargs["resource_type"] = "dynamodb"
         super().__init__(*args, **kwargs)
 
+    @cached_property
+    def client(self) -> BaseClient:
+        """Return boto3 client."""
+        return self.get_conn().meta.client
+
     def write_batch_data(self, items: Iterable) -> bool:
         """
         Write batch items to DynamoDB table with provisioned throughout capacity.
@@ -70,3 +81,25 @@ class DynamoDBHook(AwsBaseHook):
             return True
         except Exception as general_error:
             raise AirflowException(f"Failed to insert items in dynamodb, error: {general_error}")
+
+    def get_import_status(self, import_arn: str) -> tuple[str, str | None, str | None]:
+        """
+        Get import status from Dynamodb.
+
+        :param import_arn: The Amazon Resource Name (ARN) for the import.
+        :return: Import status, Error code and Error message
+        """
+        self.log.info("Poking for Dynamodb import %s", import_arn)
+
+        try:
+            describe_import = self.client.describe_import(ImportArn=import_arn)
+            status = describe_import["ImportTableDescription"]["ImportStatus"]
+            error_code = describe_import["ImportTableDescription"].get("FailureCode")
+            error_msg = describe_import["ImportTableDescription"].get("FailureMessage")
+            return status, error_code, error_msg
+        except ClientError as e:
+            error_code = e.response.get("Error", {}).get("Code")
+            if error_code == "ImportNotFoundException":
+                raise AirflowException("S3 import into Dynamodb job not found.")
+            else:
+                raise e
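
A hypothetical usage sketch of the new get_import_status method shown above; the connection id and import ARN are placeholder values, not taken from the diff:

    from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook

    hook = DynamoDBHook(aws_conn_id="aws_default")
    # Returns ImportStatus plus FailureCode/FailureMessage (both None unless the import failed).
    status, error_code, error_msg = hook.get_import_status(
        import_arn="arn:aws:dynamodb:us-east-1:123456789012:table/example_table/import/example-import-id",
    )
    if status in ("CANCELLING", "CANCELLED", "FAILED"):
        print(f"Import did not complete: {error_code} {error_msg}")
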

airflow/providers/amazon/aws/hooks/glue.py

@@ -20,6 +20,7 @@ from __future__ import annotations
 import asyncio
 import time
 from functools import cached_property
+from typing import Any
 
 from botocore.exceptions import ClientError
 
@@ -430,3 +431,125 @@ class GlueJobHook(AwsBaseHook):
             self.conn.create_job(**config)
 
         return self.job_name
+
+
+class GlueDataQualityHook(AwsBaseHook):
+    """
+    Interact with AWS Glue Data Quality.
+
+    Provide thick wrapper around :external+boto3:py:class:`boto3.client("glue") <Glue.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        kwargs["client_type"] = "glue"
+        super().__init__(*args, **kwargs)
+
+    def has_data_quality_ruleset(self, name: str) -> bool:
+        try:
+            self.conn.get_data_quality_ruleset(Name=name)
+            return True
+        except self.conn.exceptions.EntityNotFoundException:
+            return False
+
+    def _log_results(self, result: dict[str, Any]) -> None:
+        """
+        Print the outcome of evaluation run, An evaluation run can involve multiple rulesets evaluated against a data source (Glue table).
+
+        Name    Description                                 Result  EvaluatedMetrics                                                                     EvaluationMessage
+        Rule_1  RowCount between 150000 and 600000          PASS    {'Dataset.*.RowCount': 300000.0}                                                     NaN
+        Rule_2  IsComplete "marketplace"                    PASS    {'Column.marketplace.Completeness': 1.0}                                             NaN
+        Rule_3  ColumnLength "marketplace" between 1 and 2  FAIL    {'Column.marketplace.MaximumLength': 9.0, 'Column.marketplace.MinimumLength': 3.0}  Value: 9.0 does not meet the constraint requirement!
+
+        """
+        import pandas as pd
+
+        pd.set_option("display.max_rows", None)
+        pd.set_option("display.max_columns", None)
+        pd.set_option("display.width", None)
+        pd.set_option("display.max_colwidth", None)
+
+        self.log.info(
+            "AWS Glue data quality ruleset evaluation result for RulesetName: %s RulesetEvaluationRunId: %s Score: %s",
+            result.get("RulesetName"),
+            result.get("RulesetEvaluationRunId"),
+            result.get("Score"),
+        )
+
+        rule_results = result["RuleResults"]
+        rule_results_df = pd.DataFrame(rule_results)
+        self.log.info(rule_results_df)
+
+    def get_evaluation_run_results(self, run_id: str) -> dict[str, Any]:
+        response = self.conn.get_data_quality_ruleset_evaluation_run(RunId=run_id)
+
+        return self.conn.batch_get_data_quality_result(ResultIds=response["ResultIds"])
+
+    def validate_evaluation_run_results(
+        self, evaluation_run_id: str, show_results: bool = True, verify_result_status: bool = True
+    ) -> None:
+        results = self.get_evaluation_run_results(evaluation_run_id)
+        total_failed_rules = 0
+
+        if results.get("ResultsNotFound"):
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run, results not found for %s",
+                results["ResultsNotFound"],
+            )
+
+        for result in results["Results"]:
+            rule_results = result["RuleResults"]
+
+            total_failed_rules += len(
+                list(
+                    filter(
+                        lambda result: result.get("Result") == "FAIL" or result.get("Result") == "ERROR",
+                        rule_results,
+                    )
+                )
+            )
+
+            if show_results:
+                self._log_results(result)
+
+        self.log.info(
+            "AWS Glue data quality ruleset evaluation run, total number of rules failed: %s",
+            total_failed_rules,
+        )
+
+        if verify_result_status and total_failed_rules > 0:
+            raise AirflowException(
+                "AWS Glue data quality ruleset evaluation run failed for one or more rules"
+            )
+
+    def log_recommendation_results(self, run_id: str) -> None:
+        """
+        Print the outcome of recommendation run, recommendation run generates multiple rules against a data source (Glue table) in Data Quality Definition Language (DQDL) format.
+
+        Rules = [
+            IsComplete "NAME",
+            ColumnLength "EMP_ID" between 1 and 12,
+            IsUnique "EMP_ID",
+            ColumnValues "INCOME" > 50000
+        ]
+        """
+        result = self.conn.get_data_quality_rule_recommendation_run(RunId=run_id)
+
+        if result.get("RecommendedRuleset"):
+            self.log.info(
+                "AWS Glue data quality recommended rules for DatabaseName: %s TableName: %s",
+                result["DataSource"]["GlueTable"]["DatabaseName"],
+                result["DataSource"]["GlueTable"]["TableName"],
+            )
+            self.log.info(result["RecommendedRuleset"])
+        else:
+            self.log.info("AWS Glue data quality, no recommended rules available for RunId: %s", run_id)

airflow/providers/amazon/aws/operators/batch.py

@@ -65,6 +65,7 @@ class BatchOperator(BaseOperator):
     :param job_queue: the queue name on AWS Batch
     :param overrides: DEPRECATED, use container_overrides instead with the same value.
     :param container_overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated)
     :param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
     :param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue
         doesn't have a scheduling policy.
@@ -112,6 +113,7 @@ class BatchOperator(BaseOperator):
         "job_queue",
         "container_overrides",
         "array_properties",
+        "ecs_properties_override",
         "node_overrides",
         "parameters",
         "retry_strategy",
@@ -124,6 +126,7 @@ class BatchOperator(BaseOperator):
     template_fields_renderers = {
         "container_overrides": "json",
         "parameters": "json",
+        "ecs_properties_override": "json",
         "node_overrides": "json",
         "retry_strategy": "json",
     }
@@ -160,6 +163,7 @@ class BatchOperator(BaseOperator):
         overrides: dict | None = None,  # deprecated
         container_overrides: dict | None = None,
         array_properties: dict | None = None,
+        ecs_properties_override: dict | None = None,
         node_overrides: dict | None = None,
         share_identifier: str | None = None,
         scheduling_priority_override: int | None = None,
@@ -201,6 +205,7 @@ class BatchOperator(BaseOperator):
                 stacklevel=2,
             )
 
+        self.ecs_properties_override = ecs_properties_override
         self.node_overrides = node_overrides
         self.share_identifier = share_identifier
         self.scheduling_priority_override = scheduling_priority_override
@@ -296,6 +301,8 @@ class BatchOperator(BaseOperator):
             self.log.info("AWS Batch job - container overrides: %s", self.container_overrides)
         if self.array_properties:
             self.log.info("AWS Batch job - array properties: %s", self.array_properties)
+        if self.ecs_properties_override:
+            self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override)
         if self.node_overrides:
             self.log.info("AWS Batch job - node properties: %s", self.node_overrides)
 
@@ -307,6 +314,7 @@ class BatchOperator(BaseOperator):
             "parameters": self.parameters,
             "tags": self.tags,
             "containerOverrides": self.container_overrides,
+            "ecsPropertiesOverride": self.ecs_properties_override,
             "nodeOverrides": self.node_overrides,
             "retryStrategy": self.retry_strategy,
             "shareIdentifier": self.share_identifier,

airflow/providers/amazon/aws/operators/bedrock.py

@@ -20,7 +20,6 @@ import json
 from time import sleep
 from typing import TYPE_CHECKING, Any, Sequence
 
-import botocore
 from botocore.exceptions import ClientError
 
 from airflow.configuration import conf
@@ -38,7 +37,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
     BedrockKnowledgeBaseActiveTrigger,
     BedrockProvisionModelThroughputCompletedTrigger,
 )
-from airflow.providers.amazon.aws.utils import get_botocore_version, validate_execute_complete_event
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.utils.helpers import prune_dict
 from airflow.utils.timezone import utcnow
@@ -799,24 +798,11 @@ class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
     def execute(self, context: Context) -> Any:
         self.validate_inputs()
 
-        try:
-            result = self.hook.conn.retrieve_and_generate(
-                input={"text": self.input},
-                retrieveAndGenerateConfiguration=self.build_rag_config(),
-                **self.rag_kwargs,
-            )
-        except botocore.exceptions.ParamValidationError as error:
-            if (
-                'Unknown parameter in retrieveAndGenerateConfiguration: "externalSourcesConfiguration"'
-                in str(error)
-            ) and (self.source_type == "EXTERNAL_SOURCES"):
-                self.log.error(
-                    "You are attempting to use External Sources and the BOTO API returned an "
-                    "error message which may indicate the need to update botocore to do this. \n"
-                    "Support for External Sources was added in botocore 1.34.90 and you are using botocore %s",
-                    ".".join(map(str, get_botocore_version())),
-                )
-            raise
+        result = self.hook.conn.retrieve_and_generate(
+            input={"text": self.input},
+            retrieveAndGenerateConfiguration=self.build_rag_config(),
+            **self.rag_kwargs,
+        )
 
         self.log.info(
             "\nPrompt: %s\nResponse: %s\nCitations: %s",

airflow/providers/amazon/aws/operators/ecs.py

@@ -141,7 +141,7 @@ class EcsCreateClusterOperator(EcsBaseOperator):
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    region_name=self.region,
+                    region_name=self.region_name,
                 ),
                 method_name="_complete_exec_with_cluster_desc",
                 # timeout is set to ensure that if a trigger dies, the timeout does not restart
@@ -218,7 +218,7 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                    region_name=self.region,
+                    region_name=self.region_name,
                 ),
                 method_name="_complete_exec_with_cluster_desc",
                 # timeout is set to ensure that if a trigger dies, the timeout does not restart
@@ -495,7 +495,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.number_logs_exception = number_logs_exception
 
         if self.awslogs_region is None:
-            self.awslogs_region = self.region
+            self.awslogs_region = self.region_name
 
         self.arn: str | None = None
         self._started_by: str | None = None
@@ -546,7 +546,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
                    waiter_delay=self.waiter_delay,
                    waiter_max_attempts=self.waiter_max_attempts,
                    aws_conn_id=self.aws_conn_id,
-                   region=self.region,
+                   region=self.region_name,
                    log_group=self.awslogs_group,
                    log_stream=self._get_logs_stream_name(),
                ),
@@ -589,7 +589,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self._after_execution()
         if self._aws_logs_enabled():
             # same behavior as non-deferrable mode, return last line of logs of the task.
-            logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
+            logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name).conn
             one_log = logs_client.get_log_events(
                 logGroupName=self.awslogs_group,
                 logStreamName=self._get_logs_stream_name(),
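
The ECS hunks only rename the internal attribute (self.region becomes self.region_name, apparently aligning with the region_name attribute used elsewhere in the provider); the operator-level argument looks unchanged. A hypothetical task for illustration, with placeholder names:

    from airflow.providers.amazon.aws.operators.ecs import EcsCreateClusterOperator

    create_cluster = EcsCreateClusterOperator(
        task_id="create_cluster",
        cluster_name="example-cluster",
        region_name="us-east-1",  # still passed the same way; only the internal attribute was renamed
        deferrable=True,
    )
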

airflow/providers/amazon/aws/operators/emr.py

@@ -263,30 +263,34 @@ class EmrStartNotebookExecutionOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str | None = "aws_default",
         # TODO: waiter_max_attempts and waiter_delay should default to None when the other two are deprecated.
-        waiter_max_attempts: int | None | ArgNotSet = NOTSET,
-        waiter_delay: int | None | ArgNotSet = NOTSET,
-        waiter_countdown: int = 25 * 60,
-        waiter_check_interval_seconds: int = 60,
+        waiter_max_attempts: int | None = None,
+        waiter_delay: int | None = None,
+        waiter_countdown: int | None = None,
+        waiter_check_interval_seconds: int | None = None,
         **kwargs: Any,
     ):
-        if waiter_max_attempts is NOTSET:
+        if waiter_check_interval_seconds:
             warnings.warn(
-                "The parameter waiter_countdown has been deprecated to standardize "
-                "naming conventions. Please use waiter_max_attempts instead. In the "
+                "The parameter `waiter_check_interval_seconds` has been deprecated to "
+                "standardize naming conventions. Please `use waiter_delay instead`. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
-            waiter_max_attempts = waiter_countdown // waiter_check_interval_seconds
-        if waiter_delay is NOTSET:
+        else:
+            waiter_check_interval_seconds = 60
+        if waiter_countdown:
             warnings.warn(
-                "The parameter waiter_check_interval_seconds has been deprecated to "
-                "standardize naming conventions. Please use waiter_delay instead. In the "
+                "The parameter waiter_countdown has been deprecated to standardize "
+                "naming conventions. Please use waiter_max_attempts instead. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
            )
-            waiter_delay = waiter_check_interval_seconds
+            # waiter_countdown defaults to never timing out, which is not supported
+            # by boto waiters, so we will set it here to "a very long time" for now.
+            waiter_max_attempts = (waiter_countdown or 999) // waiter_check_interval_seconds
+
         super().__init__(**kwargs)
         self.editor_id = editor_id
         self.relative_path = relative_path
@@ -298,8 +302,8 @@ class EmrStartNotebookExecutionOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.cluster_id = cluster_id
         self.aws_conn_id = aws_conn_id
-        self.waiter_max_attempts = waiter_max_attempts
-        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts or 25
+        self.waiter_delay = waiter_delay or waiter_check_interval_seconds or 60
         self.master_instance_security_group_id = master_instance_security_group_id
 
     def execute(self, context: Context):
@@ -387,36 +391,40 @@ class EmrStopNotebookExecutionOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str | None = "aws_default",
         # TODO: waiter_max_attempts and waiter_delay should default to None when the other two are deprecated.
-        waiter_max_attempts: int | None | ArgNotSet = NOTSET,
-        waiter_delay: int | None | ArgNotSet = NOTSET,
-        waiter_countdown: int = 25 * 60,
-        waiter_check_interval_seconds: int = 60,
+        waiter_max_attempts: int | None = None,
+        waiter_delay: int | None = None,
+        waiter_countdown: int | None = None,
+        waiter_check_interval_seconds: int | None = None,
         **kwargs: Any,
     ):
-        if waiter_max_attempts is NOTSET:
+        if waiter_check_interval_seconds:
             warnings.warn(
-                "The parameter waiter_countdown has been deprecated to standardize "
-                "naming conventions. Please use waiter_max_attempts instead. In the "
+                "The parameter `waiter_check_interval_seconds` has been deprecated to "
+                "standardize naming conventions. Please `use waiter_delay instead`. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
-            waiter_max_attempts = waiter_countdown // waiter_check_interval_seconds
-        if waiter_delay is NOTSET:
+        else:
+            waiter_check_interval_seconds = 60
+        if waiter_countdown:
             warnings.warn(
-                "The parameter waiter_check_interval_seconds has been deprecated to "
-                "standardize naming conventions. Please use waiter_delay instead. In the "
+                "The parameter waiter_countdown has been deprecated to standardize "
+                "naming conventions. Please use waiter_max_attempts instead. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
            )
-            waiter_delay = waiter_check_interval_seconds
+            # waiter_countdown defaults to never timing out, which is not supported
+            # by boto waiters, so we will set it here to "a very long time" for now.
+            waiter_max_attempts = (waiter_countdown or 999) // waiter_check_interval_seconds
+
         super().__init__(**kwargs)
         self.notebook_execution_id = notebook_execution_id
         self.wait_for_completion = wait_for_completion
         self.aws_conn_id = aws_conn_id
-        self.waiter_max_attempts = waiter_max_attempts
-        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts or 25
+        self.waiter_delay = waiter_delay or waiter_check_interval_seconds or 60
 
     def execute(self, context: Context) -> None:
         emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
@@ -822,8 +830,8 @@ class EmrCreateJobFlowOperator(BaseOperator):
             trigger=EmrCreateJobFlowTrigger(
                 job_flow_id=self._job_flow_id,
                 aws_conn_id=self.aws_conn_id,
-                poll_interval=self.waiter_delay,
-                max_attempts=self.waiter_max_attempts,
+                waiter_delay=self.waiter_delay,
+                waiter_max_attempts=self.waiter_max_attempts,
            ),
            method_name="execute_complete",
             # timeout is set to ensure that if a trigger dies, the timeout does not restart
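
A hypothetical task using the non-deprecated waiter arguments that these hunks steer callers toward; the editor, cluster, and path values are placeholders, and the service_role argument is an assumption about the operator's signature rather than something shown in the diff:

    from airflow.providers.amazon.aws.operators.emr import EmrStartNotebookExecutionOperator

    start_notebook = EmrStartNotebookExecutionOperator(
        task_id="start_notebook_execution",
        editor_id="e-EXAMPLEEDITORID",
        relative_path="notebook.ipynb",
        cluster_id="j-EXAMPLECLUSTERID",
        service_role="EMR_Notebooks_DefaultRole",  # assumed parameter, not shown in the hunks above
        wait_for_completion=True,
        waiter_delay=60,         # replaces the deprecated waiter_check_interval_seconds
        waiter_max_attempts=25,  # replaces the deprecated waiter_countdown
    )
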