apache-airflow-providers-google 16.0.0a1__py3-none-any.whl → 16.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172):
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +43 -5
  3. airflow/providers/google/ads/operators/ads.py +1 -1
  4. airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -1
  5. airflow/providers/google/cloud/hooks/bigquery.py +63 -77
  6. airflow/providers/google/cloud/hooks/cloud_sql.py +8 -4
  7. airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +2 -2
  9. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  10. airflow/providers/google/cloud/hooks/dataprep.py +4 -1
  11. airflow/providers/google/cloud/hooks/gcs.py +5 -5
  12. airflow/providers/google/cloud/hooks/looker.py +10 -1
  13. airflow/providers/google/cloud/hooks/mlengine.py +2 -1
  14. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  15. airflow/providers/google/cloud/hooks/spanner.py +2 -2
  16. airflow/providers/google/cloud/hooks/translate.py +1 -1
  17. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
  18. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
  19. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +44 -80
  20. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +11 -2
  21. airflow/providers/google/cloud/hooks/vision.py +2 -2
  22. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  23. airflow/providers/google/cloud/links/base.py +75 -11
  24. airflow/providers/google/cloud/links/bigquery.py +0 -47
  25. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  26. airflow/providers/google/cloud/links/bigtable.py +0 -48
  27. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  28. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  29. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  30. airflow/providers/google/cloud/links/cloud_run.py +27 -0
  31. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  32. airflow/providers/google/cloud/links/cloud_storage_transfer.py +16 -43
  33. airflow/providers/google/cloud/links/cloud_tasks.py +6 -25
  34. airflow/providers/google/cloud/links/compute.py +0 -58
  35. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  36. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  37. airflow/providers/google/cloud/links/dataflow.py +0 -34
  38. airflow/providers/google/cloud/links/dataform.py +0 -64
  39. airflow/providers/google/cloud/links/datafusion.py +1 -96
  40. airflow/providers/google/cloud/links/dataplex.py +0 -154
  41. airflow/providers/google/cloud/links/dataprep.py +0 -24
  42. airflow/providers/google/cloud/links/dataproc.py +14 -90
  43. airflow/providers/google/cloud/links/datastore.py +0 -31
  44. airflow/providers/google/cloud/links/kubernetes_engine.py +5 -59
  45. airflow/providers/google/cloud/links/life_sciences.py +0 -19
  46. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  47. airflow/providers/google/cloud/links/mlengine.py +0 -70
  48. airflow/providers/google/cloud/links/pubsub.py +0 -32
  49. airflow/providers/google/cloud/links/spanner.py +0 -33
  50. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  51. airflow/providers/google/cloud/links/translate.py +16 -186
  52. airflow/providers/google/cloud/links/vertex_ai.py +8 -224
  53. airflow/providers/google/cloud/links/workflows.py +0 -52
  54. airflow/providers/google/cloud/log/gcs_task_handler.py +4 -4
  55. airflow/providers/google/cloud/operators/alloy_db.py +69 -54
  56. airflow/providers/google/cloud/operators/automl.py +16 -14
  57. airflow/providers/google/cloud/operators/bigquery.py +49 -25
  58. airflow/providers/google/cloud/operators/bigquery_dts.py +2 -4
  59. airflow/providers/google/cloud/operators/bigtable.py +35 -6
  60. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  61. airflow/providers/google/cloud/operators/cloud_build.py +74 -31
  62. airflow/providers/google/cloud/operators/cloud_composer.py +34 -35
  63. airflow/providers/google/cloud/operators/cloud_memorystore.py +68 -42
  64. airflow/providers/google/cloud/operators/cloud_run.py +9 -1
  65. airflow/providers/google/cloud/operators/cloud_sql.py +11 -15
  66. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -2
  67. airflow/providers/google/cloud/operators/compute.py +7 -39
  68. airflow/providers/google/cloud/operators/datacatalog.py +156 -20
  69. airflow/providers/google/cloud/operators/dataflow.py +37 -14
  70. airflow/providers/google/cloud/operators/dataform.py +14 -4
  71. airflow/providers/google/cloud/operators/datafusion.py +4 -12
  72. airflow/providers/google/cloud/operators/dataplex.py +180 -96
  73. airflow/providers/google/cloud/operators/dataprep.py +0 -4
  74. airflow/providers/google/cloud/operators/dataproc.py +10 -16
  75. airflow/providers/google/cloud/operators/dataproc_metastore.py +95 -87
  76. airflow/providers/google/cloud/operators/datastore.py +21 -5
  77. airflow/providers/google/cloud/operators/dlp.py +3 -26
  78. airflow/providers/google/cloud/operators/functions.py +15 -6
  79. airflow/providers/google/cloud/operators/gcs.py +1 -7
  80. airflow/providers/google/cloud/operators/kubernetes_engine.py +53 -92
  81. airflow/providers/google/cloud/operators/life_sciences.py +0 -1
  82. airflow/providers/google/cloud/operators/managed_kafka.py +106 -51
  83. airflow/providers/google/cloud/operators/mlengine.py +0 -1
  84. airflow/providers/google/cloud/operators/pubsub.py +4 -5
  85. airflow/providers/google/cloud/operators/spanner.py +0 -4
  86. airflow/providers/google/cloud/operators/speech_to_text.py +0 -1
  87. airflow/providers/google/cloud/operators/stackdriver.py +0 -8
  88. airflow/providers/google/cloud/operators/tasks.py +0 -11
  89. airflow/providers/google/cloud/operators/text_to_speech.py +0 -1
  90. airflow/providers/google/cloud/operators/translate.py +37 -13
  91. airflow/providers/google/cloud/operators/translate_speech.py +0 -1
  92. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18
  93. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8
  94. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25
  95. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +69 -7
  96. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +42 -8
  97. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +531 -0
  98. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +93 -117
  99. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +10 -8
  100. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +56 -10
  101. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +29 -6
  102. airflow/providers/google/cloud/operators/vertex_ai/ray.py +9 -6
  103. airflow/providers/google/cloud/operators/workflows.py +1 -9
  104. airflow/providers/google/cloud/sensors/bigquery.py +1 -1
  105. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -1
  106. airflow/providers/google/cloud/sensors/bigtable.py +15 -3
  107. airflow/providers/google/cloud/sensors/cloud_composer.py +6 -1
  108. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -1
  109. airflow/providers/google/cloud/sensors/dataflow.py +3 -3
  110. airflow/providers/google/cloud/sensors/dataform.py +6 -1
  111. airflow/providers/google/cloud/sensors/datafusion.py +6 -1
  112. airflow/providers/google/cloud/sensors/dataplex.py +6 -1
  113. airflow/providers/google/cloud/sensors/dataprep.py +6 -1
  114. airflow/providers/google/cloud/sensors/dataproc.py +6 -1
  115. airflow/providers/google/cloud/sensors/dataproc_metastore.py +6 -1
  116. airflow/providers/google/cloud/sensors/gcs.py +9 -3
  117. airflow/providers/google/cloud/sensors/looker.py +6 -1
  118. airflow/providers/google/cloud/sensors/pubsub.py +8 -3
  119. airflow/providers/google/cloud/sensors/tasks.py +6 -1
  120. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +6 -1
  121. airflow/providers/google/cloud/sensors/workflows.py +6 -1
  122. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +1 -1
  123. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +1 -1
  124. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +10 -7
  125. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -2
  126. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -1
  127. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -1
  128. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  129. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +1 -1
  130. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +2 -2
  131. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
  132. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -1
  133. airflow/providers/google/cloud/transfers/gcs_to_local.py +1 -1
  134. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
  135. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +5 -1
  136. airflow/providers/google/cloud/transfers/gdrive_to_local.py +1 -1
  137. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  138. airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -1
  139. airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
  140. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +1 -1
  141. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
  142. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +2 -2
  143. airflow/providers/google/cloud/transfers/sql_to_gcs.py +1 -1
  144. airflow/providers/google/cloud/triggers/bigquery.py +32 -5
  145. airflow/providers/google/cloud/triggers/dataproc.py +62 -10
  146. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  147. airflow/providers/google/common/auth_backend/google_openid.py +2 -1
  148. airflow/providers/google/common/deprecated.py +2 -1
  149. airflow/providers/google/common/hooks/base_google.py +7 -3
  150. airflow/providers/google/common/links/storage.py +0 -22
  151. airflow/providers/google/firebase/operators/firestore.py +1 -1
  152. airflow/providers/google/get_provider_info.py +14 -16
  153. airflow/providers/google/leveldb/hooks/leveldb.py +30 -1
  154. airflow/providers/google/leveldb/operators/leveldb.py +1 -1
  155. airflow/providers/google/marketing_platform/links/analytics_admin.py +3 -6
  156. airflow/providers/google/marketing_platform/operators/analytics_admin.py +0 -1
  157. airflow/providers/google/marketing_platform/operators/campaign_manager.py +4 -4
  158. airflow/providers/google/marketing_platform/operators/display_video.py +6 -6
  159. airflow/providers/google/marketing_platform/operators/search_ads.py +1 -1
  160. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +6 -1
  161. airflow/providers/google/marketing_platform/sensors/display_video.py +6 -1
  162. airflow/providers/google/suite/operators/sheets.py +3 -3
  163. airflow/providers/google/suite/sensors/drive.py +6 -1
  164. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +1 -1
  165. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  166. airflow/providers/google/suite/transfers/local_to_drive.py +1 -1
  167. airflow/providers/google/version_compat.py +28 -0
  168. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/METADATA +35 -35
  169. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/RECORD +171 -170
  170. airflow/providers/google/cloud/links/automl.py +0 -193
  171. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/WHEEL +0 -0
  172. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/entry_points.txt +0 -0
@@ -21,38 +21,26 @@ from __future__ import annotations
21
21
 
22
22
  import time
23
23
  from datetime import timedelta
24
- from typing import TYPE_CHECKING
24
+ from typing import TYPE_CHECKING, Any, Literal
25
25
 
26
26
  import vertexai
27
+ from google.cloud import aiplatform
27
28
  from vertexai.generative_models import GenerativeModel
28
- from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
29
+ from vertexai.language_models import TextEmbeddingModel
30
+ from vertexai.preview import generative_models as preview_generative_model
29
31
  from vertexai.preview.caching import CachedContent
30
32
  from vertexai.preview.evaluation import EvalResult, EvalTask
31
- from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
32
33
  from vertexai.preview.tuning import sft
33
34
 
34
- from airflow.exceptions import AirflowProviderDeprecationWarning
35
- from airflow.providers.google.common.deprecated import deprecated
36
35
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
37
36
 
38
37
  if TYPE_CHECKING:
39
- from google.cloud.aiplatform_v1 import types as types_v1
40
38
  from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
41
39
 
42
40
 
43
41
  class GenerativeModelHook(GoogleBaseHook):
44
42
  """Hook for Google Cloud Vertex AI Generative Model APIs."""
45
43
 
46
- @deprecated(
47
- planned_removal_date="April 09, 2025",
48
- use_instead="GenerativeModelHook.get_generative_model",
49
- category=AirflowProviderDeprecationWarning,
50
- )
51
- def get_text_generation_model(self, pretrained_model: str):
52
- """Return a Model Garden Model object based on Text Generation."""
53
- model = TextGenerationModel.from_pretrained(pretrained_model)
54
- return model
55
-
56
44
  def get_text_embedding_model(self, pretrained_model: str):
57
45
  """Return a Model Garden Model object based on Text Embedding."""
58
46
  model = TextEmbeddingModel.from_pretrained(pretrained_model)
@@ -61,7 +49,7 @@ class GenerativeModelHook(GoogleBaseHook):
61
49
  def get_generative_model(
62
50
  self,
63
51
  pretrained_model: str,
64
- system_instruction: str | None = None,
52
+ system_instruction: Any | None = None,
65
53
  generation_config: dict | None = None,
66
54
  safety_settings: dict | None = None,
67
55
  tools: list | None = None,
@@ -93,66 +81,13 @@ class GenerativeModelHook(GoogleBaseHook):
93
81
  def get_cached_context_model(
94
82
  self,
95
83
  cached_content_name: str,
96
- ) -> preview_generative_model:
84
+ ) -> Any:
97
85
  """Return a Generative Model with Cached Context."""
98
86
  cached_content = CachedContent(cached_content_name=cached_content_name)
99
87
 
100
- cached_context_model = preview_generative_model.from_cached_content(cached_content)
88
+ cached_context_model = preview_generative_model.GenerativeModel.from_cached_content(cached_content)
101
89
  return cached_context_model
102
90
 
103
- @deprecated(
104
- planned_removal_date="April 09, 2025",
105
- use_instead="GenerativeModelHook.generative_model_generate_content",
106
- category=AirflowProviderDeprecationWarning,
107
- )
108
- @GoogleBaseHook.fallback_to_default_project_id
109
- def text_generation_model_predict(
110
- self,
111
- prompt: str,
112
- pretrained_model: str,
113
- temperature: float,
114
- max_output_tokens: int,
115
- top_p: float,
116
- top_k: int,
117
- location: str,
118
- project_id: str = PROVIDE_PROJECT_ID,
119
- ) -> str:
120
- """
121
- Use the Vertex AI PaLM API to generate natural language text.
122
-
123
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
124
- :param location: Required. The ID of the Google Cloud location that the service belongs to.
125
- :param prompt: Required. Inputs or queries that a user or a program gives
126
- to the Vertex AI PaLM API, in order to elicit a specific response.
127
- :param pretrained_model: A pre-trained model optimized for performing natural
128
- language tasks such as classification, summarization, extraction, content
129
- creation, and ideation.
130
- :param temperature: Temperature controls the degree of randomness in token
131
- selection.
132
- :param max_output_tokens: Token limit determines the maximum amount of text
133
- output.
134
- :param top_p: Tokens are selected from most probable to least until the sum
135
- of their probabilities equals the top_p value. Defaults to 0.8.
136
- :param top_k: A top_k of 1 means the selected token is the most probable
137
- among all tokens.
138
- """
139
- vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
140
-
141
- parameters = {
142
- "temperature": temperature,
143
- "max_output_tokens": max_output_tokens,
144
- "top_p": top_p,
145
- "top_k": top_k,
146
- }
147
-
148
- model = self.get_text_generation_model(pretrained_model)
149
-
150
- response = model.predict(
151
- prompt=prompt,
152
- **parameters,
153
- )
154
- return response.text
155
-
156
91
  @GoogleBaseHook.fallback_to_default_project_id
157
92
  def text_embedding_model_get_embeddings(
158
93
  self,
@@ -182,11 +117,11 @@ class GenerativeModelHook(GoogleBaseHook):
182
117
  self,
183
118
  contents: list,
184
119
  location: str,
120
+ pretrained_model: str,
185
121
  tools: list | None = None,
186
122
  generation_config: dict | None = None,
187
123
  safety_settings: dict | None = None,
188
124
  system_instruction: str | None = None,
189
- pretrained_model: str = "gemini-pro",
190
125
  project_id: str = PROVIDE_PROJECT_ID,
191
126
  ) -> str:
192
127
  """
@@ -200,7 +135,7 @@ class GenerativeModelHook(GoogleBaseHook):
200
135
  :param safety_settings: Optional. Per request settings for blocking unsafe content.
201
136
  :param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
202
137
  :param system_instruction: Optional. An instruction given to the model to guide its behavior.
203
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
138
+ :param pretrained_model: Required. Model,
204
139
  supporting prompts with text-only input, including natural language
205
140
  tasks, multi-turn text and code chat, and code generation. It can
206
141
  output text and code.
@@ -228,10 +163,10 @@ class GenerativeModelHook(GoogleBaseHook):
228
163
  tuned_model_display_name: str | None = None,
229
164
  validation_dataset: str | None = None,
230
165
  epochs: int | None = None,
231
- adapter_size: int | None = None,
166
+ adapter_size: Literal[1, 4, 8, 16] | None = None,
232
167
  learning_rate_multiplier: float | None = None,
233
168
  project_id: str = PROVIDE_PROJECT_ID,
234
- ) -> types_v1.TuningJob:
169
+ ) -> Any:
235
170
  """
236
171
  Use the Supervised Fine Tuning API to create a tuning job.
237
172
 
@@ -277,7 +212,7 @@ class GenerativeModelHook(GoogleBaseHook):
277
212
  self,
278
213
  contents: list,
279
214
  location: str,
280
- pretrained_model: str = "gemini-pro",
215
+ pretrained_model: str,
281
216
  project_id: str = PROVIDE_PROJECT_ID,
282
217
  ) -> types_v1beta1.CountTokensResponse:
283
218
  """
@@ -287,7 +222,7 @@ class GenerativeModelHook(GoogleBaseHook):
287
222
  :param location: Required. The ID of the Google Cloud location that the service belongs to.
288
223
  :param contents: Required. The multi-part content of a message that a user or a program
289
224
  gives to the generative model, in order to elicit a specific response.
290
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
225
+ :param pretrained_model: Required. Model,
291
226
  supporting prompts with text-only input, including natural language
292
227
  tasks, multi-turn text and code chat, and code generation. It can
293
228
  output text and code.
@@ -364,8 +299,8 @@ class GenerativeModelHook(GoogleBaseHook):
364
299
  model_name: str,
365
300
  location: str,
366
301
  ttl_hours: float = 1,
367
- system_instruction: str | None = None,
368
- contents: list | None = None,
302
+ system_instruction: Any | None = None,
303
+ contents: list[Any] | None = None,
369
304
  display_name: str | None = None,
370
305
  project_id: str = PROVIDE_PROJECT_ID,
371
306
  ) -> str:
@@ -424,3 +359,32 @@ class GenerativeModelHook(GoogleBaseHook):
424
359
  )
425
360
 
426
361
  return response.text
362
+
363
+
364
+ class ExperimentRunHook(GoogleBaseHook):
365
+ """Use the Vertex AI SDK for Python to create and manage your experiment runs."""
366
+
367
+ @GoogleBaseHook.fallback_to_default_project_id
368
+ def delete_experiment_run(
369
+ self,
370
+ experiment_run_name: str,
371
+ experiment_name: str,
372
+ location: str,
373
+ project_id: str = PROVIDE_PROJECT_ID,
374
+ delete_backing_tensorboard_run: bool = False,
375
+ ) -> None:
376
+ """
377
+ Delete experiment run from the experiment.
378
+
379
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
380
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
381
+ :param experiment_name: Required. The name of the evaluation experiment.
382
+ :param experiment_run_name: Required. The specific run name or ID for this experiment.
383
+ :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
384
+ that stores time series metrics for this run.
385
+ """
386
+ self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
387
+ experiment_run = aiplatform.ExperimentRun(
388
+ run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
389
+ )
390
+ experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
@@ -22,8 +22,17 @@ from __future__ import annotations
22
22
  import dataclasses
23
23
  from typing import Any
24
24
 
25
- import vertex_ray
26
- from google._upb._message import ScalarMapContainer
25
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
26
+
27
+ try:
28
+ import vertex_ray
29
+ from google._upb._message import ScalarMapContainer # type: ignore[attr-defined]
30
+ except ImportError:
31
+ # Fallback for environments where the upb module is not available.
32
+ raise AirflowOptionalProviderFeatureException(
33
+ "google._upb._message.ScalarMapContainer is not available. "
34
+ "Please install the ray package to use this feature."
35
+ )
27
36
  from google.cloud import aiplatform
28
37
  from google.cloud.aiplatform.vertex_ray.util import resources
29
38
  from google.cloud.aiplatform_v1 import (
@@ -19,10 +19,10 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
- from collections.abc import Sequence
22
+ from collections.abc import Callable, Sequence
23
23
  from copy import deepcopy
24
24
  from functools import cached_property
25
- from typing import TYPE_CHECKING, Any, Callable
25
+ from typing import TYPE_CHECKING, Any
26
26
 
27
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
28
28
  from google.cloud.vision_v1 import (
@@ -19,14 +19,8 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
- from typing import TYPE_CHECKING
23
-
24
22
  from airflow.providers.google.cloud.links.base import BaseGoogleLink
25
23
 
26
- if TYPE_CHECKING:
27
- from airflow.models import BaseOperator
28
- from airflow.utils.context import Context
29
-
30
24
  ALLOY_DB_BASE_LINK = "/alloydb"
31
25
  ALLOY_DB_CLUSTER_LINK = (
32
26
  ALLOY_DB_BASE_LINK + "/locations/{location_id}/clusters/{cluster_id}?project={project_id}"
@@ -44,20 +38,6 @@ class AlloyDBClusterLink(BaseGoogleLink):
44
38
  key = "alloy_db_cluster"
45
39
  format_str = ALLOY_DB_CLUSTER_LINK
46
40
 
47
- @staticmethod
48
- def persist(
49
- context: Context,
50
- task_instance: BaseOperator,
51
- location_id: str,
52
- cluster_id: str,
53
- project_id: str | None,
54
- ):
55
- task_instance.xcom_push(
56
- context,
57
- key=AlloyDBClusterLink.key,
58
- value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
59
- )
60
-
61
41
 
62
42
  class AlloyDBUsersLink(BaseGoogleLink):
63
43
  """Helper class for constructing AlloyDB users Link."""
@@ -66,20 +46,6 @@ class AlloyDBUsersLink(BaseGoogleLink):
66
46
  key = "alloy_db_users"
67
47
  format_str = ALLOY_DB_USERS_LINK
68
48
 
69
- @staticmethod
70
- def persist(
71
- context: Context,
72
- task_instance: BaseOperator,
73
- location_id: str,
74
- cluster_id: str,
75
- project_id: str | None,
76
- ):
77
- task_instance.xcom_push(
78
- context,
79
- key=AlloyDBUsersLink.key,
80
- value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
81
- )
82
-
83
49
 
84
50
  class AlloyDBBackupsLink(BaseGoogleLink):
85
51
  """Helper class for constructing AlloyDB backups Link."""
@@ -87,15 +53,3 @@ class AlloyDBBackupsLink(BaseGoogleLink):
87
53
  name = "AlloyDB Backups"
88
54
  key = "alloy_db_backups"
89
55
  format_str = ALLOY_DB_BACKUPS_LINK
90
-
91
- @staticmethod
92
- def persist(
93
- context: Context,
94
- task_instance: BaseOperator,
95
- project_id: str | None,
96
- ):
97
- task_instance.xcom_push(
98
- context,
99
- key=AlloyDBBackupsLink.key,
100
- value={"project_id": project_id},
101
- )
@@ -19,19 +19,23 @@ from __future__ import annotations
19
19
 
20
20
  from typing import TYPE_CHECKING, ClassVar
21
21
 
22
- from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
23
-
24
- if TYPE_CHECKING:
25
- from airflow.models import BaseOperator
26
- from airflow.models.taskinstancekey import TaskInstanceKey
22
+ from airflow.providers.google.version_compat import (
23
+ AIRFLOW_V_3_0_PLUS,
24
+ BaseOperator,
25
+ BaseOperatorLink,
26
+ BaseSensorOperator,
27
+ )
27
28
 
28
29
  if AIRFLOW_V_3_0_PLUS:
29
- from airflow.sdk import BaseOperatorLink
30
30
  from airflow.sdk.execution_time.xcom import XCom
31
31
  else:
32
- from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef]
33
32
  from airflow.models.xcom import XCom # type: ignore[no-redef]
34
33
 
34
+ if TYPE_CHECKING:
35
+ from airflow.models.taskinstancekey import TaskInstanceKey
36
+ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
37
+ from airflow.utils.context import Context
38
+
35
39
  BASE_LINK = "https://console.cloud.google.com"
36
40
 
37
41
 
@@ -39,6 +43,12 @@ class BaseGoogleLink(BaseOperatorLink):
39
43
  """
40
44
  Base class for all Google links.
41
45
 
46
+ When you inherit this class in a Link class;
47
+ - You can call the persist method to push data to the XCom to use it later in the get_link method.
48
+ - If you have an operator which inherit the GoogleCloudBaseOperator or BaseSensorOperator
49
+ You can define extra_links_params method in the operator to pass the operator properties
50
+ to the get_link method.
51
+
42
52
  :meta private:
43
53
  """
44
54
 
@@ -46,15 +56,69 @@ class BaseGoogleLink(BaseOperatorLink):
46
56
  key: ClassVar[str]
47
57
  format_str: ClassVar[str]
48
58
 
59
+ @property
60
+ def xcom_key(self) -> str:
61
+ # NOTE: in Airflow 3 we need to have xcom_key property in the Link class.
62
+ # Since we have the key property already, this is just a proxy property method to use same
63
+ # key as in Airflow 2.
64
+ return self.key
65
+
66
+ @classmethod
67
+ def persist(cls, context: Context, **value):
68
+ """
69
+ Push arguments to the XCom to use later for link formatting at the `get_link` method.
70
+
71
+ Note: for Airflow 2 we need to call this function with context variable only
72
+ where we have the extra_links_params property method defined
73
+ """
74
+ params = {}
75
+ # TODO: remove after Airflow v2 support dropped
76
+ if not AIRFLOW_V_3_0_PLUS:
77
+ common_params = getattr(context["task"], "extra_links_params", None)
78
+ if common_params:
79
+ params.update(common_params)
80
+
81
+ context["ti"].xcom_push(
82
+ key=cls.key,
83
+ value={
84
+ **params,
85
+ **value,
86
+ },
87
+ )
88
+
89
+ def get_config(self, operator, ti_key):
90
+ conf = {}
91
+ conf.update(getattr(operator, "extra_links_params", {}))
92
+ conf.update(XCom.get_value(key=self.key, ti_key=ti_key) or {})
93
+
94
+ # if the config did not define, return None to stop URL formatting
95
+ if not conf:
96
+ return None
97
+
98
+ # Add a default value for the 'namespace' parameter for backward compatibility.
99
+ # This is for datafusion
100
+ conf.setdefault("namespace", "default")
101
+ return conf
102
+
49
103
  def get_link(
50
104
  self,
51
105
  operator: BaseOperator,
52
106
  *,
53
107
  ti_key: TaskInstanceKey,
54
108
  ) -> str:
55
- conf = XCom.get_value(key=self.key, ti_key=ti_key)
109
+ if TYPE_CHECKING:
110
+ assert isinstance(operator, (GoogleCloudBaseOperator, BaseSensorOperator))
111
+
112
+ conf = self.get_config(operator, ti_key)
56
113
  if not conf:
57
114
  return ""
58
- if self.format_str.startswith("http"):
59
- return self.format_str.format(**conf)
60
- return BASE_LINK + self.format_str.format(**conf)
115
+ return self._format_link(**conf)
116
+
117
+ def _format_link(self, **kwargs):
118
+ try:
119
+ formatted_str = self.format_str.format(**kwargs)
120
+ if formatted_str.startswith("http"):
121
+ return formatted_str
122
+ return BASE_LINK + formatted_str
123
+ except KeyError:
124
+ return ""
@@ -19,14 +19,8 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
- from typing import TYPE_CHECKING
23
-
24
22
  from airflow.providers.google.cloud.links.base import BaseGoogleLink
25
23
 
26
- if TYPE_CHECKING:
27
- from airflow.models import BaseOperator
28
- from airflow.utils.context import Context
29
-
30
24
  BIGQUERY_BASE_LINK = "/bigquery"
31
25
  BIGQUERY_DATASET_LINK = (
32
26
  BIGQUERY_BASE_LINK + "?referrer=search&project={project_id}&d={dataset_id}&p={project_id}&page=dataset"
@@ -47,19 +41,6 @@ class BigQueryDatasetLink(BaseGoogleLink):
47
41
  key = "bigquery_dataset"
48
42
  format_str = BIGQUERY_DATASET_LINK
49
43
 
50
- @staticmethod
51
- def persist(
52
- context: Context,
53
- task_instance: BaseOperator,
54
- dataset_id: str,
55
- project_id: str,
56
- ):
57
- task_instance.xcom_push(
58
- context,
59
- key=BigQueryDatasetLink.key,
60
- value={"dataset_id": dataset_id, "project_id": project_id},
61
- )
62
-
63
44
 
64
45
  class BigQueryTableLink(BaseGoogleLink):
65
46
  """Helper class for constructing BigQuery Table Link."""
@@ -68,20 +49,6 @@ class BigQueryTableLink(BaseGoogleLink):
68
49
  key = "bigquery_table"
69
50
  format_str = BIGQUERY_TABLE_LINK
70
51
 
71
- @staticmethod
72
- def persist(
73
- context: Context,
74
- task_instance: BaseOperator,
75
- project_id: str,
76
- table_id: str,
77
- dataset_id: str | None = None,
78
- ):
79
- task_instance.xcom_push(
80
- context,
81
- key=BigQueryTableLink.key,
82
- value={"dataset_id": dataset_id, "project_id": project_id, "table_id": table_id},
83
- )
84
-
85
52
 
86
53
  class BigQueryJobDetailLink(BaseGoogleLink):
87
54
  """Helper class for constructing BigQuery Job Detail Link."""
@@ -89,17 +56,3 @@ class BigQueryJobDetailLink(BaseGoogleLink):
89
56
  name = "BigQuery Job Detail"
90
57
  key = "bigquery_job_detail"
91
58
  format_str = BIGQUERY_JOB_DETAIL_LINK
92
-
93
- @staticmethod
94
- def persist(
95
- context: Context,
96
- task_instance: BaseOperator,
97
- project_id: str,
98
- location: str,
99
- job_id: str,
100
- ):
101
- task_instance.xcom_push(
102
- context,
103
- key=BigQueryJobDetailLink.key,
104
- value={"project_id": project_id, "location": location, "job_id": job_id},
105
- )
@@ -19,14 +19,8 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
- from typing import TYPE_CHECKING
23
-
24
22
  from airflow.providers.google.cloud.links.base import BaseGoogleLink
25
23
 
26
- if TYPE_CHECKING:
27
- from airflow.models import BaseOperator
28
- from airflow.utils.context import Context
29
-
30
24
  BIGQUERY_BASE_LINK = "/bigquery/transfers"
31
25
  BIGQUERY_DTS_LINK = BIGQUERY_BASE_LINK + "/locations/{region}/configs/{config_id}/runs?project={project_id}"
32
26
 
@@ -37,17 +31,3 @@ class BigQueryDataTransferConfigLink(BaseGoogleLink):
37
31
  name = "BigQuery Data Transfer Config"
38
32
  key = "bigquery_dts_config"
39
33
  format_str = BIGQUERY_DTS_LINK
40
-
41
- @staticmethod
42
- def persist(
43
- context: Context,
44
- task_instance: BaseOperator,
45
- region: str,
46
- config_id: str,
47
- project_id: str,
48
- ):
49
- task_instance.xcom_push(
50
- context,
51
- key=BigQueryDataTransferConfigLink.key,
52
- value={"project_id": project_id, "region": region, "config_id": config_id},
53
- )
@@ -16,13 +16,8 @@
16
16
  # under the License.
17
17
  from __future__ import annotations
18
18
 
19
- from typing import TYPE_CHECKING
20
-
21
19
  from airflow.providers.google.cloud.links.base import BaseGoogleLink
22
20
 
23
- if TYPE_CHECKING:
24
- from airflow.utils.context import Context
25
-
26
21
  BIGTABLE_BASE_LINK = "/bigtable"
27
22
  BIGTABLE_INSTANCE_LINK = BIGTABLE_BASE_LINK + "/instances/{instance_id}/overview?project={project_id}"
28
23
  BIGTABLE_CLUSTER_LINK = (
@@ -38,20 +33,6 @@ class BigtableInstanceLink(BaseGoogleLink):
38
33
  key = "instance_key"
39
34
  format_str = BIGTABLE_INSTANCE_LINK
40
35
 
41
- @staticmethod
42
- def persist(
43
- context: Context,
44
- task_instance,
45
- ):
46
- task_instance.xcom_push(
47
- context=context,
48
- key=BigtableInstanceLink.key,
49
- value={
50
- "instance_id": task_instance.instance_id,
51
- "project_id": task_instance.project_id,
52
- },
53
- )
54
-
55
36
 
56
37
  class BigtableClusterLink(BaseGoogleLink):
57
38
  """Helper class for constructing Bigtable Cluster link."""
@@ -60,21 +41,6 @@ class BigtableClusterLink(BaseGoogleLink):
60
41
  key = "cluster_key"
61
42
  format_str = BIGTABLE_CLUSTER_LINK
62
43
 
63
- @staticmethod
64
- def persist(
65
- context: Context,
66
- task_instance,
67
- ):
68
- task_instance.xcom_push(
69
- context=context,
70
- key=BigtableClusterLink.key,
71
- value={
72
- "instance_id": task_instance.instance_id,
73
- "cluster_id": task_instance.cluster_id,
74
- "project_id": task_instance.project_id,
75
- },
76
- )
77
-
78
44
 
79
45
  class BigtableTablesLink(BaseGoogleLink):
80
46
  """Helper class for constructing Bigtable Tables link."""
@@ -82,17 +48,3 @@ class BigtableTablesLink(BaseGoogleLink):
82
48
  name = "Bigtable Tables"
83
49
  key = "tables_key"
84
50
  format_str = BIGTABLE_TABLES_LINK
85
-
86
- @staticmethod
87
- def persist(
88
- context: Context,
89
- task_instance,
90
- ):
91
- task_instance.xcom_push(
92
- context=context,
93
- key=BigtableTablesLink.key,
94
- value={
95
- "instance_id": task_instance.instance_id,
96
- "project_id": task_instance.project_id,
97
- },
98
- )