apache-airflow-providers-amazon 9.1.0rc3__py3-none-any.whl → 9.2.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (118)
  1. airflow/providers/amazon/__init__.py +3 -3
  2. airflow/providers/amazon/aws/auth_manager/avp/facade.py +2 -1
  3. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +4 -12
  4. airflow/providers/amazon/aws/executors/batch/batch_executor.py +4 -3
  5. airflow/providers/amazon/aws/executors/batch/utils.py +3 -3
  6. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +2 -1
  7. airflow/providers/amazon/aws/executors/ecs/utils.py +3 -3
  8. airflow/providers/amazon/aws/fs/s3.py +2 -2
  9. airflow/providers/amazon/aws/hooks/appflow.py +15 -5
  10. airflow/providers/amazon/aws/hooks/athena.py +2 -1
  11. airflow/providers/amazon/aws/hooks/dms.py +161 -0
  12. airflow/providers/amazon/aws/hooks/dynamodb.py +2 -1
  13. airflow/providers/amazon/aws/hooks/eks.py +4 -3
  14. airflow/providers/amazon/aws/hooks/glue.py +5 -1
  15. airflow/providers/amazon/aws/hooks/kinesis.py +1 -1
  16. airflow/providers/amazon/aws/hooks/logs.py +2 -1
  17. airflow/providers/amazon/aws/hooks/redshift_cluster.py +4 -3
  18. airflow/providers/amazon/aws/hooks/redshift_data.py +2 -1
  19. airflow/providers/amazon/aws/hooks/redshift_sql.py +2 -6
  20. airflow/providers/amazon/aws/hooks/s3.py +9 -3
  21. airflow/providers/amazon/aws/hooks/sagemaker.py +2 -1
  22. airflow/providers/amazon/aws/hooks/ses.py +2 -1
  23. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  24. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  25. airflow/providers/amazon/aws/operators/athena.py +5 -2
  26. airflow/providers/amazon/aws/operators/base_aws.py +1 -1
  27. airflow/providers/amazon/aws/operators/batch.py +2 -1
  28. airflow/providers/amazon/aws/operators/bedrock.py +2 -1
  29. airflow/providers/amazon/aws/operators/cloud_formation.py +2 -1
  30. airflow/providers/amazon/aws/operators/comprehend.py +2 -1
  31. airflow/providers/amazon/aws/operators/datasync.py +2 -1
  32. airflow/providers/amazon/aws/operators/dms.py +531 -1
  33. airflow/providers/amazon/aws/operators/ec2.py +2 -1
  34. airflow/providers/amazon/aws/operators/ecs.py +15 -4
  35. airflow/providers/amazon/aws/operators/eks.py +8 -5
  36. airflow/providers/amazon/aws/operators/emr.py +31 -8
  37. airflow/providers/amazon/aws/operators/eventbridge.py +2 -1
  38. airflow/providers/amazon/aws/operators/glacier.py +2 -1
  39. airflow/providers/amazon/aws/operators/glue.py +12 -2
  40. airflow/providers/amazon/aws/operators/glue_crawler.py +2 -1
  41. airflow/providers/amazon/aws/operators/glue_databrew.py +2 -1
  42. airflow/providers/amazon/aws/operators/kinesis_analytics.py +2 -1
  43. airflow/providers/amazon/aws/operators/lambda_function.py +2 -1
  44. airflow/providers/amazon/aws/operators/neptune.py +2 -1
  45. airflow/providers/amazon/aws/operators/quicksight.py +2 -1
  46. airflow/providers/amazon/aws/operators/rds.py +2 -1
  47. airflow/providers/amazon/aws/operators/redshift_cluster.py +2 -1
  48. airflow/providers/amazon/aws/operators/s3.py +7 -1
  49. airflow/providers/amazon/aws/operators/sagemaker.py +2 -1
  50. airflow/providers/amazon/aws/operators/sns.py +2 -1
  51. airflow/providers/amazon/aws/operators/sqs.py +2 -1
  52. airflow/providers/amazon/aws/operators/step_function.py +2 -1
  53. airflow/providers/amazon/aws/sensors/athena.py +2 -1
  54. airflow/providers/amazon/aws/sensors/base_aws.py +1 -1
  55. airflow/providers/amazon/aws/sensors/batch.py +2 -1
  56. airflow/providers/amazon/aws/sensors/bedrock.py +2 -1
  57. airflow/providers/amazon/aws/sensors/cloud_formation.py +2 -1
  58. airflow/providers/amazon/aws/sensors/comprehend.py +2 -1
  59. airflow/providers/amazon/aws/sensors/dms.py +2 -1
  60. airflow/providers/amazon/aws/sensors/dynamodb.py +2 -1
  61. airflow/providers/amazon/aws/sensors/ec2.py +2 -1
  62. airflow/providers/amazon/aws/sensors/ecs.py +2 -1
  63. airflow/providers/amazon/aws/sensors/eks.py +2 -1
  64. airflow/providers/amazon/aws/sensors/emr.py +2 -1
  65. airflow/providers/amazon/aws/sensors/glacier.py +2 -1
  66. airflow/providers/amazon/aws/sensors/glue.py +2 -1
  67. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +2 -1
  68. airflow/providers/amazon/aws/sensors/glue_crawler.py +2 -1
  69. airflow/providers/amazon/aws/sensors/kinesis_analytics.py +2 -1
  70. airflow/providers/amazon/aws/sensors/lambda_function.py +2 -1
  71. airflow/providers/amazon/aws/sensors/opensearch_serverless.py +2 -1
  72. airflow/providers/amazon/aws/sensors/quicksight.py +2 -1
  73. airflow/providers/amazon/aws/sensors/rds.py +2 -1
  74. airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -1
  75. airflow/providers/amazon/aws/sensors/s3.py +2 -1
  76. airflow/providers/amazon/aws/sensors/sagemaker.py +2 -1
  77. airflow/providers/amazon/aws/sensors/sqs.py +2 -1
  78. airflow/providers/amazon/aws/sensors/step_function.py +2 -1
  79. airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +2 -1
  80. airflow/providers/amazon/aws/transfers/base.py +1 -1
  81. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +2 -1
  82. airflow/providers/amazon/aws/transfers/exasol_to_s3.py +2 -1
  83. airflow/providers/amazon/aws/transfers/ftp_to_s3.py +2 -1
  84. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +4 -3
  85. airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +2 -1
  86. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +4 -8
  87. airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +2 -1
  88. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +2 -1
  89. airflow/providers/amazon/aws/transfers/local_to_s3.py +2 -1
  90. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -1
  91. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +2 -1
  92. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +3 -2
  93. airflow/providers/amazon/aws/transfers/s3_to_ftp.py +2 -1
  94. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +2 -1
  95. airflow/providers/amazon/aws/transfers/s3_to_sftp.py +2 -1
  96. airflow/providers/amazon/aws/transfers/s3_to_sql.py +2 -1
  97. airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +2 -1
  98. airflow/providers/amazon/aws/transfers/sftp_to_s3.py +14 -1
  99. airflow/providers/amazon/aws/transfers/sql_to_s3.py +2 -1
  100. airflow/providers/amazon/aws/triggers/athena.py +1 -2
  101. airflow/providers/amazon/aws/triggers/base.py +2 -1
  102. airflow/providers/amazon/aws/triggers/dms.py +221 -0
  103. airflow/providers/amazon/aws/triggers/glue.py +3 -2
  104. airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -1
  105. airflow/providers/amazon/aws/triggers/redshift_data.py +2 -1
  106. airflow/providers/amazon/aws/triggers/s3.py +2 -1
  107. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -1
  108. airflow/providers/amazon/aws/triggers/sqs.py +2 -1
  109. airflow/providers/amazon/aws/utils/__init__.py +1 -15
  110. airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -1
  111. airflow/providers/amazon/aws/utils/waiter.py +20 -0
  112. airflow/providers/amazon/aws/waiters/dms.json +88 -0
  113. airflow/providers/amazon/get_provider_info.py +10 -5
  114. airflow/providers/amazon/version_compat.py +36 -0
  115. {apache_airflow_providers_amazon-9.1.0rc3.dist-info → apache_airflow_providers_amazon-9.2.0.dist-info}/METADATA +20 -26
  116. {apache_airflow_providers_amazon-9.1.0rc3.dist-info → apache_airflow_providers_amazon-9.2.0.dist-info}/RECORD +118 -115
  117. {apache_airflow_providers_amazon-9.1.0rc3.dist-info → apache_airflow_providers_amazon-9.2.0.dist-info}/WHEEL +1 -1
  118. {apache_airflow_providers_amazon-9.1.0rc3.dist-info → apache_airflow_providers_amazon-9.2.0.dist-info}/entry_points.txt +0 -0
@@ -26,9 +26,10 @@ AWS Batch services.
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
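This import swap, repeated in most of the hunks below and across the files listed above, replaces the typing.Sequence alias, deprecated by PEP 585 since Python 3.9, with collections.abc.Sequence. A minimal sketch of the pattern (the field names are hypothetical, not from this diff):

from __future__ import annotations

from collections.abc import Sequence  # preferred on Python 3.9+ (PEP 585)

# from typing import Sequence  # the deprecated alias this release removes

template_fields: Sequence[str] = ("source_arn", "target_arn")  # hypothetical field names

With "from __future__ import annotations" in effect, the ABC works everywhere the typing alias did, so the change is mechanical and behavior-preserving.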
@@ -17,8 +17,9 @@
 from __future__ import annotations
 
 import json
+from collections.abc import Sequence
 from time import sleep
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any
 
 from botocore.exceptions import ClientError
 
@@ -19,7 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
@@ -16,8 +16,9 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, ClassVar, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -20,7 +20,8 @@ from __future__ import annotations
 
 import logging
 import random
-from typing import TYPE_CHECKING, Any, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
@@ -17,11 +17,23 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, ClassVar, Sequence
+from collections.abc import Sequence
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, ClassVar
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.dms import (
+    DmsReplicationCompleteTrigger,
+    DmsReplicationConfigDeletedTrigger,
+    DmsReplicationDeprovisionedTrigger,
+    DmsReplicationStoppedTrigger,
+    DmsReplicationTerminalStatusTrigger,
+)
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.context import Context
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -277,3 +289,521 @@ class DmsStopTaskOperator(AwsBaseOperator[DmsHook]):
         """Stop AWS DMS replication task from Airflow."""
         self.hook.stop_replication_task(replication_task_arn=self.replication_task_arn)
         self.log.info("DMS replication task(%s) is stopping.", self.replication_task_arn)
+
+
+class DmsDescribeReplicationConfigsOperator(AwsBaseOperator[DmsHook]):
+    """
+    Describes AWS DMS Serverless replication configurations.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsDescribeReplicationConfigsOperator`
+
+    :param filter: Filters block for filtering results.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    """
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("filter")
+    template_fields_renderers = {"filter": "json"}
+
+    def __init__(
+        self,
+        *,
+        filter: list[dict] | None = None,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        self.filter = filter
+
+    def execute(self, context: Context) -> list:
+        """
+        Describe AWS DMS replication configurations.
+
+        :return: List of replication configurations
+        """
+        return self.hook.describe_replication_configs(filters=self.filter)
+
+
+class DmsCreateReplicationConfigOperator(AwsBaseOperator[DmsHook]):
+    """
+    Creates an AWS DMS Serverless replication configuration.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsCreateReplicationConfigOperator`
+
+    :param replication_config_id: Unique identifier used to create a ReplicationConfigArn.
+    :param source_endpoint_arn: ARN of the source endpoint
+    :param target_endpoint_arn: ARN of the target endpoint
+    :param compute_config: Parameters for provisioning a DMS Serverless replication.
+    :param replication_type: Type of DMS Serverless replication.
+    :param table_mappings: JSON table mappings
+    :param tags: Key-value tag pairs
+    :param additional_config_kwargs: Additional configuration parameters for DMS Serverless replication, passed directly to the API.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    """
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "replication_config_id",
+        "source_endpoint_arn",
+        "target_endpoint_arn",
+        "compute_config",
+        "replication_type",
+        "table_mappings",
+    )
+
+    template_fields_renderers = {"compute_config": "json", "table_mappings": "json"}
+
+    def __init__(
+        self,
+        *,
+        replication_config_id: str,
+        source_endpoint_arn: str,
+        target_endpoint_arn: str,
+        compute_config: dict[str, Any],
+        replication_type: str,
+        table_mappings: str,
+        additional_config_kwargs: dict | None = None,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+        self.replication_config_id = replication_config_id
+        self.source_endpoint_arn = source_endpoint_arn
+        self.target_endpoint_arn = target_endpoint_arn
+        self.compute_config = compute_config
+        self.replication_type = replication_type
+        self.table_mappings = table_mappings
+        self.additional_config_kwargs = additional_config_kwargs or {}
+
+    def execute(self, context: Context) -> str:
+        resp = self.hook.create_replication_config(
+            replication_config_id=self.replication_config_id,
+            source_endpoint_arn=self.source_endpoint_arn,
+            target_endpoint_arn=self.target_endpoint_arn,
+            compute_config=self.compute_config,
+            replication_type=self.replication_type,
+            table_mappings=self.table_mappings,
+            additional_config_kwargs=self.additional_config_kwargs,
+        )
+
+        self.log.info("DMS replication config(%s) has been created.", self.replication_config_id)
+        return resp
+
+
+class DmsDeleteReplicationConfigOperator(AwsBaseOperator[DmsHook]):
+    """
+    Deletes an AWS DMS Serverless replication configuration.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsDeleteReplicationConfigOperator`
+
+    :param replication_config_arn: ARN of the replication config
+    :param wait_for_completion: If True, waits for the replication config to be deleted before returning.
+        If False, the operator will return immediately after the request is made.
+    :param deferrable: Run the operator in deferrable mode.
+    :param waiter_delay: The number of seconds to wait between retries (default: 60).
+    :param waiter_max_attempts: The maximum number of attempts to be made (default: 60).
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    """
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("replication_config_arn")
+
+    VALID_STATES = ["failed", "stopped", "created"]
+    DELETING_STATES = ["deleting"]
+    TERMINAL_PROVISION_STATES = ["deprovisioned", ""]
+
+    def __init__(
+        self,
+        *,
+        replication_config_arn: str,
+        wait_for_completion: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+        self.replication_config_arn = replication_config_arn
+        self.wait_for_completion = wait_for_completion
+        self.deferrable = deferrable
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+    def execute(self, context: Context) -> None:
+        results = self.hook.describe_replications(
+            filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}]
+        )
+
+        current_state = results[0].get("Status", "")
+        self.log.info(
+            "Current state of replication config(%s) is %s.", self.replication_config_arn, current_state
+        )
+        # replication must be deprovisioned before deleting
+        provision_status = self.hook.get_provision_status(replication_config_arn=self.replication_config_arn)
+
+        if self.deferrable:
+            if current_state.lower() not in self.VALID_STATES:
+                self.log.info("Deferring until terminal status reached.")
+                self.defer(
+                    trigger=DmsReplicationTerminalStatusTrigger(
+                        replication_config_arn=self.replication_config_arn,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="retry_execution",
+                )
+            if provision_status not in self.TERMINAL_PROVISION_STATES:  # not deprovisioned
+                self.log.info("Deferring until deprovisioning completes.")
+                self.defer(
+                    trigger=DmsReplicationDeprovisionedTrigger(
+                        replication_config_arn=self.replication_config_arn,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="retry_execution",
+                )
+
+        self.hook.get_waiter("replication_terminal_status").wait(
+            Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
+            WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+        )
+        self.hook.get_waiter("replication_deprovisioned").wait(
+            Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
+            WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+        )
+        self.hook.delete_replication_config(self.replication_config_arn)
+        self.handle_delete_wait()
+
+    def handle_delete_wait(self):
+        if self.wait_for_completion:
+            if self.deferrable:
+                self.log.info("Deferring until replication config is deleted.")
+                self.defer(
+                    trigger=DmsReplicationConfigDeletedTrigger(
+                        replication_config_arn=self.replication_config_arn,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="execute_complete",
+                )
+            else:
+                self.hook.get_waiter("replication_config_deleted").wait(
+                    Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
+                    WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+                )
+                self.log.info("DMS replication config(%s) deleted.", self.replication_config_arn)
+
+    def execute_complete(self, context, event=None):
+        self.replication_config_arn = event.get("replication_config_arn")
+        self.log.info("DMS replication config(%s) deleted.", self.replication_config_arn)
+
+    def retry_execution(self, context, event=None):
+        self.replication_config_arn = event.get("replication_config_arn")
+        self.log.info("Retrying replication config(%s) deletion.", self.replication_config_arn)
+        self.execute(context)
+
+
+class DmsDescribeReplicationsOperator(AwsBaseOperator[DmsHook]):
+    """
+    Describes AWS DMS Serverless replications.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsDescribeReplicationsOperator`
+
+    :param filter: Filters block for filtering results.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    """
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("filter")
+    template_fields_renderers = {"filter": "json"}
+
+    def __init__(
+        self,
+        *,
+        filter: list[dict[str, Any]] | None = None,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+        self.filter = filter
+
+    def execute(self, context: Context) -> list[dict[str, Any]]:
+        """
+        Describe AWS DMS replications.
+
+        :return: Replications
+        """
+        return self.hook.describe_replications(self.filter)
+
+
+class DmsStartReplicationOperator(AwsBaseOperator[DmsHook]):
+    """
+    Starts an AWS DMS Serverless replication.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsStartReplicationOperator`
+
+    :param replication_config_arn: ARN of the replication config
+    :param replication_start_type: Type of replication.
+    :param cdc_start_time: Start time of CDC
+    :param cdc_start_pos: Indicates when to start CDC.
+    :param cdc_stop_pos: Indicates when to stop CDC.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    """
+
+    RUNNING_STATES = ["running"]
+    STARTABLE_STATES = ["stopped", "failed", "created"]
+    TERMINAL_STATES = ["failed", "stopped", "created"]
+    TERMINAL_PROVISION_STATES = ["deprovisioned", ""]
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "replication_config_arn", "replication_start_type", "cdc_start_time", "cdc_start_pos", "cdc_stop_pos"
+    )
+
+    def __init__(
+        self,
+        *,
+        replication_config_arn: str,
+        replication_start_type: str,
+        cdc_start_time: datetime | str | None = None,
+        cdc_start_pos: str | None = None,
+        cdc_stop_pos: str | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+        self.replication_config_arn = replication_config_arn
+        self.replication_start_type = replication_start_type
+        self.cdc_start_time = cdc_start_time
+        self.cdc_start_pos = cdc_start_pos
+        self.cdc_stop_pos = cdc_stop_pos
+        self.deferrable = deferrable
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.wait_for_completion = wait_for_completion
+
+        if self.cdc_start_time and self.cdc_start_pos:
+            raise AirflowException("Only one of cdc_start_time or cdc_start_pos should be provided.")
+
+    def execute(self, context: Context):
+        result = self.hook.describe_replications(
+            filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}]
+        )
+
+        current_status = result[0].get("Status", "")
+        provision_status = self.hook.get_provision_status(replication_config_arn=self.replication_config_arn)
+
+        if provision_status == "deprovisioning":
+            # wait for deprovisioning to complete before start/restart
+            self.log.info(
+                "Replication is deprovisioning. Must wait for deprovisioning before running replication"
+            )
+            if self.deferrable:
+                self.log.info("Deferring until deprovisioning completes.")
+                self.defer(
+                    trigger=DmsReplicationDeprovisionedTrigger(
+                        replication_config_arn=self.replication_config_arn,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="retry_execution",
+                )
+            else:
+                self.hook.get_waiter("replication_deprovisioned").wait(
+                    Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
+                    WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+                )
+                provision_status = self.hook.get_provision_status(
+                    replication_config_arn=self.replication_config_arn
+                )
+                self.log.info("Replication deprovisioning complete. Provision status: %s", provision_status)
+
+        if (
+            current_status.lower() in self.STARTABLE_STATES
+            and provision_status in self.TERMINAL_PROVISION_STATES
+        ):
+            resp = self.hook.start_replication(
+                replication_config_arn=self.replication_config_arn,
+                start_replication_type=self.replication_start_type,
+                cdc_start_time=self.cdc_start_time,
+                cdc_start_pos=self.cdc_start_pos,
+                cdc_stop_pos=self.cdc_stop_pos,
+            )
+
+            current_status = resp.get("Replication", {}).get("Status", "Unknown")
+            self.log.info(
+                "Replication(%s) started with status %s.",
+                self.replication_config_arn,
+                current_status,
+            )
+
+            if self.wait_for_completion:
+                self.log.info("Waiting for %s replication to complete.", self.replication_config_arn)
+
+                if self.deferrable:
+                    self.log.info("Deferring until %s replication completes.", self.replication_config_arn)
+                    self.defer(
+                        trigger=DmsReplicationCompleteTrigger(
+                            replication_config_arn=self.replication_config_arn,
+                            waiter_delay=self.waiter_delay,
+                            waiter_max_attempts=self.waiter_max_attempts,
+                            aws_conn_id=self.aws_conn_id,
+                        ),
+                        method_name="execute_complete",
+                    )
+
+                self.hook.get_waiter("replication_complete").wait(
+                    Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
+                    WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+                )
+                self.log.info("Replication(%s) has completed.", self.replication_config_arn)
+
+        else:
+            self.log.info("Replication(%s) is not in startable state.", self.replication_config_arn)
+            self.log.info("Status: %s Provision status: %s", current_status, provision_status)
+
+    def execute_complete(self, context, event=None):
+        self.replication_config_arn = event.get("replication_config_arn")
+        self.log.info("Replication(%s) has completed.", self.replication_config_arn)
+
+    def retry_execution(self, context, event=None):
+        self.replication_config_arn = event.get("replication_config_arn")
+        self.log.info("Retrying replication %s.", self.replication_config_arn)
+        self.execute(context)
+
+
+class DmsStopReplicationOperator(AwsBaseOperator[DmsHook]):
+    """
+    Stops an AWS DMS Serverless replication.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DmsStopReplicationOperator`
+
+    :param replication_config_arn: ARN of the replication config
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    """
+
+    STOPPED_STATES = ["stopped"]
+    NON_STOPPABLE_STATES = ["stopped"]
+
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("replication_config_arn")
+
+    def __init__(
+        self,
+        *,
+        replication_config_arn: str,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+        self.replication_config_arn = replication_config_arn
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> None:
+        results = self.hook.describe_replications(
+            filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}]
+        )
+
+        current_state = results[0].get("Status", "")
+        self.log.info(
+            "Current state of replication config(%s) is %s.", self.replication_config_arn, current_state
+        )
+
+        if current_state.lower() in self.STOPPED_STATES:
+            self.log.info("DMS replication config(%s) is already stopped.", self.replication_config_arn)
+        else:
+            resp = self.hook.stop_replication(self.replication_config_arn)
+            status = resp.get("Replication", {}).get("Status", "Unknown")
+            self.log.info(
+                "Stopping DMS replication config(%s). Current status: %s", self.replication_config_arn, status
+            )
+
+        if self.wait_for_completion:
+            self.log.info("Waiting for %s replication to stop.", self.replication_config_arn)
+
+            if self.deferrable:
+                self.log.info("Deferring until %s replication stops.", self.replication_config_arn)
+                self.defer(
+                    trigger=DmsReplicationStoppedTrigger(
+                        replication_config_arn=self.replication_config_arn,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="execute_complete",
+                )
+            self.hook.get_waiter("replication_stopped").wait(
+                Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+    def execute_complete(self, context, event=None):
+        self.replication_config_arn = event.get("replication_config_arn")
+        self.log.info("Replication(%s) has stopped.", self.replication_config_arn)
@@ -17,7 +17,8 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -18,9 +18,10 @@
 from __future__ import annotations
 
 import re
+from collections.abc import Sequence
 from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -368,7 +369,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         If None, this is the same as the `region` parameter. If that is also None,
         this is the default AWS region based on your connection settings.
     :param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs.
-        This is usually based on some custom name combined with the name of the container.
+        This should match the prefix specified in the log configuration of the task definition.
         Only required if you want logs to be shown in the Airflow UI after your job has
         finished.
     :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
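The reworded docstring ties awslogs_stream_prefix to the task definition rather than to an arbitrary custom name. With the awslogs log driver, ECS names each stream prefix/container-name/ecs-task-id, so the operator can only locate the stream if its prefix matches the one registered in the task definition. A sketch of a matching log configuration (the group, region, and prefix values are hypothetical):

log_configuration = {
    "logDriver": "awslogs",
    "options": {
        "awslogs-group": "/ecs/example",        # the operator's awslogs_group
        "awslogs-region": "us-east-1",          # the operator's awslogs_region
        "awslogs-stream-prefix": "airflow",     # the operator's awslogs_stream_prefix
    },
}
# ECS then writes streams named "<prefix>/<container>/<task-id>",
# e.g. "airflow/app/0123456789abcdef" for a container named "app".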
@@ -391,6 +392,8 @@ class EcsRunTaskOperator(EcsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the job to complete.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param do_xcom_push: If True, the operator will push the ECS task ARN to XCom with key 'ecs_task_arn'.
+        Additionally, if logs are fetched, the last log message will be pushed to XCom with the key 'return_value'. (default: False)
     """
 
     ui_color = "#f0ede4"
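The new do_xcom_push documentation describes two distinct pushes: the task ARN under a dedicated key, and the last fetched log line under the default return-value key. A sketch of how a DAG might consume them (task, cluster, and task-definition names are hypothetical):

from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

# Inside a DAG context:
run_task = EcsRunTaskOperator(
    task_id="run_task",
    task_definition="example-task",     # hypothetical
    cluster="example-cluster",          # hypothetical
    overrides={},
    do_xcom_push=True,
)

# A downstream task instance can then pull:
#   ti.xcom_pull(task_ids="run_task", key="ecs_task_arn")  # the ECS task ARN
#   ti.xcom_pull(task_ids="run_task")                      # last log line (key "return_value")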
@@ -481,6 +484,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
             self.awslogs_region = self.region_name
 
         self.arn: str | None = None
+        self.container_name: str | None = None
         self._started_by: str | None = None
 
         self.retry_args = quota_retry
@@ -597,10 +601,10 @@ class EcsRunTaskOperator(EcsBaseOperator):
 
         if self.capacity_provider_strategy:
             run_opts["capacityProviderStrategy"] = self.capacity_provider_strategy
-            if self.volume_configurations is not None:
-                run_opts["volumeConfigurations"] = self.volume_configurations
         elif self.launch_type:
             run_opts["launchType"] = self.launch_type
+        if self.volume_configurations is not None:
+            run_opts["volumeConfigurations"] = self.volume_configurations
         if self.platform_version is not None:
             run_opts["platformVersion"] = self.platform_version
         if self.group is not None:
@@ -624,6 +628,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.log.info("ECS Task started: %s", response)
 
         self.arn = response["tasks"][0]["taskArn"]
+        self.container_name = response["tasks"][0]["containers"][0]["name"]
         self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
 
     def _try_reattach_task(self, started_by: str):
@@ -659,6 +664,12 @@ class EcsRunTaskOperator(EcsBaseOperator):
         return self.awslogs_group and self.awslogs_stream_prefix
 
     def _get_logs_stream_name(self) -> str:
+        if (
+            self.awslogs_stream_prefix
+            and self.container_name
+            and not self.awslogs_stream_prefix.endswith(f"/{self.container_name}")
+        ):
+            return f"{self.awslogs_stream_prefix}/{self.container_name}/{self._get_ecs_task_id(self.arn)}"
         return f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
 
     def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
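The added branch appends the container name to the configured prefix (unless the prefix already ends with it), matching the prefix/container/task-id naming used by the awslogs driver; the final return preserves the old behaviour for prefixes that already include the container. A standalone sketch of the same logic, with illustrative values:

def stream_name(prefix: str, container: str | None, task_id: str) -> str:
    # Append the container name unless the prefix already ends with it.
    if prefix and container and not prefix.endswith(f"/{container}"):
        return f"{prefix}/{container}/{task_id}"
    return f"{prefix}/{task_id}"

assert stream_name("airflow", "app", "abc123") == "airflow/app/abc123"
assert stream_name("airflow/app", "app", "abc123") == "airflow/app/abc123"  # no double append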