apache-airflow-providers-google 10.19.0rc1__py3-none-any.whl → 10.20.0rc1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- airflow/providers/google/LICENSE +4 -4
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +4 -4
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +26 -0
- airflow/providers/google/cloud/hooks/dataflow.py +132 -1
- airflow/providers/google/cloud/hooks/datapipeline.py +22 -73
- airflow/providers/google/cloud/hooks/gcs.py +21 -0
- airflow/providers/google/cloud/hooks/pubsub.py +10 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +15 -3
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
- airflow/providers/google/cloud/links/dataflow.py +25 -0
- airflow/providers/google/cloud/openlineage/mixins.py +271 -0
- airflow/providers/google/cloud/openlineage/utils.py +5 -218
- airflow/providers/google/cloud/operators/bigquery.py +74 -20
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +76 -0
- airflow/providers/google/cloud/operators/dataflow.py +235 -1
- airflow/providers/google/cloud/operators/datapipeline.py +29 -121
- airflow/providers/google/cloud/operators/dataplex.py +1 -1
- airflow/providers/google/cloud/operators/dataproc_metastore.py +17 -6
- airflow/providers/google/cloud/operators/kubernetes_engine.py +9 -6
- airflow/providers/google/cloud/operators/pubsub.py +18 -0
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +6 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +16 -0
- airflow/providers/google/cloud/sensors/cloud_composer.py +171 -2
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +13 -0
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +56 -1
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +6 -12
- airflow/providers/google/cloud/triggers/cloud_composer.py +115 -0
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -0
- airflow/providers/google/cloud/utils/credentials_provider.py +81 -6
- airflow/providers/google/cloud/utils/external_token_supplier.py +175 -0
- airflow/providers/google/common/hooks/base_google.py +35 -1
- airflow/providers/google/common/utils/id_token_credentials.py +1 -1
- airflow/providers/google/get_provider_info.py +19 -14
- {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/METADATA +41 -35
- {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/RECORD +39 -37
- {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/operators/datapipeline.py

@@ -19,138 +19,46 @@
 
 from __future__ import annotations
 
-from
-
-from airflow.exceptions import
-from airflow.providers.google.cloud.hooks.
-from airflow.providers.google.cloud.operators.
+from deprecated import deprecated
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.dataflow import DEFAULT_DATAFLOW_LOCATION
+from airflow.providers.google.cloud.operators.dataflow import (
+    DataflowCreatePipelineOperator,
+    DataflowRunPipelineOperator,
+)
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
-
 
-
-""
-
+@deprecated(
+    reason="This operator is deprecated and will be removed after 01.12.2024. "
+    "Please use `DataflowCreatePipelineOperator`.",
+    category=AirflowProviderDeprecationWarning,
+)
+class CreateDataPipelineOperator(DataflowCreatePipelineOperator):
+    """Creates a new Data Pipelines instance from the Data Pipelines API."""
 
-    :param body: The request body (contains instance of Pipeline). See:
-        https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
-    :param project_id: The ID of the GCP project that owns the job.
-    :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-    :param gcp_conn_id: The connection ID to connect to the Google Cloud
-        Platform.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
 
-
-
-
-
-
-
-
-        self,
-        *,
-        body: dict,
-        project_id: str = PROVIDE_PROJECT_ID,
-        location: str = DEFAULT_DATAPIPELINE_LOCATION,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-
-        self.body = body
-        self.project_id = project_id
-        self.location = location
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-        self.datapipeline_hook: DataPipelineHook | None = None
-        self.body["pipelineSources"] = {"airflow": "airflow"}
-
-    def execute(self, context: Context):
-        if self.body is None:
-            raise AirflowException(
-                "Request Body not given; cannot create a Data Pipeline without the Request Body."
-            )
-        if self.project_id is None:
-            raise AirflowException(
-                "Project ID not given; cannot create a Data Pipeline without the Project ID."
-            )
-        if self.location is None:
-            raise AirflowException("location not given; cannot create a Data Pipeline without the location.")
-
-        self.datapipeline_hook = DataPipelineHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
-        self.data_pipeline = self.datapipeline_hook.create_data_pipeline(
-            project_id=self.project_id,
-            body=self.body,
-            location=self.location,
-        )
-        if self.data_pipeline:
-            if "error" in self.data_pipeline:
-                raise AirflowException(self.data_pipeline.get("error").get("message"))
-
-        return self.data_pipeline
-
-
-class RunDataPipelineOperator(GoogleCloudBaseOperator):
-    """
-    Runs a Data Pipelines Instance using the Data Pipelines API.
-
-    :param data_pipeline_name: The display name of the pipeline. In example
-        projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
-    :param project_id: The ID of the GCP project that owns the job.
-    :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-    :param gcp_conn_id: The connection ID to connect to the Google Cloud
-        Platform.
-
-    Returns the created Job in JSON representation.
-    """
+@deprecated(
+    reason="This operator is deprecated and will be removed after 01.12.2024. "
+    "Please use `DataflowRunPipelineOperator`.",
+    category=AirflowProviderDeprecationWarning,
+)
+class RunDataPipelineOperator(DataflowRunPipelineOperator):
+    """Runs a Data Pipelines Instance using the Data Pipelines API."""
 
     def __init__(
         self,
         data_pipeline_name: str,
         project_id: str = PROVIDE_PROJECT_ID,
-        location: str =
+        location: str = DEFAULT_DATAFLOW_LOCATION,
         gcp_conn_id: str = "google_cloud_default",
         **kwargs,
     ) -> None:
-        super().__init__(
-
-
-
-
-
-
-    def execute(self, context: Context):
-        self.data_pipeline_hook = DataPipelineHook(gcp_conn_id=self.gcp_conn_id)
-
-        if self.data_pipeline_name is None:
-            raise AirflowException("Data Pipeline name not given; cannot run unspecified pipeline.")
-        if self.project_id is None:
-            raise AirflowException("Data Pipeline Project ID not given; cannot run pipeline.")
-        if self.location is None:
-            raise AirflowException("Data Pipeline location not given; cannot run pipeline.")
-
-        self.response = self.data_pipeline_hook.run_data_pipeline(
-            data_pipeline_name=self.data_pipeline_name,
-            project_id=self.project_id,
-            location=self.location,
+        super().__init__(
+            pipeline_name=data_pipeline_name,
+            project_id=project_id,
+            location=location,
+            gcp_conn_id=gcp_conn_id,
+            **kwargs,
         )
-
-        if self.response:
-            if "error" in self.response:
-                raise AirflowException(self.response.get("error").get("message"))
-
-        return self.response
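The hunk above leaves both Data Pipelines operators as deprecated shims over their Dataflow replacements, so existing DAGs keep working but emit `AirflowProviderDeprecationWarning`. A hedged migration sketch follows; the project, location, and pipeline names are placeholders, and the `body` payload follows the `pipelines.create` request body referenced in the removed docstring:

```python
# Illustrative migration off the deprecated datapipeline operators.
# All IDs and names below are placeholders.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import (
    DataflowCreatePipelineOperator,
    DataflowRunPipelineOperator,
)

with DAG(dag_id="dataflow_pipeline_example", start_date=datetime(2024, 1, 1), schedule=None):
    create_pipeline = DataflowCreatePipelineOperator(
        task_id="create_pipeline",
        project_id="my-project",
        location="us-central1",
        body={
            "name": "projects/my-project/locations/us-central1/pipelines/my-pipeline",
            "type": "PIPELINE_TYPE_BATCH",
            "workload": {},  # Dataflow launch parameters per the pipelines.create reference
        },
    )

    # The deprecated RunDataPipelineOperator(data_pipeline_name=...) maps to
    # pipeline_name=..., as the super().__init__() call in the hunk shows.
    run_pipeline = DataflowRunPipelineOperator(
        task_id="run_pipeline",
        pipeline_name="my-pipeline",
        project_id="my-project",
        location="us-central1",
    )

    create_pipeline >> run_pipeline
```

The only renamed argument is `data_pipeline_name`, which becomes `pipeline_name` on the Dataflow operator.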
airflow/providers/google/cloud/operators/dataplex.py

@@ -1067,7 +1067,7 @@ class DataplexGetDataQualityScanResultOperator(GoogleCloudBaseOperator):
         is available.
     """
 
-    template_fields = ("project_id", "data_scan_id", "impersonation_chain")
+    template_fields = ("project_id", "data_scan_id", "impersonation_chain", "job_id")
 
     def __init__(
         self,
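With `job_id` now in `template_fields`, the scan-result operator can consume a job ID that only exists at runtime. An illustrative sketch, assuming a hypothetical upstream `run_scan` task that returns the Dataplex job ID via XCom; the `region` argument and all IDs are placeholders:

```python
from airflow.providers.google.cloud.operators.dataplex import (
    DataplexGetDataQualityScanResultOperator,
)

get_result = DataplexGetDataQualityScanResultOperator(
    task_id="get_scan_result",
    project_id="my-project",
    region="us-central1",
    data_scan_id="my-data-scan",
    # Newly templated: rendered from the (assumed) upstream task's return value.
    job_id="{{ ti.xcom_pull(task_ids='run_scan') }}",
)
```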
airflow/providers/google/cloud/operators/dataproc_metastore.py

@@ -431,7 +431,7 @@ class DataprocMetastoreCreateServiceOperator(GoogleCloudBaseOperator):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
         )
-        self.log.info("Creating Dataproc Metastore service: %s", self.
+        self.log.info("Creating Dataproc Metastore service: %s", self.service_id)
         try:
             operation = hook.create_service(
                 region=self.region,

@@ -548,13 +548,24 @@ class DataprocMetastoreDeleteBackupOperator(GoogleCloudBaseOperator):
 class DataprocMetastoreDeleteServiceOperator(GoogleCloudBaseOperator):
     """Delete a single service.
 
-    :param
-        [DataprocMetastore.DeleteService][google.cloud.metastore.v1.DataprocMetastore.DeleteService].
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
     :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param service_id: Required. The ID of the metastore service, which is used as the final component of
+        the metastore service's name. This value must be between 2 and 63 characters long inclusive, begin
+        with a letter, end with a letter or number, and consist of alphanumeric ASCII characters or
+        hyphens.
     :param retry: Designation of what errors, if any, should be retried.
     :param timeout: The timeout for this request.
     :param metadata: Strings which should be sent along with the request as metadata.
-    :param gcp_conn_id:
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
     """
 
     template_fields: Sequence[str] = (

@@ -589,7 +600,7 @@ class DataprocMetastoreDeleteServiceOperator(GoogleCloudBaseOperator):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
         )
-        self.log.info("Deleting Dataproc Metastore service: %s", self.
+        self.log.info("Deleting Dataproc Metastore service: %s", self.service_id)
         operation = hook.delete_service(
             region=self.region,
             project_id=self.project_id,

@@ -599,7 +610,7 @@ class DataprocMetastoreDeleteServiceOperator(GoogleCloudBaseOperator):
             metadata=self.metadata,
         )
         hook.wait_for_operation(self.timeout, operation)
-        self.log.info("Service %s deleted successfully", self.
+        self.log.info("Service %s deleted successfully", self.service_id)
 
 
 class DataprocMetastoreExportMetadataOperator(GoogleCloudBaseOperator):
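For reference, a minimal sketch of the operator whose docstring the hunks above complete; all IDs are placeholders, and `service_id` must satisfy the naming constraints just documented:

```python
from airflow.providers.google.cloud.operators.dataproc_metastore import (
    DataprocMetastoreDeleteServiceOperator,
)

delete_service = DataprocMetastoreDeleteServiceOperator(
    task_id="delete_metastore_service",
    project_id="my-project",
    region="us-central1",
    service_id="my-metastore-service",  # 2-63 chars, starts with a letter, alphanumeric or hyphens
    # Optional, per the new docstring: run the request as another service account.
    impersonation_chain="transfer-sa@my-project.iam.gserviceaccount.com",
)
```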
airflow/providers/google/cloud/operators/kubernetes_engine.py

@@ -73,6 +73,7 @@ except ImportError:
 
 if TYPE_CHECKING:
     from kubernetes.client.models import V1Job, V1Pod
+    from pendulum import DateTime
 
     from airflow.utils.context import Context
 

@@ -773,16 +774,16 @@ class GKEStartPodOperator(KubernetesPodOperator):
         self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate
         return self._cluster_url, self._ssl_ca_cert
 
-    def invoke_defer_method(self):
+    def invoke_defer_method(self, last_log_time: DateTime | None = None):
         """Redefine triggers which are being used in child classes."""
         trigger_start_time = utcnow()
         self.defer(
             trigger=GKEStartPodTrigger(
-                pod_name=self.pod.metadata.name,
-                pod_namespace=self.pod.metadata.namespace,
+                pod_name=self.pod.metadata.name,  # type: ignore[union-attr]
+                pod_namespace=self.pod.metadata.namespace,  # type: ignore[union-attr]
                 trigger_start_time=trigger_start_time,
-                cluster_url=self._cluster_url,
-                ssl_ca_cert=self._ssl_ca_cert,
+                cluster_url=self._cluster_url,  # type: ignore[arg-type]
+                ssl_ca_cert=self._ssl_ca_cert,  # type: ignore[arg-type]
                 get_logs=self.get_logs,
                 startup_timeout=self.startup_timeout_seconds,
                 cluster_context=self.cluster_context,

@@ -792,6 +793,8 @@ class GKEStartPodOperator(KubernetesPodOperator):
                 on_finish_action=self.on_finish_action,
                 gcp_conn_id=self.gcp_conn_id,
                 impersonation_chain=self.impersonation_chain,
+                logging_interval=self.logging_interval,
+                last_log_time=last_log_time,
             ),
             method_name="execute_complete",
             kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert},

@@ -802,7 +805,7 @@ class GKEStartPodOperator(KubernetesPodOperator):
         self._cluster_url = kwargs["cluster_url"]
         self._ssl_ca_cert = kwargs["ssl_ca_cert"]
 
-        return super().
+        return super().trigger_reentry(context, event)
 
 
 class GKEStartJobOperator(KubernetesJobOperator):
|
|
604
604
|
m1 = {"data": b"Hello, World!", "attributes": {"type": "greeting"}}
|
605
605
|
m2 = {"data": b"Knock, knock"}
|
606
606
|
m3 = {"attributes": {"foo": ""}}
|
607
|
+
m4 = {"data": b"Who's there?", "attributes": {"ordering_key": "knock_knock"}}
|
607
608
|
|
608
609
|
t1 = PubSubPublishMessageOperator(
|
609
610
|
project_id="my-project",
|
@@ -613,6 +614,15 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
|
|
613
614
|
dag=dag,
|
614
615
|
)
|
615
616
|
|
617
|
+
t2 = PubSubPublishMessageOperator(
|
618
|
+
project_id="my-project",
|
619
|
+
topic="my_topic",
|
620
|
+
messages=[m4],
|
621
|
+
create_topic=True,
|
622
|
+
enable_message_ordering=True,
|
623
|
+
dag=dag,
|
624
|
+
)
|
625
|
+
|
616
626
|
``project_id``, ``topic``, and ``messages`` are templated so you can use Jinja templating
|
617
627
|
in their values.
|
618
628
|
|
@@ -632,6 +642,10 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
|
|
632
642
|
https://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage
|
633
643
|
:param gcp_conn_id: The connection ID to use connecting to
|
634
644
|
Google Cloud.
|
645
|
+
:param enable_message_ordering: If true, messages published with the same
|
646
|
+
ordering_key in PubsubMessage will be delivered to the subscribers in the order
|
647
|
+
in which they are received by the Pub/Sub system. Otherwise, they may be
|
648
|
+
delivered in any order. Default is False.
|
635
649
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
636
650
|
credentials, or chained list of accounts required to get the access_token
|
637
651
|
of the last account in the list, which will be impersonated in the request.
|
@@ -646,6 +660,7 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
|
|
646
660
|
"project_id",
|
647
661
|
"topic",
|
648
662
|
"messages",
|
663
|
+
"enable_message_ordering",
|
649
664
|
"impersonation_chain",
|
650
665
|
)
|
651
666
|
ui_color = "#0273d4"
|
@@ -657,6 +672,7 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
|
|
657
672
|
messages: list,
|
658
673
|
project_id: str = PROVIDE_PROJECT_ID,
|
659
674
|
gcp_conn_id: str = "google_cloud_default",
|
675
|
+
enable_message_ordering: bool = False,
|
660
676
|
impersonation_chain: str | Sequence[str] | None = None,
|
661
677
|
**kwargs,
|
662
678
|
) -> None:
|
@@ -665,12 +681,14 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
|
|
665
681
|
self.topic = topic
|
666
682
|
self.messages = messages
|
667
683
|
self.gcp_conn_id = gcp_conn_id
|
684
|
+
self.enable_message_ordering = enable_message_ordering
|
668
685
|
self.impersonation_chain = impersonation_chain
|
669
686
|
|
670
687
|
def execute(self, context: Context) -> None:
|
671
688
|
hook = PubSubHook(
|
672
689
|
gcp_conn_id=self.gcp_conn_id,
|
673
690
|
impersonation_chain=self.impersonation_chain,
|
691
|
+
enable_message_ordering=self.enable_message_ordering,
|
674
692
|
)
|
675
693
|
|
676
694
|
self.log.info("Publishing to topic %s", self.topic)
|
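A sketch of ordered publishing with the new flag, mirroring the `m4`/`t2` docstring example above; all messages share one `ordering_key`, so subscribers receive them in publish order:

```python
from airflow.providers.google.cloud.operators.pubsub import PubSubPublishMessageOperator

ordered_publish = PubSubPublishMessageOperator(
    task_id="publish_ordered",
    project_id="my-project",  # placeholder
    topic="my_topic",         # placeholder
    enable_message_ordering=True,
    messages=[
        # One ordering key per logical stream; "job-42" is illustrative.
        {"data": b"step 1", "attributes": {"ordering_key": "job-42"}},
        {"data": b"step 2", "attributes": {"ordering_key": "job-42"}},
        {"data": b"step 3", "attributes": {"ordering_key": "job-42"}},
    ],
)
```

Ordering is only end-to-end when the subscription also has message ordering enabled; that is a Pub/Sub requirement, not something this operator controls.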
airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py

@@ -138,6 +138,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         region: str,
         impersonation_chain: str | Sequence[str] | None = None,
         parent_model: str | None = None,
+        window_stride_length: int | None = None,
+        window_max_count: int | None = None,
         **kwargs,
     ) -> None:
         super().__init__(

@@ -170,6 +172,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         self.quantiles = quantiles
         self.validation_options = validation_options
         self.budget_milli_node_hours = budget_milli_node_hours
+        self.window_stride_length = window_stride_length
+        self.window_max_count = window_max_count
 
     def execute(self, context: Context):
         self.hook = AutoMLHook(

@@ -220,6 +224,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             model_display_name=self.model_display_name,
             model_labels=self.model_labels,
             sync=self.sync,
+            window_stride_length=self.window_stride_length,
+            window_max_count=self.window_max_count,
         )
 
         if model:
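Both window arguments are forwarded to `AutoMLHook.create_auto_ml_forecasting_training_job` and control how rolling forecast windows are cut from the training data. A heavily hedged sketch; every value is a placeholder, the operator takes many more options than shown, and normally only one of the two window strategies is set:

```python
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
    CreateAutoMLForecastingTrainingJobOperator,
)

# Sketch only: the column names and IDs are invented for illustration.
create_forecasting_job = CreateAutoMLForecastingTrainingJobOperator(
    task_id="auto_ml_forecasting_task",
    project_id="my-project",
    region="us-central1",
    display_name="my-forecasting-job",
    dataset_id="my-dataset-id",
    target_column="sales",
    time_column="date",
    time_series_identifier_column="store_id",
    available_at_forecast_columns=["date"],
    unavailable_at_forecast_columns=["sales"],
    forecast_horizon=30,
    data_granularity_unit="day",
    data_granularity_count=1,
    window_stride_length=1,  # new in 10.20.0: cut a forecast window every N rows
    # window_max_count=1000,  # alternative new knob: cap the total number of windows
)
```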
airflow/providers/google/cloud/operators/vertex_ai/generative_model.py

@@ -187,6 +187,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
         to the Multi-modal model, in order to elicit a specific response (templated).
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per request settings for blocking unsafe content.
     :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can

@@ -210,6 +212,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         project_id: str,
         location: str,
         prompt: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,

@@ -219,6 +223,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.location = location
         self.prompt = prompt
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
         self.pretrained_model = pretrained_model
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

@@ -232,6 +238,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
             project_id=self.project_id,
             location=self.location,
             prompt=self.prompt,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             pretrained_model=self.pretrained_model,
         )
 

@@ -251,6 +259,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
         to the Multi-modal model, in order to elicit a specific response (templated).
+    :param generation_config: Optional. Generation configuration settings.
+    :param safety_settings: Optional. Per request settings for blocking unsafe content.
     :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can

@@ -279,6 +289,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         prompt: str,
         media_gcs_path: str,
         mime_type: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro-vision",
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,

@@ -288,6 +300,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.location = location
         self.prompt = prompt
+        self.generation_config = generation_config
+        self.safety_settings = safety_settings
         self.pretrained_model = pretrained_model
         self.media_gcs_path = media_gcs_path
         self.mime_type = mime_type

@@ -303,6 +317,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
             project_id=self.project_id,
             location=self.location,
             prompt=self.prompt,
+            generation_config=self.generation_config,
+            safety_settings=self.safety_settings,
             pretrained_model=self.pretrained_model,
             media_gcs_path=self.media_gcs_path,
             mime_type=self.mime_type,
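Both prompt operators now forward plain-dict `generation_config` and `safety_settings` to the hook. A hedged sketch; the enum import path follows the Vertex AI generative-models SDK and is an assumption that may need adjusting to the installed `google-cloud-aiplatform` version:

```python
# Assumed import path; older SDK versions expose these under
# vertexai.preview.generative_models instead.
from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    PromptMultimodalModelOperator,
)

prompt_task = PromptMultimodalModelOperator(
    task_id="prompt_gemini",
    project_id="my-project",  # placeholder
    location="us-central1",   # placeholder
    prompt="Summarize the quarterly sales report in three bullet points.",
    # Decoding parameters passed straight through to the model.
    generation_config={"temperature": 0.2, "max_output_tokens": 256, "top_p": 0.95},
    # Per-request harm-category thresholds.
    safety_settings={
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    },
)
```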
airflow/providers/google/cloud/sensors/cloud_composer.py

@@ -19,13 +19,24 @@
 
 from __future__ import annotations
 
-
+import json
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
+from dateutil import parser
 from deprecated import deprecated
+from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
-from airflow.providers.google.cloud.
+from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerDAGRunTrigger,
+    CloudComposerExecutionTrigger,
+)
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context

@@ -117,3 +128,161 @@ class CloudComposerEnvironmentSensor(BaseSensorOperator):
             if self.soft_fail:
                 raise AirflowSkipException(message)
             raise AirflowException(message)
+
+
+class CloudComposerDAGRunSensor(BaseSensorOperator):
+    """
+    Check if a DAG run has completed.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param environment_id: The name of the Composer environment.
+    :param composer_dag_id: The ID of executable DAG.
+    :param allowed_states: Iterable of allowed states, default is ``['success']``.
+    :param execution_range: execution DAGs time range. Sensor checks DAGs states only for DAGs which were
+        started in this time range. For yesterday, use [positive!] datetime.timedelta(days=1).
+        For future, use [negative!] datetime.timedelta(days=-1). For specific time, use list of
+        datetimes [datetime(2024,3,22,11,0,0), datetime(2024,3,22,12,0,0)].
+        Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for states from specific time in the
+        past till current time execution.
+        Default value datetime.timedelta(days=1).
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param poll_interval: Optional: Control the rate of the poll for the result of deferrable run.
+    :param deferrable: Run sensor in deferrable mode.
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "composer_dag_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        allowed_states: Iterable[str] | None = None,
+        execution_range: timedelta | list[datetime] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
+        self.execution_range = execution_range
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
+
+    def _get_execution_dates(self, context) -> tuple[datetime, datetime]:
+        if isinstance(self.execution_range, timedelta):
+            if self.execution_range < timedelta(0):
+                return context["logical_date"], context["logical_date"] - self.execution_range
+            else:
+                return context["logical_date"] - self.execution_range, context["logical_date"]
+        elif isinstance(self.execution_range, list) and len(self.execution_range) > 0:
+            return self.execution_range[0], self.execution_range[1] if len(
+                self.execution_range
+            ) > 1 else context["logical_date"]
+        else:
+            return context["logical_date"] - timedelta(1), context["logical_date"]
+
+    def poke(self, context: Context) -> bool:
+        start_date, end_date = self._get_execution_dates(context)
+
+        if datetime.now(end_date.tzinfo) < end_date:
+            return False
+
+        dag_runs = self._pull_dag_runs()
+
+        self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+        allowed_states_status = self._check_dag_runs_states(
+            dag_runs=dag_runs,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        return allowed_states_status
+
+    def _pull_dag_runs(self) -> list[dict]:
+        """Pull the list of dag runs."""
+        hook = CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        dag_runs_cmd = hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command="dags",
+            subcommand="list-runs",
+            parameters=["-d", self.composer_dag_id, "-o", "json"],
+        )
+        cmd_result = hook.wait_command_execution_result(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+        )
+        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        return dag_runs
+
+    def _check_dag_runs_states(
+        self,
+        dag_runs: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> bool:
+        for dag_run in dag_runs:
+            if (
+                start_date.timestamp()
+                < parser.parse(dag_run["execution_date"]).timestamp()
+                < end_date.timestamp()
+            ) and dag_run["state"] not in self.allowed_states:
+                return False
+        return True
+
+    def execute(self, context: Context) -> None:
+        if self.deferrable:
+            start_date, end_date = self._get_execution_dates(context)
+            self.defer(
+                trigger=CloudComposerDAGRunTrigger(
+                    project_id=self.project_id,
+                    region=self.region,
+                    environment_id=self.environment_id,
+                    composer_dag_id=self.composer_dag_id,
+                    start_date=start_date,
+                    end_date=end_date,
+                    allowed_states=self.allowed_states,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+            )
+        super().execute(context)
+
+    def execute_complete(self, context: Context, event: dict):
+        if event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info("DAG %s has executed successfully.", self.composer_dag_id)
airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py

@@ -122,3 +122,16 @@ class AzureBlobStorageToGCSOperator(BaseOperator):
             self.bucket_name,
         )
         return f"gs://{self.bucket_name}/{self.object_name}"
+
+    def get_openlineage_facets_on_start(self):
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        wasb_hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
+        account_name = wasb_hook.get_conn().account_name
+
+        return OperatorLineage(
+            inputs=[Dataset(namespace=f"wasbs://{self.container_name}@{account_name}", name=self.blob_name)],
+            outputs=[Dataset(namespace=f"gs://{self.bucket_name}", name=self.object_name)],
+        )
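The transfer logic itself is unchanged; when `apache-airflow-providers-openlineage` is installed, its listener calls `get_openlineage_facets_on_start` and records the `wasbs://` input and `gs://` output datasets automatically. A hedged usage sketch, with connection IDs, names, and the local staging `filename` all placeholders:

```python
from airflow.providers.google.cloud.transfers.azure_blob_to_gcs import (
    AzureBlobStorageToGCSOperator,
)

transfer = AzureBlobStorageToGCSOperator(
    task_id="azure_blob_to_gcs",
    wasb_conn_id="wasb_default",
    container_name="my-container",
    blob_name="reports/2024-05.csv",
    bucket_name="my-gcs-bucket",
    object_name="reports/2024-05.csv",
    filename="/tmp/2024-05.csv",  # local staging path used during the copy
    gcp_conn_id="google_cloud_default",
)
```

Lineage then shows `wasbs://my-container@<storage account>` as the input and `gs://my-gcs-bucket` as the output without any extra configuration on the task.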