apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. airflow/providers/google/__init__.py +3 -3
  2. airflow/providers/google/cloud/hooks/automl.py +1 -1
  3. airflow/providers/google/cloud/hooks/bigquery.py +64 -33
  4. airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
  5. airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
  6. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
  7. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +246 -32
  9. airflow/providers/google/cloud/hooks/dataplex.py +6 -2
  10. airflow/providers/google/cloud/hooks/dlp.py +14 -14
  11. airflow/providers/google/cloud/hooks/gcs.py +6 -2
  12. airflow/providers/google/cloud/hooks/gdm.py +2 -2
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/mlengine.py +8 -4
  15. airflow/providers/google/cloud/hooks/pubsub.py +1 -1
  16. airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
  18. airflow/providers/google/cloud/links/vertex_ai.py +2 -1
  19. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
  20. airflow/providers/google/cloud/operators/automl.py +13 -12
  21. airflow/providers/google/cloud/operators/bigquery.py +36 -22
  22. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
  23. airflow/providers/google/cloud/operators/bigtable.py +7 -6
  24. airflow/providers/google/cloud/operators/cloud_build.py +12 -11
  25. airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
  26. airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
  27. airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
  29. airflow/providers/google/cloud/operators/compute.py +12 -11
  30. airflow/providers/google/cloud/operators/datacatalog.py +21 -20
  31. airflow/providers/google/cloud/operators/dataflow.py +59 -42
  32. airflow/providers/google/cloud/operators/datafusion.py +11 -10
  33. airflow/providers/google/cloud/operators/datapipeline.py +3 -2
  34. airflow/providers/google/cloud/operators/dataprep.py +5 -4
  35. airflow/providers/google/cloud/operators/dataproc.py +19 -16
  36. airflow/providers/google/cloud/operators/datastore.py +8 -7
  37. airflow/providers/google/cloud/operators/dlp.py +31 -30
  38. airflow/providers/google/cloud/operators/functions.py +4 -3
  39. airflow/providers/google/cloud/operators/gcs.py +66 -41
  40. airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
  41. airflow/providers/google/cloud/operators/life_sciences.py +2 -1
  42. airflow/providers/google/cloud/operators/mlengine.py +11 -10
  43. airflow/providers/google/cloud/operators/pubsub.py +6 -5
  44. airflow/providers/google/cloud/operators/spanner.py +7 -6
  45. airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
  46. airflow/providers/google/cloud/operators/stackdriver.py +11 -10
  47. airflow/providers/google/cloud/operators/tasks.py +14 -13
  48. airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
  49. airflow/providers/google/cloud/operators/translate_speech.py +2 -1
  50. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
  51. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
  52. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
  53. airflow/providers/google/cloud/operators/vision.py +13 -12
  54. airflow/providers/google/cloud/operators/workflows.py +10 -9
  55. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  56. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
  57. airflow/providers/google/cloud/sensors/bigtable.py +2 -1
  58. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
  59. airflow/providers/google/cloud/sensors/dataflow.py +239 -52
  60. airflow/providers/google/cloud/sensors/datafusion.py +2 -1
  61. airflow/providers/google/cloud/sensors/dataproc.py +3 -2
  62. airflow/providers/google/cloud/sensors/gcs.py +14 -12
  63. airflow/providers/google/cloud/sensors/tasks.py +2 -1
  64. airflow/providers/google/cloud/sensors/workflows.py +2 -1
  65. airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
  66. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
  67. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
  68. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  69. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
  70. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  71. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
  72. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
  73. airflow/providers/google/cloud/triggers/bigquery.py +14 -3
  74. airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
  75. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
  76. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
  77. airflow/providers/google/cloud/triggers/dataflow.py +504 -4
  78. airflow/providers/google/cloud/triggers/dataproc.py +110 -26
  79. airflow/providers/google/cloud/triggers/mlengine.py +2 -1
  80. airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
  81. airflow/providers/google/common/hooks/base_google.py +45 -7
  82. airflow/providers/google/firebase/hooks/firestore.py +2 -2
  83. airflow/providers/google/firebase/operators/firestore.py +2 -1
  84. airflow/providers/google/get_provider_info.py +3 -2
  85. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
  86. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
  87. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
  88. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
  89. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
@@ -17,19 +17,23 @@
17
17
  # under the License.
18
18
  from __future__ import annotations
19
19
 
20
+ import shlex
20
21
  from typing import TYPE_CHECKING, Sequence
21
22
 
22
23
  from google.api_core.exceptions import AlreadyExists
23
24
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
24
25
  from google.cloud.orchestration.airflow.service_v1 import ImageVersion
25
- from google.cloud.orchestration.airflow.service_v1.types import Environment
26
+ from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse
26
27
 
27
28
  from airflow.configuration import conf
28
29
  from airflow.exceptions import AirflowException
29
30
  from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
30
31
  from airflow.providers.google.cloud.links.base import BaseGoogleLink
31
32
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
32
- from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
33
+ from airflow.providers.google.cloud.triggers.cloud_composer import (
34
+ CloudComposerAirflowCLICommandTrigger,
35
+ CloudComposerExecutionTrigger,
36
+ )
33
37
  from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
34
38
 
35
39
  if TYPE_CHECKING:
@@ -651,3 +655,144 @@ class CloudComposerListImageVersionsOperator(GoogleCloudBaseOperator):
651
655
  metadata=self.metadata,
652
656
  )
653
657
  return [ImageVersion.to_dict(image) for image in result]
658
+
659
+
660
class CloudComposerRunAirflowCLICommandOperator(GoogleCloudBaseOperator):
    """
    Run Airflow command for provided Composer environment.

    :param project_id: The ID of the Google Cloud project that the service belongs to.
    :param region: The ID of the Google Cloud region that the service belongs to.
    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
    :param command: Airflow command.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param deferrable: Run operator in the deferrable mode
    :param poll_interval: Optional: Control the rate of the poll for the result of deferrable run.
        By default, the trigger will poll every 10 seconds.
    """

    template_fields = (
        "project_id",
        "region",
        "environment_id",
        "command",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        project_id: str,
        region: str,
        environment_id: str,
        command: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        poll_interval: int = 10,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.environment_id = environment_id
        self.command = command
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.deferrable = deferrable
        self.poll_interval = poll_interval

    def execute(self, context: Context):
        """Start the Airflow CLI command and either defer or wait for its result."""
        hook = CloudComposerHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        self.log.info("Executing the command: [ airflow %s ]...", self.command)

        cmd, subcommand, parameters = self._parse_cmd_to_args(self.command)
        execution_cmd_info = hook.execute_airflow_command(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
            command=cmd,
            subcommand=subcommand,
            parameters=parameters,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        # Convert the proto response to a plain dict so it can be serialized
        # into the trigger / returned as an XCom value.
        execution_cmd_info_dict = ExecuteAirflowCommandResponse.to_dict(execution_cmd_info)

        self.log.info("Command has been started. execution_id=%s", execution_cmd_info_dict["execution_id"])

        if self.deferrable:
            self.defer(
                trigger=CloudComposerAirflowCLICommandTrigger(
                    project_id=self.project_id,
                    region=self.region,
                    environment_id=self.environment_id,
                    execution_cmd_info=execution_cmd_info_dict,
                    gcp_conn_id=self.gcp_conn_id,
                    impersonation_chain=self.impersonation_chain,
                    poll_interval=self.poll_interval,
                ),
                method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
            )
            return

        # Synchronous path: poll until the command execution finishes.
        result = hook.wait_command_execution_result(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
            execution_cmd_info=execution_cmd_info_dict,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
            poll_interval=self.poll_interval,
        )
        result_str = self._merge_cmd_output_result(result)
        self.log.info("Command execution result:\n%s", result_str)
        return result

    def execute_complete(self, context: Context, event: dict) -> dict:
        """Finalize a deferred run: re-raise trigger errors, otherwise return the result."""
        if event and event["status"] == "error":
            raise AirflowException(event["message"])
        result: dict = event["result"]
        result_str = self._merge_cmd_output_result(result)
        self.log.info("Command execution result:\n%s", result_str)
        return result

    def _parse_cmd_to_args(self, cmd: str) -> tuple:
        """Parse user command to command, subcommand and parameters.

        :param cmd: Full Airflow CLI command string, e.g. ``"dags list -o json"``.
        :return: ``(command, subcommand, parameters)`` where missing pieces are ``None``.
        :raises AirflowException: If the provided command string is empty.
        """
        tokens = shlex.split(cmd)
        if not tokens:
            raise AirflowException("The provided command is empty.")

        # After the emptiness guard the first token is guaranteed to exist;
        # subcommand/parameters stay None when absent (the API accepts that).
        command = tokens[0]
        subcommand = tokens[1] if len(tokens) > 1 else None
        parameters = tokens[2:] if len(tokens) > 2 else None

        return command, subcommand, parameters

    def _merge_cmd_output_result(self, result) -> str:
        """Merge output to one string."""
        result_str = "\n".join(line_dict["content"] for line_dict in result["output"])
        return result_str
@@ -43,6 +43,7 @@ from airflow.providers.google.cloud.links.cloud_memorystore import (
43
43
  RedisInstanceListLink,
44
44
  )
45
45
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
46
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
46
47
 
47
48
  if TYPE_CHECKING:
48
49
  from google.api_core.retry import Retry
@@ -112,7 +113,7 @@ class CloudMemorystoreCreateInstanceOperator(GoogleCloudBaseOperator):
112
113
  location: str,
113
114
  instance_id: str,
114
115
  instance: dict | Instance,
115
- project_id: str | None = None,
116
+ project_id: str = PROVIDE_PROJECT_ID,
116
117
  retry: Retry | _MethodDefault = DEFAULT,
117
118
  timeout: float | None = None,
118
119
  metadata: Sequence[tuple[str, str]] = (),
@@ -198,7 +199,7 @@ class CloudMemorystoreDeleteInstanceOperator(GoogleCloudBaseOperator):
198
199
  *,
199
200
  location: str,
200
201
  instance: str,
201
- project_id: str | None = None,
202
+ project_id: str = PROVIDE_PROJECT_ID,
202
203
  retry: Retry | _MethodDefault = DEFAULT,
203
204
  timeout: float | None = None,
204
205
  metadata: Sequence[tuple[str, str]] = (),
@@ -283,7 +284,7 @@ class CloudMemorystoreExportInstanceOperator(GoogleCloudBaseOperator):
283
284
  location: str,
284
285
  instance: str,
285
286
  output_config: dict | OutputConfig,
286
- project_id: str | None = None,
287
+ project_id: str = PROVIDE_PROJECT_ID,
287
288
  retry: Retry | _MethodDefault = DEFAULT,
288
289
  timeout: float | None = None,
289
290
  metadata: Sequence[tuple[str, str]] = (),
@@ -376,7 +377,7 @@ class CloudMemorystoreFailoverInstanceOperator(GoogleCloudBaseOperator):
376
377
  location: str,
377
378
  instance: str,
378
379
  data_protection_mode: FailoverInstanceRequest.DataProtectionMode,
379
- project_id: str | None = None,
380
+ project_id: str = PROVIDE_PROJECT_ID,
380
381
  retry: Retry | _MethodDefault = DEFAULT,
381
382
  timeout: float | None = None,
382
383
  metadata: Sequence[tuple[str, str]] = (),
@@ -462,7 +463,7 @@ class CloudMemorystoreGetInstanceOperator(GoogleCloudBaseOperator):
462
463
  *,
463
464
  location: str,
464
465
  instance: str,
465
- project_id: str | None = None,
466
+ project_id: str = PROVIDE_PROJECT_ID,
466
467
  retry: Retry | _MethodDefault = DEFAULT,
467
468
  timeout: float | None = None,
468
469
  metadata: Sequence[tuple[str, str]] = (),
@@ -556,7 +557,7 @@ class CloudMemorystoreImportOperator(GoogleCloudBaseOperator):
556
557
  location: str,
557
558
  instance: str,
558
559
  input_config: dict | InputConfig,
559
- project_id: str | None = None,
560
+ project_id: str = PROVIDE_PROJECT_ID,
560
561
  retry: Retry | _MethodDefault = DEFAULT,
561
562
  timeout: float | None = None,
562
563
  metadata: Sequence[tuple[str, str]] = (),
@@ -646,7 +647,7 @@ class CloudMemorystoreListInstancesOperator(GoogleCloudBaseOperator):
646
647
  *,
647
648
  location: str,
648
649
  page_size: int,
649
- project_id: str | None = None,
650
+ project_id: str = PROVIDE_PROJECT_ID,
650
651
  retry: Retry | _MethodDefault = DEFAULT,
651
652
  timeout: float | None = None,
652
653
  metadata: Sequence[tuple[str, str]] = (),
@@ -749,7 +750,7 @@ class CloudMemorystoreUpdateInstanceOperator(GoogleCloudBaseOperator):
749
750
  instance: dict | Instance,
750
751
  location: str | None = None,
751
752
  instance_id: str | None = None,
752
- project_id: str | None = None,
753
+ project_id: str = PROVIDE_PROJECT_ID,
753
754
  retry: Retry | _MethodDefault = DEFAULT,
754
755
  timeout: float | None = None,
755
756
  metadata: Sequence[tuple[str, str]] = (),
@@ -842,7 +843,7 @@ class CloudMemorystoreScaleInstanceOperator(GoogleCloudBaseOperator):
842
843
  memory_size_gb: int,
843
844
  location: str | None = None,
844
845
  instance_id: str | None = None,
845
- project_id: str | None = None,
846
+ project_id: str = PROVIDE_PROJECT_ID,
846
847
  retry: Retry | _MethodDefault = DEFAULT,
847
848
  timeout: float | None = None,
848
849
  metadata: Sequence[tuple[str, str]] = (),
@@ -954,7 +955,7 @@ class CloudMemorystoreCreateInstanceAndImportOperator(GoogleCloudBaseOperator):
954
955
  instance_id: str,
955
956
  instance: dict | Instance,
956
957
  input_config: dict | InputConfig,
957
- project_id: str | None = None,
958
+ project_id: str = PROVIDE_PROJECT_ID,
958
959
  retry: Retry | _MethodDefault = DEFAULT,
959
960
  timeout: float | None = None,
960
961
  metadata: Sequence[tuple[str, str]] = (),
@@ -1061,7 +1062,7 @@ class CloudMemorystoreExportAndDeleteInstanceOperator(GoogleCloudBaseOperator):
1061
1062
  location: str,
1062
1063
  instance: str,
1063
1064
  output_config: dict | OutputConfig,
1064
- project_id: str | None = None,
1065
+ project_id: str = PROVIDE_PROJECT_ID,
1065
1066
  retry: Retry | _MethodDefault = DEFAULT,
1066
1067
  timeout: float | None = None,
1067
1068
  metadata: Sequence[tuple[str, str]] = (),
@@ -1243,7 +1244,7 @@ class CloudMemorystoreMemcachedCreateInstanceOperator(GoogleCloudBaseOperator):
1243
1244
  location: str,
1244
1245
  instance_id: str,
1245
1246
  instance: dict | cloud_memcache.Instance,
1246
- project_id: str | None = None,
1247
+ project_id: str = PROVIDE_PROJECT_ID,
1247
1248
  retry: Retry | _MethodDefault = DEFAULT,
1248
1249
  timeout: float | None = None,
1249
1250
  metadata: Sequence[tuple[str, str]] = (),
@@ -1316,7 +1317,7 @@ class CloudMemorystoreMemcachedDeleteInstanceOperator(GoogleCloudBaseOperator):
1316
1317
  self,
1317
1318
  location: str,
1318
1319
  instance: str,
1319
- project_id: str | None = None,
1320
+ project_id: str = PROVIDE_PROJECT_ID,
1320
1321
  retry: Retry | _MethodDefault = DEFAULT,
1321
1322
  timeout: float | None = None,
1322
1323
  metadata: Sequence[tuple[str, str]] = (),
@@ -1390,7 +1391,7 @@ class CloudMemorystoreMemcachedGetInstanceOperator(GoogleCloudBaseOperator):
1390
1391
  *,
1391
1392
  location: str,
1392
1393
  instance: str,
1393
- project_id: str | None = None,
1394
+ project_id: str = PROVIDE_PROJECT_ID,
1394
1395
  retry: Retry | _MethodDefault = DEFAULT,
1395
1396
  timeout: float | None = None,
1396
1397
  metadata: Sequence[tuple[str, str]] = (),
@@ -1474,7 +1475,7 @@ class CloudMemorystoreMemcachedListInstancesOperator(GoogleCloudBaseOperator):
1474
1475
  self,
1475
1476
  *,
1476
1477
  location: str,
1477
- project_id: str | None = None,
1478
+ project_id: str = PROVIDE_PROJECT_ID,
1478
1479
  retry: Retry | _MethodDefault = DEFAULT,
1479
1480
  timeout: float | None = None,
1480
1481
  metadata: Sequence[tuple[str, str]] = (),
@@ -1572,7 +1573,7 @@ class CloudMemorystoreMemcachedUpdateInstanceOperator(GoogleCloudBaseOperator):
1572
1573
  instance: dict | cloud_memcache.Instance,
1573
1574
  location: str | None = None,
1574
1575
  instance_id: str | None = None,
1575
- project_id: str | None = None,
1576
+ project_id: str = PROVIDE_PROJECT_ID,
1576
1577
  retry: Retry | _MethodDefault = DEFAULT,
1577
1578
  timeout: float | None = None,
1578
1579
  metadata: Sequence[tuple[str, str]] = (),
@@ -19,6 +19,7 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
+ from functools import cached_property
22
23
  from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
23
24
 
24
25
  from googleapiclient.errors import HttpError
@@ -31,7 +32,7 @@ from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatab
31
32
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
32
33
  from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
33
34
  from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
34
- from airflow.providers.google.common.hooks.base_google import get_field
35
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, get_field
35
36
  from airflow.providers.google.common.links.storage import FileDetailsLink
36
37
 
37
38
  if TYPE_CHECKING:
@@ -244,7 +245,7 @@ class CloudSQLBaseOperator(GoogleCloudBaseOperator):
244
245
  self,
245
246
  *,
246
247
  instance: str,
247
- project_id: str | None = None,
248
+ project_id: str = PROVIDE_PROJECT_ID,
248
249
  gcp_conn_id: str = "google_cloud_default",
249
250
  api_version: str = "v1beta4",
250
251
  impersonation_chain: str | Sequence[str] | None = None,
@@ -337,7 +338,7 @@ class CloudSQLCreateInstanceOperator(CloudSQLBaseOperator):
337
338
  *,
338
339
  body: dict,
339
340
  instance: str,
340
- project_id: str | None = None,
341
+ project_id: str = PROVIDE_PROJECT_ID,
341
342
  gcp_conn_id: str = "google_cloud_default",
342
343
  api_version: str = "v1beta4",
343
344
  validate_body: bool = True,
@@ -440,7 +441,7 @@ class CloudSQLInstancePatchOperator(CloudSQLBaseOperator):
440
441
  *,
441
442
  body: dict,
442
443
  instance: str,
443
- project_id: str | None = None,
444
+ project_id: str = PROVIDE_PROJECT_ID,
444
445
  gcp_conn_id: str = "google_cloud_default",
445
446
  api_version: str = "v1beta4",
446
447
  impersonation_chain: str | Sequence[str] | None = None,
@@ -572,7 +573,7 @@ class CloudSQLCloneInstanceOperator(CloudSQLBaseOperator):
572
573
  instance: str,
573
574
  destination_instance_name: str,
574
575
  clone_context: dict | None = None,
575
- project_id: str | None = None,
576
+ project_id: str = PROVIDE_PROJECT_ID,
576
577
  gcp_conn_id: str = "google_cloud_default",
577
578
  api_version: str = "v1beta4",
578
579
  impersonation_chain: str | Sequence[str] | None = None,
@@ -663,7 +664,7 @@ class CloudSQLCreateInstanceDatabaseOperator(CloudSQLBaseOperator):
663
664
  *,
664
665
  instance: str,
665
666
  body: dict,
666
- project_id: str | None = None,
667
+ project_id: str = PROVIDE_PROJECT_ID,
667
668
  gcp_conn_id: str = "google_cloud_default",
668
669
  api_version: str = "v1beta4",
669
670
  validate_body: bool = True,
@@ -771,7 +772,7 @@ class CloudSQLPatchInstanceDatabaseOperator(CloudSQLBaseOperator):
771
772
  instance: str,
772
773
  database: str,
773
774
  body: dict,
774
- project_id: str | None = None,
775
+ project_id: str = PROVIDE_PROJECT_ID,
775
776
  gcp_conn_id: str = "google_cloud_default",
776
777
  api_version: str = "v1beta4",
777
778
  validate_body: bool = True,
@@ -867,7 +868,7 @@ class CloudSQLDeleteInstanceDatabaseOperator(CloudSQLBaseOperator):
867
868
  *,
868
869
  instance: str,
869
870
  database: str,
870
- project_id: str | None = None,
871
+ project_id: str = PROVIDE_PROJECT_ID,
871
872
  gcp_conn_id: str = "google_cloud_default",
872
873
  api_version: str = "v1beta4",
873
874
  impersonation_chain: str | Sequence[str] | None = None,
@@ -957,7 +958,7 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
957
958
  *,
958
959
  instance: str,
959
960
  body: dict,
960
- project_id: str | None = None,
961
+ project_id: str = PROVIDE_PROJECT_ID,
961
962
  gcp_conn_id: str = "google_cloud_default",
962
963
  api_version: str = "v1beta4",
963
964
  validate_body: bool = True,
@@ -1104,7 +1105,7 @@ class CloudSQLImportInstanceOperator(CloudSQLBaseOperator):
1104
1105
  *,
1105
1106
  instance: str,
1106
1107
  body: dict,
1107
- project_id: str | None = None,
1108
+ project_id: str = PROVIDE_PROJECT_ID,
1108
1109
  gcp_conn_id: str = "google_cloud_default",
1109
1110
  api_version: str = "v1beta4",
1110
1111
  validate_body: bool = True,
@@ -1181,10 +1182,35 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
1181
1182
  details on how to define ``gcpcloudsql://`` connection.
1182
1183
  :param sql_proxy_binary_path: (optional) Path to the cloud-sql-proxy binary.
1183
1184
  is not specified or the binary is not present, it is automatically downloaded.
1185
+ :param ssl_cert: (optional) Path to client certificate to authenticate when SSL is used. Overrides the
1186
+ connection field ``sslcert``.
1187
+ :param ssl_key: (optional) Path to client private key to authenticate when SSL is used. Overrides the
1188
+ connection field ``sslkey``.
1189
+ :param ssl_root_cert: (optional) Path to server's certificate to authenticate when SSL is used. Overrides
1190
+ the connection field ``sslrootcert``.
1191
+ :param ssl_secret_id: (optional) ID of the secret in Google Cloud Secret Manager that stores SSL
1192
+ certificate in the format below:
1193
+
1194
+ {'sslcert': '',
1195
+ 'sslkey': '',
1196
+ 'sslrootcert': ''}
1197
+
1198
+ Overrides the connection fields ``sslcert``, ``sslkey``, ``sslrootcert``.
1199
+ Note that according to the Secret Manager requirements, the mentioned dict should be saved as a
1200
+ string, and encoded with base64.
1201
+ Note that this parameter is incompatible with parameters ``ssl_cert``, ``ssl_key``, ``ssl_root_cert``.
1184
1202
  """
1185
1203
 
1186
1204
  # [START gcp_sql_query_template_fields]
1187
- template_fields: Sequence[str] = ("sql", "gcp_cloudsql_conn_id", "gcp_conn_id")
1205
+ template_fields: Sequence[str] = (
1206
+ "sql",
1207
+ "gcp_cloudsql_conn_id",
1208
+ "gcp_conn_id",
1209
+ "ssl_server_cert",
1210
+ "ssl_client_cert",
1211
+ "ssl_client_key",
1212
+ "ssl_secret_id",
1213
+ )
1188
1214
  template_ext: Sequence[str] = (".sql",)
1189
1215
  template_fields_renderers = {"sql": "sql"}
1190
1216
  # [END gcp_sql_query_template_fields]
@@ -1199,6 +1225,10 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
1199
1225
  gcp_conn_id: str = "google_cloud_default",
1200
1226
  gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
1201
1227
  sql_proxy_binary_path: str | None = None,
1228
+ ssl_server_cert: str | None = None,
1229
+ ssl_client_cert: str | None = None,
1230
+ ssl_client_key: str | None = None,
1231
+ ssl_secret_id: str | None = None,
1202
1232
  **kwargs,
1203
1233
  ) -> None:
1204
1234
  super().__init__(**kwargs)
@@ -1209,6 +1239,10 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
1209
1239
  self.parameters = parameters
1210
1240
  self.gcp_connection: Connection | None = None
1211
1241
  self.sql_proxy_binary_path = sql_proxy_binary_path
1242
+ self.ssl_server_cert = ssl_server_cert
1243
+ self.ssl_client_cert = ssl_client_cert
1244
+ self.ssl_client_key = ssl_client_key
1245
+ self.ssl_secret_id = ssl_secret_id
1212
1246
 
1213
1247
  def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
1214
1248
  cloud_sql_proxy_runner = None
@@ -1228,12 +1262,8 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
1228
1262
 
1229
1263
  def execute(self, context: Context):
1230
1264
  self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
1231
- hook = CloudSQLDatabaseHook(
1232
- gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
1233
- gcp_conn_id=self.gcp_conn_id,
1234
- default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, "project"),
1235
- sql_proxy_binary_path=self.sql_proxy_binary_path,
1236
- )
1265
+
1266
+ hook = self.hook
1237
1267
  hook.validate_ssl_certs()
1238
1268
  connection = hook.create_connection()
1239
1269
  hook.validate_socket_path_length()
@@ -1242,3 +1272,16 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
1242
1272
  self._execute_query(hook, database_hook)
1243
1273
  finally:
1244
1274
  hook.cleanup_database_hook()
1275
+
1276
+ @cached_property
1277
+ def hook(self):
1278
+ return CloudSQLDatabaseHook(
1279
+ gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
1280
+ gcp_conn_id=self.gcp_conn_id,
1281
+ default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, "project"),
1282
+ sql_proxy_binary_path=self.sql_proxy_binary_path,
1283
+ ssl_root_cert=self.ssl_server_cert,
1284
+ ssl_cert=self.ssl_client_cert,
1285
+ ssl_key=self.ssl_client_key,
1286
+ ssl_secret_id=self.ssl_secret_id,
1287
+ )
@@ -28,6 +28,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
28
28
  from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
29
29
  ACCESS_KEY_ID,
30
30
  AWS_ACCESS_KEY,
31
+ AWS_ROLE_ARN,
31
32
  AWS_S3_DATA_SOURCE,
32
33
  BUCKET_NAME,
33
34
  DAY,
@@ -62,6 +63,7 @@ from airflow.providers.google.cloud.links.cloud_storage_transfer import (
62
63
  )
63
64
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
64
65
  from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
66
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
65
67
 
66
68
  if TYPE_CHECKING:
67
69
  from airflow.utils.context import Context
@@ -78,15 +80,23 @@ class TransferJobPreprocessor:
78
80
  self.default_schedule = default_schedule
79
81
 
80
82
  def _inject_aws_credentials(self) -> None:
81
- if TRANSFER_SPEC in self.body and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC]:
82
- aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3")
83
- aws_credentials = aws_hook.get_credentials()
84
- aws_access_key_id = aws_credentials.access_key # type: ignore[attr-defined]
85
- aws_secret_access_key = aws_credentials.secret_key # type: ignore[attr-defined]
86
- self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
87
- ACCESS_KEY_ID: aws_access_key_id,
88
- SECRET_ACCESS_KEY: aws_secret_access_key,
89
- }
83
+ if TRANSFER_SPEC not in self.body:
84
+ return
85
+
86
+ if AWS_S3_DATA_SOURCE not in self.body[TRANSFER_SPEC]:
87
+ return
88
+
89
+ if AWS_ROLE_ARN in self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE]:
90
+ return
91
+
92
+ aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3")
93
+ aws_credentials = aws_hook.get_credentials()
94
+ aws_access_key_id = aws_credentials.access_key # type: ignore[attr-defined]
95
+ aws_secret_access_key = aws_credentials.secret_key # type: ignore[attr-defined]
96
+ self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
97
+ ACCESS_KEY_ID: aws_access_key_id,
98
+ SECRET_ACCESS_KEY: aws_secret_access_key,
99
+ }
90
100
 
91
101
  def _reformat_date(self, field_key: str) -> None:
92
102
  schedule = self.body[SCHEDULE]
@@ -234,7 +244,7 @@ class CloudDataTransferServiceCreateJobOperator(GoogleCloudBaseOperator):
234
244
  aws_conn_id: str | None = "aws_default",
235
245
  gcp_conn_id: str = "google_cloud_default",
236
246
  api_version: str = "v1",
237
- project_id: str | None = None,
247
+ project_id: str = PROVIDE_PROJECT_ID,
238
248
  google_impersonation_chain: str | Sequence[str] | None = None,
239
249
  **kwargs,
240
250
  ) -> None:
@@ -324,7 +334,7 @@ class CloudDataTransferServiceUpdateJobOperator(GoogleCloudBaseOperator):
324
334
  aws_conn_id: str | None = "aws_default",
325
335
  gcp_conn_id: str = "google_cloud_default",
326
336
  api_version: str = "v1",
327
- project_id: str | None = None,
337
+ project_id: str = PROVIDE_PROJECT_ID,
328
338
  google_impersonation_chain: str | Sequence[str] | None = None,
329
339
  **kwargs,
330
340
  ) -> None:
@@ -407,7 +417,7 @@ class CloudDataTransferServiceDeleteJobOperator(GoogleCloudBaseOperator):
407
417
  job_name: str,
408
418
  gcp_conn_id: str = "google_cloud_default",
409
419
  api_version: str = "v1",
410
- project_id: str | None = None,
420
+ project_id: str = PROVIDE_PROJECT_ID,
411
421
  google_impersonation_chain: str | Sequence[str] | None = None,
412
422
  **kwargs,
413
423
  ) -> None:
@@ -467,7 +477,7 @@ class CloudDataTransferServiceGetOperationOperator(GoogleCloudBaseOperator):
467
477
  def __init__(
468
478
  self,
469
479
  *,
470
- project_id: str | None = None,
480
+ project_id: str = PROVIDE_PROJECT_ID,
471
481
  operation_name: str,
472
482
  gcp_conn_id: str = "google_cloud_default",
473
483
  api_version: str = "v1",
@@ -541,7 +551,7 @@ class CloudDataTransferServiceListOperationsOperator(GoogleCloudBaseOperator):
541
551
  def __init__(
542
552
  self,
543
553
  request_filter: dict,
544
- project_id: str | None = None,
554
+ project_id: str = PROVIDE_PROJECT_ID,
545
555
  gcp_conn_id: str = "google_cloud_default",
546
556
  api_version: str = "v1",
547
557
  google_impersonation_chain: str | Sequence[str] | None = None,
@@ -818,6 +828,9 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
818
828
  account from the list granting this role to the originating account (templated).
819
829
  :param delete_job_after_completion: If True, delete the job after complete.
820
830
  If set to True, 'wait' must be set to True.
831
+ :param aws_role_arn: Optional AWS role ARN for workload identity federation. This will
832
+ override the `aws_conn_id` for authentication between GCP and AWS; see
833
+ https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#AwsS3Data
821
834
  """
822
835
 
823
836
  template_fields: Sequence[str] = (
@@ -829,6 +842,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
829
842
  "description",
830
843
  "object_conditions",
831
844
  "google_impersonation_chain",
845
+ "aws_role_arn",
832
846
  )
833
847
  ui_color = "#e09411"
834
848
 
@@ -839,7 +853,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
839
853
  gcs_bucket: str,
840
854
  s3_path: str | None = None,
841
855
  gcs_path: str | None = None,
842
- project_id: str | None = None,
856
+ project_id: str = PROVIDE_PROJECT_ID,
843
857
  aws_conn_id: str | None = "aws_default",
844
858
  gcp_conn_id: str = "google_cloud_default",
845
859
  description: str | None = None,
@@ -850,6 +864,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
850
864
  timeout: float | None = None,
851
865
  google_impersonation_chain: str | Sequence[str] | None = None,
852
866
  delete_job_after_completion: bool = False,
867
+ aws_role_arn: str | None = None,
853
868
  **kwargs,
854
869
  ) -> None:
855
870
  super().__init__(**kwargs)
@@ -868,6 +883,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
868
883
  self.timeout = timeout
869
884
  self.google_impersonation_chain = google_impersonation_chain
870
885
  self.delete_job_after_completion = delete_job_after_completion
886
+ self.aws_role_arn = aws_role_arn
871
887
  self._validate_inputs()
872
888
 
873
889
  def _validate_inputs(self) -> None:
@@ -918,6 +934,9 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
918
934
  if self.transfer_options is not None:
919
935
  body[TRANSFER_SPEC][TRANSFER_OPTIONS] = self.transfer_options # type: ignore[index]
920
936
 
937
+ if self.aws_role_arn is not None:
938
+ body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ROLE_ARN] = self.aws_role_arn # type: ignore[index]
939
+
921
940
  return body
922
941
 
923
942
 
@@ -1007,7 +1026,7 @@ class CloudDataTransferServiceGCSToGCSOperator(GoogleCloudBaseOperator):
1007
1026
  destination_bucket: str,
1008
1027
  source_path: str | None = None,
1009
1028
  destination_path: str | None = None,
1010
- project_id: str | None = None,
1029
+ project_id: str = PROVIDE_PROJECT_ID,
1011
1030
  gcp_conn_id: str = "google_cloud_default",
1012
1031
  description: str | None = None,
1013
1032
  schedule: dict | None = None,