apache-airflow-providers-amazon 8.24.0rc1__py3-none-any.whl → 8.25.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- airflow/providers/amazon/LICENSE +4 -4
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
- airflow/providers/amazon/aws/hooks/comprehend.py +33 -0
- airflow/providers/amazon/aws/hooks/glue.py +123 -0
- airflow/providers/amazon/aws/hooks/redshift_sql.py +8 -1
- airflow/providers/amazon/aws/operators/bedrock.py +6 -20
- airflow/providers/amazon/aws/operators/comprehend.py +148 -1
- airflow/providers/amazon/aws/operators/emr.py +38 -30
- airflow/providers/amazon/aws/operators/glue.py +408 -2
- airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
- airflow/providers/amazon/aws/sensors/comprehend.py +112 -1
- airflow/providers/amazon/aws/sensors/glue.py +260 -2
- airflow/providers/amazon/aws/sensors/s3.py +35 -5
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/comprehend.py +36 -0
- airflow/providers/amazon/aws/triggers/glue.py +76 -2
- airflow/providers/amazon/aws/utils/__init__.py +2 -3
- airflow/providers/amazon/aws/waiters/comprehend.json +55 -0
- airflow/providers/amazon/aws/waiters/glue.json +98 -0
- airflow/providers/amazon/get_provider_info.py +20 -13
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/METADATA +22 -21
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/RECORD +26 -26
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/glue.py

```diff
@@ -22,13 +22,20 @@ import urllib.parse
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+from botocore.exceptions import ClientError
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
-from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.glue import (
+    GlueDataQualityRuleRecommendationRunCompleteTrigger,
+    GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
+    GlueJobCompleteTrigger,
+)
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
```
```diff
@@ -239,3 +246,402 @@ class GlueJobOperator(BaseOperator):
         )
         if not response["SuccessfulSubmissions"]:
             self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
+
+
+class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
+    """
+    Creates a data quality ruleset with DQDL rules applied to a specified Glue table.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GlueDataQualityOperator`
+
+    :param name: A unique name for the data quality ruleset.
+    :param ruleset: A Data Quality Definition Language (DQDL) ruleset.
+        For more information, see the Glue developer guide.
+    :param description: A description of the data quality ruleset.
+    :param update_rule_set: Set this flag to True to update an existing ruleset. (default: False)
+    :param data_quality_ruleset_kwargs: Extra arguments for the ruleset.
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = ("name", "ruleset", "description", "data_quality_ruleset_kwargs")
+
+    template_fields_renderers = {
+        "data_quality_ruleset_kwargs": "json",
+    }
+    ui_color = "#ededed"
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        ruleset: str,
+        description: str = "AWS Glue Data Quality Rule Set With Airflow",
+        update_rule_set: bool = False,
+        data_quality_ruleset_kwargs: dict | None = None,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.name = name
+        self.ruleset = ruleset.strip()
+        self.description = description
+        self.update_rule_set = update_rule_set
+        self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {}
+        self.aws_conn_id = aws_conn_id
+
+    def validate_inputs(self) -> None:
+        if not self.ruleset.startswith("Rules") or not self.ruleset.endswith("]"):
+            raise AttributeError("RuleSet must start with 'Rules = [' and end with ']'")
+
+        if self.data_quality_ruleset_kwargs.get("TargetTable"):
+            target_table = self.data_quality_ruleset_kwargs["TargetTable"]
+
+            if not target_table.get("TableName") or not target_table.get("DatabaseName"):
+                raise AttributeError("Target table must have DatabaseName and TableName")
+
+    def execute(self, context: Context):
+        self.validate_inputs()
+
+        config = {
+            "Name": self.name,
+            "Ruleset": self.ruleset,
+            "Description": self.description,
+            **self.data_quality_ruleset_kwargs,
+        }
+        try:
+            if self.update_rule_set:
+                self.hook.conn.update_data_quality_ruleset(**config)
+                self.log.info("AWS Glue data quality ruleset updated successfully")
+            else:
+                self.hook.conn.create_data_quality_ruleset(**config)
+                self.log.info("AWS Glue data quality ruleset created successfully")
+        except ClientError as error:
+            raise AirflowException(
+                f"AWS Glue data quality ruleset failed: {error.response['Error']['Message']}"
+            )
```
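To show how the new operator is wired up, here is a minimal usage sketch; the task id, ruleset name, DQDL rule, and table identifiers are hypothetical, not taken from the diff:

```python
from airflow.providers.amazon.aws.operators.glue import GlueDataQualityOperator

# Create (or, with update_rule_set=True, update) a DQDL ruleset bound to a Glue table.
create_ruleset = GlueDataQualityOperator(
    task_id="create_data_quality_ruleset",       # hypothetical task id
    name="reviews_ruleset",                      # hypothetical ruleset name
    ruleset='Rules = [IsComplete "review_id"]',  # hypothetical DQDL rule
    data_quality_ruleset_kwargs={
        "TargetTable": {
            "TableName": "reviews",              # hypothetical table
            "DatabaseName": "my_database",       # hypothetical database
        },
    },
)
```

The hunk continues with the evaluation-run operator: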
```diff
+class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
+    """
+    Evaluate a ruleset against a data source (Glue table).
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GlueDataQualityRuleSetEvaluationRunOperator`
+
+    :param datasource: The data source (Glue table) associated with this run. (templated)
+    :param role: IAM role supplied for job execution. (templated)
+    :param rule_set_names: A list of ruleset names for evaluation. (templated)
+    :param number_of_workers: The number of G.1X workers to be used in the run. (default: 5)
+    :param timeout: The timeout for a run in minutes. This is the maximum time that a run can consume
+        resources before it is terminated and enters TIMEOUT status. (default: 2,880)
+    :param verify_result_status: Validate all the ruleset rules evaluation run results.
+        If any rule status is Fail or Error, an exception is thrown. (default: True)
+    :param show_results: Displays all the ruleset rules evaluation run results. (default: True)
+    :param rule_set_evaluation_run_kwargs: Extra arguments for the evaluation run. (templated)
+    :param wait_for_completion: Whether to wait for the run to finish. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for run completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the run to finish.
+        This implies waiting for completion. This mode requires the aiobotocore module to be installed.
+        (default: False)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = GlueDataQualityHook
+
+    template_fields: Sequence[str] = (
+        "datasource",
+        "role",
+        "rule_set_names",
+        "rule_set_evaluation_run_kwargs",
+    )
+
+    template_fields_renderers = {"datasource": "json", "rule_set_evaluation_run_kwargs": "json"}
+
+    ui_color = "#ededed"
+
+    def __init__(
+        self,
+        *,
+        datasource: dict,
+        role: str,
+        rule_set_names: list[str],
+        number_of_workers: int = 5,
+        timeout: int = 2880,
+        verify_result_status: bool = True,
+        show_results: bool = True,
+        rule_set_evaluation_run_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.datasource = datasource
+        self.role = role
+        self.rule_set_names = rule_set_names
+        self.number_of_workers = number_of_workers
+        self.timeout = timeout
+        self.verify_result_status = verify_result_status
+        self.show_results = show_results
+        self.rule_set_evaluation_run_kwargs = rule_set_evaluation_run_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+        self.aws_conn_id = aws_conn_id
+
+    def validate_inputs(self) -> None:
+        glue_table = self.datasource.get("GlueTable", {})
+
+        if not glue_table.get("DatabaseName") or not glue_table.get("TableName"):
+            raise AttributeError("DataSource glue table must have DatabaseName and TableName")
+
+        not_found_ruleset = [
+            ruleset_name
+            for ruleset_name in self.rule_set_names
+            if not self.hook.has_data_quality_ruleset(ruleset_name)
+        ]
+
+        if not_found_ruleset:
+            raise AirflowException(f"Following RulesetNames are not found {not_found_ruleset}")
+
+    def execute(self, context: Context) -> str:
+        self.validate_inputs()
+
+        self.log.info(
+            "Submitting AWS Glue data quality ruleset evaluation run for RulesetNames %s", self.rule_set_names
+        )
+
+        response = self.hook.conn.start_data_quality_ruleset_evaluation_run(
+            DataSource=self.datasource,
+            Role=self.role,
+            NumberOfWorkers=self.number_of_workers,
+            Timeout=self.timeout,
+            RulesetNames=self.rule_set_names,
+            **self.rule_set_evaluation_run_kwargs,
+        )
+
+        evaluation_run_id = response["RunId"]
+
+        message_description = (
+            f"AWS Glue data quality ruleset evaluation run RunId: {evaluation_run_id} to complete."
+        )
+        if self.deferrable:
+            self.log.info("Deferring %s", message_description)
+            self.defer(
+                trigger=GlueDataQualityRuleSetEvaluationRunCompleteTrigger(
+                    evaluation_run_id=response["RunId"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", message_description)
+
+            self.hook.get_waiter("data_quality_ruleset_evaluation_run_complete").wait(
+                RunId=evaluation_run_id,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run completed RunId: %s", evaluation_run_id
+            )
+
+            self.hook.validate_evaluation_run_results(
+                evaluation_run_id=evaluation_run_id,
+                show_results=self.show_results,
+                verify_result_status=self.verify_result_status,
+            )
+        else:
+            self.log.info("AWS Glue data quality ruleset evaluation run runId: %s.", evaluation_run_id)
+
+        return evaluation_run_id
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            raise AirflowException(f"Error: AWS Glue data quality ruleset evaluation run: {event}")
+
+        self.hook.validate_evaluation_run_results(
+            evaluation_run_id=event["evaluation_run_id"],
+            show_results=self.show_results,
+            verify_result_status=self.verify_result_status,
+        )
+
+        return event["evaluation_run_id"]
```
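A matching sketch for the evaluation run; the datasource, role ARN, and ruleset name are again hypothetical. With deferrable=True the wait is handed to the new GlueDataQualityRuleSetEvaluationRunCompleteTrigger:

```python
from airflow.providers.amazon.aws.operators.glue import GlueDataQualityRuleSetEvaluationRunOperator

evaluate_ruleset = GlueDataQualityRuleSetEvaluationRunOperator(
    task_id="evaluate_data_quality_ruleset",  # hypothetical task id
    datasource={
        "GlueTable": {
            "TableName": "reviews",           # hypothetical table
            "DatabaseName": "my_database",    # hypothetical database
        }
    },
    role="arn:aws:iam::123456789012:role/GlueDataQualityRole",  # hypothetical role
    rule_set_names=["reviews_ruleset"],       # hypothetical ruleset name
    deferrable=True,  # requires aiobotocore; waits via the new trigger
)
```

The hunk closes with the recommendation-run operator: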
```diff
+class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
+    """
+    Starts a recommendation run that is used to generate rules; Glue Data Quality analyzes the data and comes up with recommendations for a potential ruleset.
+
+    Recommendation runs are automatically deleted after 90 days.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GlueDataQualityRuleRecommendationRunOperator`
+
+    :param datasource: The data source (Glue table) associated with this run. (templated)
+    :param role: IAM role supplied for job execution. (templated)
+    :param number_of_workers: The number of G.1X workers to be used in the run. (default: 5)
+    :param timeout: The timeout for a run in minutes. This is the maximum time that a run can consume
+        resources before it is terminated and enters TIMEOUT status. (default: 2,880)
+    :param show_results: Displays the recommended ruleset (a set of rules) when the recommendation
+        run completes. (default: True)
+    :param recommendation_run_kwargs: Extra arguments for the recommendation run. (templated)
+    :param wait_for_completion: Whether to wait for the run to finish. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for run completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the run to finish.
+        This implies waiting for completion. This mode requires the aiobotocore module to be installed.
+        (default: False)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = (
+        "datasource",
+        "role",
+        "recommendation_run_kwargs",
+    )
+
+    template_fields_renderers = {"datasource": "json", "recommendation_run_kwargs": "json"}
+
+    ui_color = "#ededed"
+
+    def __init__(
+        self,
+        *,
+        datasource: dict,
+        role: str,
+        number_of_workers: int = 5,
+        timeout: int = 2880,
+        show_results: bool = True,
+        recommendation_run_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.datasource = datasource
+        self.role = role
+        self.number_of_workers = number_of_workers
+        self.timeout = timeout
+        self.show_results = show_results
+        self.recommendation_run_kwargs = recommendation_run_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> str:
+        glue_table = self.datasource.get("GlueTable", {})
+
+        if not glue_table.get("DatabaseName") or not glue_table.get("TableName"):
+            raise AttributeError("DataSource glue table must have DatabaseName and TableName")
+
+        self.log.info("Submitting AWS Glue data quality recommendation run with %s", self.datasource)
+
+        try:
+            response = self.hook.conn.start_data_quality_rule_recommendation_run(
+                DataSource=self.datasource,
+                Role=self.role,
+                NumberOfWorkers=self.number_of_workers,
+                Timeout=self.timeout,
+                **self.recommendation_run_kwargs,
+            )
+        except ClientError as error:
+            raise AirflowException(
+                f"AWS Glue data quality recommendation run failed: {error.response['Error']['Message']}"
+            )
+
+        recommendation_run_id = response["RunId"]
+
+        message_description = (
+            f"AWS Glue data quality recommendation run RunId: {recommendation_run_id} to complete."
+        )
+        if self.deferrable:
+            self.log.info("Deferring %s", message_description)
+            self.defer(
+                trigger=GlueDataQualityRuleRecommendationRunCompleteTrigger(
+                    recommendation_run_id=recommendation_run_id,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", message_description)
+
+            self.hook.get_waiter("data_quality_rule_recommendation_run_complete").wait(
+                RunId=recommendation_run_id,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+            self.log.info(
+                "AWS Glue data quality recommendation run completed RunId: %s", recommendation_run_id
+            )
+
+            if self.show_results:
+                self.hook.log_recommendation_results(run_id=recommendation_run_id)
+
+        else:
+            self.log.info(message_description)
+
+        return recommendation_run_id
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            raise AirflowException(f"Error: AWS Glue data quality rule recommendation run: {event}")
+
+        if self.show_results:
+            self.hook.log_recommendation_results(run_id=event["recommendation_run_id"])
+
+        return event["recommendation_run_id"]
```
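A sketch for the recommendation run, which logs the recommended ruleset when show_results is left at its default; all names are hypothetical:

```python
from airflow.providers.amazon.aws.operators.glue import GlueDataQualityRuleRecommendationRunOperator

recommend_rules = GlueDataQualityRuleRecommendationRunOperator(
    task_id="recommend_data_quality_rules",  # hypothetical task id
    datasource={
        "GlueTable": {
            "TableName": "reviews",          # hypothetical table
            "DatabaseName": "my_database",   # hypothetical database
        }
    },
    role="arn:aws:iam::123456789012:role/GlueDataQualityRole",  # hypothetical role
    show_results=True,  # logs the recommended ruleset on completion
)
```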
airflow/providers/amazon/aws/operators/sagemaker.py

```diff
@@ -136,26 +136,64 @@ class SageMakerBaseOperator(BaseOperator):
         :param describe_func: The `describe_` function for that kind of job.
             We use it as an O(1) way to check if a job exists.
         """
-        job_name = proposed_name
-        while self._check_if_job_exists(job_name, describe_func):
+        return self._get_unique_name(
+            proposed_name, fail_if_exists, describe_func, self._check_if_job_exists, "job"
+        )
+
+    def _get_unique_name(
+        self,
+        proposed_name: str,
+        fail_if_exists: bool,
+        describe_func: Callable[[str], Any],
+        check_exists_func: Callable[[str, Callable[[str], Any]], bool],
+        resource_type: str,
+    ) -> str:
+        """
+        Return the proposed name if it doesn't already exist, otherwise return it with a timestamp suffix.
+
+        :param proposed_name: Base name.
+        :param fail_if_exists: Will throw an error if a resource with that name already exists
+            instead of finding a new name.
+        :param check_exists_func: The function to check if the resource exists.
+            It should take the resource name and a describe function as arguments.
+        :param resource_type: Type of the resource (e.g., "model", "job").
+        """
+        self._check_resource_type(resource_type)
+        name = proposed_name
+        while check_exists_func(name, describe_func):
             # this while should loop only once in most cases, just setting it this way to regenerate a name
             # in case there is collision.
             if fail_if_exists:
-                raise AirflowException(f"A SageMaker job with name {name} already exists.")
+                raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
             else:
-                job_name = f"{proposed_name}-{time.time_ns()//1000000}"
-                self.log.info("Changed job name to '%s' to avoid collision.", job_name)
-        return job_name
+                name = f"{proposed_name}-{time.time_ns()//1000000}"
+                self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
+        return name
+
+    def _check_resource_type(self, resource_type: str):
+        """Raise exception if resource type is not 'model' or 'job'."""
+        if resource_type not in ("model", "job"):
+            raise AirflowException(
+                "Argument resource_type accepts only 'model' and 'job'. "
+                f"Provided value: '{resource_type}'."
+            )
 
-    def _check_if_job_exists(self, job_name, describe_func: Callable[[str], Any]) -> bool:
+    def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
         """Return True if job exists, False otherwise."""
+        return self._check_if_resource_exists(job_name, "job", describe_func)
+
+    def _check_if_resource_exists(
+        self, resource_name: str, resource_type: str, describe_func: Callable[[str], Any]
+    ) -> bool:
+        """Return True if resource exists, False otherwise."""
+        self._check_resource_type(resource_type)
         try:
-            describe_func(job_name)
-            self.log.info("Found existing job with name '%s'.", job_name)
+            describe_func(resource_name)
+            self.log.info("Found existing %s with name '%s'.", resource_type, resource_name)
             return True
         except ClientError as e:
             if e.response["Error"]["Code"] == "ValidationException":
-                return False  # ValidationException is thrown when the job could not be found
+                return False  # ValidationException is thrown when the resource could not be found
             else:
                 raise e
 
```
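The refactor keeps the original collision strategy: when fail_if_exists is False, the proposed name is regenerated with a millisecond-epoch suffix. A standalone sketch of that renaming, with a hypothetical base name:

```python
import time

proposed_name = "my-sagemaker-model"  # hypothetical base name
# _get_unique_name appends the current epoch time in milliseconds on
# collision, e.g. "my-sagemaker-model-1718000000123".
unique_name = f"{proposed_name}-{time.time_ns() // 1000000}"
print(unique_name)
```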
```diff
@@ -637,6 +675,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         max_ingestion_time: int | None = None,
         check_if_job_exists: bool = True,
         action_if_job_exists: str = "timestamp",
+        check_if_model_exists: bool = True,
+        action_if_model_exists: str = "timestamp",
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
```
```diff
@@ -660,6 +700,14 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
                 Provided value: '{action_if_job_exists}'."
             )
+        self.check_if_model_exists = check_if_model_exists
+        if action_if_model_exists in ("fail", "timestamp"):
+            self.action_if_model_exists = action_if_model_exists
+        else:
+            raise AirflowException(
+                f"Argument action_if_model_exists accepts only 'timestamp' and 'fail'. \
+                Provided value: '{action_if_model_exists}'."
+            )
         self.deferrable = deferrable
         self.serialized_model: dict
         self.serialized_transform: dict
```
```diff
@@ -697,6 +745,14 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
 
         model_config = self.config.get("Model")
         if model_config:
+            if self.check_if_model_exists:
+                model_config["ModelName"] = self._get_unique_model_name(
+                    model_config["ModelName"],
+                    self.action_if_model_exists == "fail",
+                    self.hook.describe_model,
+                )
+                if "ModelName" in self.config["Transform"].keys():
+                    self.config["Transform"]["ModelName"] = model_config["ModelName"]
             self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
             self.hook.create_model(model_config)
 
```
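A sketch of the new flags on SageMakerTransformOperator; the Model and Transform sections use standard SageMaker API fields, but every name, ARN, image, and S3 URI below is hypothetical:

```python
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

transform = SageMakerTransformOperator(
    task_id="sagemaker_transform",  # hypothetical task id
    config={
        "Model": {
            "ModelName": "my-model",  # suffixed with a timestamp on collision
            "PrimaryContainer": {
                "Image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:latest",
                "ModelDataUrl": "s3://my-bucket/model.tar.gz",
            },
            "ExecutionRoleArn": "arn:aws:iam::123456789012:role/SageMakerRole",
        },
        "Transform": {
            "TransformJobName": "my-transform-job",
            "ModelName": "my-model",  # kept in sync if the model is renamed
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "s3://my-bucket/input"}
                }
            },
            "TransformOutput": {"S3OutputPath": "s3://my-bucket/output"},
            "TransformResources": {"InstanceType": "ml.m5.large", "InstanceCount": 1},
        },
    },
    check_if_model_exists=True,          # default: probe with describe_model first
    action_if_model_exists="timestamp",  # "fail" raises instead of renaming
)
```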
```diff
@@ -752,6 +808,17 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
 
         return self.serialize_result(transform_config["TransformJobName"])
 
+    def _get_unique_model_name(
+        self, proposed_name: str, fail_if_exists: bool, describe_func: Callable[[str], Any]
+    ) -> str:
+        return self._get_unique_name(
+            proposed_name, fail_if_exists, describe_func, self._check_if_model_exists, "model"
+        )
+
+    def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str], Any]) -> bool:
+        """Return True if model exists, False otherwise."""
+        return self._check_if_resource_exists(model_name, "model", describe_func)
+
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
 
```
```diff
@@ -883,7 +950,8 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
     def execute(self, context: Context) -> dict:
         self.preprocess_config()
         self.log.info(
-            "Creating SageMaker Hyper-Parameter Tuning Job %s", self.config["HyperParameterTuningJobName"]
+            "Creating SageMaker Hyper-Parameter Tuning Job %s",
+            self.config["HyperParameterTuningJobName"],
         )
         response = self.hook.create_tuning_job(
             self.config,
```
```diff
@@ -1234,7 +1302,12 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
     :return str: Returns the ARN of the pipeline execution created in Amazon SageMaker.
     """
 
-    template_fields: Sequence[str] = ("aws_conn_id", "pipeline_name", "display_name", "pipeline_params")
+    template_fields: Sequence[str] = (
+        "aws_conn_id",
+        "pipeline_name",
+        "display_name",
+        "pipeline_params",
+    )
 
     def __init__(
         self,
```