apache-airflow-providers-google 18.0.0__py3-none-any.whl → 18.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of apache-airflow-providers-google might be problematic; see the registry's release page for more details.

Files changed (72):
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +5 -5
  3. airflow/providers/google/assets/gcs.py +1 -11
  4. airflow/providers/google/cloud/bundles/__init__.py +16 -0
  5. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  6. airflow/providers/google/cloud/hooks/bigquery.py +45 -42
  7. airflow/providers/google/cloud/hooks/cloud_composer.py +131 -1
  8. airflow/providers/google/cloud/hooks/cloud_sql.py +88 -13
  9. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +16 -0
  10. airflow/providers/google/cloud/hooks/dataflow.py +1 -1
  11. airflow/providers/google/cloud/hooks/dataprep.py +1 -1
  12. airflow/providers/google/cloud/hooks/dataproc.py +3 -0
  13. airflow/providers/google/cloud/hooks/gcs.py +107 -3
  14. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  15. airflow/providers/google/cloud/hooks/looker.py +1 -1
  16. airflow/providers/google/cloud/hooks/spanner.py +45 -0
  17. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +30 -0
  18. airflow/providers/google/cloud/links/base.py +11 -11
  19. airflow/providers/google/cloud/links/dataproc.py +2 -10
  20. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  21. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  22. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  23. airflow/providers/google/cloud/openlineage/facets.py +102 -1
  24. airflow/providers/google/cloud/openlineage/mixins.py +3 -1
  25. airflow/providers/google/cloud/operators/bigquery.py +2 -9
  26. airflow/providers/google/cloud/operators/cloud_run.py +2 -1
  27. airflow/providers/google/cloud/operators/cloud_sql.py +1 -1
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +89 -6
  29. airflow/providers/google/cloud/operators/datafusion.py +36 -7
  30. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  31. airflow/providers/google/cloud/operators/spanner.py +22 -6
  32. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +7 -0
  33. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +30 -0
  34. airflow/providers/google/cloud/operators/workflows.py +17 -6
  35. airflow/providers/google/cloud/sensors/bigquery.py +1 -1
  36. airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -6
  37. airflow/providers/google/cloud/sensors/bigtable.py +1 -6
  38. airflow/providers/google/cloud/sensors/cloud_composer.py +65 -31
  39. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -6
  40. airflow/providers/google/cloud/sensors/dataflow.py +1 -1
  41. airflow/providers/google/cloud/sensors/dataform.py +1 -6
  42. airflow/providers/google/cloud/sensors/datafusion.py +1 -6
  43. airflow/providers/google/cloud/sensors/dataplex.py +1 -6
  44. airflow/providers/google/cloud/sensors/dataprep.py +1 -6
  45. airflow/providers/google/cloud/sensors/dataproc.py +1 -6
  46. airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -6
  47. airflow/providers/google/cloud/sensors/gcs.py +1 -7
  48. airflow/providers/google/cloud/sensors/looker.py +1 -6
  49. airflow/providers/google/cloud/sensors/pubsub.py +1 -6
  50. airflow/providers/google/cloud/sensors/tasks.py +1 -6
  51. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +1 -6
  52. airflow/providers/google/cloud/sensors/workflows.py +1 -6
  53. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  54. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +2 -1
  55. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +11 -2
  56. airflow/providers/google/cloud/triggers/bigquery.py +15 -3
  57. airflow/providers/google/cloud/triggers/cloud_composer.py +51 -21
  58. airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
  59. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +90 -0
  60. airflow/providers/google/cloud/triggers/pubsub.py +14 -18
  61. airflow/providers/google/common/hooks/base_google.py +1 -1
  62. airflow/providers/google/get_provider_info.py +15 -0
  63. airflow/providers/google/leveldb/hooks/leveldb.py +1 -1
  64. airflow/providers/google/marketing_platform/links/analytics_admin.py +2 -8
  65. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -6
  66. airflow/providers/google/marketing_platform/sensors/display_video.py +1 -6
  67. airflow/providers/google/suite/sensors/drive.py +1 -6
  68. airflow/providers/google/version_compat.py +0 -20
  69. {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/METADATA +15 -15
  70. {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/RECORD +72 -65
  71. {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/WHEEL +0 -0
  72. {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/entry_points.txt +0 -0
@@ -40,6 +40,7 @@ from airflow.providers.google.cloud.utils.helpers import resource_path_to_dict
40
40
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
41
41
 
42
42
  if TYPE_CHECKING:
43
+ from airflow.providers.openlineage.extractors import OperatorLineage
43
44
  from airflow.utils.context import Context
44
45
 
45
46
 
@@ -777,6 +778,7 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
777
778
  self.pipeline_timeout = pipeline_timeout
778
779
  self.deferrable = deferrable
779
780
  self.poll_interval = poll_interval
781
+ self.pipeline_id: str | None = None
780
782
 
781
783
  if success_states:
782
784
  self.success_states = success_states
@@ -796,14 +798,14 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
796
798
  project_id=self.project_id,
797
799
  )
798
800
  api_url = instance["apiEndpoint"]
799
- pipeline_id = hook.start_pipeline(
801
+ self.pipeline_id = hook.start_pipeline(
800
802
  pipeline_name=self.pipeline_name,
801
803
  pipeline_type=self.pipeline_type,
802
804
  instance_url=api_url,
803
805
  namespace=self.namespace,
804
806
  runtime_args=self.runtime_args,
805
807
  )
806
- self.log.info("Pipeline %s submitted successfully.", pipeline_id)
808
+ self.log.info("Pipeline %s submitted successfully.", self.pipeline_id)
807
809
 
808
810
  DataFusionPipelineLink.persist(
809
811
  context=context,
@@ -824,7 +826,7 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
824
826
  namespace=self.namespace,
825
827
  pipeline_name=self.pipeline_name,
826
828
  pipeline_type=self.pipeline_type.value,
827
- pipeline_id=pipeline_id,
829
+ pipeline_id=self.pipeline_id,
828
830
  poll_interval=self.poll_interval,
829
831
  gcp_conn_id=self.gcp_conn_id,
830
832
  impersonation_chain=self.impersonation_chain,
@@ -834,19 +836,21 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
834
836
  else:
835
837
  if not self.asynchronous:
836
838
  # when NOT using asynchronous mode it will just wait for pipeline to finish and print message
837
- self.log.info("Waiting when pipeline %s will be in one of the success states", pipeline_id)
839
+ self.log.info(
840
+ "Waiting when pipeline %s will be in one of the success states", self.pipeline_id
841
+ )
838
842
  hook.wait_for_pipeline_state(
839
843
  success_states=self.success_states,
840
- pipeline_id=pipeline_id,
844
+ pipeline_id=self.pipeline_id,
841
845
  pipeline_name=self.pipeline_name,
842
846
  pipeline_type=self.pipeline_type,
843
847
  namespace=self.namespace,
844
848
  instance_url=api_url,
845
849
  timeout=self.pipeline_timeout,
846
850
  )
847
- self.log.info("Pipeline %s discovered success state.", pipeline_id)
851
+ self.log.info("Pipeline %s discovered success state.", self.pipeline_id)
848
852
  # otherwise, return pipeline_id so that sensor can use it later to check the pipeline state
849
- return pipeline_id
853
+ return self.pipeline_id
850
854
 
851
855
  def execute_complete(self, context: Context, event: dict[str, Any]):
852
856
  """
@@ -863,6 +867,31 @@ class CloudDataFusionStartPipelineOperator(GoogleCloudBaseOperator):
863
867
  )
864
868
  return event["pipeline_id"]
865
869
 
870
+ def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
871
+ """Build and return OpenLineage facets and datasets for the completed pipeline start."""
872
+ from airflow.providers.common.compat.openlineage.facet import Dataset
873
+ from airflow.providers.google.cloud.openlineage.facets import DataFusionRunFacet
874
+ from airflow.providers.openlineage.extractors import OperatorLineage
875
+
876
+ pipeline_resource = f"{self.project_id}:{self.location}:{self.instance_name}:{self.pipeline_name}"
877
+
878
+ inputs = [Dataset(namespace="datafusion", name=pipeline_resource)]
879
+
880
+ if self.pipeline_id:
881
+ output_name = f"{pipeline_resource}:{self.pipeline_id}"
882
+ else:
883
+ output_name = f"{pipeline_resource}:unknown"
884
+ outputs = [Dataset(namespace="datafusion", name=output_name)]
885
+
886
+ run_facets = {
887
+ "dataFusionRun": DataFusionRunFacet(
888
+ runId=self.pipeline_id,
889
+ runtimeArgs=self.runtime_args,
890
+ )
891
+ }
892
+
893
+ return OperatorLineage(inputs=inputs, outputs=outputs, run_facets=run_facets, job_facets={})
894
+
866
895
 
867
896
  class CloudDataFusionStopPipelineOperator(GoogleCloudBaseOperator):
868
897
  """
@@ -0,0 +1,389 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ """This module contains Google Gen AI operators."""
19
+
20
+ from __future__ import annotations
21
+
22
+ from collections.abc import Sequence
23
+ from typing import TYPE_CHECKING, Any
24
+
25
+ from airflow.providers.google.cloud.hooks.gen_ai import (
26
+ GenAIGenerativeModelHook,
27
+ )
28
+ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
29
+
30
+ if TYPE_CHECKING:
31
+ from google.genai.types import (
32
+ ContentListUnion,
33
+ ContentListUnionDict,
34
+ CountTokensConfigOrDict,
35
+ CreateCachedContentConfigOrDict,
36
+ CreateTuningJobConfigOrDict,
37
+ EmbedContentConfigOrDict,
38
+ GenerateContentConfig,
39
+ TuningDatasetOrDict,
40
+ )
41
+
42
+ from airflow.utils.context import Context
43
+
44
+
45
class GenAIGenerateEmbeddingsOperator(GoogleCloudBaseOperator):
    """
    Generate embeddings for words, phrases, sentences, and code with the Gen AI Embeddings API.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param model: Required. The name of the embedding model to use, for example
        ``text-embedding-005`` (templated).
    :param contents: Required. The contents to embed, e.g. a list of strings (templated).
    :param config: Optional. Configuration for embeddings (templated).
    :param gcp_conn_id: Optional. The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional. Service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :return: The response returned by ``GenAIGenerativeModelHook.embed_content``; the same
        value is also pushed to XCom under the key ``model_response``.
    """

    template_fields = ("location", "project_id", "impersonation_chain", "contents", "model", "config")

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        model: str,
        contents: ContentListUnion | ContentListUnionDict | list[str],
        config: EmbedContentConfigOrDict | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.contents = contents
        self.config = config
        self.model = model
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        self.hook = GenAIGenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        self.log.info("Generating text embeddings...")
        response = self.hook.embed_content(
            project_id=self.project_id,
            location=self.location,
            contents=self.contents,
            model=self.model,
            config=self.config,
        )

        self.log.info("Model response: %s", response)
        context["ti"].xcom_push(key="model_response", value=response)

        return response
111
+
112
+
113
class GenAIGenerateContentOperator(GoogleCloudBaseOperator):
    """
    Ask a Gen AI model to generate a response for the given contents.

    Input capabilities differ between models, including tuned models.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param model: Required. The name of the model to use for content generation,
        which can be a text-only or multimodal model. For example, `gemini-pro` or
        `gemini-pro-vision` (templated).
    :param contents: Required. The multi-part content of a message that a user or a program
        gives to the generative model, in order to elicit a specific response (templated).
    :param generation_config: Optional. Generation configuration settings (templated).
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :return: The object returned by ``GenAIGenerativeModelHook.generate_content``; it is
        also pushed to XCom under the key ``model_response``.
    """

    template_fields = (
        "generation_config",
        "location",
        "project_id",
        "impersonation_chain",
        "contents",
        "model",
    )

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        contents: ContentListUnionDict,
        model: str,
        generation_config: GenerateContentConfig | dict[str, Any] | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Connection and identity settings.
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        # Request settings.
        self.project_id = project_id
        self.location = location
        self.model = model
        self.contents = contents
        self.generation_config = generation_config

    def execute(self, context: Context):
        hook = GenAIGenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        # The hook is kept on the instance as well as used locally.
        self.hook = hook

        generated = hook.generate_content(
            project_id=self.project_id,
            location=self.location,
            model=self.model,
            contents=self.contents,
            generation_config=self.generation_config,
        )

        self.log.info("Created Content: %s", generated)
        context["ti"].xcom_push(key="model_response", value=generated)
        return generated
185
+
186
+
187
class GenAISupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
    """
    Create a tuning job to adapt model behavior with a labeled dataset.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the service belongs to (templated).
    :param source_model: Required. A pre-trained model optimized for performing natural
        language tasks such as classification, summarization, extraction, content
        creation, and ideation (templated).
    :param training_dataset: Required. The training dataset, as a ``TuningDataset`` or an
        equivalent dict, for example ``{"gcs_uri": "gs://bucket/train.jsonl"}``. The dataset
        must be formatted as a JSONL file. For best results, provide at least 100 to 500
        examples (templated).
    :param tuning_job_config: Optional. Configuration of the Tuning job to be created (templated).
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :return: A dict with keys ``tuned_model_name`` and ``tuned_model_endpoint_name``; both
        values are also pushed to XCom under those same keys.
    :raises AirflowException: If the tuning job returned by the hook has no tuned model.
    """

    template_fields = (
        "location",
        "project_id",
        "impersonation_chain",
        "training_dataset",
        "tuning_job_config",
        "source_model",
    )

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        source_model: str,
        training_dataset: TuningDatasetOrDict,
        tuning_job_config: CreateTuningJobConfigOrDict | dict[str, Any] | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.source_model = source_model
        self.training_dataset = training_dataset
        self.tuning_job_config = tuning_job_config
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        from airflow.exceptions import AirflowException

        self.hook = GenAIGenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        response = self.hook.supervised_fine_tuning_train(
            project_id=self.project_id,
            location=self.location,
            source_model=self.source_model,
            training_dataset=self.training_dataset,
            tuning_job_config=self.tuning_job_config,
        )

        # ``tuned_model`` is optional on the returned job. Fail with a descriptive error
        # instead of an opaque ``'NoneType' object has no attribute 'model'``.
        tuned_model = response.tuned_model
        if tuned_model is None:
            raise AirflowException(
                f"Tuning job {getattr(response, 'name', None)} returned no tuned model "
                f"(state: {getattr(response, 'state', None)})."
            )

        model_name = tuned_model.model
        endpoint_name = tuned_model.endpoint

        self.log.info("Tuned Model Name: %s", model_name)
        self.log.info("Tuned Model EndpointName: %s", endpoint_name)

        context["ti"].xcom_push(key="tuned_model_name", value=model_name)
        context["ti"].xcom_push(key="tuned_model_endpoint_name", value=endpoint_name)

        return {
            "tuned_model_name": model_name,
            "tuned_model_endpoint_name": endpoint_name,
        }
265
+
266
+
267
class GenAICountTokensOperator(GoogleCloudBaseOperator):
    """
    Use Count Tokens API to calculate the number of input tokens before sending a request to Gemini API.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param contents: Required. The multi-part content of a message that a user or a program
        gives to the generative model, in order to elicit a specific response (templated).
    :param model: Required. Model, supporting prompts with text-only input,
        including natural language tasks, multi-turn text and code chat,
        and code generation. It can output text and code (templated).
    :param config: Optional. Configuration for Count Tokens (templated).
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :return: The ``total_tokens`` value reported by the API; it is also pushed to XCom
        under the key ``total_tokens``.
    """

    template_fields = ("location", "project_id", "impersonation_chain", "contents", "model", "config")

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        contents: ContentListUnion | ContentListUnionDict,
        model: str,
        config: CountTokensConfigOrDict | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.contents = contents
        self.model = model
        self.config = config
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        self.hook = GenAIGenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        response = self.hook.count_tokens(
            project_id=self.project_id,
            location=self.location,
            contents=self.contents,
            model=self.model,
            config=self.config,
        )

        total_tokens = response.total_tokens
        self.log.info("Total tokens: %s", total_tokens)
        context["ti"].xcom_push(key="total_tokens", value=total_tokens)

        # Also return the count, consistent with the other operators in this module,
        # so downstream tasks can use the task's return value.
        return total_tokens
330
+
331
+
332
class GenAICreateCachedContentOperator(GoogleCloudBaseOperator):
    """
    Create a CachedContent resource for a publisher model.

    Caching repeated, large-token content reduces the cost of subsequent requests that reuse it.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the service belongs to (templated).
    :param model: Required. The name of the publisher model to use for cached content (templated).
    :param cached_content_config: Optional. Configuration of the Cached Content (templated).
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :return: The value returned by ``GenAIGenerativeModelHook.create_cached_content``; it is
        also pushed to XCom under the key ``cached_content``.
    """

    template_fields = ("location", "project_id", "impersonation_chain", "model", "cached_content_config")

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        model: str,
        cached_content_config: CreateCachedContentConfigOrDict | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Where and how to authenticate.
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        # What to cache.
        self.project_id = project_id
        self.location = location
        self.model = model
        self.cached_content_config = cached_content_config

    def execute(self, context: Context):
        self.hook = GenAIGenerativeModelHook(
            impersonation_chain=self.impersonation_chain,
            gcp_conn_id=self.gcp_conn_id,
        )

        name = self.hook.create_cached_content(
            cached_content_config=self.cached_content_config,
            model=self.model,
            location=self.location,
            project_id=self.project_id,
        )

        self.log.info("Cached Content Name: %s", name)
        context["ti"].xcom_push(key="cached_content", value=name)
        return name
@@ -20,6 +20,7 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
+ from functools import cached_property
23
24
  from typing import TYPE_CHECKING
24
25
 
25
26
  from airflow.exceptions import AirflowException
@@ -29,6 +30,7 @@ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseO
29
30
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
30
31
 
31
32
  if TYPE_CHECKING:
33
+ from airflow.providers.openlineage.extractors import OperatorLineage
32
34
  from airflow.utils.context import Context
33
35
 
34
36
 
@@ -254,6 +256,13 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
254
256
  self.impersonation_chain = impersonation_chain
255
257
  super().__init__(**kwargs)
256
258
 
259
@cached_property
def hook(self) -> SpannerHook:
    """Build the SpannerHook for this operator once and reuse it on later accesses."""
    spanner_hook = SpannerHook(
        impersonation_chain=self.impersonation_chain,
        gcp_conn_id=self.gcp_conn_id,
    )
    return spanner_hook
265
+
257
266
  def _validate_inputs(self) -> None:
258
267
  if self.project_id == "":
259
268
  raise AirflowException("The required parameter 'project_id' is empty")
@@ -265,10 +274,6 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
265
274
  raise AirflowException("The required parameter 'query' is empty")
266
275
 
267
276
  def execute(self, context: Context):
268
- hook = SpannerHook(
269
- gcp_conn_id=self.gcp_conn_id,
270
- impersonation_chain=self.impersonation_chain,
271
- )
272
277
  if isinstance(self.query, str):
273
278
  queries = [x.strip() for x in self.query.split(";")]
274
279
  self.sanitize_queries(queries)
@@ -281,7 +286,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
281
286
  self.database_id,
282
287
  )
283
288
  self.log.info("Executing queries: %s", queries)
284
- result_rows_count_per_query = hook.execute_dml(
289
+ result_rows_count_per_query = self.hook.execute_dml(
285
290
  project_id=self.project_id,
286
291
  instance_id=self.instance_id,
287
292
  database_id=self.database_id,
@@ -291,7 +296,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
291
296
  context=context,
292
297
  instance_id=self.instance_id,
293
298
  database_id=self.database_id,
294
- project_id=self.project_id or hook.project_id,
299
+ project_id=self.project_id or self.hook.project_id,
295
300
  )
296
301
  return result_rows_count_per_query
297
302
 
@@ -305,6 +310,17 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
305
310
  if queries and queries[-1] == "":
306
311
  queries.pop()
307
312
 
313
def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
    """
    Return OpenLineage facets for the executed query, built the same way as for SQL operators.

    Delegates to ``get_openlineage_facets_with_sql`` with this operator's Spanner hook, its
    ``query`` as given to the operator, the GCP connection id and ``database_id``.

    :param task_instance: The finished task instance (unused here).
    :return: Whatever ``get_openlineage_facets_with_sql`` produces, possibly ``None``.
    """
    from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql

    lineage_kwargs = {
        "hook": self.hook,
        "sql": self.query,
        "conn_id": self.gcp_conn_id,
        "database": self.database_id,
    }
    return get_openlineage_facets_with_sql(**lineage_kwargs)
323
+
308
324
 
309
325
  class SpannerDeployDatabaseInstanceOperator(GoogleCloudBaseOperator):
310
326
  """
@@ -29,6 +29,7 @@ from google.cloud.aiplatform import datasets
29
29
  from google.cloud.aiplatform.models import Model
30
30
  from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
31
31
 
32
+ from airflow.exceptions import AirflowProviderDeprecationWarning
32
33
  from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
33
34
  from airflow.providers.google.cloud.links.vertex_ai import (
34
35
  VertexAIModelLink,
@@ -36,6 +37,7 @@ from airflow.providers.google.cloud.links.vertex_ai import (
36
37
  VertexAITrainingPipelinesLink,
37
38
  )
38
39
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
40
+ from airflow.providers.google.common.deprecated import deprecated
39
41
 
40
42
  if TYPE_CHECKING:
41
43
  from google.api_core.retry import Retry
@@ -473,6 +475,11 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
473
475
  return result
474
476
 
475
477
 
478
+ @deprecated(
479
+ planned_removal_date="March 24, 2026",
480
+ use_instead="airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator",
481
+ category=AirflowProviderDeprecationWarning,
482
+ )
476
483
  class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
477
484
  """Create Auto ML Video Training job."""
478
485
 
@@ -36,6 +36,11 @@ if TYPE_CHECKING:
36
36
  from airflow.utils.context import Context
37
37
 
38
38
 
39
+ @deprecated(
40
+ planned_removal_date="January 3, 2026",
41
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateEmbeddingsOperator",
42
+ category=AirflowProviderDeprecationWarning,
43
+ )
39
44
  class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
40
45
  """
41
46
  Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
@@ -99,6 +104,11 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
99
104
  return response
100
105
 
101
106
 
107
+ @deprecated(
108
+ planned_removal_date="January 3, 2026",
109
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
110
+ category=AirflowProviderDeprecationWarning,
111
+ )
102
112
  class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
103
113
  """
104
114
  Use the Vertex AI Gemini Pro foundation model to generate content.
@@ -178,6 +188,11 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
178
188
  return response
179
189
 
180
190
 
191
+ @deprecated(
192
+ planned_removal_date="January 3, 2026",
193
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAISupervisedFineTuningTrainOperator",
194
+ category=AirflowProviderDeprecationWarning,
195
+ )
181
196
  class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
182
197
  """
183
198
  Use the Supervised Fine Tuning API to create a tuning job.
@@ -280,6 +295,11 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
280
295
  return result
281
296
 
282
297
 
298
+ @deprecated(
299
+ planned_removal_date="January 3, 2026",
300
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICountTokensOperator",
301
+ category=AirflowProviderDeprecationWarning,
302
+ )
283
303
  class CountTokensOperator(GoogleCloudBaseOperator):
284
304
  """
285
305
  Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
@@ -443,6 +463,11 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
443
463
  return response.summary_metrics
444
464
 
445
465
 
466
+ @deprecated(
467
+ planned_removal_date="January 3, 2026",
468
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICreateCachedContentOperator",
469
+ category=AirflowProviderDeprecationWarning,
470
+ )
446
471
  class CreateCachedContentOperator(GoogleCloudBaseOperator):
447
472
  """
448
473
  Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
@@ -522,6 +547,11 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
522
547
  return cached_content_name
523
548
 
524
549
 
550
+ @deprecated(
551
+ planned_removal_date="January 3, 2026",
552
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
553
+ category=AirflowProviderDeprecationWarning,
554
+ )
525
555
  class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
526
556
  """
527
557
  Generate a response from CachedContent.