apache-airflow-providers-amazon 8.24.0rc1__py3-none-any.whl → 8.25.0__py3-none-any.whl

This diff compares the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. airflow/providers/amazon/LICENSE +4 -4
  2. airflow/providers/amazon/__init__.py +1 -1
  3. airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
  4. airflow/providers/amazon/aws/hooks/comprehend.py +33 -0
  5. airflow/providers/amazon/aws/hooks/glue.py +123 -0
  6. airflow/providers/amazon/aws/hooks/redshift_sql.py +8 -1
  7. airflow/providers/amazon/aws/operators/bedrock.py +6 -20
  8. airflow/providers/amazon/aws/operators/comprehend.py +148 -1
  9. airflow/providers/amazon/aws/operators/emr.py +38 -30
  10. airflow/providers/amazon/aws/operators/glue.py +408 -2
  11. airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
  12. airflow/providers/amazon/aws/sensors/comprehend.py +112 -1
  13. airflow/providers/amazon/aws/sensors/glue.py +260 -2
  14. airflow/providers/amazon/aws/sensors/s3.py +35 -5
  15. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
  16. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -1
  17. airflow/providers/amazon/aws/triggers/comprehend.py +36 -0
  18. airflow/providers/amazon/aws/triggers/glue.py +76 -2
  19. airflow/providers/amazon/aws/utils/__init__.py +2 -3
  20. airflow/providers/amazon/aws/waiters/comprehend.json +55 -0
  21. airflow/providers/amazon/aws/waiters/glue.json +98 -0
  22. airflow/providers/amazon/get_provider_info.py +20 -13
  23. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/METADATA +22 -21
  24. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/RECORD +26 -26
  25. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/WHEEL +0 -0
  26. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/glue.py
@@ -22,13 +22,20 @@ import urllib.parse
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+from botocore.exceptions import ClientError
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
-from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.glue import (
+    GlueDataQualityRuleRecommendationRunCompleteTrigger,
+    GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
+    GlueJobCompleteTrigger,
+)
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
@@ -239,3 +246,402 @@ class GlueJobOperator(BaseOperator):
             )
             if not response["SuccessfulSubmissions"]:
                 self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
+
+
+class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
+    """
+    Creates a data quality ruleset with DQDL rules applied to a specified Glue table.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GlueDataQualityOperator`
+
+    :param name: A unique name for the data quality ruleset.
+    :param ruleset: A Data Quality Definition Language (DQDL) ruleset.
+        For more information, see the Glue developer guide.
+    :param description: A description of the data quality ruleset.
+    :param update_rule_set: To update existing ruleset, Set this flag to True. (default: False)
+    :param data_quality_ruleset_kwargs: Extra arguments for RuleSet.
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = ("name", "ruleset", "description", "data_quality_ruleset_kwargs")
+
+    template_fields_renderers = {
+        "data_quality_ruleset_kwargs": "json",
+    }
+    ui_color = "#ededed"
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        ruleset: str,
+        description: str = "AWS Glue Data Quality Rule Set With Airflow",
+        update_rule_set: bool = False,
+        data_quality_ruleset_kwargs: dict | None = None,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.name = name
+        self.ruleset = ruleset.strip()
+        self.description = description
+        self.update_rule_set = update_rule_set
+        self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {}
+        self.aws_conn_id = aws_conn_id
+
+    def validate_inputs(self) -> None:
+        if not self.ruleset.startswith("Rules") or not self.ruleset.endswith("]"):
+            raise AttributeError("RuleSet must starts with Rules = [ and ends with ]")
+
+        if self.data_quality_ruleset_kwargs.get("TargetTable"):
+            target_table = self.data_quality_ruleset_kwargs["TargetTable"]
+
+            if not target_table.get("TableName") or not target_table.get("DatabaseName"):
+                raise AttributeError("Target table must have DatabaseName and TableName")
+
+    def execute(self, context: Context):
+        self.validate_inputs()
+
+        config = {
+            "Name": self.name,
+            "Ruleset": self.ruleset,
+            "Description": self.description,
+            **self.data_quality_ruleset_kwargs,
+        }
+        try:
+            if self.update_rule_set:
+                self.hook.conn.update_data_quality_ruleset(**config)
+                self.log.info("AWS Glue data quality ruleset updated successfully")
+            else:
+                self.hook.conn.create_data_quality_ruleset(**config)
+                self.log.info("AWS Glue data quality ruleset created successfully")
+        except ClientError as error:
+            raise AirflowException(
+                f"AWS Glue data quality ruleset failed: {error.response['Error']['Message']}"
+            )
+
+
+class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
+    """
+    Evaluate a ruleset against a data source (Glue table).
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GlueDataQualityRuleSetEvaluationRunOperator`
+
+    :param datasource: The data source (Glue table) associated with this run. (templated)
+    :param role: IAM role supplied for job execution. (templated)
+    :param rule_set_names: A list of ruleset names for evaluation. (templated)
+    :param number_of_workers: The number of G.1X workers to be used in the run. (default: 5)
+    :param timeout: The timeout for a run in minutes. This is the maximum time that a run can consume resources
+        before it is terminated and enters TIMEOUT status. (default: 2,880)
+    :param verify_result_status: Validate all the ruleset rules evaluation run results,
+        If any of the rule status is Fail or Error then an exception is thrown. (default: True)
+    :param show_results: Displays all the ruleset rules evaluation run results. (default: True)
+    :param rule_set_evaluation_run_kwargs: Extra arguments for evaluation run. (templated)
+    :param wait_for_completion: Whether to wait for job to stop. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the job to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = GlueDataQualityHook
+
+    template_fields: Sequence[str] = (
+        "datasource",
+        "role",
+        "rule_set_names",
+        "rule_set_evaluation_run_kwargs",
+    )
+
+    template_fields_renderers = {"datasource": "json", "rule_set_evaluation_run_kwargs": "json"}
+
+    ui_color = "#ededed"
+
+    def __init__(
+        self,
+        *,
+        datasource: dict,
+        role: str,
+        rule_set_names: list[str],
+        number_of_workers: int = 5,
+        timeout: int = 2880,
+        verify_result_status: bool = True,
+        show_results: bool = True,
+        rule_set_evaluation_run_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.datasource = datasource
+        self.role = role
+        self.rule_set_names = rule_set_names
+        self.number_of_workers = number_of_workers
+        self.timeout = timeout
+        self.verify_result_status = verify_result_status
+        self.show_results = show_results
+        self.rule_set_evaluation_run_kwargs = rule_set_evaluation_run_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+        self.aws_conn_id = aws_conn_id
+
+    def validate_inputs(self) -> None:
+        glue_table = self.datasource.get("GlueTable", {})
+
+        if not glue_table.get("DatabaseName") or not glue_table.get("TableName"):
+            raise AttributeError("DataSource glue table must have DatabaseName and TableName")
+
+        not_found_ruleset = [
+            ruleset_name
+            for ruleset_name in self.rule_set_names
+            if not self.hook.has_data_quality_ruleset(ruleset_name)
+        ]
+
+        if not_found_ruleset:
+            raise AirflowException(f"Following RulesetNames are not found {not_found_ruleset}")
+
+    def execute(self, context: Context) -> str:
+        self.validate_inputs()
+
+        self.log.info(
+            "Submitting AWS Glue data quality ruleset evaluation run for RulesetNames %s", self.rule_set_names
+        )
+
+        response = self.hook.conn.start_data_quality_ruleset_evaluation_run(
+            DataSource=self.datasource,
+            Role=self.role,
+            NumberOfWorkers=self.number_of_workers,
+            Timeout=self.timeout,
+            RulesetNames=self.rule_set_names,
+            **self.rule_set_evaluation_run_kwargs,
+        )
+
+        evaluation_run_id = response["RunId"]
+
+        message_description = (
+            f"AWS Glue data quality ruleset evaluation run RunId: {evaluation_run_id} to complete."
+        )
+        if self.deferrable:
+            self.log.info("Deferring %s", message_description)
+            self.defer(
+                trigger=GlueDataQualityRuleSetEvaluationRunCompleteTrigger(
+                    evaluation_run_id=response["RunId"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", message_description)
+
+            self.hook.get_waiter("data_quality_ruleset_evaluation_run_complete").wait(
+                RunId=evaluation_run_id,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run completed RunId: %s", evaluation_run_id
+            )
+
+            self.hook.validate_evaluation_run_results(
+                evaluation_run_id=evaluation_run_id,
+                show_results=self.show_results,
+                verify_result_status=self.verify_result_status,
+            )
+        else:
+            self.log.info("AWS Glue data quality ruleset evaluation run runId: %s.", evaluation_run_id)
+
+        return evaluation_run_id
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            raise AirflowException(f"Error: AWS Glue data quality ruleset evaluation run: {event}")
+
+        self.hook.validate_evaluation_run_results(
+            evaluation_run_id=event["evaluation_run_id"],
+            show_results=self.show_results,
+            verify_result_status=self.verify_result_status,
+        )
+
+        return event["evaluation_run_id"]
+
+
+class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
+    """
+    Starts a recommendation run that is used to generate rules, Glue Data Quality analyzes the data and comes up with recommendations for a potential ruleset.
+
+    Recommendation runs are automatically deleted after 90 days.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GlueDataQualityRuleRecommendationRunOperator`
+
+    :param datasource: The data source (Glue table) associated with this run. (templated)
+    :param role: IAM role supplied for job execution. (templated)
+    :param number_of_workers: The number of G.1X workers to be used in the run. (default: 5)
+    :param timeout: The timeout for a run in minutes. This is the maximum time that a run can consume resources
+        before it is terminated and enters TIMEOUT status. (default: 2,880)
+    :param show_results: Displays the recommended ruleset (a set of rules), when recommendation run completes. (default: True)
+    :param recommendation_run_kwargs: Extra arguments for recommendation run. (templated)
+    :param wait_for_completion: Whether to wait for job to stop. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the job to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = (
+        "datasource",
+        "role",
+        "recommendation_run_kwargs",
+    )
+
+    template_fields_renderers = {"datasource": "json", "recommendation_run_kwargs": "json"}
+
+    ui_color = "#ededed"
+
+    def __init__(
+        self,
+        *,
+        datasource: dict,
+        role: str,
+        number_of_workers: int = 5,
+        timeout: int = 2880,
+        show_results: bool = True,
+        recommendation_run_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.datasource = datasource
+        self.role = role
+        self.number_of_workers = number_of_workers
+        self.timeout = timeout
+        self.show_results = show_results
+        self.recommendation_run_kwargs = recommendation_run_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> str:
+        glue_table = self.datasource.get("GlueTable", {})
+
+        if not glue_table.get("DatabaseName") or not glue_table.get("TableName"):
+            raise AttributeError("DataSource glue table must have DatabaseName and TableName")
+
+        self.log.info("Submitting AWS Glue data quality recommendation run with %s", self.datasource)
+
+        try:
+            response = self.hook.conn.start_data_quality_rule_recommendation_run(
+                DataSource=self.datasource,
+                Role=self.role,
+                NumberOfWorkers=self.number_of_workers,
+                Timeout=self.timeout,
+                **self.recommendation_run_kwargs,
+            )
+        except ClientError as error:
+            raise AirflowException(
+                f"AWS Glue data quality recommendation run failed: {error.response['Error']['Message']}"
+            )
+
+        recommendation_run_id = response["RunId"]
+
+        message_description = (
+            f"AWS Glue data quality recommendation run RunId: {recommendation_run_id} to complete."
+        )
+        if self.deferrable:
+            self.log.info("Deferring %s", message_description)
+            self.defer(
+                trigger=GlueDataQualityRuleRecommendationRunCompleteTrigger(
+                    recommendation_run_id=recommendation_run_id,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", message_description)
+
+            self.hook.get_waiter("data_quality_rule_recommendation_run_complete").wait(
+                RunId=recommendation_run_id,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+            self.log.info(
+                "AWS Glue data quality recommendation run completed RunId: %s", recommendation_run_id
+            )
+
+            if self.show_results:
+                self.hook.log_recommendation_results(run_id=recommendation_run_id)
+
+        else:
+            self.log.info(message_description)
+
+        return recommendation_run_id
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            raise AirflowException(f"Error: AWS Glue data quality rule recommendation run: {event}")
+
+        if self.show_results:
+            self.hook.log_recommendation_results(run_id=event["recommendation_run_id"])
+
+        return event["recommendation_run_id"]
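
For orientation, below is a minimal usage sketch of the three Glue Data Quality operators added above. It is not taken from the package: the DAG id, ruleset text, database and table names, and IAM role ARN are illustrative placeholders, and the tasks are wired together only to show the intended call pattern (define or recommend a ruleset, then run an evaluation against a Glue table).

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.glue import (
    GlueDataQualityOperator,
    GlueDataQualityRuleRecommendationRunOperator,
    GlueDataQualityRuleSetEvaluationRunOperator,
)

# Hypothetical DAG; all names, the ruleset body, and the role ARN are placeholders.
with DAG("example_glue_data_quality", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
    create_ruleset = GlueDataQualityOperator(
        task_id="create_ruleset",
        name="my_ruleset",
        # Must start with "Rules" and end with "]", per validate_inputs() above.
        ruleset='Rules = [ColumnCount > 1, IsComplete "order_id"]',
    )

    recommend_rules = GlueDataQualityRuleRecommendationRunOperator(
        task_id="recommend_rules",
        datasource={"GlueTable": {"DatabaseName": "my_db", "TableName": "orders"}},
        role="arn:aws:iam::123456789012:role/GlueDataQualityRole",
    )

    evaluate_ruleset = GlueDataQualityRuleSetEvaluationRunOperator(
        task_id="evaluate_ruleset",
        datasource={"GlueTable": {"DatabaseName": "my_db", "TableName": "orders"}},
        role="arn:aws:iam::123456789012:role/GlueDataQualityRole",
        rule_set_names=["my_ruleset"],
    )

    create_ruleset >> [recommend_rules, evaluate_ruleset]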
airflow/providers/amazon/aws/operators/sagemaker.py
@@ -136,26 +136,64 @@ class SageMakerBaseOperator(BaseOperator):
         :param describe_func: The `describe_` function for that kind of job.
             We use it as an O(1) way to check if a job exists.
         """
-        job_name = proposed_name
-        while self._check_if_job_exists(job_name, describe_func):
+        return self._get_unique_name(
+            proposed_name, fail_if_exists, describe_func, self._check_if_job_exists, "job"
+        )
+
+    def _get_unique_name(
+        self,
+        proposed_name: str,
+        fail_if_exists: bool,
+        describe_func: Callable[[str], Any],
+        check_exists_func: Callable[[str, Callable[[str], Any]], bool],
+        resource_type: str,
+    ) -> str:
+        """
+        Return the proposed name if it doesn't already exist, otherwise returns it with a timestamp suffix.
+
+        :param proposed_name: Base name.
+        :param fail_if_exists: Will throw an error if a resource with that name already exists
+            instead of finding a new name.
+        :param check_exists_func: The function to check if the resource exists.
+            It should take the resource name and a describe function as arguments.
+        :param resource_type: Type of the resource (e.g., "model", "job").
+        """
+        self._check_resource_type(resource_type)
+        name = proposed_name
+        while check_exists_func(name, describe_func):
             # this while should loop only once in most cases, just setting it this way to regenerate a name
             # in case there is collision.
             if fail_if_exists:
-                raise AirflowException(f"A SageMaker job with name {job_name} already exists.")
+                raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
             else:
-                job_name = f"{proposed_name}-{time.time_ns()//1000000}"
-                self.log.info("Changed job name to '%s' to avoid collision.", job_name)
-        return job_name
+                name = f"{proposed_name}-{time.time_ns()//1000000}"
+                self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
+        return name
+
+    def _check_resource_type(self, resource_type: str):
+        """Raise exception if resource type is not 'model' or 'job'."""
+        if resource_type not in ("model", "job"):
+            raise AirflowException(
+                "Argument resource_type accepts only 'model' and 'job'. "
+                f"Provided value: '{resource_type}'."
+            )
 
-    def _check_if_job_exists(self, job_name, describe_func: Callable[[str], Any]) -> bool:
+    def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
         """Return True if job exists, False otherwise."""
+        return self._check_if_resource_exists(job_name, "job", describe_func)
+
+    def _check_if_resource_exists(
+        self, resource_name: str, resource_type: str, describe_func: Callable[[str], Any]
+    ) -> bool:
+        """Return True if resource exists, False otherwise."""
+        self._check_resource_type(resource_type)
         try:
-            describe_func(job_name)
-            self.log.info("Found existing job with name '%s'.", job_name)
+            describe_func(resource_name)
+            self.log.info("Found existing %s with name '%s'.", resource_type, resource_name)
             return True
         except ClientError as e:
             if e.response["Error"]["Code"] == "ValidationException":
-                return False  # ValidationException is thrown when the job could not be found
+                return False  # ValidationException is thrown when the resource could not be found
             else:
                 raise e
 
@@ -637,6 +675,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         max_ingestion_time: int | None = None,
         check_if_job_exists: bool = True,
         action_if_job_exists: str = "timestamp",
+        check_if_model_exists: bool = True,
+        action_if_model_exists: str = "timestamp",
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
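
The new action_if_model_exists="timestamp" option reuses the millisecond-timestamp suffix scheme introduced in the _get_unique_name refactor above. As a standalone illustration of that naming rule (the helper below is mine, not part of the provider):

import time

def suffixed_name(proposed_name: str, already_exists: bool) -> str:
    # Mirrors the collision handling above: keep the name if it is free,
    # otherwise append the current epoch time in milliseconds.
    if not already_exists:
        return proposed_name
    return f"{proposed_name}-{time.time_ns() // 1_000_000}"

print(suffixed_name("my-model", already_exists=True))  # e.g. "my-model-1718000000000"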
@@ -660,6 +700,14 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
                 Provided value: '{action_if_job_exists}'."
             )
+        self.check_if_model_exists = check_if_model_exists
+        if action_if_model_exists in ("fail", "timestamp"):
+            self.action_if_model_exists = action_if_model_exists
+        else:
+            raise AirflowException(
+                f"Argument action_if_model_exists accepts only 'timestamp' and 'fail'. \
+                Provided value: '{action_if_model_exists}'."
+            )
         self.deferrable = deferrable
         self.serialized_model: dict
         self.serialized_transform: dict
@@ -697,6 +745,14 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
 
         model_config = self.config.get("Model")
         if model_config:
+            if self.check_if_model_exists:
+                model_config["ModelName"] = self._get_unique_model_name(
+                    model_config["ModelName"],
+                    self.action_if_model_exists == "fail",
+                    self.hook.describe_model,
+                )
+                if "ModelName" in self.config["Transform"].keys():
+                    self.config["Transform"]["ModelName"] = model_config["ModelName"]
             self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
             self.hook.create_model(model_config)
 
@@ -752,6 +808,17 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
 
         return self.serialize_result(transform_config["TransformJobName"])
 
+    def _get_unique_model_name(
+        self, proposed_name: str, fail_if_exists: bool, describe_func: Callable[[str], Any]
+    ) -> str:
+        return self._get_unique_name(
+            proposed_name, fail_if_exists, describe_func, self._check_if_model_exists, "model"
+        )
+
+    def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str], Any]) -> bool:
+        """Return True if model exists, False otherwise."""
+        return self._check_if_resource_exists(model_name, "model", describe_func)
+
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
 
@@ -883,7 +950,8 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
     def execute(self, context: Context) -> dict:
         self.preprocess_config()
         self.log.info(
-            "Creating SageMaker Hyper-Parameter Tuning Job %s", self.config["HyperParameterTuningJobName"]
+            "Creating SageMaker Hyper-Parameter Tuning Job %s",
+            self.config["HyperParameterTuningJobName"],
         )
         response = self.hook.create_tuning_job(
             self.config,
@@ -1234,7 +1302,12 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
     :return str: Returns The ARN of the pipeline execution created in Amazon SageMaker.
     """
 
-    template_fields: Sequence[str] = ("aws_conn_id", "pipeline_name", "display_name", "pipeline_params")
+    template_fields: Sequence[str] = (
+        "aws_conn_id",
+        "pipeline_name",
+        "display_name",
+        "pipeline_params",
+    )
 
     def __init__(
         self,
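
To show where the new SageMakerTransformOperator parameters fit, here is a hedged sketch of a task that opts into the model-collision handling added above. The task id, connection defaults, and the contents of transform_config (which follow the usual boto3 create_model / create_transform_job request shapes) are placeholders, not values from this package.

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

# Illustrative config only; bucket names, image URI, and role ARN are placeholders.
transform_config = {
    "Model": {
        "ModelName": "my-model",
        "PrimaryContainer": {
            "Image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:latest",
            "ModelDataUrl": "s3://my-bucket/model.tar.gz",
        },
        "ExecutionRoleArn": "arn:aws:iam::123456789012:role/SageMakerRole",
    },
    "Transform": {
        "TransformJobName": "my-transform-job",
        "ModelName": "my-model",
        "TransformInput": {
            "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "s3://my-bucket/input/"}}
        },
        "TransformOutput": {"S3OutputPath": "s3://my-bucket/output/"},
        "TransformResources": {"InstanceType": "ml.m5.large", "InstanceCount": 1},
    },
}

transform = SageMakerTransformOperator(
    task_id="batch_transform",
    config=transform_config,
    check_if_model_exists=True,          # look the model name up before creating it
    action_if_model_exists="timestamp",  # on collision, suffix the name instead of failing
)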