apache-airflow-providers-google 16.0.0a1__py3-none-any.whl → 16.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +43 -5
  3. airflow/providers/google/ads/operators/ads.py +1 -1
  4. airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -1
  5. airflow/providers/google/cloud/hooks/bigquery.py +63 -77
  6. airflow/providers/google/cloud/hooks/cloud_sql.py +8 -4
  7. airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +2 -2
  9. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  10. airflow/providers/google/cloud/hooks/dataprep.py +4 -1
  11. airflow/providers/google/cloud/hooks/gcs.py +5 -5
  12. airflow/providers/google/cloud/hooks/looker.py +10 -1
  13. airflow/providers/google/cloud/hooks/mlengine.py +2 -1
  14. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  15. airflow/providers/google/cloud/hooks/spanner.py +2 -2
  16. airflow/providers/google/cloud/hooks/translate.py +1 -1
  17. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
  18. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
  19. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +44 -80
  20. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +11 -2
  21. airflow/providers/google/cloud/hooks/vision.py +2 -2
  22. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  23. airflow/providers/google/cloud/links/base.py +75 -11
  24. airflow/providers/google/cloud/links/bigquery.py +0 -47
  25. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  26. airflow/providers/google/cloud/links/bigtable.py +0 -48
  27. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  28. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  29. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  30. airflow/providers/google/cloud/links/cloud_run.py +27 -0
  31. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  32. airflow/providers/google/cloud/links/cloud_storage_transfer.py +16 -43
  33. airflow/providers/google/cloud/links/cloud_tasks.py +6 -25
  34. airflow/providers/google/cloud/links/compute.py +0 -58
  35. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  36. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  37. airflow/providers/google/cloud/links/dataflow.py +0 -34
  38. airflow/providers/google/cloud/links/dataform.py +0 -64
  39. airflow/providers/google/cloud/links/datafusion.py +1 -96
  40. airflow/providers/google/cloud/links/dataplex.py +0 -154
  41. airflow/providers/google/cloud/links/dataprep.py +0 -24
  42. airflow/providers/google/cloud/links/dataproc.py +14 -90
  43. airflow/providers/google/cloud/links/datastore.py +0 -31
  44. airflow/providers/google/cloud/links/kubernetes_engine.py +5 -59
  45. airflow/providers/google/cloud/links/life_sciences.py +0 -19
  46. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  47. airflow/providers/google/cloud/links/mlengine.py +0 -70
  48. airflow/providers/google/cloud/links/pubsub.py +0 -32
  49. airflow/providers/google/cloud/links/spanner.py +0 -33
  50. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  51. airflow/providers/google/cloud/links/translate.py +16 -186
  52. airflow/providers/google/cloud/links/vertex_ai.py +8 -224
  53. airflow/providers/google/cloud/links/workflows.py +0 -52
  54. airflow/providers/google/cloud/log/gcs_task_handler.py +4 -4
  55. airflow/providers/google/cloud/operators/alloy_db.py +69 -54
  56. airflow/providers/google/cloud/operators/automl.py +16 -14
  57. airflow/providers/google/cloud/operators/bigquery.py +49 -25
  58. airflow/providers/google/cloud/operators/bigquery_dts.py +2 -4
  59. airflow/providers/google/cloud/operators/bigtable.py +35 -6
  60. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  61. airflow/providers/google/cloud/operators/cloud_build.py +74 -31
  62. airflow/providers/google/cloud/operators/cloud_composer.py +34 -35
  63. airflow/providers/google/cloud/operators/cloud_memorystore.py +68 -42
  64. airflow/providers/google/cloud/operators/cloud_run.py +9 -1
  65. airflow/providers/google/cloud/operators/cloud_sql.py +11 -15
  66. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -2
  67. airflow/providers/google/cloud/operators/compute.py +7 -39
  68. airflow/providers/google/cloud/operators/datacatalog.py +156 -20
  69. airflow/providers/google/cloud/operators/dataflow.py +37 -14
  70. airflow/providers/google/cloud/operators/dataform.py +14 -4
  71. airflow/providers/google/cloud/operators/datafusion.py +4 -12
  72. airflow/providers/google/cloud/operators/dataplex.py +180 -96
  73. airflow/providers/google/cloud/operators/dataprep.py +0 -4
  74. airflow/providers/google/cloud/operators/dataproc.py +10 -16
  75. airflow/providers/google/cloud/operators/dataproc_metastore.py +95 -87
  76. airflow/providers/google/cloud/operators/datastore.py +21 -5
  77. airflow/providers/google/cloud/operators/dlp.py +3 -26
  78. airflow/providers/google/cloud/operators/functions.py +15 -6
  79. airflow/providers/google/cloud/operators/gcs.py +1 -7
  80. airflow/providers/google/cloud/operators/kubernetes_engine.py +53 -92
  81. airflow/providers/google/cloud/operators/life_sciences.py +0 -1
  82. airflow/providers/google/cloud/operators/managed_kafka.py +106 -51
  83. airflow/providers/google/cloud/operators/mlengine.py +0 -1
  84. airflow/providers/google/cloud/operators/pubsub.py +4 -5
  85. airflow/providers/google/cloud/operators/spanner.py +0 -4
  86. airflow/providers/google/cloud/operators/speech_to_text.py +0 -1
  87. airflow/providers/google/cloud/operators/stackdriver.py +0 -8
  88. airflow/providers/google/cloud/operators/tasks.py +0 -11
  89. airflow/providers/google/cloud/operators/text_to_speech.py +0 -1
  90. airflow/providers/google/cloud/operators/translate.py +37 -13
  91. airflow/providers/google/cloud/operators/translate_speech.py +0 -1
  92. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18
  93. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8
  94. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25
  95. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +69 -7
  96. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +42 -8
  97. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +531 -0
  98. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +93 -117
  99. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +10 -8
  100. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +56 -10
  101. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +29 -6
  102. airflow/providers/google/cloud/operators/vertex_ai/ray.py +9 -6
  103. airflow/providers/google/cloud/operators/workflows.py +1 -9
  104. airflow/providers/google/cloud/sensors/bigquery.py +1 -1
  105. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -1
  106. airflow/providers/google/cloud/sensors/bigtable.py +15 -3
  107. airflow/providers/google/cloud/sensors/cloud_composer.py +6 -1
  108. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -1
  109. airflow/providers/google/cloud/sensors/dataflow.py +3 -3
  110. airflow/providers/google/cloud/sensors/dataform.py +6 -1
  111. airflow/providers/google/cloud/sensors/datafusion.py +6 -1
  112. airflow/providers/google/cloud/sensors/dataplex.py +6 -1
  113. airflow/providers/google/cloud/sensors/dataprep.py +6 -1
  114. airflow/providers/google/cloud/sensors/dataproc.py +6 -1
  115. airflow/providers/google/cloud/sensors/dataproc_metastore.py +6 -1
  116. airflow/providers/google/cloud/sensors/gcs.py +9 -3
  117. airflow/providers/google/cloud/sensors/looker.py +6 -1
  118. airflow/providers/google/cloud/sensors/pubsub.py +8 -3
  119. airflow/providers/google/cloud/sensors/tasks.py +6 -1
  120. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +6 -1
  121. airflow/providers/google/cloud/sensors/workflows.py +6 -1
  122. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +1 -1
  123. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +1 -1
  124. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +10 -7
  125. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -2
  126. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -1
  127. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -1
  128. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  129. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +1 -1
  130. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +2 -2
  131. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
  132. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -1
  133. airflow/providers/google/cloud/transfers/gcs_to_local.py +1 -1
  134. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
  135. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +5 -1
  136. airflow/providers/google/cloud/transfers/gdrive_to_local.py +1 -1
  137. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  138. airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -1
  139. airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
  140. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +1 -1
  141. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
  142. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +2 -2
  143. airflow/providers/google/cloud/transfers/sql_to_gcs.py +1 -1
  144. airflow/providers/google/cloud/triggers/bigquery.py +32 -5
  145. airflow/providers/google/cloud/triggers/dataproc.py +62 -10
  146. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  147. airflow/providers/google/common/auth_backend/google_openid.py +2 -1
  148. airflow/providers/google/common/deprecated.py +2 -1
  149. airflow/providers/google/common/hooks/base_google.py +7 -3
  150. airflow/providers/google/common/links/storage.py +0 -22
  151. airflow/providers/google/firebase/operators/firestore.py +1 -1
  152. airflow/providers/google/get_provider_info.py +14 -16
  153. airflow/providers/google/leveldb/hooks/leveldb.py +30 -1
  154. airflow/providers/google/leveldb/operators/leveldb.py +1 -1
  155. airflow/providers/google/marketing_platform/links/analytics_admin.py +3 -6
  156. airflow/providers/google/marketing_platform/operators/analytics_admin.py +0 -1
  157. airflow/providers/google/marketing_platform/operators/campaign_manager.py +4 -4
  158. airflow/providers/google/marketing_platform/operators/display_video.py +6 -6
  159. airflow/providers/google/marketing_platform/operators/search_ads.py +1 -1
  160. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +6 -1
  161. airflow/providers/google/marketing_platform/sensors/display_video.py +6 -1
  162. airflow/providers/google/suite/operators/sheets.py +3 -3
  163. airflow/providers/google/suite/sensors/drive.py +6 -1
  164. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +1 -1
  165. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  166. airflow/providers/google/suite/transfers/local_to_drive.py +1 -1
  167. airflow/providers/google/version_compat.py +28 -0
  168. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/METADATA +35 -35
  169. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/RECORD +171 -170
  170. airflow/providers/google/cloud/links/automl.py +0 -193
  171. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/WHEEL +0 -0
  172. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/entry_points.txt +0 -0
@@ -20,107 +20,21 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any, Literal
24
24
 
25
- from airflow.exceptions import AirflowProviderDeprecationWarning
26
- from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
25
+ from google.api_core import exceptions
26
+
27
+ from airflow.exceptions import AirflowException
28
+ from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
29
+ ExperimentRunHook,
30
+ GenerativeModelHook,
31
+ )
27
32
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
28
- from airflow.providers.google.common.deprecated import deprecated
29
33
 
30
34
  if TYPE_CHECKING:
31
35
  from airflow.utils.context import Context
32
36
 
33
37
 
34
- @deprecated(
35
- planned_removal_date="April 09, 2025",
36
- use_instead="GenerativeModelGenerateContentOperator",
37
- category=AirflowProviderDeprecationWarning,
38
- )
39
- class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
40
- """
41
- Uses the Vertex AI PaLM API to generate natural language text.
42
-
43
- :param project_id: Required. The ID of the Google Cloud project that the
44
- service belongs to (templated).
45
- :param location: Required. The ID of the Google Cloud location that the
46
- service belongs to (templated).
47
- :param prompt: Required. Inputs or queries that a user or a program gives
48
- to the Vertex AI PaLM API, in order to elicit a specific response (templated).
49
- :param pretrained_model: By default uses the pre-trained model `text-bison`,
50
- optimized for performing natural language tasks such as classification,
51
- summarization, extraction, content creation, and ideation.
52
- :param temperature: Temperature controls the degree of randomness in token
53
- selection. Defaults to 0.0.
54
- :param max_output_tokens: Token limit determines the maximum amount of text
55
- output. Defaults to 256.
56
- :param top_p: Tokens are selected from most probable to least until the sum
57
- of their probabilities equals the top_p value. Defaults to 0.8.
58
- :param top_k: A top_k of 1 means the selected token is the most probable
59
- among all tokens. Defaults to 0.4.
60
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
61
- :param impersonation_chain: Optional service account to impersonate using short-term
62
- credentials, or chained list of accounts required to get the access_token
63
- of the last account in the list, which will be impersonated in the request.
64
- If set as a string, the account must grant the originating account
65
- the Service Account Token Creator IAM role.
66
- If set as a sequence, the identities from the list must grant
67
- Service Account Token Creator IAM role to the directly preceding identity, with first
68
- account from the list granting this role to the originating account (templated).
69
- """
70
-
71
- template_fields = ("location", "project_id", "impersonation_chain", "prompt")
72
-
73
- def __init__(
74
- self,
75
- *,
76
- project_id: str,
77
- location: str,
78
- prompt: str,
79
- pretrained_model: str = "text-bison",
80
- temperature: float = 0.0,
81
- max_output_tokens: int = 256,
82
- top_p: float = 0.8,
83
- top_k: int = 40,
84
- gcp_conn_id: str = "google_cloud_default",
85
- impersonation_chain: str | Sequence[str] | None = None,
86
- **kwargs,
87
- ) -> None:
88
- super().__init__(**kwargs)
89
- self.project_id = project_id
90
- self.location = location
91
- self.prompt = prompt
92
- self.pretrained_model = pretrained_model
93
- self.temperature = temperature
94
- self.max_output_tokens = max_output_tokens
95
- self.top_p = top_p
96
- self.top_k = top_k
97
- self.gcp_conn_id = gcp_conn_id
98
- self.impersonation_chain = impersonation_chain
99
-
100
- def execute(self, context: Context):
101
- self.hook = GenerativeModelHook(
102
- gcp_conn_id=self.gcp_conn_id,
103
- impersonation_chain=self.impersonation_chain,
104
- )
105
-
106
- self.log.info("Submitting prompt")
107
- response = self.hook.text_generation_model_predict(
108
- project_id=self.project_id,
109
- location=self.location,
110
- prompt=self.prompt,
111
- pretrained_model=self.pretrained_model,
112
- temperature=self.temperature,
113
- max_output_tokens=self.max_output_tokens,
114
- top_p=self.top_p,
115
- top_k=self.top_k,
116
- )
117
-
118
- self.log.info("Model response: %s", response)
119
- self.xcom_push(context, key="model_response", value=response)
120
-
121
- return response
122
-
123
-
124
38
  class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
125
39
  """
126
40
  Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
@@ -130,9 +44,8 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
130
44
  :param location: Required. The ID of the Google Cloud location that the
131
45
  service belongs to (templated).
132
46
  :param prompt: Required. Inputs or queries that a user or a program gives
133
- to the Vertex AI PaLM API, in order to elicit a specific response (templated).
134
- :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
135
- optimized for performing text embeddings.
47
+ to the Vertex AI Generative Model API, in order to elicit a specific response (templated).
48
+ :param pretrained_model: Required. Model, optimized for performing text embeddings.
136
49
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
137
50
  :param impersonation_chain: Optional service account to impersonate using short-term
138
51
  credentials, or chained list of accounts required to get the access_token
@@ -152,7 +65,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
152
65
  project_id: str,
153
66
  location: str,
154
67
  prompt: str,
155
- pretrained_model: str = "textembedding-gecko",
68
+ pretrained_model: str,
156
69
  gcp_conn_id: str = "google_cloud_default",
157
70
  impersonation_chain: str | Sequence[str] | None = None,
158
71
  **kwargs,
@@ -180,7 +93,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
180
93
  )
181
94
 
182
95
  self.log.info("Model response: %s", response)
183
- self.xcom_push(context, key="model_response", value=response)
96
+ context["ti"].xcom_push(key="model_response", value=response)
184
97
 
185
98
  return response
186
99
 
@@ -199,10 +112,9 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
199
112
  :param safety_settings: Optional. Per request settings for blocking unsafe content.
200
113
  :param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
201
114
  :param system_instruction: Optional. An instruction given to the model to guide its behavior.
202
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
203
- supporting prompts with text-only input, including natural language
204
- tasks, multi-turn text and code chat, and code generation. It can
205
- output text and code.
115
+ :param pretrained_model: Required. The name of the model to use for content generation,
116
+ which can be a text-only or multimodal model. For example, `gemini-pro` or
117
+ `gemini-pro-vision`.
206
118
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
207
119
  :param impersonation_chain: Optional service account to impersonate using short-term
208
120
  credentials, or chained list of accounts required to get the access_token
@@ -226,7 +138,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
226
138
  generation_config: dict | None = None,
227
139
  safety_settings: dict | None = None,
228
140
  system_instruction: str | None = None,
229
- pretrained_model: str = "gemini-pro",
141
+ pretrained_model: str,
230
142
  gcp_conn_id: str = "google_cloud_default",
231
143
  impersonation_chain: str | Sequence[str] | None = None,
232
144
  **kwargs,
@@ -260,7 +172,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
260
172
  )
261
173
 
262
174
  self.log.info("Model response: %s", response)
263
- self.xcom_push(context, key="model_response", value=response)
175
+ context["ti"].xcom_push(key="model_response", value=response)
264
176
 
265
177
  return response
266
178
 
@@ -310,7 +222,7 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
310
222
  tuned_model_display_name: str | None = None,
311
223
  validation_dataset: str | None = None,
312
224
  epochs: int | None = None,
313
- adapter_size: int | None = None,
225
+ adapter_size: Literal[1, 4, 8, 16] | None = None,
314
226
  learning_rate_multiplier: float | None = None,
315
227
  gcp_conn_id: str = "google_cloud_default",
316
228
  impersonation_chain: str | Sequence[str] | None = None,
@@ -349,8 +261,8 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
349
261
  self.log.info("Tuned Model Name: %s", response.tuned_model_name)
350
262
  self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
351
263
 
352
- self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
353
- self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
264
+ context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name)
265
+ context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
354
266
 
355
267
  result = {
356
268
  "tuned_model_name": response.tuned_model_name,
@@ -370,10 +282,9 @@ class CountTokensOperator(GoogleCloudBaseOperator):
370
282
  service belongs to (templated).
371
283
  :param contents: Required. The multi-part content of a message that a user or a program
372
284
  gives to the generative model, in order to elicit a specific response.
373
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
374
- supporting prompts with text-only input, including natural language
375
- tasks, multi-turn text and code chat, and code generation. It can
376
- output text and code.
285
+ :param pretrained_model: Required. Model, supporting prompts with text-only input,
286
+ including natural language tasks, multi-turn text and code chat,
287
+ and code generation. It can output text and code.
377
288
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
378
289
  :param impersonation_chain: Optional service account to impersonate using short-term
379
290
  credentials, or chained list of accounts required to get the access_token
@@ -393,7 +304,7 @@ class CountTokensOperator(GoogleCloudBaseOperator):
393
304
  project_id: str,
394
305
  location: str,
395
306
  contents: list,
396
- pretrained_model: str = "gemini-pro",
307
+ pretrained_model: str,
397
308
  gcp_conn_id: str = "google_cloud_default",
398
309
  impersonation_chain: str | Sequence[str] | None = None,
399
310
  **kwargs,
@@ -421,8 +332,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
421
332
  self.log.info("Total tokens: %s", response.total_tokens)
422
333
  self.log.info("Total billable characters: %s", response.total_billable_characters)
423
334
 
424
- self.xcom_push(context, key="total_tokens", value=response.total_tokens)
425
- self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)
335
+ context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
336
+ context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters)
426
337
 
427
338
 
428
339
  class RunEvaluationOperator(GoogleCloudBaseOperator):
@@ -562,8 +473,8 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
562
473
  project_id: str,
563
474
  location: str,
564
475
  model_name: str,
565
- system_instruction: str | None = None,
566
- contents: list | None = None,
476
+ system_instruction: Any | None = None,
477
+ contents: list[Any] | None = None,
567
478
  ttl_hours: float = 1,
568
479
  display_name: str | None = None,
569
480
  gcp_conn_id: str = "google_cloud_default",
@@ -674,3 +585,68 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
674
585
  self.log.info("Cached Content Response: %s", cached_content_text)
675
586
 
676
587
  return cached_content_text
588
+
589
+
590
+ class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
591
+ """
592
+ Delete the specified experiment run.
593
+
594
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
595
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
596
+ :param experiment_name: Required. The name of the evaluation experiment.
597
+ :param experiment_run_name: Required. The specific run name or ID for this experiment.
598
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
599
+ :param impersonation_chain: Optional service account to impersonate using short-term
600
+ credentials, or chained list of accounts required to get the access_token
601
+ of the last account in the list, which will be impersonated in the request.
602
+ If set as a string, the account must grant the originating account
603
+ the Service Account Token Creator IAM role.
604
+ If set as a sequence, the identities from the list must grant
605
+ Service Account Token Creator IAM role to the directly preceding identity, with first
606
+ account from the list granting this role to the originating account (templated).
607
+ """
608
+
609
+ template_fields = (
610
+ "location",
611
+ "project_id",
612
+ "impersonation_chain",
613
+ "experiment_name",
614
+ "experiment_run_name",
615
+ )
616
+
617
+ def __init__(
618
+ self,
619
+ *,
620
+ project_id: str,
621
+ location: str,
622
+ experiment_name: str,
623
+ experiment_run_name: str,
624
+ gcp_conn_id: str = "google_cloud_default",
625
+ impersonation_chain: str | Sequence[str] | None = None,
626
+ **kwargs,
627
+ ) -> None:
628
+ super().__init__(**kwargs)
629
+ self.project_id = project_id
630
+ self.location = location
631
+ self.experiment_name = experiment_name
632
+ self.experiment_run_name = experiment_run_name
633
+ self.gcp_conn_id = gcp_conn_id
634
+ self.impersonation_chain = impersonation_chain
635
+
636
+ def execute(self, context: Context) -> None:
637
+ self.hook = ExperimentRunHook(
638
+ gcp_conn_id=self.gcp_conn_id,
639
+ impersonation_chain=self.impersonation_chain,
640
+ )
641
+
642
+ try:
643
+ self.hook.delete_experiment_run(
644
+ project_id=self.project_id,
645
+ location=self.location,
646
+ experiment_name=self.experiment_name,
647
+ experiment_run_name=self.experiment_run_name,
648
+ )
649
+ except exceptions.NotFound:
650
+ raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
651
+
652
+ self.log.info("Deleted experiment run: %s", self.experiment_run_name)
@@ -257,10 +257,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
257
257
  hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
258
258
  self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
259
259
 
260
- self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
261
- VertexAITrainingLink.persist(
262
- context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
263
- )
260
+ context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
261
+ VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id)
264
262
 
265
263
  if self.deferrable:
266
264
  self.defer(
@@ -355,9 +353,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
355
353
  timeout=self.timeout,
356
354
  metadata=self.metadata,
357
355
  )
358
- VertexAITrainingLink.persist(
359
- context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
360
- )
356
+ VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id)
361
357
  self.log.info("Hyperparameter tuning job was gotten.")
362
358
  return types.HyperparameterTuningJob.to_dict(result)
363
359
  except NotFound:
@@ -487,6 +483,12 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
487
483
  self.gcp_conn_id = gcp_conn_id
488
484
  self.impersonation_chain = impersonation_chain
489
485
 
486
+ @property
487
+ def extra_links_params(self) -> dict[str, Any]:
488
+ return {
489
+ "project_id": self.project_id,
490
+ }
491
+
490
492
  def execute(self, context: Context):
491
493
  hook = HyperparameterTuningJobHook(
492
494
  gcp_conn_id=self.gcp_conn_id,
@@ -503,5 +505,5 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
503
505
  timeout=self.timeout,
504
506
  metadata=self.metadata,
505
507
  )
506
- VertexAIHyperparameterTuningJobListLink.persist(context=context, task_instance=self)
508
+ VertexAIHyperparameterTuningJobListLink.persist(context=context)
507
509
  return [types.HyperparameterTuningJob.to_dict(result) for result in results]
@@ -20,7 +20,7 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any
24
24
 
25
25
  from google.api_core.exceptions import NotFound
26
26
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -161,6 +161,13 @@ class GetModelOperator(GoogleCloudBaseOperator):
161
161
  self.gcp_conn_id = gcp_conn_id
162
162
  self.impersonation_chain = impersonation_chain
163
163
 
164
+ @property
165
+ def extra_links_params(self) -> dict[str, Any]:
166
+ return {
167
+ "region": self.region,
168
+ "project_id": self.project_id,
169
+ }
170
+
164
171
  def execute(self, context: Context):
165
172
  hook = ModelServiceHook(
166
173
  gcp_conn_id=self.gcp_conn_id,
@@ -179,8 +186,8 @@ class GetModelOperator(GoogleCloudBaseOperator):
179
186
  )
180
187
  self.log.info("Model found. Model ID: %s", self.model_id)
181
188
 
182
- self.xcom_push(context, key="model_id", value=self.model_id)
183
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
189
+ context["ti"].xcom_push(key="model_id", value=self.model_id)
190
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
184
191
  return Model.to_dict(model)
185
192
  except NotFound:
186
193
  self.log.info("The Model ID %s does not exist.", self.model_id)
@@ -257,7 +264,12 @@ class ExportModelOperator(GoogleCloudBaseOperator):
257
264
  metadata=self.metadata,
258
265
  )
259
266
  hook.wait_for_operation(timeout=self.timeout, operation=operation)
260
- VertexAIModelExportLink.persist(context=context, task_instance=self)
267
+ VertexAIModelExportLink.persist(
268
+ context=context,
269
+ output_config=self.output_config,
270
+ model_id=self.model_id,
271
+ project_id=self.project_id,
272
+ )
261
273
  self.log.info("Model was exported.")
262
274
  except NotFound:
263
275
  self.log.info("The Model ID %s does not exist.", self.model_id)
@@ -335,6 +347,12 @@ class ListModelsOperator(GoogleCloudBaseOperator):
335
347
  self.gcp_conn_id = gcp_conn_id
336
348
  self.impersonation_chain = impersonation_chain
337
349
 
350
+ @property
351
+ def extra_links_params(self) -> dict[str, Any]:
352
+ return {
353
+ "project_id": self.project_id,
354
+ }
355
+
338
356
  def execute(self, context: Context):
339
357
  hook = ModelServiceHook(
340
358
  gcp_conn_id=self.gcp_conn_id,
@@ -352,7 +370,7 @@ class ListModelsOperator(GoogleCloudBaseOperator):
352
370
  timeout=self.timeout,
353
371
  metadata=self.metadata,
354
372
  )
355
- VertexAIModelListLink.persist(context=context, task_instance=self)
373
+ VertexAIModelListLink.persist(context=context)
356
374
  return [Model.to_dict(result) for result in results]
357
375
 
358
376
 
@@ -407,6 +425,13 @@ class UploadModelOperator(GoogleCloudBaseOperator):
407
425
  self.gcp_conn_id = gcp_conn_id
408
426
  self.impersonation_chain = impersonation_chain
409
427
 
428
+ @property
429
+ def extra_links_params(self) -> dict[str, Any]:
430
+ return {
431
+ "region": self.region,
432
+ "project_id": self.project_id,
433
+ }
434
+
410
435
  def execute(self, context: Context):
411
436
  hook = ModelServiceHook(
412
437
  gcp_conn_id=self.gcp_conn_id,
@@ -428,8 +453,8 @@ class UploadModelOperator(GoogleCloudBaseOperator):
428
453
  model_id = hook.extract_model_id(model_resp)
429
454
  self.log.info("Model was uploaded. Model ID: %s", model_id)
430
455
 
431
- self.xcom_push(context, key="model_id", value=model_id)
432
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
456
+ context["ti"].xcom_push(key="model_id", value=model_id)
457
+ VertexAIModelLink.persist(context=context, model_id=model_id)
433
458
  return model_resp
434
459
 
435
460
 
@@ -553,6 +578,13 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
553
578
  self.gcp_conn_id = gcp_conn_id
554
579
  self.impersonation_chain = impersonation_chain
555
580
 
581
+ @property
582
+ def extra_links_params(self) -> dict[str, Any]:
583
+ return {
584
+ "region": self.region,
585
+ "project_id": self.project_id,
586
+ }
587
+
556
588
  def execute(self, context: Context):
557
589
  hook = ModelServiceHook(
558
590
  gcp_conn_id=self.gcp_conn_id,
@@ -571,7 +603,7 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
571
603
  timeout=self.timeout,
572
604
  metadata=self.metadata,
573
605
  )
574
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
606
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
575
607
  return Model.to_dict(updated_model)
576
608
 
577
609
 
@@ -627,6 +659,13 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
627
659
  self.gcp_conn_id = gcp_conn_id
628
660
  self.impersonation_chain = impersonation_chain
629
661
 
662
+ @property
663
+ def extra_links_params(self) -> dict[str, Any]:
664
+ return {
665
+ "region": self.region,
666
+ "project_id": self.project_id,
667
+ }
668
+
630
669
  def execute(self, context: Context):
631
670
  hook = ModelServiceHook(
632
671
  gcp_conn_id=self.gcp_conn_id,
@@ -645,7 +684,7 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
645
684
  timeout=self.timeout,
646
685
  metadata=self.metadata,
647
686
  )
648
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
687
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
649
688
  return Model.to_dict(updated_model)
650
689
 
651
690
 
@@ -701,6 +740,13 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
701
740
  self.gcp_conn_id = gcp_conn_id
702
741
  self.impersonation_chain = impersonation_chain
703
742
 
743
+ @property
744
+ def extra_links_params(self) -> dict[str, Any]:
745
+ return {
746
+ "region": self.region,
747
+ "project_id": self.project_id,
748
+ }
749
+
704
750
  def execute(self, context: Context):
705
751
  hook = ModelServiceHook(
706
752
  gcp_conn_id=self.gcp_conn_id,
@@ -721,7 +767,7 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
721
767
  timeout=self.timeout,
722
768
  metadata=self.metadata,
723
769
  )
724
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
770
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
725
771
  return Model.to_dict(updated_model)
726
772
 
727
773
 
@@ -112,6 +112,10 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
112
112
  "project_id",
113
113
  "input_artifacts",
114
114
  "impersonation_chain",
115
+ "template_path",
116
+ "pipeline_root",
117
+ "parameter_values",
118
+ "service_account",
115
119
  ]
116
120
  operator_extra_links = (VertexAIPipelineJobLink(),)
117
121
 
@@ -162,6 +166,13 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
162
166
  self.deferrable = deferrable
163
167
  self.poll_interval = poll_interval
164
168
 
169
+ @property
170
+ def extra_links_params(self) -> dict[str, Any]:
171
+ return {
172
+ "region": self.region,
173
+ "project_id": self.project_id,
174
+ }
175
+
165
176
  def execute(self, context: Context):
166
177
  self.log.info("Running Pipeline job")
167
178
  pipeline_job_obj: PipelineJob = self.hook.submit_pipeline_job(
@@ -184,8 +195,8 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
184
195
  )
185
196
  pipeline_job_id = pipeline_job_obj.job_id
186
197
  self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
187
- self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
188
- VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id)
198
+ context["ti"].xcom_push(key="pipeline_job_id", value=pipeline_job_id)
199
+ VertexAIPipelineJobLink.persist(context=context, pipeline_id=pipeline_job_id)
189
200
 
190
201
  if self.deferrable:
191
202
  pipeline_job_obj.wait_for_resource_creation()
@@ -276,6 +287,13 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
276
287
  self.gcp_conn_id = gcp_conn_id
277
288
  self.impersonation_chain = impersonation_chain
278
289
 
290
+ @property
291
+ def extra_links_params(self) -> dict[str, Any]:
292
+ return {
293
+ "region": self.region,
294
+ "project_id": self.project_id,
295
+ }
296
+
279
297
  def execute(self, context: Context):
280
298
  hook = PipelineJobHook(
281
299
  gcp_conn_id=self.gcp_conn_id,
@@ -292,9 +310,7 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
292
310
  timeout=self.timeout,
293
311
  metadata=self.metadata,
294
312
  )
295
- VertexAIPipelineJobLink.persist(
296
- context=context, task_instance=self, pipeline_id=self.pipeline_job_id
297
- )
313
+ VertexAIPipelineJobLink.persist(context=context, pipeline_id=self.pipeline_job_id)
298
314
  self.log.info("Pipeline job was gotten.")
299
315
  return types.PipelineJob.to_dict(result)
300
316
  except NotFound:
@@ -408,6 +424,13 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
408
424
  self.gcp_conn_id = gcp_conn_id
409
425
  self.impersonation_chain = impersonation_chain
410
426
 
427
+ @property
428
+ def extra_links_params(self) -> dict[str, Any]:
429
+ return {
430
+ "region": self.region,
431
+ "project_id": self.project_id,
432
+ }
433
+
411
434
  def execute(self, context: Context):
412
435
  hook = PipelineJobHook(
413
436
  gcp_conn_id=self.gcp_conn_id,
@@ -424,7 +447,7 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
424
447
  timeout=self.timeout,
425
448
  metadata=self.metadata,
426
449
  )
427
- VertexAIPipelineJobListLink.persist(context=context, task_instance=self)
450
+ VertexAIPipelineJobListLink.persist(context=context)
428
451
  return [types.PipelineJob.to_dict(result) for result in results]
429
452
 
430
453
 
@@ -188,12 +188,13 @@ class CreateRayClusterOperator(RayBaseOperator):
188
188
  labels=self.labels,
189
189
  )
190
190
  cluster_id = self.hook.extract_cluster_id(cluster_path)
191
- self.xcom_push(
192
- context=context,
191
+ context["ti"].xcom_push(
193
192
  key="cluster_id",
194
193
  value=cluster_id,
195
194
  )
196
- VertexAIRayClusterLink.persist(context=context, task_instance=self, cluster_id=cluster_id)
195
+ VertexAIRayClusterLink.persist(
196
+ context=context, location=self.location, cluster_id=cluster_id, project_id=self.project_id
197
+ )
197
198
  self.log.info("Ray cluster was created.")
198
199
  except Exception as error:
199
200
  raise AirflowException(error)
@@ -220,7 +221,7 @@ class ListRayClustersOperator(RayBaseOperator):
220
221
  operator_extra_links = (VertexAIRayClusterListLink(),)
221
222
 
222
223
  def execute(self, context: Context):
223
- VertexAIRayClusterListLink.persist(context=context, task_instance=self)
224
+ VertexAIRayClusterListLink.persist(context=context, project_id=self.project_id)
224
225
  self.log.info("Listing Clusters from location %s.", self.location)
225
226
  try:
226
227
  ray_cluster_list = self.hook.list_ray_clusters(
@@ -268,8 +269,9 @@ class GetRayClusterOperator(RayBaseOperator):
268
269
  def execute(self, context: Context):
269
270
  VertexAIRayClusterLink.persist(
270
271
  context=context,
271
- task_instance=self,
272
+ location=self.location,
272
273
  cluster_id=self.cluster_id,
274
+ project_id=self.project_id,
273
275
  )
274
276
  self.log.info("Getting Cluster: %s", self.cluster_id)
275
277
  try:
@@ -325,8 +327,9 @@ class UpdateRayClusterOperator(RayBaseOperator):
325
327
  def execute(self, context: Context):
326
328
  VertexAIRayClusterLink.persist(
327
329
  context=context,
328
- task_instance=self,
330
+ location=self.location,
329
331
  cluster_id=self.cluster_id,
332
+ project_id=self.project_id,
330
333
  )
331
334
  self.log.info("Updating a Ray cluster.")
332
335
  try: