apache-airflow-providers-google 10.19.0rc1__py3-none-any.whl → 10.20.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. airflow/providers/google/LICENSE +4 -4
  2. airflow/providers/google/__init__.py +1 -1
  3. airflow/providers/google/ads/hooks/ads.py +4 -4
  4. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +26 -0
  5. airflow/providers/google/cloud/hooks/dataflow.py +132 -1
  6. airflow/providers/google/cloud/hooks/datapipeline.py +22 -73
  7. airflow/providers/google/cloud/hooks/gcs.py +21 -0
  8. airflow/providers/google/cloud/hooks/pubsub.py +10 -1
  9. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -0
  10. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +15 -3
  11. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
  12. airflow/providers/google/cloud/links/dataflow.py +25 -0
  13. airflow/providers/google/cloud/openlineage/mixins.py +271 -0
  14. airflow/providers/google/cloud/openlineage/utils.py +5 -218
  15. airflow/providers/google/cloud/operators/bigquery.py +74 -20
  16. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +76 -0
  17. airflow/providers/google/cloud/operators/dataflow.py +235 -1
  18. airflow/providers/google/cloud/operators/datapipeline.py +29 -121
  19. airflow/providers/google/cloud/operators/dataplex.py +1 -1
  20. airflow/providers/google/cloud/operators/dataproc_metastore.py +17 -6
  21. airflow/providers/google/cloud/operators/kubernetes_engine.py +9 -6
  22. airflow/providers/google/cloud/operators/pubsub.py +18 -0
  23. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +6 -0
  24. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +16 -0
  25. airflow/providers/google/cloud/sensors/cloud_composer.py +171 -2
  26. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +13 -0
  27. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +56 -1
  28. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +6 -12
  29. airflow/providers/google/cloud/triggers/cloud_composer.py +115 -0
  30. airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -0
  31. airflow/providers/google/cloud/utils/credentials_provider.py +81 -6
  32. airflow/providers/google/cloud/utils/external_token_supplier.py +175 -0
  33. airflow/providers/google/common/hooks/base_google.py +35 -1
  34. airflow/providers/google/common/utils/id_token_credentials.py +1 -1
  35. airflow/providers/google/get_provider_info.py +19 -14
  36. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/METADATA +41 -35
  37. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/RECORD +39 -37
  38. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/WHEEL +0 -0
  39. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/operators/datapipeline.py
@@ -19,138 +19,46 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
-
-from airflow.exceptions import AirflowException
-from airflow.providers.google.cloud.hooks.datapipeline import DEFAULT_DATAPIPELINE_LOCATION, DataPipelineHook
-from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from deprecated import deprecated
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.dataflow import DEFAULT_DATAFLOW_LOCATION
+from airflow.providers.google.cloud.operators.dataflow import (
+    DataflowCreatePipelineOperator,
+    DataflowRunPipelineOperator,
+)
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
-
 
-class CreateDataPipelineOperator(GoogleCloudBaseOperator):
-    """
-    Creates a new Data Pipelines instance from the Data Pipelines API.
+@deprecated(
+    reason="This operator is deprecated and will be removed after 01.12.2024. "
+    "Please use `DataflowCreatePipelineOperator`.",
+    category=AirflowProviderDeprecationWarning,
+)
+class CreateDataPipelineOperator(DataflowCreatePipelineOperator):
+    """Creates a new Data Pipelines instance from the Data Pipelines API."""
 
-    :param body: The request body (contains instance of Pipeline). See:
-        https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
-    :param project_id: The ID of the GCP project that owns the job.
-    :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-    :param gcp_conn_id: The connection ID to connect to the Google Cloud
-        Platform.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
 
-    .. warning::
-        This option requires Apache Beam 2.39.0 or newer.
-
-    Returns the created Data Pipelines instance in JSON representation.
-    """
-
-    def __init__(
-        self,
-        *,
-        body: dict,
-        project_id: str = PROVIDE_PROJECT_ID,
-        location: str = DEFAULT_DATAPIPELINE_LOCATION,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-
-        self.body = body
-        self.project_id = project_id
-        self.location = location
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-        self.datapipeline_hook: DataPipelineHook | None = None
-        self.body["pipelineSources"] = {"airflow": "airflow"}
-
-    def execute(self, context: Context):
-        if self.body is None:
-            raise AirflowException(
-                "Request Body not given; cannot create a Data Pipeline without the Request Body."
-            )
-        if self.project_id is None:
-            raise AirflowException(
-                "Project ID not given; cannot create a Data Pipeline without the Project ID."
-            )
-        if self.location is None:
-            raise AirflowException("location not given; cannot create a Data Pipeline without the location.")
-
-        self.datapipeline_hook = DataPipelineHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
-        self.data_pipeline = self.datapipeline_hook.create_data_pipeline(
-            project_id=self.project_id,
-            body=self.body,
-            location=self.location,
-        )
-        if self.data_pipeline:
-            if "error" in self.data_pipeline:
-                raise AirflowException(self.data_pipeline.get("error").get("message"))
-
-        return self.data_pipeline
-
-
-class RunDataPipelineOperator(GoogleCloudBaseOperator):
-    """
-    Runs a Data Pipelines Instance using the Data Pipelines API.
-
-    :param data_pipeline_name: The display name of the pipeline. In example
-        projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
-    :param project_id: The ID of the GCP project that owns the job.
-    :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-    :param gcp_conn_id: The connection ID to connect to the Google Cloud
-        Platform.
-
-    Returns the created Job in JSON representation.
-    """
+@deprecated(
+    reason="This operator is deprecated and will be removed after 01.12.2024. "
+    "Please use `DataflowRunPipelineOperator`.",
+    category=AirflowProviderDeprecationWarning,
+)
+class RunDataPipelineOperator(DataflowRunPipelineOperator):
+    """Runs a Data Pipelines Instance using the Data Pipelines API."""
 
     def __init__(
         self,
         data_pipeline_name: str,
         project_id: str = PROVIDE_PROJECT_ID,
-        location: str = DEFAULT_DATAPIPELINE_LOCATION,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
         gcp_conn_id: str = "google_cloud_default",
         **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
-
-        self.data_pipeline_name = data_pipeline_name
-        self.project_id = project_id
-        self.location = location
-        self.gcp_conn_id = gcp_conn_id
-
-    def execute(self, context: Context):
-        self.data_pipeline_hook = DataPipelineHook(gcp_conn_id=self.gcp_conn_id)
-
-        if self.data_pipeline_name is None:
-            raise AirflowException("Data Pipeline name not given; cannot run unspecified pipeline.")
-        if self.project_id is None:
-            raise AirflowException("Data Pipeline Project ID not given; cannot run pipeline.")
-        if self.location is None:
-            raise AirflowException("Data Pipeline location not given; cannot run pipeline.")
-
-        self.response = self.data_pipeline_hook.run_data_pipeline(
-            data_pipeline_name=self.data_pipeline_name,
-            project_id=self.project_id,
-            location=self.location,
+        super().__init__(
+            pipeline_name=data_pipeline_name,
+            project_id=project_id,
+            location=location,
+            gcp_conn_id=gcp_conn_id,
+            **kwargs,
         )
-
-        if self.response:
-            if "error" in self.response:
-                raise AirflowException(self.response.get("error").get("message"))
-
-        return self.response
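
Both classes now merely subclass their Dataflow replacements, so existing DAGs keep working but emit AirflowProviderDeprecationWarning at parse time. A minimal migration sketch, assuming placeholder project/location values and an abbreviated pipeline body (nothing below is taken from the release itself):

from airflow.providers.google.cloud.operators.dataflow import (
    DataflowCreatePipelineOperator,
    DataflowRunPipelineOperator,
)

# Takes the same Data Pipelines request body the deprecated operator accepted.
create_pipeline = DataflowCreatePipelineOperator(
    task_id="create_pipeline",
    project_id="my-project",
    location="us-central1",
    body={
        "name": "projects/my-project/locations/us-central1/pipelines/my-pipeline",
        "type": "PIPELINE_TYPE_BATCH",
        "workload": {},  # placeholder for the Dataflow template workload
    },
)

# Note the rename: RunDataPipelineOperator's data_pipeline_name maps to pipeline_name.
run_pipeline = DataflowRunPipelineOperator(
    task_id="run_pipeline",
    project_id="my-project",
    location="us-central1",
    pipeline_name="my-pipeline",
)

create_pipeline >> run_pipeline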
airflow/providers/google/cloud/operators/dataplex.py
@@ -1067,7 +1067,7 @@ class DataplexGetDataQualityScanResultOperator(GoogleCloudBaseOperator):
         is available.
     """
 
-    template_fields = ("project_id", "data_scan_id", "impersonation_chain")
+    template_fields = ("project_id", "data_scan_id", "impersonation_chain", "job_id")
 
     def __init__(
         self,
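
Adding job_id to template_fields means the result fetch can be wired to an upstream scan via Jinja at render time. A hedged sketch — the task IDs, the region kwarg, and the assumption that the upstream task pushes its job ID to XCom are illustrative:

from airflow.providers.google.cloud.operators.dataplex import (
    DataplexGetDataQualityScanResultOperator,
)

get_scan_result = DataplexGetDataQualityScanResultOperator(
    task_id="get_scan_result",
    project_id="my-project",
    region="us-central1",
    data_scan_id="my-data-scan",
    # Now renderable: pull the job ID produced by a previous scan task.
    job_id="{{ task_instance.xcom_pull(task_ids='run_scan') }}",
)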
airflow/providers/google/cloud/operators/dataproc_metastore.py
@@ -431,7 +431,7 @@ class DataprocMetastoreCreateServiceOperator(GoogleCloudBaseOperator):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
         )
-        self.log.info("Creating Dataproc Metastore service: %s", self.project_id)
+        self.log.info("Creating Dataproc Metastore service: %s", self.service_id)
         try:
             operation = hook.create_service(
                 region=self.region,
@@ -548,13 +548,24 @@ class DataprocMetastoreDeleteBackupOperator(GoogleCloudBaseOperator):
 class DataprocMetastoreDeleteServiceOperator(GoogleCloudBaseOperator):
     """Delete a single service.
 
-    :param request: The request object. Request message for
-        [DataprocMetastore.DeleteService][google.cloud.metastore.v1.DataprocMetastore.DeleteService].
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
     :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param service_id: Required. The ID of the metastore service, which is used as the final component of
+        the metastore service's name. This value must be between 2 and 63 characters long inclusive, begin
+        with a letter, end with a letter or number, and consist of alphanumeric ASCII characters or
+        hyphens.
     :param retry: Designation of what errors, if any, should be retried.
     :param timeout: The timeout for this request.
     :param metadata: Strings which should be sent along with the request as metadata.
-    :param gcp_conn_id:
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
     """
 
     template_fields: Sequence[str] = (
@@ -589,7 +600,7 @@ class DataprocMetastoreDeleteServiceOperator(GoogleCloudBaseOperator):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
         )
-        self.log.info("Deleting Dataproc Metastore service: %s", self.project_id)
+        self.log.info("Deleting Dataproc Metastore service: %s", self.service_id)
         operation = hook.delete_service(
             region=self.region,
             project_id=self.project_id,
@@ -599,7 +610,7 @@ class DataprocMetastoreDeleteServiceOperator(GoogleCloudBaseOperator):
             metadata=self.metadata,
         )
         hook.wait_for_operation(self.timeout, operation)
-        self.log.info("Service %s deleted successfully", self.project_id)
+        self.log.info("Service %s deleted successfully", self.service_id)
 
 
 class DataprocMetastoreExportMetadataOperator(GoogleCloudBaseOperator):
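
With the docstring now spelling out the actual parameters, a delete call reads as below (a sketch with placeholder IDs; retry, timeout, and metadata keep their defaults):

from airflow.providers.google.cloud.operators.dataproc_metastore import (
    DataprocMetastoreDeleteServiceOperator,
)

delete_service = DataprocMetastoreDeleteServiceOperator(
    task_id="delete_metastore_service",
    project_id="my-project",
    region="us-central1",
    service_id="my-metastore",  # what the operator now logs, instead of the project ID
)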
airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -73,6 +73,7 @@ except ImportError:
 
 if TYPE_CHECKING:
     from kubernetes.client.models import V1Job, V1Pod
+    from pendulum import DateTime
 
     from airflow.utils.context import Context
 
@@ -773,16 +774,16 @@ class GKEStartPodOperator(KubernetesPodOperator):
         self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate
         return self._cluster_url, self._ssl_ca_cert
 
-    def invoke_defer_method(self):
+    def invoke_defer_method(self, last_log_time: DateTime | None = None):
         """Redefine triggers which are being used in child classes."""
         trigger_start_time = utcnow()
         self.defer(
             trigger=GKEStartPodTrigger(
-                pod_name=self.pod.metadata.name,
-                pod_namespace=self.pod.metadata.namespace,
+                pod_name=self.pod.metadata.name,  # type: ignore[union-attr]
+                pod_namespace=self.pod.metadata.namespace,  # type: ignore[union-attr]
                 trigger_start_time=trigger_start_time,
-                cluster_url=self._cluster_url,
-                ssl_ca_cert=self._ssl_ca_cert,
+                cluster_url=self._cluster_url,  # type: ignore[arg-type]
+                ssl_ca_cert=self._ssl_ca_cert,  # type: ignore[arg-type]
                 get_logs=self.get_logs,
                 startup_timeout=self.startup_timeout_seconds,
                 cluster_context=self.cluster_context,
@@ -792,6 +793,8 @@ class GKEStartPodOperator(KubernetesPodOperator):
                 on_finish_action=self.on_finish_action,
                 gcp_conn_id=self.gcp_conn_id,
                 impersonation_chain=self.impersonation_chain,
+                logging_interval=self.logging_interval,
+                last_log_time=last_log_time,
             ),
             method_name="execute_complete",
             kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert},
@@ -802,7 +805,7 @@ class GKEStartPodOperator(KubernetesPodOperator):
         self._cluster_url = kwargs["cluster_url"]
         self._ssl_ca_cert = kwargs["ssl_ca_cert"]
 
-        return super().execute_complete(context, event, **kwargs)
+        return super().trigger_reentry(context, event)
 
 
 class GKEStartJobOperator(KubernetesJobOperator):
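
The last_log_time/logging_interval plumbing only matters for deferrable pods that stream logs: the trigger now resumes periodically and hands the last-seen log timestamp back into invoke_defer_method so tailing continues where it left off. A hedged usage sketch — logging_interval is inherited from the underlying KubernetesPodOperator, and every value below is a placeholder:

from airflow.providers.google.cloud.operators.kubernetes_engine import GKEStartPodOperator

pod_task = GKEStartPodOperator(
    task_id="pod_task",
    project_id="my-project",
    location="us-central1-a",
    cluster_name="my-cluster",
    name="echo",
    namespace="default",
    image="busybox",
    cmds=["sh", "-c", "echo hello"],
    get_logs=True,
    deferrable=True,
    logging_interval=60,  # resume from the trigger every 60s to ship fresh logs
)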
airflow/providers/google/cloud/operators/pubsub.py
@@ -604,6 +604,7 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
         m1 = {"data": b"Hello, World!", "attributes": {"type": "greeting"}}
         m2 = {"data": b"Knock, knock"}
         m3 = {"attributes": {"foo": ""}}
+        m4 = {"data": b"Who's there?", "attributes": {"ordering_key": "knock_knock"}}
 
         t1 = PubSubPublishMessageOperator(
             project_id="my-project",
@@ -613,6 +614,15 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
             dag=dag,
         )
 
+        t2 = PubSubPublishMessageOperator(
+            project_id="my-project",
+            topic="my_topic",
+            messages=[m4],
+            create_topic=True,
+            enable_message_ordering=True,
+            dag=dag,
+        )
+
    ``project_id``, ``topic``, and ``messages`` are templated so you can use Jinja templating
    in their values.
 
@@ -632,6 +642,10 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
        https://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage
    :param gcp_conn_id: The connection ID to use connecting to
        Google Cloud.
+    :param enable_message_ordering: If true, messages published with the same
+        ordering_key in PubsubMessage will be delivered to the subscribers in the order
+        in which they are received by the Pub/Sub system. Otherwise, they may be
+        delivered in any order. Default is False.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
@@ -646,6 +660,7 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
         "project_id",
         "topic",
         "messages",
+        "enable_message_ordering",
         "impersonation_chain",
     )
     ui_color = "#0273d4"
@@ -657,6 +672,7 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
         messages: list,
         project_id: str = PROVIDE_PROJECT_ID,
         gcp_conn_id: str = "google_cloud_default",
+        enable_message_ordering: bool = False,
         impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
     ) -> None:
@@ -665,12 +681,14 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
         self.topic = topic
         self.messages = messages
         self.gcp_conn_id = gcp_conn_id
+        self.enable_message_ordering = enable_message_ordering
         self.impersonation_chain = impersonation_chain
 
     def execute(self, context: Context) -> None:
         hook = PubSubHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
+            enable_message_ordering=self.enable_message_ordering,
         )
 
         self.log.info("Publishing to topic %s", self.topic)
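
enable_message_ordering is forwarded to PubSubHook and ultimately maps onto the publisher option of the same name in google-cloud-pubsub; ordering is only guaranteed between messages that share an ordering_key. For comparison, a minimal sketch of the equivalent raw-client setup (project and topic names are placeholders):

from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import PublisherOptions

publisher = pubsub_v1.PublisherClient(
    publisher_options=PublisherOptions(enable_message_ordering=True)
)
topic_path = publisher.topic_path("my-project", "my_topic")

# Messages sharing an ordering_key are delivered to subscribers in publish order.
for payload in (b"Knock, knock", b"Who's there?"):
    publisher.publish(topic_path, payload, ordering_key="knock_knock")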
airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -138,6 +138,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         region: str,
         impersonation_chain: str | Sequence[str] | None = None,
         parent_model: str | None = None,
+        window_stride_length: int | None = None,
+        window_max_count: int | None = None,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -170,6 +172,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         self.quantiles = quantiles
         self.validation_options = validation_options
         self.budget_milli_node_hours = budget_milli_node_hours
+        self.window_stride_length = window_stride_length
+        self.window_max_count = window_max_count
 
     def execute(self, context: Context):
         self.hook = AutoMLHook(
@@ -220,6 +224,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             model_display_name=self.model_display_name,
             model_labels=self.model_labels,
             sync=self.sync,
+            window_stride_length=self.window_stride_length,
+            window_max_count=self.window_max_count,
         )
 
         if model:
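
window_stride_length and window_max_count are forwarded untouched to the training job and correspond to the Vertex AI SDK's one-of window strategies (fixed stride versus a capped window count). A hedged sketch — the dataset and column names are placeholders, and only one window argument should be supplied:

from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
    CreateAutoMLForecastingTrainingJobOperator,
)

create_forecasting_job = CreateAutoMLForecastingTrainingJobOperator(
    task_id="auto_ml_forecasting_task",
    project_id="my-project",
    region="us-central1",
    display_name="forecast-model",
    dataset_id="1234567890",
    target_column="sales",
    time_column="date",
    time_series_identifier_column="store_id",
    available_at_forecast_columns=["date"],
    unavailable_at_forecast_columns=["sales"],
    forecast_horizon=30,
    data_granularity_unit="day",
    data_granularity_count=1,
    window_stride_length=1,  # alternatively window_max_count=...; not both
)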
airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -187,6 +187,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
         to the Multi-modal model, in order to elicit a specific response (templated).
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per request settings for blocking unsafe content.
     :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -210,6 +212,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         project_id: str,
         location: str,
         prompt: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -219,6 +223,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.location = location
         self.prompt = prompt
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
         self.pretrained_model = pretrained_model
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
@@ -232,6 +238,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
             project_id=self.project_id,
             location=self.location,
             prompt=self.prompt,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             pretrained_model=self.pretrained_model,
         )
 
@@ -251,6 +259,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
         to the Multi-modal model, in order to elicit a specific response (templated).
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per request settings for blocking unsafe content.
     :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -279,6 +289,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         prompt: str,
         media_gcs_path: str,
         mime_type: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro-vision",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
@@ -288,6 +300,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.location = location
         self.prompt = prompt
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
         self.pretrained_model = pretrained_model
         self.media_gcs_path = media_gcs_path
         self.mime_type = mime_type
@@ -303,6 +317,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
             project_id=self.project_id,
             location=self.location,
             prompt=self.prompt,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             pretrained_model=self.pretrained_model,
             media_gcs_path=self.media_gcs_path,
             mime_type=self.mime_type,
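
Both prompt operators forward generation_config and safety_settings to the hook as plain dicts, so they should accept the same shapes the Vertex AI SDK uses. A hedged sketch, assuming the vertexai.generative_models enums for the safety map (all values illustrative):

from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    PromptMultimodalModelOperator,
)

prompt_task = PromptMultimodalModelOperator(
    task_id="prompt_gemini",
    project_id="my-project",
    location="us-central1",
    prompt="Summarize this report in three bullet points.",
    generation_config={"max_output_tokens": 256, "temperature": 0.2, "top_p": 0.95},
    safety_settings={
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    },
)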
airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -19,13 +19,24 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Sequence
+import json
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
+from dateutil import parser
 from deprecated import deprecated
+from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
-from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger
+from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerDAGRunTrigger,
+    CloudComposerExecutionTrigger,
+)
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -117,3 +128,161 @@ class CloudComposerEnvironmentSensor(BaseSensorOperator):
         if self.soft_fail:
             raise AirflowSkipException(message)
         raise AirflowException(message)
+
+
+class CloudComposerDAGRunSensor(BaseSensorOperator):
+    """
+    Check if a DAG run has completed.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param environment_id: The name of the Composer environment.
+    :param composer_dag_id: The ID of executable DAG.
+    :param allowed_states: Iterable of allowed states, default is ``['success']``.
+    :param execution_range: execution DAGs time range. Sensor checks DAGs states only for DAGs which were
+        started in this time range. For yesterday, use [positive!] datetime.timedelta(days=1).
+        For future, use [negative!] datetime.timedelta(days=-1). For specific time, use list of
+        datetimes [datetime(2024,3,22,11,0,0), datetime(2024,3,22,12,0,0)].
+        Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for states from specific time in the
+        past till current time execution.
+        Default value datetime.timedelta(days=1).
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param poll_interval: Optional: Control the rate of the poll for the result of deferrable run.
+    :param deferrable: Run sensor in deferrable mode.
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "composer_dag_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        allowed_states: Iterable[str] | None = None,
+        execution_range: timedelta | list[datetime] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
+        self.execution_range = execution_range
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
+
+    def _get_execution_dates(self, context) -> tuple[datetime, datetime]:
+        if isinstance(self.execution_range, timedelta):
+            if self.execution_range < timedelta(0):
+                return context["logical_date"], context["logical_date"] - self.execution_range
+            else:
+                return context["logical_date"] - self.execution_range, context["logical_date"]
+        elif isinstance(self.execution_range, list) and len(self.execution_range) > 0:
+            return self.execution_range[0], self.execution_range[1] if len(
+                self.execution_range
+            ) > 1 else context["logical_date"]
+        else:
+            return context["logical_date"] - timedelta(1), context["logical_date"]
+
+    def poke(self, context: Context) -> bool:
+        start_date, end_date = self._get_execution_dates(context)
+
+        if datetime.now(end_date.tzinfo) < end_date:
+            return False
+
+        dag_runs = self._pull_dag_runs()
+
+        self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+        allowed_states_status = self._check_dag_runs_states(
+            dag_runs=dag_runs,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        return allowed_states_status
+
+    def _pull_dag_runs(self) -> list[dict]:
+        """Pull the list of dag runs."""
+        hook = CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        dag_runs_cmd = hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command="dags",
+            subcommand="list-runs",
+            parameters=["-d", self.composer_dag_id, "-o", "json"],
+        )
+        cmd_result = hook.wait_command_execution_result(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+        )
+        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        return dag_runs
+
+    def _check_dag_runs_states(
+        self,
+        dag_runs: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> bool:
+        for dag_run in dag_runs:
+            if (
+                start_date.timestamp()
+                < parser.parse(dag_run["execution_date"]).timestamp()
+                < end_date.timestamp()
+            ) and dag_run["state"] not in self.allowed_states:
+                return False
+        return True
+
+    def execute(self, context: Context) -> None:
+        if self.deferrable:
+            start_date, end_date = self._get_execution_dates(context)
+            self.defer(
+                trigger=CloudComposerDAGRunTrigger(
+                    project_id=self.project_id,
+                    region=self.region,
+                    environment_id=self.environment_id,
+                    composer_dag_id=self.composer_dag_id,
+                    start_date=start_date,
+                    end_date=end_date,
+                    allowed_states=self.allowed_states,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+            )
+        super().execute(context)
+
+    def execute_complete(self, context: Context, event: dict):
+        if event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info("DAG %s has executed successfully.", self.composer_dag_id)
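
A usage sketch assembled from the signature above: defer until every run of a target DAG started within the last day reaches an allowed state (all IDs are placeholders):

from datetime import timedelta

from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor

wait_for_dag = CloudComposerDAGRunSensor(
    task_id="wait_for_composer_dag",
    project_id="my-project",
    region="us-central1",
    environment_id="my-composer-env",
    composer_dag_id="target_dag",
    allowed_states=["success"],
    execution_range=timedelta(days=1),  # positive timedelta means "look back"
    deferrable=True,
    poll_interval=30,
)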
airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
@@ -122,3 +122,16 @@ class AzureBlobStorageToGCSOperator(BaseOperator):
             self.bucket_name,
         )
         return f"gs://{self.bucket_name}/{self.object_name}"
+
+    def get_openlineage_facets_on_start(self):
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        wasb_hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
+        account_name = wasb_hook.get_conn().account_name
+
+        return OperatorLineage(
+            inputs=[Dataset(namespace=f"wasbs://{self.container_name}@{account_name}", name=self.blob_name)],
+            outputs=[Dataset(namespace=f"gs://{self.bucket_name}", name=self.object_name)],
+        )
+ )