apache-airflow-providers-amazon 8.24.0__py3-none-any.whl → 8.24.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/LICENSE +4 -4
- airflow/providers/amazon/aws/hooks/base_aws.py +3 -8
- airflow/providers/amazon/aws/hooks/glue.py +0 -123
- airflow/providers/amazon/aws/operators/bedrock.py +20 -6
- airflow/providers/amazon/aws/operators/emr.py +30 -38
- airflow/providers/amazon/aws/operators/glue.py +2 -408
- airflow/providers/amazon/aws/operators/sagemaker.py +12 -85
- airflow/providers/amazon/aws/sensors/glue.py +2 -260
- airflow/providers/amazon/aws/sensors/s3.py +5 -35
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +1 -0
- airflow/providers/amazon/aws/triggers/glue.py +2 -76
- airflow/providers/amazon/aws/waiters/glue.json +0 -98
- airflow/providers/amazon/get_provider_info.py +12 -18
- {apache_airflow_providers_amazon-8.24.0.dist-info → apache_airflow_providers_amazon-8.24.0rc1.dist-info}/METADATA +17 -18
- {apache_airflow_providers_amazon-8.24.0.dist-info → apache_airflow_providers_amazon-8.24.0rc1.dist-info}/RECORD +17 -17
- {apache_airflow_providers_amazon-8.24.0.dist-info → apache_airflow_providers_amazon-8.24.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.24.0.dist-info → apache_airflow_providers_amazon-8.24.0rc1.dist-info}/entry_points.txt +0 -0
@@ -22,20 +22,13 @@ import urllib.parse
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
-from botocore.exceptions import ClientError
-
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
-from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
-from airflow.providers.amazon.aws.triggers.glue import (
-    GlueDataQualityRuleRecommendationRunCompleteTrigger,
-    GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
-    GlueJobCompleteTrigger,
-)
+from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
@@ -246,402 +239,3 @@ class GlueJobOperator(BaseOperator):
         )
         if not response["SuccessfulSubmissions"]:
             self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
-
-
-class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
-    """
-    Creates a data quality ruleset with DQDL rules applied to a specified Glue table.
-
-    .. seealso::
-        For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:GlueDataQualityOperator`
-
-    :param name: A unique name for the data quality ruleset.
-    :param ruleset: A Data Quality Definition Language (DQDL) ruleset.
-        For more information, see the Glue developer guide.
-    :param description: A description of the data quality ruleset.
-    :param update_rule_set: To update existing ruleset, Set this flag to True. (default: False)
-    :param data_quality_ruleset_kwargs: Extra arguments for RuleSet.
-
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is ``None`` or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
-    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
-    :param verify: Whether or not to verify SSL certificates. See:
-        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
-    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
-        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
-    """
-
-    aws_hook_class = GlueDataQualityHook
-    template_fields: Sequence[str] = ("name", "ruleset", "description", "data_quality_ruleset_kwargs")
-
-    template_fields_renderers = {
-        "data_quality_ruleset_kwargs": "json",
-    }
-    ui_color = "#ededed"
-
-    def __init__(
-        self,
-        *,
-        name: str,
-        ruleset: str,
-        description: str = "AWS Glue Data Quality Rule Set With Airflow",
-        update_rule_set: bool = False,
-        data_quality_ruleset_kwargs: dict | None = None,
-        aws_conn_id: str | None = "aws_default",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.name = name
-        self.ruleset = ruleset.strip()
-        self.description = description
-        self.update_rule_set = update_rule_set
-        self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {}
-        self.aws_conn_id = aws_conn_id
-
-    def validate_inputs(self) -> None:
-        if not self.ruleset.startswith("Rules") or not self.ruleset.endswith("]"):
-            raise AttributeError("RuleSet must starts with Rules = [ and ends with ]")
-
-        if self.data_quality_ruleset_kwargs.get("TargetTable"):
-            target_table = self.data_quality_ruleset_kwargs["TargetTable"]
-
-            if not target_table.get("TableName") or not target_table.get("DatabaseName"):
-                raise AttributeError("Target table must have DatabaseName and TableName")
-
-    def execute(self, context: Context):
-        self.validate_inputs()
-
-        config = {
-            "Name": self.name,
-            "Ruleset": self.ruleset,
-            "Description": self.description,
-            **self.data_quality_ruleset_kwargs,
-        }
-        try:
-            if self.update_rule_set:
-                self.hook.conn.update_data_quality_ruleset(**config)
-                self.log.info("AWS Glue data quality ruleset updated successfully")
-            else:
-                self.hook.conn.create_data_quality_ruleset(**config)
-                self.log.info("AWS Glue data quality ruleset created successfully")
-        except ClientError as error:
-            raise AirflowException(
-                f"AWS Glue data quality ruleset failed: {error.response['Error']['Message']}"
-            )
-
-
-class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
-    """
-    Evaluate a ruleset against a data source (Glue table).
-
-    .. seealso::
-        For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:GlueDataQualityRuleSetEvaluationRunOperator`
-
-    :param datasource: The data source (Glue table) associated with this run. (templated)
-    :param role: IAM role supplied for job execution. (templated)
-    :param rule_set_names: A list of ruleset names for evaluation. (templated)
-    :param number_of_workers: The number of G.1X workers to be used in the run. (default: 5)
-    :param timeout: The timeout for a run in minutes. This is the maximum time that a run can consume resources
-        before it is terminated and enters TIMEOUT status. (default: 2,880)
-    :param verify_result_status: Validate all the ruleset rules evaluation run results,
-        If any of the rule status is Fail or Error then an exception is thrown. (default: True)
-    :param show_results: Displays all the ruleset rules evaluation run results. (default: True)
-    :param rule_set_evaluation_run_kwargs: Extra arguments for evaluation run. (templated)
-    :param wait_for_completion: Whether to wait for job to stop. (default: True)
-    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
-    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
-    :param deferrable: If True, the operator will wait asynchronously for the job to stop.
-        This implies waiting for completion. This mode requires aiobotocore module to be installed.
-        (default: False)
-
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is ``None`` or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
-    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
-    :param verify: Whether or not to verify SSL certificates. See:
-        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
-    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
-        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
-    """
-
-    aws_hook_class = GlueDataQualityHook
-
-    template_fields: Sequence[str] = (
-        "datasource",
-        "role",
-        "rule_set_names",
-        "rule_set_evaluation_run_kwargs",
-    )
-
-    template_fields_renderers = {"datasource": "json", "rule_set_evaluation_run_kwargs": "json"}
-
-    ui_color = "#ededed"
-
-    def __init__(
-        self,
-        *,
-        datasource: dict,
-        role: str,
-        rule_set_names: list[str],
-        number_of_workers: int = 5,
-        timeout: int = 2880,
-        verify_result_status: bool = True,
-        show_results: bool = True,
-        rule_set_evaluation_run_kwargs: dict[str, Any] | None = None,
-        wait_for_completion: bool = True,
-        waiter_delay: int = 60,
-        waiter_max_attempts: int = 20,
-        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
-        aws_conn_id: str | None = "aws_default",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.datasource = datasource
-        self.role = role
-        self.rule_set_names = rule_set_names
-        self.number_of_workers = number_of_workers
-        self.timeout = timeout
-        self.verify_result_status = verify_result_status
-        self.show_results = show_results
-        self.rule_set_evaluation_run_kwargs = rule_set_evaluation_run_kwargs or {}
-        self.wait_for_completion = wait_for_completion
-        self.waiter_delay = waiter_delay
-        self.waiter_max_attempts = waiter_max_attempts
-        self.deferrable = deferrable
-        self.aws_conn_id = aws_conn_id
-
-    def validate_inputs(self) -> None:
-        glue_table = self.datasource.get("GlueTable", {})
-
-        if not glue_table.get("DatabaseName") or not glue_table.get("TableName"):
-            raise AttributeError("DataSource glue table must have DatabaseName and TableName")
-
-        not_found_ruleset = [
-            ruleset_name
-            for ruleset_name in self.rule_set_names
-            if not self.hook.has_data_quality_ruleset(ruleset_name)
-        ]
-
-        if not_found_ruleset:
-            raise AirflowException(f"Following RulesetNames are not found {not_found_ruleset}")
-
-    def execute(self, context: Context) -> str:
-        self.validate_inputs()
-
-        self.log.info(
-            "Submitting AWS Glue data quality ruleset evaluation run for RulesetNames %s", self.rule_set_names
-        )
-
-        response = self.hook.conn.start_data_quality_ruleset_evaluation_run(
-            DataSource=self.datasource,
-            Role=self.role,
-            NumberOfWorkers=self.number_of_workers,
-            Timeout=self.timeout,
-            RulesetNames=self.rule_set_names,
-            **self.rule_set_evaluation_run_kwargs,
-        )
-
-        evaluation_run_id = response["RunId"]
-
-        message_description = (
-            f"AWS Glue data quality ruleset evaluation run RunId: {evaluation_run_id} to complete."
-        )
-        if self.deferrable:
-            self.log.info("Deferring %s", message_description)
-            self.defer(
-                trigger=GlueDataQualityRuleSetEvaluationRunCompleteTrigger(
-                    evaluation_run_id=response["RunId"],
-                    waiter_delay=self.waiter_delay,
-                    waiter_max_attempts=self.waiter_max_attempts,
-                    aws_conn_id=self.aws_conn_id,
-                ),
-                method_name="execute_complete",
-            )
-
-        elif self.wait_for_completion:
-            self.log.info("Waiting for %s", message_description)
-
-            self.hook.get_waiter("data_quality_ruleset_evaluation_run_complete").wait(
-                RunId=evaluation_run_id,
-                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
-            )
-
-            self.log.info(
-                "AWS Glue data quality ruleset evaluation run completed RunId: %s", evaluation_run_id
-            )
-
-            self.hook.validate_evaluation_run_results(
-                evaluation_run_id=evaluation_run_id,
-                show_results=self.show_results,
-                verify_result_status=self.verify_result_status,
-            )
-        else:
-            self.log.info("AWS Glue data quality ruleset evaluation run runId: %s.", evaluation_run_id)
-
-        return evaluation_run_id
-
-    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
-
-        if event["status"] != "success":
-            raise AirflowException(f"Error: AWS Glue data quality ruleset evaluation run: {event}")
-
-        self.hook.validate_evaluation_run_results(
-            evaluation_run_id=event["evaluation_run_id"],
-            show_results=self.show_results,
-            verify_result_status=self.verify_result_status,
-        )
-
-        return event["evaluation_run_id"]
-
-
-class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
-    """
-    Starts a recommendation run that is used to generate rules, Glue Data Quality analyzes the data and comes up with recommendations for a potential ruleset.
-
-    Recommendation runs are automatically deleted after 90 days.
-
-    .. seealso::
-        For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:GlueDataQualityRuleRecommendationRunOperator`
-
-    :param datasource: The data source (Glue table) associated with this run. (templated)
-    :param role: IAM role supplied for job execution. (templated)
-    :param number_of_workers: The number of G.1X workers to be used in the run. (default: 5)
-    :param timeout: The timeout for a run in minutes. This is the maximum time that a run can consume resources
-        before it is terminated and enters TIMEOUT status. (default: 2,880)
-    :param show_results: Displays the recommended ruleset (a set of rules), when recommendation run completes. (default: True)
-    :param recommendation_run_kwargs: Extra arguments for recommendation run. (templated)
-    :param wait_for_completion: Whether to wait for job to stop. (default: True)
-    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
-    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
-    :param deferrable: If True, the operator will wait asynchronously for the job to stop.
-        This implies waiting for completion. This mode requires aiobotocore module to be installed.
-        (default: False)
-
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is ``None`` or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
-    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
-    :param verify: Whether or not to verify SSL certificates. See:
-        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
-    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
-        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
-    """
-
-    aws_hook_class = GlueDataQualityHook
-    template_fields: Sequence[str] = (
-        "datasource",
-        "role",
-        "recommendation_run_kwargs",
-    )
-
-    template_fields_renderers = {"datasource": "json", "recommendation_run_kwargs": "json"}
-
-    ui_color = "#ededed"
-
-    def __init__(
-        self,
-        *,
-        datasource: dict,
-        role: str,
-        number_of_workers: int = 5,
-        timeout: int = 2880,
-        show_results: bool = True,
-        recommendation_run_kwargs: dict[str, Any] | None = None,
-        wait_for_completion: bool = True,
-        waiter_delay: int = 60,
-        waiter_max_attempts: int = 20,
-        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
-        aws_conn_id: str | None = "aws_default",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.datasource = datasource
-        self.role = role
-        self.number_of_workers = number_of_workers
-        self.timeout = timeout
-        self.show_results = show_results
-        self.recommendation_run_kwargs = recommendation_run_kwargs or {}
-        self.wait_for_completion = wait_for_completion
-        self.waiter_delay = waiter_delay
-        self.waiter_max_attempts = waiter_max_attempts
-        self.deferrable = deferrable
-        self.aws_conn_id = aws_conn_id
-
-    def execute(self, context: Context) -> str:
-        glue_table = self.datasource.get("GlueTable", {})
-
-        if not glue_table.get("DatabaseName") or not glue_table.get("TableName"):
-            raise AttributeError("DataSource glue table must have DatabaseName and TableName")
-
-        self.log.info("Submitting AWS Glue data quality recommendation run with %s", self.datasource)
-
-        try:
-            response = self.hook.conn.start_data_quality_rule_recommendation_run(
-                DataSource=self.datasource,
-                Role=self.role,
-                NumberOfWorkers=self.number_of_workers,
-                Timeout=self.timeout,
-                **self.recommendation_run_kwargs,
-            )
-        except ClientError as error:
-            raise AirflowException(
-                f"AWS Glue data quality recommendation run failed: {error.response['Error']['Message']}"
-            )
-
-        recommendation_run_id = response["RunId"]
-
-        message_description = (
-            f"AWS Glue data quality recommendation run RunId: {recommendation_run_id} to complete."
-        )
-        if self.deferrable:
-            self.log.info("Deferring %s", message_description)
-            self.defer(
-                trigger=GlueDataQualityRuleRecommendationRunCompleteTrigger(
-                    recommendation_run_id=recommendation_run_id,
-                    waiter_delay=self.waiter_delay,
-                    waiter_max_attempts=self.waiter_max_attempts,
-                    aws_conn_id=self.aws_conn_id,
-                ),
-                method_name="execute_complete",
-            )
-
-        elif self.wait_for_completion:
-            self.log.info("Waiting for %s", message_description)
-
-            self.hook.get_waiter("data_quality_rule_recommendation_run_complete").wait(
-                RunId=recommendation_run_id,
-                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
-            )
-            self.log.info(
-                "AWS Glue data quality recommendation run completed RunId: %s", recommendation_run_id
-            )
-
-            if self.show_results:
-                self.hook.log_recommendation_results(run_id=recommendation_run_id)
-
-        else:
-            self.log.info(message_description)
-
-        return recommendation_run_id
-
-    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
-
-        if event["status"] != "success":
-            raise AirflowException(f"Error: AWS Glue data quality rule recommendation run: {event}")
-
-        if self.show_results:
-            self.hook.log_recommendation_results(run_id=event["recommendation_run_id"])
-
-        return event["recommendation_run_id"]
@@ -136,64 +136,26 @@ class SageMakerBaseOperator(BaseOperator):
         :param describe_func: The `describe_` function for that kind of job.
             We use it as an O(1) way to check if a job exists.
         """
-        return self._get_unique_name(
-            proposed_name, fail_if_exists, describe_func, self._check_if_job_exists, "job"
-        )
-
-    def _get_unique_name(
-        self,
-        proposed_name: str,
-        fail_if_exists: bool,
-        describe_func: Callable[[str], Any],
-        check_exists_func: Callable[[str, Callable[[str], Any]], bool],
-        resource_type: str,
-    ) -> str:
-        """
-        Return the proposed name if it doesn't already exist, otherwise returns it with a timestamp suffix.
-
-        :param proposed_name: Base name.
-        :param fail_if_exists: Will throw an error if a resource with that name already exists
-            instead of finding a new name.
-        :param check_exists_func: The function to check if the resource exists.
-            It should take the resource name and a describe function as arguments.
-        :param resource_type: Type of the resource (e.g., "model", "job").
-        """
-        self._check_resource_type(resource_type)
-        name = proposed_name
-        while check_exists_func(name, describe_func):
+        job_name = proposed_name
+        while self._check_if_job_exists(job_name, describe_func):
             # this while should loop only once in most cases, just setting it this way to regenerate a name
             # in case there is collision.
             if fail_if_exists:
-                raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
+                raise AirflowException(f"A SageMaker job with name {job_name} already exists.")
             else:
-                name = f"{proposed_name}-{time.time_ns()//1000000}"
-                self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
-        return name
-
-    def _check_resource_type(self, resource_type: str):
-        """Raise exception if resource type is not 'model' or 'job'."""
-        if resource_type not in ("model", "job"):
-            raise AirflowException(
-                "Argument resource_type accepts only 'model' and 'job'. "
-                f"Provided value: '{resource_type}'."
-            )
+                job_name = f"{proposed_name}-{time.time_ns()//1000000}"
+                self.log.info("Changed job name to '%s' to avoid collision.", job_name)
+        return job_name
 
-    def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
+    def _check_if_job_exists(self, job_name, describe_func: Callable[[str], Any]) -> bool:
         """Return True if job exists, False otherwise."""
-        return self._check_if_resource_exists(job_name, "job", describe_func)
-
-    def _check_if_resource_exists(
-        self, resource_name: str, resource_type: str, describe_func: Callable[[str], Any]
-    ) -> bool:
-        """Return True if resource exists, False otherwise."""
-        self._check_resource_type(resource_type)
         try:
-            describe_func(resource_name)
-            self.log.info("Found existing %s with name '%s'.", resource_type, resource_name)
+            describe_func(job_name)
+            self.log.info("Found existing job with name '%s'.", job_name)
             return True
         except ClientError as e:
             if e.response["Error"]["Code"] == "ValidationException":
-                return False  # ValidationException is thrown when the resource could not be found
+                return False  # ValidationException is thrown when the job could not be found
             else:
                 raise e
 
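
Both sides of this hunk implement the same collision-avoidance idea: probe for an existing resource via its describe_ call, and on collision append a millisecond timestamp to the proposed name. A standalone sketch of the rc1 job-only logic, detached from the operator class (RuntimeError stands in for AirflowException; describe_func is any SageMaker describe_* callable):

    import time

    from botocore.exceptions import ClientError

    def check_if_job_exists(job_name, describe_func):
        # SageMaker describe_* calls raise a ValidationException ClientError
        # when the named resource does not exist.
        try:
            describe_func(job_name)
            return True
        except ClientError as e:
            if e.response["Error"]["Code"] == "ValidationException":
                return False
            raise

    def get_unique_job_name(proposed_name, fail_if_exists, describe_func):
        job_name = proposed_name
        while check_if_job_exists(job_name, describe_func):
            if fail_if_exists:
                raise RuntimeError(f"A SageMaker job with name {job_name} already exists.")
            # millisecond timestamp suffix, as in the provider code above
            job_name = f"{proposed_name}-{time.time_ns() // 1_000_000}"
        return job_name
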
@@ -675,8 +637,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         max_ingestion_time: int | None = None,
         check_if_job_exists: bool = True,
         action_if_job_exists: str = "timestamp",
-        check_if_model_exists: bool = True,
-        action_if_model_exists: str = "timestamp",
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
@@ -700,14 +660,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
                 Provided value: '{action_if_job_exists}'."
             )
-        self.check_if_model_exists = check_if_model_exists
-        if action_if_model_exists in ("fail", "timestamp"):
-            self.action_if_model_exists = action_if_model_exists
-        else:
-            raise AirflowException(
-                f"Argument action_if_model_exists accepts only 'timestamp' and 'fail'. \
-                Provided value: '{action_if_model_exists}'."
-            )
         self.deferrable = deferrable
         self.serialized_model: dict
         self.serialized_transform: dict
@@ -745,14 +697,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
 
         model_config = self.config.get("Model")
         if model_config:
-            if self.check_if_model_exists:
-                model_config["ModelName"] = self._get_unique_model_name(
-                    model_config["ModelName"],
-                    self.action_if_model_exists == "fail",
-                    self.hook.describe_model,
-                )
-                if "ModelName" in self.config["Transform"].keys():
-                    self.config["Transform"]["ModelName"] = model_config["ModelName"]
             self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
             self.hook.create_model(model_config)
 
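
The branch removed above renamed a colliding model and then propagated the new name into the Transform section of the config, since the transform job must reference the model that actually gets created. A minimal rendering of that sync step; the rename helper and config values here are placeholders, not the provider's API (and the real helper renamed only on collision, while this stand-in always renames):

    import time

    def unique_model_name(proposed: str) -> str:
        # stand-in for the removed _get_unique_model_name helper
        return f"{proposed}-{time.time_ns() // 1_000_000}"

    config = {
        "Model": {"ModelName": "my-model"},
        "Transform": {"TransformJobName": "my-transform", "ModelName": "my-model"},
    }

    model_config = config["Model"]
    model_config["ModelName"] = unique_model_name(model_config["ModelName"])
    # keep the Transform section pointing at the (possibly renamed) model
    if "ModelName" in config["Transform"]:
        config["Transform"]["ModelName"] = model_config["ModelName"]
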
@@ -808,17 +752,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
 
         return self.serialize_result(transform_config["TransformJobName"])
 
-    def _get_unique_model_name(
-        self, proposed_name: str, fail_if_exists: bool, describe_func: Callable[[str], Any]
-    ) -> str:
-        return self._get_unique_name(
-            proposed_name, fail_if_exists, describe_func, self._check_if_model_exists, "model"
-        )
-
-    def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str], Any]) -> bool:
-        """Return True if model exists, False otherwise."""
-        return self._check_if_resource_exists(model_name, "model", describe_func)
-
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
 
@@ -950,8 +883,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
     def execute(self, context: Context) -> dict:
         self.preprocess_config()
         self.log.info(
-            "Creating SageMaker Hyper-Parameter Tuning Job %s",
-            self.config["HyperParameterTuningJobName"],
+            "Creating SageMaker Hyper-Parameter Tuning Job %s", self.config["HyperParameterTuningJobName"]
         )
         response = self.hook.create_tuning_job(
             self.config,
@@ -1302,12 +1234,7 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
     :return str: Returns The ARN of the pipeline execution created in Amazon SageMaker.
     """
 
-    template_fields: Sequence[str] = (
-        "aws_conn_id",
-        "pipeline_name",
-        "display_name",
-        "pipeline_params",
-    )
+    template_fields: Sequence[str] = ("aws_conn_id", "pipeline_name", "display_name", "pipeline_params")
 
     def __init__(
         self,