apache-airflow-providers-amazon 8.24.0rc1__py3-none-any.whl → 8.25.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (26)
  1. airflow/providers/amazon/LICENSE +4 -4
  2. airflow/providers/amazon/__init__.py +1 -1
  3. airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
  4. airflow/providers/amazon/aws/hooks/comprehend.py +33 -0
  5. airflow/providers/amazon/aws/hooks/glue.py +123 -0
  6. airflow/providers/amazon/aws/hooks/redshift_sql.py +8 -1
  7. airflow/providers/amazon/aws/operators/bedrock.py +6 -20
  8. airflow/providers/amazon/aws/operators/comprehend.py +148 -1
  9. airflow/providers/amazon/aws/operators/emr.py +38 -30
  10. airflow/providers/amazon/aws/operators/glue.py +408 -2
  11. airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
  12. airflow/providers/amazon/aws/sensors/comprehend.py +112 -1
  13. airflow/providers/amazon/aws/sensors/glue.py +260 -2
  14. airflow/providers/amazon/aws/sensors/s3.py +35 -5
  15. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
  16. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -1
  17. airflow/providers/amazon/aws/triggers/comprehend.py +36 -0
  18. airflow/providers/amazon/aws/triggers/glue.py +76 -2
  19. airflow/providers/amazon/aws/utils/__init__.py +2 -3
  20. airflow/providers/amazon/aws/waiters/comprehend.json +55 -0
  21. airflow/providers/amazon/aws/waiters/glue.json +98 -0
  22. airflow/providers/amazon/get_provider_info.py +20 -13
  23. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/METADATA +22 -21
  24. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/RECORD +26 -26
  25. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/WHEEL +0 -0
  26. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/LICENSE
@@ -215,7 +215,7 @@ Third party Apache 2.0 licenses
 
 The following components are provided under the Apache 2.0 License.
 See project link for details. The text of each license is also included
-at licenses/LICENSE-[project].txt.
+at 3rd-party-licenses/LICENSE-[project].txt.
 
 (ALv2 License) hue v4.3.0 (https://github.com/cloudera/hue/)
 (ALv2 License) jqclock v2.3.0 (https://github.com/JohnRDOrazio/jQuery-Clock-Plugin)
@@ -227,7 +227,7 @@ MIT licenses
 ========================================================================
 
 The following components are provided under the MIT License. See project link for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (MIT License) jquery v3.5.1 (https://jquery.org/license/)
 (MIT License) dagre-d3 v0.6.4 (https://github.com/cpettitt/dagre-d3)
@@ -243,11 +243,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
 BSD 3-Clause licenses
 ========================================================================
 The following components are provided under the BSD 3-Clause license. See project links for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (BSD 3 License) d3 v5.16.0 (https://d3js.org)
 (BSD 3 License) d3-shape v2.1.0 (https://github.com/d3/d3-shape)
 (BSD 3 License) cgroupspy 0.2.1 (https://github.com/cloudsigma/cgroupspy)
 
 ========================================================================
-See licenses/LICENSES-ui.txt for packages used in `/airflow/www`
+See 3rd-party-licenses/LICENSES-ui.txt for packages used in `/airflow/www`
airflow/providers/amazon/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "8.24.0"
+__version__ = "8.25.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -156,7 +156,9 @@ class BaseSessionFactory(LoggingMixin):
 
         return async_get_session()
 
-    def create_session(self, deferrable: bool = False) -> boto3.session.Session:
+    def create_session(
+        self, deferrable: bool = False
+    ) -> boto3.session.Session | aiobotocore.session.AioSession:
         """Create boto3 or aiobotocore Session from connection config."""
         if not self.conn:
             self.log.info(
@@ -198,7 +200,7 @@ class BaseSessionFactory(LoggingMixin):
 
     def _create_session_with_assume_role(
         self, session_kwargs: dict[str, Any], deferrable: bool = False
-    ) -> boto3.session.Session:
+    ) -> boto3.session.Session | aiobotocore.session.AioSession:
         if self.conn.assume_role_method == "assume_role_with_web_identity":
             # Deferred credentials have no initial credentials
             credential_fetcher = self._get_web_identity_credential_fetcher()
@@ -239,7 +241,10 @@ class BaseSessionFactory(LoggingMixin):
         session._credentials = credentials
         session.set_config_variable("region", self.basic_session.region_name)
 
-        return boto3.session.Session(botocore_session=session, **session_kwargs)
+        if not deferrable:
+            return boto3.session.Session(botocore_session=session, **session_kwargs)
+
+        return session
 
     def _refresh_credentials(self) -> dict[str, Any]:
         self.log.debug("Refreshing credentials")
airflow/providers/amazon/aws/hooks/comprehend.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
@@ -35,3 +36,35 @@ class ComprehendHook(AwsBaseHook):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["client_type"] = "comprehend"
         super().__init__(*args, **kwargs)
+
+    def validate_document_classifier_training_status(
+        self, document_classifier_arn: str, fail_on_warnings: bool = False
+    ) -> None:
+        """
+        Log the Information about the document classifier.
+
+        NumberOfLabels
+        NumberOfTrainedDocuments
+        NumberOfTestDocuments
+        EvaluationMetrics
+
+        """
+        response = self.conn.describe_document_classifier(DocumentClassifierArn=document_classifier_arn)
+
+        status = response["DocumentClassifierProperties"]["Status"]
+
+        if status == "TRAINED_WITH_WARNING":
+            self.log.info(
+                "AWS Comprehend document classifier training completed with %s, Message: %s please review the skipped files folder in the output location %s",
+                status,
+                response["DocumentClassifierProperties"]["Message"],
+                response["DocumentClassifierProperties"]["OutputDataConfig"]["S3Uri"],
+            )
+
+            if fail_on_warnings:
+                raise AirflowException("Warnings in AWS Comprehend document classifier training.")
+
+        self.log.info(
+            "AWS Comprehend document classifier metadata: %s",
+            response["DocumentClassifierProperties"]["ClassifierMetadata"],
+        )
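The new hook method can also be called directly, for example from a PythonOperator callable; a minimal sketch (the classifier ARN and connection id are placeholders):

    from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook

    def check_classifier_training(document_classifier_arn: str) -> None:
        hook = ComprehendHook(aws_conn_id="aws_default")
        # Logs the classifier metadata (label/document counts, evaluation metrics)
        # and, with fail_on_warnings=True, raises AirflowException if training
        # finished in the TRAINED_WITH_WARNING state.
        hook.validate_document_classifier_training_status(
            document_classifier_arn=document_classifier_arn,
            fail_on_warnings=True,
        )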
airflow/providers/amazon/aws/hooks/glue.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import asyncio
 import time
 from functools import cached_property
+from typing import Any
 
 from botocore.exceptions import ClientError
 
@@ -430,3 +431,125 @@ class GlueJobHook(AwsBaseHook):
         self.conn.create_job(**config)
 
         return self.job_name
+
+
+class GlueDataQualityHook(AwsBaseHook):
+    """
+    Interact with AWS Glue Data Quality.
+
+    Provide thick wrapper around :external+boto3:py:class:`boto3.client("glue") <Glue.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        kwargs["client_type"] = "glue"
+        super().__init__(*args, **kwargs)
+
+    def has_data_quality_ruleset(self, name: str) -> bool:
+        try:
+            self.conn.get_data_quality_ruleset(Name=name)
+            return True
+        except self.conn.exceptions.EntityNotFoundException:
+            return False
+
+    def _log_results(self, result: dict[str, Any]) -> None:
+        """
+        Print the outcome of evaluation run, An evaluation run can involve multiple rulesets evaluated against a data source (Glue table).
+
+        Name    Description                                  Result  EvaluatedMetrics                                                                     EvaluationMessage
+        Rule_1  RowCount between 150000 and 600000           PASS    {'Dataset.*.RowCount': 300000.0}                                                     NaN
+        Rule_2  IsComplete "marketplace"                      PASS    {'Column.marketplace.Completeness': 1.0}                                             NaN
+        Rule_3  ColumnLength "marketplace" between 1 and 2    FAIL    {'Column.marketplace.MaximumLength': 9.0, 'Column.marketplace.MinimumLength': 3.0}  Value: 9.0 does not meet the constraint requirement!
+
+        """
+        import pandas as pd
+
+        pd.set_option("display.max_rows", None)
+        pd.set_option("display.max_columns", None)
+        pd.set_option("display.width", None)
+        pd.set_option("display.max_colwidth", None)
+
+        self.log.info(
+            "AWS Glue data quality ruleset evaluation result for RulesetName: %s RulesetEvaluationRunId: %s Score: %s",
+            result.get("RulesetName"),
+            result.get("RulesetEvaluationRunId"),
+            result.get("Score"),
+        )
+
+        rule_results = result["RuleResults"]
+        rule_results_df = pd.DataFrame(rule_results)
+        self.log.info(rule_results_df)
+
+    def get_evaluation_run_results(self, run_id: str) -> dict[str, Any]:
+        response = self.conn.get_data_quality_ruleset_evaluation_run(RunId=run_id)
+
+        return self.conn.batch_get_data_quality_result(ResultIds=response["ResultIds"])
+
+    def validate_evaluation_run_results(
+        self, evaluation_run_id: str, show_results: bool = True, verify_result_status: bool = True
+    ) -> None:
+        results = self.get_evaluation_run_results(evaluation_run_id)
+        total_failed_rules = 0
+
+        if results.get("ResultsNotFound"):
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run, results not found for %s",
+                results["ResultsNotFound"],
+            )
+
+        for result in results["Results"]:
+            rule_results = result["RuleResults"]
+
+            total_failed_rules += len(
+                list(
+                    filter(
+                        lambda result: result.get("Result") == "FAIL" or result.get("Result") == "ERROR",
+                        rule_results,
+                    )
+                )
+            )
+
+            if show_results:
+                self._log_results(result)
+
+        self.log.info(
+            "AWS Glue data quality ruleset evaluation run, total number of rules failed: %s",
+            total_failed_rules,
+        )
+
+        if verify_result_status and total_failed_rules > 0:
+            raise AirflowException(
+                "AWS Glue data quality ruleset evaluation run failed for one or more rules"
+            )
+
+    def log_recommendation_results(self, run_id: str) -> None:
+        """
+        Print the outcome of recommendation run, recommendation run generates multiple rules against a data source (Glue table) in Data Quality Definition Language (DQDL) format.
+
+        Rules = [
+            IsComplete "NAME",
+            ColumnLength "EMP_ID" between 1 and 12,
+            IsUnique "EMP_ID",
+            ColumnValues "INCOME" > 50000
+        ]
+        """
+        result = self.conn.get_data_quality_rule_recommendation_run(RunId=run_id)
+
+        if result.get("RecommendedRuleset"):
+            self.log.info(
+                "AWS Glue data quality recommended rules for DatabaseName: %s TableName: %s",
+                result["DataSource"]["GlueTable"]["DatabaseName"],
+                result["DataSource"]["GlueTable"]["TableName"],
+            )
+            self.log.info(result["RecommendedRuleset"])
+        else:
+            self.log.info("AWS Glue data quality, no recommended rules available for RunId: %s", run_id)
airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -20,14 +20,19 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 import redshift_connector
+from packaging.version import Version
 from redshift_connector import Connection as RedshiftConnection
 from sqlalchemy import create_engine
 from sqlalchemy.engine.url import URL
 
+from airflow import __version__ as AIRFLOW_VERSION
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
+_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
+
+
 if TYPE_CHECKING:
     from airflow.models.connection import Connection
     from airflow.providers.openlineage.sqlparser import DatabaseInfo
@@ -257,4 +262,6 @@ class RedshiftSQLHook(DbApiHook):
 
     def get_openlineage_default_schema(self) -> str | None:
         """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
-        return self.get_first("SELECT CURRENT_SCHEMA();")[0]
+        if _IS_AIRFLOW_2_10_OR_HIGHER:
+            return self.get_first("SELECT CURRENT_SCHEMA();")[0]
+        return super().get_openlineage_default_schema()
airflow/providers/amazon/aws/operators/bedrock.py
@@ -20,7 +20,6 @@ import json
 from time import sleep
 from typing import TYPE_CHECKING, Any, Sequence
 
-import botocore
 from botocore.exceptions import ClientError
 
 from airflow.configuration import conf
@@ -38,7 +37,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
     BedrockKnowledgeBaseActiveTrigger,
     BedrockProvisionModelThroughputCompletedTrigger,
 )
-from airflow.providers.amazon.aws.utils import get_botocore_version, validate_execute_complete_event
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.utils.helpers import prune_dict
 from airflow.utils.timezone import utcnow
@@ -799,24 +798,11 @@ class BedrockRaGOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
     def execute(self, context: Context) -> Any:
         self.validate_inputs()
 
-        try:
-            result = self.hook.conn.retrieve_and_generate(
-                input={"text": self.input},
-                retrieveAndGenerateConfiguration=self.build_rag_config(),
-                **self.rag_kwargs,
-            )
-        except botocore.exceptions.ParamValidationError as error:
-            if (
-                'Unknown parameter in retrieveAndGenerateConfiguration: "externalSourcesConfiguration"'
-                in str(error)
-            ) and (self.source_type == "EXTERNAL_SOURCES"):
-                self.log.error(
-                    "You are attempting to use External Sources and the BOTO API returned an "
-                    "error message which may indicate the need to update botocore to do this. \n"
-                    "Support for External Sources was added in botocore 1.34.90 and you are using botocore %s",
-                    ".".join(map(str, get_botocore_version())),
-                )
-            raise
+        result = self.hook.conn.retrieve_and_generate(
+            input={"text": self.input},
+            retrieveAndGenerateConfiguration=self.build_rag_config(),
+            **self.rag_kwargs,
+        )
 
         self.log.info(
             "\nPrompt: %s\nResponse: %s\nCitations: %s",
airflow/providers/amazon/aws/operators/comprehend.py
@@ -23,7 +23,10 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
-from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.triggers.comprehend import (
+    ComprehendCreateDocumentClassifierCompletedTrigger,
+    ComprehendPiiEntitiesDetectionJobCompletedTrigger,
+)
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.utils.timezone import utcnow
@@ -190,3 +193,147 @@ class ComprehendStartPiiEntitiesDetectionJobOperator(ComprehendBaseOperator):
 
         self.log.info("Comprehend pii entities detection job `%s` complete.", event["job_id"])
         return event["job_id"]
+
+
+class ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook]):
+    """
+    Create a comprehend document classifier that can categorize documents.
+
+    Provide a set of training documents that are labeled with the categories.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:ComprehendCreateDocumentClassifierOperator`
+
+    :param document_classifier_name: The name of the document classifier. (templated)
+    :param input_data_config: Specifies the format and location of the input data for the job. (templated)
+    :param mode: Indicates the mode in which the classifier will be trained. (templated)
+    :param data_access_role_arn: The Amazon Resource Name (ARN) of the IAM role that grants Amazon Comprehend
+        read access to your input data. (templated)
+    :param language_code: The language of the input documents. You can specify any of the languages supported by
+        Amazon Comprehend. All documents must be in the same language. (templated)
+    :param fail_on_warnings: If set to True, the document classifier training job will throw an error when the
+        status is TRAINED_WITH_WARNING. (default False)
+    :param output_data_config: Specifies the location for the output files from a custom classifier job.
+        This parameter is required for a request that creates a native document model. (templated)
+    :param document_classifier_kwargs: Any optional parameters to pass to the document classifier. (templated)
+
+    :param wait_for_completion: Whether to wait for job to stop. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the job to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = ComprehendHook
+
+    template_fields: Sequence[str] = aws_template_fields(
+        "document_classifier_name",
+        "input_data_config",
+        "mode",
+        "data_access_role_arn",
+        "language_code",
+        "output_data_config",
+        "document_classifier_kwargs",
+    )
+
+    template_fields_renderers: dict = {
+        "input_data_config": "json",
+        "output_data_config": "json",
+        "document_classifier_kwargs": "json",
+    }
+
+    def __init__(
+        self,
+        document_classifier_name: str,
+        input_data_config: dict[str, Any],
+        mode: str,
+        data_access_role_arn: str,
+        language_code: str,
+        fail_on_warnings: bool = False,
+        output_data_config: dict[str, Any] | None = None,
+        document_classifier_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.document_classifier_name = document_classifier_name
+        self.input_data_config = input_data_config
+        self.mode = mode
+        self.data_access_role_arn = data_access_role_arn
+        self.language_code = language_code
+        self.fail_on_warnings = fail_on_warnings
+        self.output_data_config = output_data_config
+        self.document_classifier_kwargs = document_classifier_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> str:
+        if self.output_data_config:
+            self.document_classifier_kwargs["OutputDataConfig"] = self.output_data_config
+
+        document_classifier_arn = self.hook.conn.create_document_classifier(
+            DocumentClassifierName=self.document_classifier_name,
+            InputDataConfig=self.input_data_config,
+            Mode=self.mode,
+            DataAccessRoleArn=self.data_access_role_arn,
+            LanguageCode=self.language_code,
+            **self.document_classifier_kwargs,
+        )["DocumentClassifierArn"]
+
+        message_description = f"document classifier {document_classifier_arn} to complete."
+        if self.deferrable:
+            self.log.info("Deferring %s", message_description)
+            self.defer(
+                trigger=ComprehendCreateDocumentClassifierCompletedTrigger(
+                    document_classifier_arn=document_classifier_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", message_description)
+
+            self.hook.get_waiter("create_document_classifier_complete").wait(
+                DocumentClassifierArn=document_classifier_arn,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+            self.hook.validate_document_classifier_training_status(
+                document_classifier_arn=document_classifier_arn, fail_on_warnings=self.fail_on_warnings
+            )
+
+        return document_classifier_arn
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+        if event["status"] != "success":
+            raise AirflowException("Error while running comprehend create document classifier: %s", event)
+
+        self.hook.validate_document_classifier_training_status(
+            document_classifier_arn=event["document_classifier_arn"], fail_on_warnings=self.fail_on_warnings
+        )
+
+        self.log.info("Comprehend document classifier `%s` complete.", event["document_classifier_arn"])
+
+        return event["document_classifier_arn"]
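A hedged DAG sketch using the new operator (the dag id, bucket, role ARN and other values are placeholders):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.comprehend import (
        ComprehendCreateDocumentClassifierOperator,
    )

    with DAG(dag_id="example_comprehend_classifier", start_date=datetime(2024, 1, 1), schedule=None):
        create_classifier = ComprehendCreateDocumentClassifierOperator(
            task_id="create_document_classifier",
            document_classifier_name="example-classifier",
            input_data_config={
                "DataFormat": "COMPREHEND_CSV",
                "S3Uri": "s3://example-bucket/training-docs/",
            },
            mode="MULTI_CLASS",
            data_access_role_arn="arn:aws:iam::123456789012:role/ExampleComprehendRole",
            language_code="en",
            fail_on_warnings=False,
        )

The operator returns the new classifier's ARN, waits for training by default, and defers to the new ComprehendCreateDocumentClassifierCompletedTrigger when deferrable=True.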
airflow/providers/amazon/aws/operators/emr.py
@@ -263,30 +263,34 @@ class EmrStartNotebookExecutionOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str | None = "aws_default",
         # TODO: waiter_max_attempts and waiter_delay should default to None when the other two are deprecated.
-        waiter_max_attempts: int | None | ArgNotSet = NOTSET,
-        waiter_delay: int | None | ArgNotSet = NOTSET,
-        waiter_countdown: int = 25 * 60,
-        waiter_check_interval_seconds: int = 60,
+        waiter_max_attempts: int | None = None,
+        waiter_delay: int | None = None,
+        waiter_countdown: int | None = None,
+        waiter_check_interval_seconds: int | None = None,
         **kwargs: Any,
     ):
-        if waiter_max_attempts is NOTSET:
+        if waiter_check_interval_seconds:
             warnings.warn(
-                "The parameter waiter_countdown has been deprecated to standardize "
-                "naming conventions. Please use waiter_max_attempts instead. In the "
+                "The parameter `waiter_check_interval_seconds` has been deprecated to "
+                "standardize naming conventions. Please `use waiter_delay instead`. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
-            waiter_max_attempts = waiter_countdown // waiter_check_interval_seconds
-        if waiter_delay is NOTSET:
+        else:
+            waiter_check_interval_seconds = 60
+        if waiter_countdown:
             warnings.warn(
-                "The parameter waiter_check_interval_seconds has been deprecated to "
-                "standardize naming conventions. Please use waiter_delay instead. In the "
+                "The parameter waiter_countdown has been deprecated to standardize "
+                "naming conventions. Please use waiter_max_attempts instead. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
-            waiter_delay = waiter_check_interval_seconds
+            # waiter_countdown defaults to never timing out, which is not supported
+            # by boto waiters, so we will set it here to "a very long time" for now.
+            waiter_max_attempts = (waiter_countdown or 999) // waiter_check_interval_seconds
+
         super().__init__(**kwargs)
         self.editor_id = editor_id
         self.relative_path = relative_path
@@ -298,8 +302,8 @@ class EmrStartNotebookExecutionOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.cluster_id = cluster_id
         self.aws_conn_id = aws_conn_id
-        self.waiter_max_attempts = waiter_max_attempts
-        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts or 25
+        self.waiter_delay = waiter_delay or waiter_check_interval_seconds or 60
         self.master_instance_security_group_id = master_instance_security_group_id
 
     def execute(self, context: Context):
@@ -387,36 +391,40 @@ class EmrStopNotebookExecutionOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str | None = "aws_default",
         # TODO: waiter_max_attempts and waiter_delay should default to None when the other two are deprecated.
-        waiter_max_attempts: int | None | ArgNotSet = NOTSET,
-        waiter_delay: int | None | ArgNotSet = NOTSET,
-        waiter_countdown: int = 25 * 60,
-        waiter_check_interval_seconds: int = 60,
+        waiter_max_attempts: int | None = None,
+        waiter_delay: int | None = None,
+        waiter_countdown: int | None = None,
+        waiter_check_interval_seconds: int | None = None,
         **kwargs: Any,
     ):
-        if waiter_max_attempts is NOTSET:
+        if waiter_check_interval_seconds:
             warnings.warn(
-                "The parameter waiter_countdown has been deprecated to standardize "
-                "naming conventions. Please use waiter_max_attempts instead. In the "
+                "The parameter `waiter_check_interval_seconds` has been deprecated to "
+                "standardize naming conventions. Please `use waiter_delay instead`. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
-            waiter_max_attempts = waiter_countdown // waiter_check_interval_seconds
-        if waiter_delay is NOTSET:
+        else:
+            waiter_check_interval_seconds = 60
+        if waiter_countdown:
             warnings.warn(
-                "The parameter waiter_check_interval_seconds has been deprecated to "
-                "standardize naming conventions. Please use waiter_delay instead. In the "
+                "The parameter waiter_countdown has been deprecated to standardize "
+                "naming conventions. Please use waiter_max_attempts instead. In the "
                 "future this will default to None and defer to the waiter's default value.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
-            waiter_delay = waiter_check_interval_seconds
+            # waiter_countdown defaults to never timing out, which is not supported
+            # by boto waiters, so we will set it here to "a very long time" for now.
+            waiter_max_attempts = (waiter_countdown or 999) // waiter_check_interval_seconds
+
         super().__init__(**kwargs)
         self.notebook_execution_id = notebook_execution_id
         self.wait_for_completion = wait_for_completion
         self.aws_conn_id = aws_conn_id
-        self.waiter_max_attempts = waiter_max_attempts
-        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts or 25
+        self.waiter_delay = waiter_delay or waiter_check_interval_seconds or 60
 
     def execute(self, context: Context) -> None:
         emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
@@ -822,8 +830,8 @@ class EmrCreateJobFlowOperator(BaseOperator):
             trigger=EmrCreateJobFlowTrigger(
                 job_flow_id=self._job_flow_id,
                 aws_conn_id=self.aws_conn_id,
-                poll_interval=self.waiter_delay,
-                max_attempts=self.waiter_max_attempts,
+                waiter_delay=self.waiter_delay,
+                waiter_max_attempts=self.waiter_max_attempts,
             ),
             method_name="execute_complete",
             # timeout is set to ensure that if a trigger dies, the timeout does not restart
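Because the deprecated waiter_countdown and waiter_check_interval_seconds parameters now map onto waiter_max_attempts and waiter_delay, callers can migrate directly; a hedged sketch (the task id and execution id are placeholders):

    from airflow.providers.amazon.aws.operators.emr import EmrStopNotebookExecutionOperator

    # Before (deprecated): waiter_countdown=1500, waiter_check_interval_seconds=60
    stop_execution = EmrStopNotebookExecutionOperator(
        task_id="stop_notebook_execution",
        notebook_execution_id="ex-EXAMPLE0123456789",
        wait_for_completion=True,
        waiter_delay=60,         # replaces waiter_check_interval_seconds
        waiter_max_attempts=25,  # replaces waiter_countdown // waiter_check_interval_seconds
    )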