apache-airflow-providers-google 18.0.0rc1__py3-none-any.whl → 18.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +5 -5
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/bigquery.py +45 -42
- airflow/providers/google/cloud/hooks/cloud_composer.py +131 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +88 -13
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +16 -0
- airflow/providers/google/cloud/hooks/dataflow.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +3 -0
- airflow/providers/google/cloud/hooks/gcs.py +107 -3
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/looker.py +1 -1
- airflow/providers/google/cloud/hooks/spanner.py +45 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +30 -0
- airflow/providers/google/cloud/links/base.py +11 -11
- airflow/providers/google/cloud/links/dataproc.py +2 -10
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +3 -1
- airflow/providers/google/cloud/operators/bigquery.py +2 -9
- airflow/providers/google/cloud/operators/cloud_run.py +2 -1
- airflow/providers/google/cloud/operators/cloud_sql.py +1 -1
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +89 -6
- airflow/providers/google/cloud/operators/datafusion.py +36 -7
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/spanner.py +22 -6
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +7 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +30 -0
- airflow/providers/google/cloud/operators/workflows.py +17 -6
- airflow/providers/google/cloud/sensors/bigquery.py +1 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -6
- airflow/providers/google/cloud/sensors/bigtable.py +1 -6
- airflow/providers/google/cloud/sensors/cloud_composer.py +65 -31
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -6
- airflow/providers/google/cloud/sensors/dataflow.py +1 -1
- airflow/providers/google/cloud/sensors/dataform.py +1 -6
- airflow/providers/google/cloud/sensors/datafusion.py +1 -6
- airflow/providers/google/cloud/sensors/dataplex.py +1 -6
- airflow/providers/google/cloud/sensors/dataprep.py +1 -6
- airflow/providers/google/cloud/sensors/dataproc.py +1 -6
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -6
- airflow/providers/google/cloud/sensors/gcs.py +1 -7
- airflow/providers/google/cloud/sensors/looker.py +1 -6
- airflow/providers/google/cloud/sensors/pubsub.py +1 -6
- airflow/providers/google/cloud/sensors/tasks.py +1 -6
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +1 -6
- airflow/providers/google/cloud/sensors/workflows.py +1 -6
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +2 -1
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +11 -2
- airflow/providers/google/cloud/triggers/bigquery.py +15 -3
- airflow/providers/google/cloud/triggers/cloud_composer.py +51 -21
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +90 -0
- airflow/providers/google/cloud/triggers/pubsub.py +14 -18
- airflow/providers/google/common/hooks/base_google.py +1 -1
- airflow/providers/google/get_provider_info.py +15 -0
- airflow/providers/google/leveldb/hooks/leveldb.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +2 -8
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -6
- airflow/providers/google/marketing_platform/sensors/display_video.py +1 -6
- airflow/providers/google/suite/sensors/drive.py +1 -6
- airflow/providers/google/version_compat.py +0 -20
- {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/METADATA +15 -15
- {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/RECORD +72 -65
- {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/operators/datafusion.py

@@ -40,6 +40,7 @@ from airflow.providers.google.cloud.utils.helpers import resource_path_to_dict
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -777,6 +778,7 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
         self.pipeline_timeout = pipeline_timeout
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.pipeline_id: str | None = None
 
         if success_states:
             self.success_states = success_states
@@ -796,14 +798,14 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
             project_id=self.project_id,
         )
        api_url = instance["apiEndpoint"]
-        pipeline_id = hook.start_pipeline(
+        self.pipeline_id = hook.start_pipeline(
             pipeline_name=self.pipeline_name,
             pipeline_type=self.pipeline_type,
             instance_url=api_url,
             namespace=self.namespace,
             runtime_args=self.runtime_args,
         )
-        self.log.info("Pipeline %s submitted successfully.", pipeline_id)
+        self.log.info("Pipeline %s submitted successfully.", self.pipeline_id)
 
         DataFusionPipelineLink.persist(
             context=context,
@@ -824,7 +826,7 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
                 namespace=self.namespace,
                 pipeline_name=self.pipeline_name,
                 pipeline_type=self.pipeline_type.value,
-                pipeline_id=pipeline_id,
+                pipeline_id=self.pipeline_id,
                 poll_interval=self.poll_interval,
                 gcp_conn_id=self.gcp_conn_id,
                 impersonation_chain=self.impersonation_chain,
@@ -834,19 +836,21 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
         else:
             if not self.asynchronous:
                 # when NOT using asynchronous mode it will just wait for pipeline to finish and print message
-                self.log.info("Waiting when pipeline %s will be in one of the success states", pipeline_id)
+                self.log.info(
+                    "Waiting when pipeline %s will be in one of the success states", self.pipeline_id
+                )
                 hook.wait_for_pipeline_state(
                     success_states=self.success_states,
-                    pipeline_id=pipeline_id,
+                    pipeline_id=self.pipeline_id,
                     pipeline_name=self.pipeline_name,
                     pipeline_type=self.pipeline_type,
                     namespace=self.namespace,
                     instance_url=api_url,
                     timeout=self.pipeline_timeout,
                 )
-                self.log.info("Pipeline %s discovered success state.", pipeline_id)
+                self.log.info("Pipeline %s discovered success state.", self.pipeline_id)
             # otherwise, return pipeline_id so that sensor can use it later to check the pipeline state
-            return pipeline_id
+            return self.pipeline_id
 
     def execute_complete(self, context: Context, event: dict[str, Any]):
         """
@@ -863,6 +867,31 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
         )
         return event["pipeline_id"]
 
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        """Build and return OpenLineage facets and datasets for the completed pipeline start."""
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.facets import DataFusionRunFacet
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        pipeline_resource = f"{self.project_id}:{self.location}:{self.instance_name}:{self.pipeline_name}"
+
+        inputs = [Dataset(namespace="datafusion", name=pipeline_resource)]
+
+        if self.pipeline_id:
+            output_name = f"{pipeline_resource}:{self.pipeline_id}"
+        else:
+            output_name = f"{pipeline_resource}:unknown"
+        outputs = [Dataset(namespace="datafusion", name=output_name)]
+
+        run_facets = {
+            "dataFusionRun": DataFusionRunFacet(
+                runId=self.pipeline_id,
+                runtimeArgs=self.runtime_args,
+            )
+        }
+
+        return OperatorLineage(inputs=inputs, outputs=outputs, run_facets=run_facets, job_facets={})
+
 
 class CloudDataFusionStopPipelineOperator(GoogleCloudBaseOperator):
     """
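The hunks above keep the started pipeline's id on the operator instance so OpenLineage can reference it after execution. A minimal usage sketch follows, assuming a DAG context; the task id and all resource names are placeholders, and only constructor parameters referenced in the diff are used. With this change, get_openlineage_facets_on_complete() would emit datasets in the "datafusion" namespace named project:location:instance:pipeline (input) and the same string suffixed with the run id (output).

# Hedged usage sketch; ids and names below are placeholders, not values from the diff.
from airflow.providers.google.cloud.operators.datafusion import CloudDataFusionStartPipelineOperator

start_pipeline = CloudDataFusionStartPipelineOperator(
    task_id="start_pipeline",           # assumed task id
    pipeline_name="example_pipeline",   # placeholder pipeline name
    instance_name="example-instance",   # placeholder Data Fusion instance
    location="us-central1",             # placeholder region
    project_id="example-project",       # placeholder project
    runtime_args={"env": "dev"},        # surfaced as runtimeArgs in the DataFusionRunFacet above
)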
airflow/providers/google/cloud/operators/gen_ai.py (new file)

@@ -0,0 +1,389 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Gen AI operators."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.google.cloud.hooks.gen_ai import (
+    GenAIGenerativeModelHook,
+)
+from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+
+if TYPE_CHECKING:
+    from google.genai.types import (
+        ContentListUnion,
+        ContentListUnionDict,
+        CountTokensConfigOrDict,
+        CreateCachedContentConfigOrDict,
+        CreateTuningJobConfigOrDict,
+        EmbedContentConfigOrDict,
+        GenerateContentConfig,
+        TuningDatasetOrDict,
+    )
+
+    from airflow.utils.context import Context
+
+
+class GenAIGenerateEmbeddingsOperator(GoogleCloudBaseOperator):
+    """
+    Uses the Gemini AI Embeddings API to generate embeddings for words, phrases, sentences, and code.
+
+    :param project_id: Required. The ID of the Google Cloud project that the
+        service belongs to (templated).
+    :param location: Required. The ID of the Google Cloud location that the
+        service belongs to (templated).
+    :param model: Required. The name of the model to use for content generation,
+        which can be a text-only or multimodal model. For example, `gemini-pro` or
+        `gemini-pro-vision`.
+    :param contents: Optional. The contents to use for embedding.
+    :param config: Optional. Configuration for embeddings.
+    :param gcp_conn_id: Optional. The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = ("location", "project_id", "impersonation_chain", "contents", "model", "config")
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        model: str,
+        contents: ContentListUnion | ContentListUnionDict | list[str],
+        config: EmbedContentConfigOrDict | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.contents = contents
+        self.config = config
+        self.model = model
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenAIGenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        self.log.info("Generating text embeddings...")
+        response = self.hook.embed_content(
+            project_id=self.project_id,
+            location=self.location,
+            contents=self.contents,
+            model=self.model,
+            config=self.config,
+        )
+
+        self.log.info("Model response: %s", response)
+        context["ti"].xcom_push(key="model_response", value=response)
+
+        return response
+
+
+class GenAIGenerateContentOperator(GoogleCloudBaseOperator):
+    """
+    Generate a model response based on given configuration. Input capabilities differ between models, including tuned models.
+
+    :param project_id: Required. The ID of the Google Cloud project that the
+        service belongs to (templated).
+    :param location: Required. The ID of the Google Cloud location that the
+        service belongs to (templated).
+    :param model: Required. The name of the model to use for content generation,
+        which can be a text-only or multimodal model. For example, `gemini-pro` or
+        `gemini-pro-vision`.
+    :param contents: Required. The multi-part content of a message that a user or a program
+        gives to the generative model, in order to elicit a specific response.
+    :param generation_config: Optional. Generation configuration settings.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "generation_config",
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "contents",
+        "model",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        contents: ContentListUnionDict,
+        model: str,
+        generation_config: GenerateContentConfig | dict[str, Any] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.contents = contents
+        self.generation_config = generation_config
+        self.model = model
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenAIGenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        response = self.hook.generate_content(
+            project_id=self.project_id,
+            location=self.location,
+            model=self.model,
+            contents=self.contents,
+            generation_config=self.generation_config,
+        )
+
+        self.log.info("Created Content: %s", response)
+        context["ti"].xcom_push(key="model_response", value=response)
+
+        return response
+
+
+class GenAISupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
+    """
+    Create a tuning job to adapt model behavior with a labeled dataset.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param source_model: Required. A pre-trained model optimized for performing natural
+        language tasks such as classification, summarization, extraction, content
+        creation, and ideation.
+    :param training_dataset: Required. Cloud Storage URI of your training dataset. The dataset
+        must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
+    :param tuning_job_config: Optional. Configuration of the Tuning job to be created.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "training_dataset",
+        "tuning_job_config",
+        "source_model",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        source_model: str,
+        training_dataset: TuningDatasetOrDict,
+        tuning_job_config: CreateTuningJobConfigOrDict | dict[str, Any] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.source_model = source_model
+        self.training_dataset = training_dataset
+        self.tuning_job_config = tuning_job_config
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenAIGenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        response = self.hook.supervised_fine_tuning_train(
+            project_id=self.project_id,
+            location=self.location,
+            source_model=self.source_model,
+            training_dataset=self.training_dataset,
+            tuning_job_config=self.tuning_job_config,
+        )
+
+        self.log.info("Tuned Model Name: %s", response.tuned_model.model)  # type: ignore[union-attr,arg-type]
+        self.log.info("Tuned Model EndpointName: %s", response.tuned_model.endpoint)  # type: ignore[union-attr,arg-type]
+
+        context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model.model)  # type: ignore[union-attr,arg-type]
+        context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model.endpoint)  # type: ignore[union-attr,arg-type]
+
+        result = {
+            "tuned_model_name": response.tuned_model.model,  # type: ignore[union-attr,arg-type]
+            "tuned_model_endpoint_name": response.tuned_model.endpoint,  # type: ignore[union-attr,arg-type]
+        }
+
+        return result
+
+
+class GenAICountTokensOperator(GoogleCloudBaseOperator):
+    """
+    Use Count Tokens API to calculate the number of input tokens before sending a request to Gemini API.
+
+    :param project_id: Required. The ID of the Google Cloud project that the
+        service belongs to (templated).
+    :param location: Required. The ID of the Google Cloud location that the
+        service belongs to (templated).
+    :param contents: Required. The multi-part content of a message that a user or a program
+        gives to the generative model, in order to elicit a specific response.
+    :param model: Required. Model, supporting prompts with text-only input,
+        including natural language tasks, multi-turn text and code chat,
+        and code generation. It can output text and code.
+    :param config: Optional. Configuration for Count Tokens.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = ("location", "project_id", "impersonation_chain", "contents", "model", "config")
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        contents: ContentListUnion | ContentListUnionDict,
+        model: str,
+        config: CountTokensConfigOrDict | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.contents = contents
+        self.model = model
+        self.config = config
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenAIGenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        response = self.hook.count_tokens(
+            project_id=self.project_id,
+            location=self.location,
+            contents=self.contents,
+            model=self.model,
+            config=self.config,
+        )
+
+        self.log.info("Total tokens: %s", response.total_tokens)
+        context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
+
+
+class GenAICreateCachedContentOperator(GoogleCloudBaseOperator):
+    """
+    Create CachedContent resource to reduce the cost of requests that contain repeat content with high input token counts.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param model: Required. The name of the publisher model to use for cached content.
+    :param cached_content_config: Optional. Configuration of the Cached Content.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = ("location", "project_id", "impersonation_chain", "model", "cached_content_config")
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        model: str,
+        cached_content_config: CreateCachedContentConfigOrDict | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.location = location
+        self.model = model
+        self.cached_content_config = cached_content_config
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenAIGenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        cached_content_name = self.hook.create_cached_content(
+            project_id=self.project_id,
+            location=self.location,
+            model=self.model,
+            cached_content_config=self.cached_content_config,
+        )
+
+        self.log.info("Cached Content Name: %s", cached_content_name)
+        context["ti"].xcom_push(key="cached_content", value=cached_content_name)
+
+        return cached_content_name
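The new module above only defines the operators; nothing in the diff shows them in use. As a hedged illustration, two tasks built from the constructor signatures in the diff could look like the sketch below. The project, location, model, and prompt values are placeholders, and the simple string-list form of `contents` is just one shape the google-genai content types accept.

# Hedged usage sketch for the new gen_ai operators; all concrete values are placeholders.
from airflow.providers.google.cloud.operators.gen_ai import (
    GenAICountTokensOperator,
    GenAIGenerateContentOperator,
)

count_tokens = GenAICountTokensOperator(
    task_id="count_tokens",                                   # assumed task id
    project_id="example-project",                             # placeholder project
    location="us-central1",                                   # placeholder location
    model="gemini-pro",                                       # example model name from the docstrings
    contents=["Summarize Apache Airflow in one sentence."],   # placeholder prompt
)

generate_content = GenAIGenerateContentOperator(
    task_id="generate_content",
    project_id="example-project",
    location="us-central1",
    model="gemini-pro",
    contents=["Summarize Apache Airflow in one sentence."],
)

# Both operators push their results to XCom ("total_tokens" and "model_response", per the diff above).
count_tokens >> generate_content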
airflow/providers/google/cloud/operators/spanner.py

@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowException
@@ -29,6 +30,7 @@ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseO
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -254,6 +256,13 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
         self.impersonation_chain = impersonation_chain
         super().__init__(**kwargs)
 
+    @cached_property
+    def hook(self) -> SpannerHook:
+        return SpannerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
     def _validate_inputs(self) -> None:
         if self.project_id == "":
             raise AirflowException("The required parameter 'project_id' is empty")
@@ -265,10 +274,6 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
             raise AirflowException("The required parameter 'query' is empty")
 
     def execute(self, context: Context):
-        hook = SpannerHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
         if isinstance(self.query, str):
             queries = [x.strip() for x in self.query.split(";")]
             self.sanitize_queries(queries)
@@ -281,7 +286,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
             self.database_id,
         )
         self.log.info("Executing queries: %s", queries)
-        result_rows_count_per_query = hook.execute_dml(
+        result_rows_count_per_query = self.hook.execute_dml(
             project_id=self.project_id,
             instance_id=self.instance_id,
             database_id=self.database_id,
@@ -291,7 +296,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
             context=context,
             instance_id=self.instance_id,
             database_id=self.database_id,
-            project_id=self.project_id or hook.project_id,
+            project_id=self.project_id or self.hook.project_id,
         )
         return result_rows_count_per_query
 
@@ -305,6 +310,17 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
         if queries and queries[-1] == "":
             queries.pop()
 
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        """Build a generic OpenLineage facet, aligned with SQL-based operators."""
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+
+        return get_openlineage_facets_with_sql(
+            hook=self.hook,
+            sql=self.query,
+            conn_id=self.gcp_conn_id,
+            database=self.database_id,
+        )
+
 
 class SpannerDeployDatabaseInstanceOperator(GoogleCloudBaseOperator):
     """
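For context, a hedged sketch of the operator touched above: with `hook` now a cached_property, execute() and get_openlineage_facets_on_complete() share one SpannerHook, so SQL-based lineage is derived from the same connection and the DML string. The ids and the query below are placeholders; the parameters mirror the attributes referenced in the diff.

# Hedged usage sketch; project, instance, database ids and the DML are placeholders.
from airflow.providers.google.cloud.operators.spanner import SpannerQueryDatabaseInstanceOperator

spanner_dml = SpannerQueryDatabaseInstanceOperator(
    task_id="spanner_dml",            # assumed task id
    project_id="example-project",     # placeholder project
    instance_id="example-instance",   # placeholder Spanner instance
    database_id="example-db",         # placeholder database
    # Multiple statements separated by ";" are split and sanitized, as shown in execute() above.
    query="DELETE FROM users WHERE active = FALSE; DELETE FROM sessions WHERE expired = TRUE",
    gcp_conn_id="google_cloud_default",
)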
airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py

@@ -29,6 +29,7 @@ from google.cloud.aiplatform import datasets
 from google.cloud.aiplatform.models import Model
 from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
 from airflow.providers.google.cloud.links.vertex_ai import (
     VertexAIModelLink,
@@ -36,6 +37,7 @@ from airflow.providers.google.cloud.links.vertex_ai import (
     VertexAITrainingPipelinesLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.common.deprecated import deprecated
 
 if TYPE_CHECKING:
     from google.api_core.retry import Retry
@@ -473,6 +475,11 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         return result
 
 
+@deprecated(
+    planned_removal_date="March 24, 2026",
+    use_instead="airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
     """Create Auto ML Video Training job."""
 
airflow/providers/google/cloud/operators/vertex_ai/generative_model.py

@@ -36,6 +36,11 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
+@deprecated(
+    planned_removal_date="January 3, 2026",
+    use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateEmbeddingsOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
     """
     Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
@@ -99,6 +104,11 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
         return response
 
 
+@deprecated(
+    planned_removal_date="January 3, 2026",
+    use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
     """
     Use the Vertex AI Gemini Pro foundation model to generate content.
@@ -178,6 +188,11 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
         return response
 
 
+@deprecated(
+    planned_removal_date="January 3, 2026",
+    use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAISupervisedFineTuningTrainOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
     """
     Use the Supervised Fine Tuning API to create a tuning job.
@@ -280,6 +295,11 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
         return result
 
 
+@deprecated(
+    planned_removal_date="January 3, 2026",
+    use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICountTokensOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class CountTokensOperator(GoogleCloudBaseOperator):
     """
     Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
@@ -443,6 +463,11 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
         return response.summary_metrics
 
 
+@deprecated(
+    planned_removal_date="January 3, 2026",
+    use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICreateCachedContentOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class CreateCachedContentOperator(GoogleCloudBaseOperator):
     """
     Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
@@ -522,6 +547,11 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
         return cached_content_name
 
 
+@deprecated(
+    planned_removal_date="January 3, 2026",
+    use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
     """
     Generate a response from CachedContent.