apache-airflow-providers-google 16.0.0rc1__py3-none-any.whl → 16.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +9 -5
  3. airflow/providers/google/ads/operators/ads.py +1 -1
  4. airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -1
  5. airflow/providers/google/cloud/hooks/bigquery.py +2 -3
  6. airflow/providers/google/cloud/hooks/cloud_sql.py +8 -4
  7. airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +2 -2
  9. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  10. airflow/providers/google/cloud/hooks/dataprep.py +4 -1
  11. airflow/providers/google/cloud/hooks/gcs.py +2 -2
  12. airflow/providers/google/cloud/hooks/looker.py +5 -1
  13. airflow/providers/google/cloud/hooks/mlengine.py +2 -1
  14. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  15. airflow/providers/google/cloud/hooks/spanner.py +2 -2
  16. airflow/providers/google/cloud/hooks/translate.py +1 -1
  17. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
  18. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +43 -14
  19. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +11 -2
  20. airflow/providers/google/cloud/hooks/vision.py +2 -2
  21. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  22. airflow/providers/google/cloud/links/base.py +75 -11
  23. airflow/providers/google/cloud/links/bigquery.py +0 -47
  24. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  25. airflow/providers/google/cloud/links/bigtable.py +0 -48
  26. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  27. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  28. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  29. airflow/providers/google/cloud/links/cloud_run.py +1 -33
  30. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  31. airflow/providers/google/cloud/links/cloud_storage_transfer.py +16 -43
  32. airflow/providers/google/cloud/links/cloud_tasks.py +6 -25
  33. airflow/providers/google/cloud/links/compute.py +0 -58
  34. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  35. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  36. airflow/providers/google/cloud/links/dataflow.py +0 -34
  37. airflow/providers/google/cloud/links/dataform.py +0 -64
  38. airflow/providers/google/cloud/links/datafusion.py +1 -96
  39. airflow/providers/google/cloud/links/dataplex.py +0 -154
  40. airflow/providers/google/cloud/links/dataprep.py +0 -24
  41. airflow/providers/google/cloud/links/dataproc.py +14 -90
  42. airflow/providers/google/cloud/links/datastore.py +0 -31
  43. airflow/providers/google/cloud/links/kubernetes_engine.py +5 -59
  44. airflow/providers/google/cloud/links/life_sciences.py +0 -19
  45. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  46. airflow/providers/google/cloud/links/mlengine.py +0 -70
  47. airflow/providers/google/cloud/links/pubsub.py +0 -32
  48. airflow/providers/google/cloud/links/spanner.py +0 -33
  49. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  50. airflow/providers/google/cloud/links/translate.py +16 -186
  51. airflow/providers/google/cloud/links/vertex_ai.py +8 -224
  52. airflow/providers/google/cloud/links/workflows.py +0 -52
  53. airflow/providers/google/cloud/operators/alloy_db.py +69 -54
  54. airflow/providers/google/cloud/operators/automl.py +16 -14
  55. airflow/providers/google/cloud/operators/bigquery.py +0 -15
  56. airflow/providers/google/cloud/operators/bigquery_dts.py +2 -4
  57. airflow/providers/google/cloud/operators/bigtable.py +35 -6
  58. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  59. airflow/providers/google/cloud/operators/cloud_build.py +74 -31
  60. airflow/providers/google/cloud/operators/cloud_composer.py +34 -35
  61. airflow/providers/google/cloud/operators/cloud_memorystore.py +68 -42
  62. airflow/providers/google/cloud/operators/cloud_run.py +0 -1
  63. airflow/providers/google/cloud/operators/cloud_sql.py +11 -15
  64. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -2
  65. airflow/providers/google/cloud/operators/compute.py +7 -39
  66. airflow/providers/google/cloud/operators/datacatalog.py +156 -20
  67. airflow/providers/google/cloud/operators/dataflow.py +37 -14
  68. airflow/providers/google/cloud/operators/dataform.py +14 -4
  69. airflow/providers/google/cloud/operators/datafusion.py +4 -12
  70. airflow/providers/google/cloud/operators/dataplex.py +180 -96
  71. airflow/providers/google/cloud/operators/dataprep.py +0 -4
  72. airflow/providers/google/cloud/operators/dataproc.py +10 -16
  73. airflow/providers/google/cloud/operators/dataproc_metastore.py +95 -87
  74. airflow/providers/google/cloud/operators/datastore.py +21 -5
  75. airflow/providers/google/cloud/operators/dlp.py +3 -26
  76. airflow/providers/google/cloud/operators/functions.py +15 -6
  77. airflow/providers/google/cloud/operators/gcs.py +0 -7
  78. airflow/providers/google/cloud/operators/kubernetes_engine.py +50 -7
  79. airflow/providers/google/cloud/operators/life_sciences.py +0 -1
  80. airflow/providers/google/cloud/operators/managed_kafka.py +106 -51
  81. airflow/providers/google/cloud/operators/mlengine.py +0 -1
  82. airflow/providers/google/cloud/operators/pubsub.py +2 -4
  83. airflow/providers/google/cloud/operators/spanner.py +0 -4
  84. airflow/providers/google/cloud/operators/speech_to_text.py +0 -1
  85. airflow/providers/google/cloud/operators/stackdriver.py +0 -8
  86. airflow/providers/google/cloud/operators/tasks.py +0 -11
  87. airflow/providers/google/cloud/operators/text_to_speech.py +0 -1
  88. airflow/providers/google/cloud/operators/translate.py +37 -13
  89. airflow/providers/google/cloud/operators/translate_speech.py +0 -1
  90. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18
  91. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8
  92. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25
  93. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +69 -7
  94. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +42 -8
  95. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +531 -0
  96. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +93 -25
  97. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +10 -8
  98. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +56 -10
  99. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +25 -6
  100. airflow/providers/google/cloud/operators/vertex_ai/ray.py +9 -6
  101. airflow/providers/google/cloud/operators/workflows.py +1 -9
  102. airflow/providers/google/cloud/sensors/bigquery.py +1 -1
  103. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -1
  104. airflow/providers/google/cloud/sensors/bigtable.py +15 -3
  105. airflow/providers/google/cloud/sensors/cloud_composer.py +6 -1
  106. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -1
  107. airflow/providers/google/cloud/sensors/dataflow.py +3 -3
  108. airflow/providers/google/cloud/sensors/dataform.py +6 -1
  109. airflow/providers/google/cloud/sensors/datafusion.py +6 -1
  110. airflow/providers/google/cloud/sensors/dataplex.py +6 -1
  111. airflow/providers/google/cloud/sensors/dataprep.py +6 -1
  112. airflow/providers/google/cloud/sensors/dataproc.py +6 -1
  113. airflow/providers/google/cloud/sensors/dataproc_metastore.py +6 -1
  114. airflow/providers/google/cloud/sensors/gcs.py +9 -3
  115. airflow/providers/google/cloud/sensors/looker.py +6 -1
  116. airflow/providers/google/cloud/sensors/pubsub.py +8 -3
  117. airflow/providers/google/cloud/sensors/tasks.py +6 -1
  118. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +6 -1
  119. airflow/providers/google/cloud/sensors/workflows.py +6 -1
  120. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +1 -1
  121. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +1 -1
  122. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +1 -2
  123. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -2
  124. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -1
  125. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -1
  126. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  127. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +1 -1
  128. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -1
  129. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
  130. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -1
  131. airflow/providers/google/cloud/transfers/gcs_to_local.py +1 -1
  132. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
  133. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +5 -1
  134. airflow/providers/google/cloud/transfers/gdrive_to_local.py +1 -1
  135. airflow/providers/google/cloud/transfers/http_to_gcs.py +1 -1
  136. airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -1
  137. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +1 -1
  138. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
  139. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +2 -2
  140. airflow/providers/google/cloud/transfers/sql_to_gcs.py +1 -1
  141. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  142. airflow/providers/google/common/auth_backend/google_openid.py +2 -1
  143. airflow/providers/google/common/deprecated.py +2 -1
  144. airflow/providers/google/common/hooks/base_google.py +7 -3
  145. airflow/providers/google/common/links/storage.py +0 -22
  146. airflow/providers/google/firebase/operators/firestore.py +1 -1
  147. airflow/providers/google/get_provider_info.py +0 -11
  148. airflow/providers/google/leveldb/hooks/leveldb.py +5 -1
  149. airflow/providers/google/leveldb/operators/leveldb.py +1 -1
  150. airflow/providers/google/marketing_platform/links/analytics_admin.py +3 -6
  151. airflow/providers/google/marketing_platform/operators/analytics_admin.py +0 -1
  152. airflow/providers/google/marketing_platform/operators/campaign_manager.py +4 -4
  153. airflow/providers/google/marketing_platform/operators/display_video.py +6 -6
  154. airflow/providers/google/marketing_platform/operators/search_ads.py +1 -1
  155. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +6 -1
  156. airflow/providers/google/marketing_platform/sensors/display_video.py +6 -1
  157. airflow/providers/google/suite/operators/sheets.py +3 -3
  158. airflow/providers/google/suite/sensors/drive.py +6 -1
  159. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +1 -1
  160. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  161. airflow/providers/google/suite/transfers/local_to_drive.py +1 -1
  162. airflow/providers/google/version_compat.py +28 -0
  163. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/METADATA +19 -20
  164. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/RECORD +166 -166
  165. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/WHEEL +0 -0
  166. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/entry_points.txt +0 -0
@@ -20,9 +20,15 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any, Literal
24
24
 
25
- from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
25
+ from google.api_core import exceptions
26
+
27
+ from airflow.exceptions import AirflowException
28
+ from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
29
+ ExperimentRunHook,
30
+ GenerativeModelHook,
31
+ )
26
32
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
27
33
 
28
34
  if TYPE_CHECKING:
@@ -38,9 +44,8 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
38
44
  :param location: Required. The ID of the Google Cloud location that the
39
45
  service belongs to (templated).
40
46
  :param prompt: Required. Inputs or queries that a user or a program gives
41
- to the Vertex AI PaLM API, in order to elicit a specific response (templated).
42
- :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
43
- optimized for performing text embeddings.
47
+ to the Vertex AI Generative Model API, in order to elicit a specific response (templated).
48
+ :param pretrained_model: Required. Model, optimized for performing text embeddings.
44
49
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
45
50
  :param impersonation_chain: Optional service account to impersonate using short-term
46
51
  credentials, or chained list of accounts required to get the access_token
@@ -60,7 +65,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
60
65
  project_id: str,
61
66
  location: str,
62
67
  prompt: str,
63
- pretrained_model: str = "textembedding-gecko",
68
+ pretrained_model: str,
64
69
  gcp_conn_id: str = "google_cloud_default",
65
70
  impersonation_chain: str | Sequence[str] | None = None,
66
71
  **kwargs,
@@ -88,7 +93,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
88
93
  )
89
94
 
90
95
  self.log.info("Model response: %s", response)
91
- self.xcom_push(context, key="model_response", value=response)
96
+ context["ti"].xcom_push(key="model_response", value=response)
92
97
 
93
98
  return response
94
99
 
@@ -107,10 +112,9 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
107
112
  :param safety_settings: Optional. Per request settings for blocking unsafe content.
108
113
  :param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
109
114
  :param system_instruction: Optional. An instruction given to the model to guide its behavior.
110
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
111
- supporting prompts with text-only input, including natural language
112
- tasks, multi-turn text and code chat, and code generation. It can
113
- output text and code.
115
+ :param pretrained_model: Required. The name of the model to use for content generation,
116
+ which can be a text-only or multimodal model. For example, `gemini-pro` or
117
+ `gemini-pro-vision`.
114
118
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
115
119
  :param impersonation_chain: Optional service account to impersonate using short-term
116
120
  credentials, or chained list of accounts required to get the access_token
@@ -134,7 +138,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
134
138
  generation_config: dict | None = None,
135
139
  safety_settings: dict | None = None,
136
140
  system_instruction: str | None = None,
137
- pretrained_model: str = "gemini-pro",
141
+ pretrained_model: str,
138
142
  gcp_conn_id: str = "google_cloud_default",
139
143
  impersonation_chain: str | Sequence[str] | None = None,
140
144
  **kwargs,
@@ -168,7 +172,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
168
172
  )
169
173
 
170
174
  self.log.info("Model response: %s", response)
171
- self.xcom_push(context, key="model_response", value=response)
175
+ context["ti"].xcom_push(key="model_response", value=response)
172
176
 
173
177
  return response
174
178
 
@@ -218,7 +222,7 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
218
222
  tuned_model_display_name: str | None = None,
219
223
  validation_dataset: str | None = None,
220
224
  epochs: int | None = None,
221
- adapter_size: int | None = None,
225
+ adapter_size: Literal[1, 4, 8, 16] | None = None,
222
226
  learning_rate_multiplier: float | None = None,
223
227
  gcp_conn_id: str = "google_cloud_default",
224
228
  impersonation_chain: str | Sequence[str] | None = None,
@@ -257,8 +261,8 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
257
261
  self.log.info("Tuned Model Name: %s", response.tuned_model_name)
258
262
  self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
259
263
 
260
- self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
261
- self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
264
+ context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name)
265
+ context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
262
266
 
263
267
  result = {
264
268
  "tuned_model_name": response.tuned_model_name,
@@ -278,10 +282,9 @@ class CountTokensOperator(GoogleCloudBaseOperator):
278
282
  service belongs to (templated).
279
283
  :param contents: Required. The multi-part content of a message that a user or a program
280
284
  gives to the generative model, in order to elicit a specific response.
281
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
282
- supporting prompts with text-only input, including natural language
283
- tasks, multi-turn text and code chat, and code generation. It can
284
- output text and code.
285
+ :param pretrained_model: Required. Model, supporting prompts with text-only input,
286
+ including natural language tasks, multi-turn text and code chat,
287
+ and code generation. It can output text and code.
285
288
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
286
289
  :param impersonation_chain: Optional service account to impersonate using short-term
287
290
  credentials, or chained list of accounts required to get the access_token
@@ -301,7 +304,7 @@ class CountTokensOperator(GoogleCloudBaseOperator):
301
304
  project_id: str,
302
305
  location: str,
303
306
  contents: list,
304
- pretrained_model: str = "gemini-pro",
307
+ pretrained_model: str,
305
308
  gcp_conn_id: str = "google_cloud_default",
306
309
  impersonation_chain: str | Sequence[str] | None = None,
307
310
  **kwargs,
@@ -329,8 +332,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
329
332
  self.log.info("Total tokens: %s", response.total_tokens)
330
333
  self.log.info("Total billable characters: %s", response.total_billable_characters)
331
334
 
332
- self.xcom_push(context, key="total_tokens", value=response.total_tokens)
333
- self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)
335
+ context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
336
+ context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters)
334
337
 
335
338
 
336
339
  class RunEvaluationOperator(GoogleCloudBaseOperator):
@@ -470,8 +473,8 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
470
473
  project_id: str,
471
474
  location: str,
472
475
  model_name: str,
473
- system_instruction: str | None = None,
474
- contents: list | None = None,
476
+ system_instruction: Any | None = None,
477
+ contents: list[Any] | None = None,
475
478
  ttl_hours: float = 1,
476
479
  display_name: str | None = None,
477
480
  gcp_conn_id: str = "google_cloud_default",
@@ -582,3 +585,68 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
582
585
  self.log.info("Cached Content Response: %s", cached_content_text)
583
586
 
584
587
  return cached_content_text
588
+
589
+
590
class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
    """
    Delete a run from an experiment using ``ExperimentRunHook``.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :param location: Required. The ID of the Google Cloud location that the service belongs to.
    :param experiment_name: Required. The name of the experiment that owns the run.
    :param experiment_run_name: Required. The specific run name or ID to delete from this experiment.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :raises AirflowException: If the hook reports that the experiment run was not found
        (``google.api_core.exceptions.NotFound``). The original error is kept as ``__cause__``.
    """

    template_fields = (
        "location",
        "project_id",
        "impersonation_chain",
        "experiment_name",
        "experiment_run_name",
    )

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        experiment_name: str,
        experiment_run_name: str,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.experiment_name = experiment_name
        self.experiment_run_name = experiment_run_name
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context) -> None:
        """Delete the configured experiment run and log the result."""
        self.hook = ExperimentRunHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        try:
            self.hook.delete_experiment_run(
                project_id=self.project_id,
                location=self.location,
                experiment_name=self.experiment_name,
                experiment_run_name=self.experiment_run_name,
            )
        except exceptions.NotFound as err:
            # Fail the task with a clear message, but chain the original API error
            # so its details stay visible in the traceback.
            raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found") from err

        self.log.info("Deleted experiment run: %s", self.experiment_run_name)
@@ -257,10 +257,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
257
257
  hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
258
258
  self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
259
259
 
260
- self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
261
- VertexAITrainingLink.persist(
262
- context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
263
- )
260
+ context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
261
+ VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id)
264
262
 
265
263
  if self.deferrable:
266
264
  self.defer(
@@ -355,9 +353,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
355
353
  timeout=self.timeout,
356
354
  metadata=self.metadata,
357
355
  )
358
- VertexAITrainingLink.persist(
359
- context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
360
- )
356
+ VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id)
361
357
  self.log.info("Hyperparameter tuning job was gotten.")
362
358
  return types.HyperparameterTuningJob.to_dict(result)
363
359
  except NotFound:
@@ -487,6 +483,12 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
487
483
  self.gcp_conn_id = gcp_conn_id
488
484
  self.impersonation_chain = impersonation_chain
489
485
 
486
+ @property
487
+ def extra_links_params(self) -> dict[str, Any]:
488
+ return {
489
+ "project_id": self.project_id,
490
+ }
491
+
490
492
  def execute(self, context: Context):
491
493
  hook = HyperparameterTuningJobHook(
492
494
  gcp_conn_id=self.gcp_conn_id,
@@ -503,5 +505,5 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
503
505
  timeout=self.timeout,
504
506
  metadata=self.metadata,
505
507
  )
506
- VertexAIHyperparameterTuningJobListLink.persist(context=context, task_instance=self)
508
+ VertexAIHyperparameterTuningJobListLink.persist(context=context)
507
509
  return [types.HyperparameterTuningJob.to_dict(result) for result in results]
@@ -20,7 +20,7 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any
24
24
 
25
25
  from google.api_core.exceptions import NotFound
26
26
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -161,6 +161,13 @@ class GetModelOperator(GoogleCloudBaseOperator):
161
161
  self.gcp_conn_id = gcp_conn_id
162
162
  self.impersonation_chain = impersonation_chain
163
163
 
164
+ @property
165
+ def extra_links_params(self) -> dict[str, Any]:
166
+ return {
167
+ "region": self.region,
168
+ "project_id": self.project_id,
169
+ }
170
+
164
171
  def execute(self, context: Context):
165
172
  hook = ModelServiceHook(
166
173
  gcp_conn_id=self.gcp_conn_id,
@@ -179,8 +186,8 @@ class GetModelOperator(GoogleCloudBaseOperator):
179
186
  )
180
187
  self.log.info("Model found. Model ID: %s", self.model_id)
181
188
 
182
- self.xcom_push(context, key="model_id", value=self.model_id)
183
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
189
+ context["ti"].xcom_push(key="model_id", value=self.model_id)
190
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
184
191
  return Model.to_dict(model)
185
192
  except NotFound:
186
193
  self.log.info("The Model ID %s does not exist.", self.model_id)
@@ -257,7 +264,12 @@ class ExportModelOperator(GoogleCloudBaseOperator):
257
264
  metadata=self.metadata,
258
265
  )
259
266
  hook.wait_for_operation(timeout=self.timeout, operation=operation)
260
- VertexAIModelExportLink.persist(context=context, task_instance=self)
267
+ VertexAIModelExportLink.persist(
268
+ context=context,
269
+ output_config=self.output_config,
270
+ model_id=self.model_id,
271
+ project_id=self.project_id,
272
+ )
261
273
  self.log.info("Model was exported.")
262
274
  except NotFound:
263
275
  self.log.info("The Model ID %s does not exist.", self.model_id)
@@ -335,6 +347,12 @@ class ListModelsOperator(GoogleCloudBaseOperator):
335
347
  self.gcp_conn_id = gcp_conn_id
336
348
  self.impersonation_chain = impersonation_chain
337
349
 
350
+ @property
351
+ def extra_links_params(self) -> dict[str, Any]:
352
+ return {
353
+ "project_id": self.project_id,
354
+ }
355
+
338
356
  def execute(self, context: Context):
339
357
  hook = ModelServiceHook(
340
358
  gcp_conn_id=self.gcp_conn_id,
@@ -352,7 +370,7 @@ class ListModelsOperator(GoogleCloudBaseOperator):
352
370
  timeout=self.timeout,
353
371
  metadata=self.metadata,
354
372
  )
355
- VertexAIModelListLink.persist(context=context, task_instance=self)
373
+ VertexAIModelListLink.persist(context=context)
356
374
  return [Model.to_dict(result) for result in results]
357
375
 
358
376
 
@@ -407,6 +425,13 @@ class UploadModelOperator(GoogleCloudBaseOperator):
407
425
  self.gcp_conn_id = gcp_conn_id
408
426
  self.impersonation_chain = impersonation_chain
409
427
 
428
+ @property
429
+ def extra_links_params(self) -> dict[str, Any]:
430
+ return {
431
+ "region": self.region,
432
+ "project_id": self.project_id,
433
+ }
434
+
410
435
  def execute(self, context: Context):
411
436
  hook = ModelServiceHook(
412
437
  gcp_conn_id=self.gcp_conn_id,
@@ -428,8 +453,8 @@ class UploadModelOperator(GoogleCloudBaseOperator):
428
453
  model_id = hook.extract_model_id(model_resp)
429
454
  self.log.info("Model was uploaded. Model ID: %s", model_id)
430
455
 
431
- self.xcom_push(context, key="model_id", value=model_id)
432
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
456
+ context["ti"].xcom_push(key="model_id", value=model_id)
457
+ VertexAIModelLink.persist(context=context, model_id=model_id)
433
458
  return model_resp
434
459
 
435
460
 
@@ -553,6 +578,13 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
553
578
  self.gcp_conn_id = gcp_conn_id
554
579
  self.impersonation_chain = impersonation_chain
555
580
 
581
+ @property
582
+ def extra_links_params(self) -> dict[str, Any]:
583
+ return {
584
+ "region": self.region,
585
+ "project_id": self.project_id,
586
+ }
587
+
556
588
  def execute(self, context: Context):
557
589
  hook = ModelServiceHook(
558
590
  gcp_conn_id=self.gcp_conn_id,
@@ -571,7 +603,7 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
571
603
  timeout=self.timeout,
572
604
  metadata=self.metadata,
573
605
  )
574
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
606
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
575
607
  return Model.to_dict(updated_model)
576
608
 
577
609
 
@@ -627,6 +659,13 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
627
659
  self.gcp_conn_id = gcp_conn_id
628
660
  self.impersonation_chain = impersonation_chain
629
661
 
662
+ @property
663
+ def extra_links_params(self) -> dict[str, Any]:
664
+ return {
665
+ "region": self.region,
666
+ "project_id": self.project_id,
667
+ }
668
+
630
669
  def execute(self, context: Context):
631
670
  hook = ModelServiceHook(
632
671
  gcp_conn_id=self.gcp_conn_id,
@@ -645,7 +684,7 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
645
684
  timeout=self.timeout,
646
685
  metadata=self.metadata,
647
686
  )
648
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
687
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
649
688
  return Model.to_dict(updated_model)
650
689
 
651
690
 
@@ -701,6 +740,13 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
701
740
  self.gcp_conn_id = gcp_conn_id
702
741
  self.impersonation_chain = impersonation_chain
703
742
 
743
+ @property
744
+ def extra_links_params(self) -> dict[str, Any]:
745
+ return {
746
+ "region": self.region,
747
+ "project_id": self.project_id,
748
+ }
749
+
704
750
  def execute(self, context: Context):
705
751
  hook = ModelServiceHook(
706
752
  gcp_conn_id=self.gcp_conn_id,
@@ -721,7 +767,7 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
721
767
  timeout=self.timeout,
722
768
  metadata=self.metadata,
723
769
  )
724
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
770
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
725
771
  return Model.to_dict(updated_model)
726
772
 
727
773
 
@@ -166,6 +166,13 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
166
166
  self.deferrable = deferrable
167
167
  self.poll_interval = poll_interval
168
168
 
169
+ @property
170
+ def extra_links_params(self) -> dict[str, Any]:
171
+ return {
172
+ "region": self.region,
173
+ "project_id": self.project_id,
174
+ }
175
+
169
176
  def execute(self, context: Context):
170
177
  self.log.info("Running Pipeline job")
171
178
  pipeline_job_obj: PipelineJob = self.hook.submit_pipeline_job(
@@ -188,8 +195,8 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
188
195
  )
189
196
  pipeline_job_id = pipeline_job_obj.job_id
190
197
  self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
191
- self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
192
- VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id)
198
+ context["ti"].xcom_push(key="pipeline_job_id", value=pipeline_job_id)
199
+ VertexAIPipelineJobLink.persist(context=context, pipeline_id=pipeline_job_id)
193
200
 
194
201
  if self.deferrable:
195
202
  pipeline_job_obj.wait_for_resource_creation()
@@ -280,6 +287,13 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
280
287
  self.gcp_conn_id = gcp_conn_id
281
288
  self.impersonation_chain = impersonation_chain
282
289
 
290
+ @property
291
+ def extra_links_params(self) -> dict[str, Any]:
292
+ return {
293
+ "region": self.region,
294
+ "project_id": self.project_id,
295
+ }
296
+
283
297
  def execute(self, context: Context):
284
298
  hook = PipelineJobHook(
285
299
  gcp_conn_id=self.gcp_conn_id,
@@ -296,9 +310,7 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
296
310
  timeout=self.timeout,
297
311
  metadata=self.metadata,
298
312
  )
299
- VertexAIPipelineJobLink.persist(
300
- context=context, task_instance=self, pipeline_id=self.pipeline_job_id
301
- )
313
+ VertexAIPipelineJobLink.persist(context=context, pipeline_id=self.pipeline_job_id)
302
314
  self.log.info("Pipeline job was gotten.")
303
315
  return types.PipelineJob.to_dict(result)
304
316
  except NotFound:
@@ -412,6 +424,13 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
412
424
  self.gcp_conn_id = gcp_conn_id
413
425
  self.impersonation_chain = impersonation_chain
414
426
 
427
+ @property
428
+ def extra_links_params(self) -> dict[str, Any]:
429
+ return {
430
+ "region": self.region,
431
+ "project_id": self.project_id,
432
+ }
433
+
415
434
  def execute(self, context: Context):
416
435
  hook = PipelineJobHook(
417
436
  gcp_conn_id=self.gcp_conn_id,
@@ -428,7 +447,7 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
428
447
  timeout=self.timeout,
429
448
  metadata=self.metadata,
430
449
  )
431
- VertexAIPipelineJobListLink.persist(context=context, task_instance=self)
450
+ VertexAIPipelineJobListLink.persist(context=context)
432
451
  return [types.PipelineJob.to_dict(result) for result in results]
433
452
 
434
453
 
@@ -188,12 +188,13 @@ class CreateRayClusterOperator(RayBaseOperator):
188
188
  labels=self.labels,
189
189
  )
190
190
  cluster_id = self.hook.extract_cluster_id(cluster_path)
191
- self.xcom_push(
192
- context=context,
191
+ context["ti"].xcom_push(
193
192
  key="cluster_id",
194
193
  value=cluster_id,
195
194
  )
196
- VertexAIRayClusterLink.persist(context=context, task_instance=self, cluster_id=cluster_id)
195
+ VertexAIRayClusterLink.persist(
196
+ context=context, location=self.location, cluster_id=cluster_id, project_id=self.project_id
197
+ )
197
198
  self.log.info("Ray cluster was created.")
198
199
  except Exception as error:
199
200
  raise AirflowException(error)
@@ -220,7 +221,7 @@ class ListRayClustersOperator(RayBaseOperator):
220
221
  operator_extra_links = (VertexAIRayClusterListLink(),)
221
222
 
222
223
  def execute(self, context: Context):
223
- VertexAIRayClusterListLink.persist(context=context, task_instance=self)
224
+ VertexAIRayClusterListLink.persist(context=context, project_id=self.project_id)
224
225
  self.log.info("Listing Clusters from location %s.", self.location)
225
226
  try:
226
227
  ray_cluster_list = self.hook.list_ray_clusters(
@@ -268,8 +269,9 @@ class GetRayClusterOperator(RayBaseOperator):
268
269
  def execute(self, context: Context):
269
270
  VertexAIRayClusterLink.persist(
270
271
  context=context,
271
- task_instance=self,
272
+ location=self.location,
272
273
  cluster_id=self.cluster_id,
274
+ project_id=self.project_id,
273
275
  )
274
276
  self.log.info("Getting Cluster: %s", self.cluster_id)
275
277
  try:
@@ -325,8 +327,9 @@ class UpdateRayClusterOperator(RayBaseOperator):
325
327
  def execute(self, context: Context):
326
328
  VertexAIRayClusterLink.persist(
327
329
  context=context,
328
- task_instance=self,
330
+ location=self.location,
329
331
  cluster_id=self.cluster_id,
332
+ project_id=self.project_id,
330
333
  )
331
334
  self.log.info("Updating a Ray cluster.")
332
335
  try:
@@ -147,7 +147,6 @@ class WorkflowsCreateWorkflowOperator(GoogleCloudBaseOperator):
147
147
 
148
148
  WorkflowsWorkflowDetailsLink.persist(
149
149
  context=context,
150
- task_instance=self,
151
150
  location_id=self.location,
152
151
  workflow_id=self.workflow_id,
153
152
  project_id=self.project_id or hook.project_id,
@@ -235,7 +234,6 @@ class WorkflowsUpdateWorkflowOperator(GoogleCloudBaseOperator):
235
234
 
236
235
  WorkflowsWorkflowDetailsLink.persist(
237
236
  context=context,
238
- task_instance=self,
239
237
  location_id=self.location,
240
238
  workflow_id=self.workflow_id,
241
239
  project_id=self.project_id or hook.project_id,
@@ -368,7 +366,6 @@ class WorkflowsListWorkflowsOperator(GoogleCloudBaseOperator):
368
366
 
369
367
  WorkflowsListOfWorkflowsLink.persist(
370
368
  context=context,
371
- task_instance=self,
372
369
  project_id=self.project_id or hook.project_id,
373
370
  )
374
371
 
@@ -434,7 +431,6 @@ class WorkflowsGetWorkflowOperator(GoogleCloudBaseOperator):
434
431
 
435
432
  WorkflowsWorkflowDetailsLink.persist(
436
433
  context=context,
437
- task_instance=self,
438
434
  location_id=self.location,
439
435
  workflow_id=self.workflow_id,
440
436
  project_id=self.project_id or hook.project_id,
@@ -505,11 +501,10 @@ class WorkflowsCreateExecutionOperator(GoogleCloudBaseOperator):
505
501
  metadata=self.metadata,
506
502
  )
507
503
  execution_id = execution.name.split("/")[-1]
508
- self.xcom_push(context, key="execution_id", value=execution_id)
504
+ context["task_instance"].xcom_push(key="execution_id", value=execution_id)
509
505
 
510
506
  WorkflowsExecutionLink.persist(
511
507
  context=context,
512
- task_instance=self,
513
508
  location_id=self.location,
514
509
  workflow_id=self.workflow_id,
515
510
  execution_id=execution_id,
@@ -582,7 +577,6 @@ class WorkflowsCancelExecutionOperator(GoogleCloudBaseOperator):
582
577
 
583
578
  WorkflowsExecutionLink.persist(
584
579
  context=context,
585
- task_instance=self,
586
580
  location_id=self.location,
587
581
  workflow_id=self.workflow_id,
588
582
  execution_id=self.execution_id,
@@ -661,7 +655,6 @@ class WorkflowsListExecutionsOperator(GoogleCloudBaseOperator):
661
655
 
662
656
  WorkflowsWorkflowDetailsLink.persist(
663
657
  context=context,
664
- task_instance=self,
665
658
  location_id=self.location,
666
659
  workflow_id=self.workflow_id,
667
660
  project_id=self.project_id or hook.project_id,
@@ -737,7 +730,6 @@ class WorkflowsGetExecutionOperator(GoogleCloudBaseOperator):
737
730
 
738
731
  WorkflowsExecutionLink.persist(
739
732
  context=context,
740
- task_instance=self,
741
733
  location_id=self.location,
742
734
  workflow_id=self.workflow_id,
743
735
  execution_id=self.execution_id,
@@ -31,7 +31,7 @@ from airflow.providers.google.cloud.triggers.bigquery import (
31
31
  BigQueryTableExistenceTrigger,
32
32
  BigQueryTablePartitionExistenceTrigger,
33
33
  )
34
- from airflow.sensors.base import BaseSensorOperator
34
+ from airflow.providers.google.version_compat import BaseSensorOperator
35
35
 
36
36
  if TYPE_CHECKING:
37
37
  from airflow.utils.context import Context
@@ -28,7 +28,12 @@ from google.cloud.bigquery_datatransfer_v1 import TransferState
28
28
  from airflow.exceptions import AirflowException
29
29
  from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
30
30
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
31
- from airflow.sensors.base import BaseSensorOperator
31
+ from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
32
+
33
+ if AIRFLOW_V_3_0_PLUS:
34
+ from airflow.sdk import BaseSensorOperator
35
+ else:
36
+ from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef]
32
37
 
33
38
  if TYPE_CHECKING:
34
39
  from google.api_core.retry import Retry
@@ -20,7 +20,7 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any
24
24
 
25
25
  import google.api_core.exceptions
26
26
  from google.cloud.bigtable import enums
@@ -30,7 +30,12 @@ from airflow.providers.google.cloud.hooks.bigtable import BigtableHook
30
30
  from airflow.providers.google.cloud.links.bigtable import BigtableTablesLink
31
31
  from airflow.providers.google.cloud.operators.bigtable import BigtableValidationMixin
32
32
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
33
- from airflow.sensors.base import BaseSensorOperator
33
+ from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
34
+
35
+ if AIRFLOW_V_3_0_PLUS:
36
+ from airflow.sdk import BaseSensorOperator
37
+ else:
38
+ from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef]
34
39
 
35
40
  if TYPE_CHECKING:
36
41
  from airflow.utils.context import Context
@@ -89,6 +94,13 @@ class BigtableTableReplicationCompletedSensor(BaseSensorOperator, BigtableValida
89
94
  self.impersonation_chain = impersonation_chain
90
95
  super().__init__(**kwargs)
91
96
 
97
+ @property
98
+ def extra_links_params(self) -> dict[str, Any]:
99
+ return {
100
+ "instance_id": self.instance_id,
101
+ "project_id": self.project_id,
102
+ }
103
+
92
104
  def poke(self, context: Context) -> bool:
93
105
  hook = BigtableHook(
94
106
  gcp_conn_id=self.gcp_conn_id,
@@ -119,5 +131,5 @@ class BigtableTableReplicationCompletedSensor(BaseSensorOperator, BigtableValida
119
131
  return False
120
132
 
121
133
  self.log.info("Table '%s' is replicated.", self.table_id)
122
- BigtableTablesLink.persist(context=context, task_instance=self)
134
+ BigtableTablesLink.persist(context=context)
123
135
  return True