apache-airflow-providers-google 16.0.0a1__py3-none-any.whl → 16.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172) hide show
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +43 -5
  3. airflow/providers/google/ads/operators/ads.py +1 -1
  4. airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -1
  5. airflow/providers/google/cloud/hooks/bigquery.py +63 -77
  6. airflow/providers/google/cloud/hooks/cloud_sql.py +8 -4
  7. airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +2 -2
  9. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  10. airflow/providers/google/cloud/hooks/dataprep.py +4 -1
  11. airflow/providers/google/cloud/hooks/gcs.py +5 -5
  12. airflow/providers/google/cloud/hooks/looker.py +10 -1
  13. airflow/providers/google/cloud/hooks/mlengine.py +2 -1
  14. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  15. airflow/providers/google/cloud/hooks/spanner.py +2 -2
  16. airflow/providers/google/cloud/hooks/translate.py +1 -1
  17. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
  18. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
  19. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +44 -80
  20. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +11 -2
  21. airflow/providers/google/cloud/hooks/vision.py +2 -2
  22. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  23. airflow/providers/google/cloud/links/base.py +75 -11
  24. airflow/providers/google/cloud/links/bigquery.py +0 -47
  25. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  26. airflow/providers/google/cloud/links/bigtable.py +0 -48
  27. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  28. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  29. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  30. airflow/providers/google/cloud/links/cloud_run.py +27 -0
  31. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  32. airflow/providers/google/cloud/links/cloud_storage_transfer.py +16 -43
  33. airflow/providers/google/cloud/links/cloud_tasks.py +6 -25
  34. airflow/providers/google/cloud/links/compute.py +0 -58
  35. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  36. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  37. airflow/providers/google/cloud/links/dataflow.py +0 -34
  38. airflow/providers/google/cloud/links/dataform.py +0 -64
  39. airflow/providers/google/cloud/links/datafusion.py +1 -96
  40. airflow/providers/google/cloud/links/dataplex.py +0 -154
  41. airflow/providers/google/cloud/links/dataprep.py +0 -24
  42. airflow/providers/google/cloud/links/dataproc.py +14 -90
  43. airflow/providers/google/cloud/links/datastore.py +0 -31
  44. airflow/providers/google/cloud/links/kubernetes_engine.py +5 -59
  45. airflow/providers/google/cloud/links/life_sciences.py +0 -19
  46. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  47. airflow/providers/google/cloud/links/mlengine.py +0 -70
  48. airflow/providers/google/cloud/links/pubsub.py +0 -32
  49. airflow/providers/google/cloud/links/spanner.py +0 -33
  50. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  51. airflow/providers/google/cloud/links/translate.py +16 -186
  52. airflow/providers/google/cloud/links/vertex_ai.py +8 -224
  53. airflow/providers/google/cloud/links/workflows.py +0 -52
  54. airflow/providers/google/cloud/log/gcs_task_handler.py +4 -4
  55. airflow/providers/google/cloud/operators/alloy_db.py +69 -54
  56. airflow/providers/google/cloud/operators/automl.py +16 -14
  57. airflow/providers/google/cloud/operators/bigquery.py +49 -25
  58. airflow/providers/google/cloud/operators/bigquery_dts.py +2 -4
  59. airflow/providers/google/cloud/operators/bigtable.py +35 -6
  60. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  61. airflow/providers/google/cloud/operators/cloud_build.py +74 -31
  62. airflow/providers/google/cloud/operators/cloud_composer.py +34 -35
  63. airflow/providers/google/cloud/operators/cloud_memorystore.py +68 -42
  64. airflow/providers/google/cloud/operators/cloud_run.py +9 -1
  65. airflow/providers/google/cloud/operators/cloud_sql.py +11 -15
  66. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -2
  67. airflow/providers/google/cloud/operators/compute.py +7 -39
  68. airflow/providers/google/cloud/operators/datacatalog.py +156 -20
  69. airflow/providers/google/cloud/operators/dataflow.py +37 -14
  70. airflow/providers/google/cloud/operators/dataform.py +14 -4
  71. airflow/providers/google/cloud/operators/datafusion.py +4 -12
  72. airflow/providers/google/cloud/operators/dataplex.py +180 -96
  73. airflow/providers/google/cloud/operators/dataprep.py +0 -4
  74. airflow/providers/google/cloud/operators/dataproc.py +10 -16
  75. airflow/providers/google/cloud/operators/dataproc_metastore.py +95 -87
  76. airflow/providers/google/cloud/operators/datastore.py +21 -5
  77. airflow/providers/google/cloud/operators/dlp.py +3 -26
  78. airflow/providers/google/cloud/operators/functions.py +15 -6
  79. airflow/providers/google/cloud/operators/gcs.py +1 -7
  80. airflow/providers/google/cloud/operators/kubernetes_engine.py +53 -92
  81. airflow/providers/google/cloud/operators/life_sciences.py +0 -1
  82. airflow/providers/google/cloud/operators/managed_kafka.py +106 -51
  83. airflow/providers/google/cloud/operators/mlengine.py +0 -1
  84. airflow/providers/google/cloud/operators/pubsub.py +4 -5
  85. airflow/providers/google/cloud/operators/spanner.py +0 -4
  86. airflow/providers/google/cloud/operators/speech_to_text.py +0 -1
  87. airflow/providers/google/cloud/operators/stackdriver.py +0 -8
  88. airflow/providers/google/cloud/operators/tasks.py +0 -11
  89. airflow/providers/google/cloud/operators/text_to_speech.py +0 -1
  90. airflow/providers/google/cloud/operators/translate.py +37 -13
  91. airflow/providers/google/cloud/operators/translate_speech.py +0 -1
  92. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18
  93. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8
  94. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25
  95. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +69 -7
  96. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +42 -8
  97. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +531 -0
  98. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +93 -117
  99. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +10 -8
  100. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +56 -10
  101. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +29 -6
  102. airflow/providers/google/cloud/operators/vertex_ai/ray.py +9 -6
  103. airflow/providers/google/cloud/operators/workflows.py +1 -9
  104. airflow/providers/google/cloud/sensors/bigquery.py +1 -1
  105. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -1
  106. airflow/providers/google/cloud/sensors/bigtable.py +15 -3
  107. airflow/providers/google/cloud/sensors/cloud_composer.py +6 -1
  108. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -1
  109. airflow/providers/google/cloud/sensors/dataflow.py +3 -3
  110. airflow/providers/google/cloud/sensors/dataform.py +6 -1
  111. airflow/providers/google/cloud/sensors/datafusion.py +6 -1
  112. airflow/providers/google/cloud/sensors/dataplex.py +6 -1
  113. airflow/providers/google/cloud/sensors/dataprep.py +6 -1
  114. airflow/providers/google/cloud/sensors/dataproc.py +6 -1
  115. airflow/providers/google/cloud/sensors/dataproc_metastore.py +6 -1
  116. airflow/providers/google/cloud/sensors/gcs.py +9 -3
  117. airflow/providers/google/cloud/sensors/looker.py +6 -1
  118. airflow/providers/google/cloud/sensors/pubsub.py +8 -3
  119. airflow/providers/google/cloud/sensors/tasks.py +6 -1
  120. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +6 -1
  121. airflow/providers/google/cloud/sensors/workflows.py +6 -1
  122. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +1 -1
  123. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +1 -1
  124. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +10 -7
  125. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -2
  126. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -1
  127. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -1
  128. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  129. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +1 -1
  130. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +2 -2
  131. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
  132. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -1
  133. airflow/providers/google/cloud/transfers/gcs_to_local.py +1 -1
  134. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
  135. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +5 -1
  136. airflow/providers/google/cloud/transfers/gdrive_to_local.py +1 -1
  137. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  138. airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -1
  139. airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
  140. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +1 -1
  141. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
  142. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +2 -2
  143. airflow/providers/google/cloud/transfers/sql_to_gcs.py +1 -1
  144. airflow/providers/google/cloud/triggers/bigquery.py +32 -5
  145. airflow/providers/google/cloud/triggers/dataproc.py +62 -10
  146. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  147. airflow/providers/google/common/auth_backend/google_openid.py +2 -1
  148. airflow/providers/google/common/deprecated.py +2 -1
  149. airflow/providers/google/common/hooks/base_google.py +7 -3
  150. airflow/providers/google/common/links/storage.py +0 -22
  151. airflow/providers/google/firebase/operators/firestore.py +1 -1
  152. airflow/providers/google/get_provider_info.py +14 -16
  153. airflow/providers/google/leveldb/hooks/leveldb.py +30 -1
  154. airflow/providers/google/leveldb/operators/leveldb.py +1 -1
  155. airflow/providers/google/marketing_platform/links/analytics_admin.py +3 -6
  156. airflow/providers/google/marketing_platform/operators/analytics_admin.py +0 -1
  157. airflow/providers/google/marketing_platform/operators/campaign_manager.py +4 -4
  158. airflow/providers/google/marketing_platform/operators/display_video.py +6 -6
  159. airflow/providers/google/marketing_platform/operators/search_ads.py +1 -1
  160. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +6 -1
  161. airflow/providers/google/marketing_platform/sensors/display_video.py +6 -1
  162. airflow/providers/google/suite/operators/sheets.py +3 -3
  163. airflow/providers/google/suite/sensors/drive.py +6 -1
  164. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +1 -1
  165. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  166. airflow/providers/google/suite/transfers/local_to_drive.py +1 -1
  167. airflow/providers/google/version_compat.py +28 -0
  168. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/METADATA +35 -35
  169. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/RECORD +171 -170
  170. airflow/providers/google/cloud/links/automl.py +0 -193
  171. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/WHEEL +0 -0
  172. {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/entry_points.txt +0 -0
@@ -37,6 +37,7 @@ from airflow.providers.google.cloud.links.translate import (
37
37
  TranslationNativeDatasetLink,
38
38
  )
39
39
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
40
+ from airflow.providers.google.cloud.operators.vertex_ai.dataset import DatasetImportDataResultsCheckHelper
40
41
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
41
42
 
42
43
  if TYPE_CHECKING:
@@ -394,7 +395,6 @@ class TranslateTextBatchOperator(GoogleCloudBaseOperator):
394
395
  self.log.info("Translate text batch job started.")
395
396
  TranslateTextBatchLink.persist(
396
397
  context=context,
397
- task_instance=self,
398
398
  project_id=self.project_id or hook.project_id,
399
399
  output_config=self.output_config,
400
400
  )
@@ -480,15 +480,15 @@ class TranslateCreateDatasetOperator(GoogleCloudBaseOperator):
480
480
  result = hook.wait_for_operation_result(result_operation)
481
481
  result = type(result).to_dict(result)
482
482
  dataset_id = hook.extract_object_id(result)
483
- self.xcom_push(context, key="dataset_id", value=dataset_id)
483
+ context["ti"].xcom_push(key="dataset_id", value=dataset_id)
484
484
  self.log.info("Dataset creation complete. The dataset_id: %s.", dataset_id)
485
485
 
486
486
  project_id = self.project_id or hook.project_id
487
487
  TranslationNativeDatasetLink.persist(
488
488
  context=context,
489
- task_instance=self,
490
489
  dataset_id=dataset_id,
491
490
  project_id=project_id,
491
+ location=self.location,
492
492
  )
493
493
  return result
494
494
 
@@ -556,7 +556,6 @@ class TranslateDatasetsListOperator(GoogleCloudBaseOperator):
556
556
  project_id = self.project_id or hook.project_id
557
557
  TranslationDatasetsListLink.persist(
558
558
  context=context,
559
- task_instance=self,
560
559
  project_id=project_id,
561
560
  )
562
561
  self.log.info("Requesting datasets list")
@@ -577,7 +576,7 @@ class TranslateDatasetsListOperator(GoogleCloudBaseOperator):
577
576
  return result_ids
578
577
 
579
578
 
580
- class TranslateImportDataOperator(GoogleCloudBaseOperator):
579
+ class TranslateImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
581
580
  """
582
581
  Import data to the translation dataset.
583
582
 
@@ -604,6 +603,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
604
603
  If set as a sequence, the identities from the list must grant
605
604
  Service Account Token Creator IAM role to the directly preceding identity, with first
606
605
  account from the list granting this role to the originating account (templated).
606
+ :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
607
607
  """
608
608
 
609
609
  template_fields: Sequence[str] = (
@@ -629,6 +629,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
629
629
  retry: Retry | _MethodDefault = DEFAULT,
630
630
  gcp_conn_id: str = "google_cloud_default",
631
631
  impersonation_chain: str | Sequence[str] | None = None,
632
+ raise_for_empty_result: bool = False,
632
633
  **kwargs,
633
634
  ) -> None:
634
635
  super().__init__(**kwargs)
@@ -641,9 +642,21 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
641
642
  self.retry = retry
642
643
  self.gcp_conn_id = gcp_conn_id
643
644
  self.impersonation_chain = impersonation_chain
645
+ self.raise_for_empty_result = raise_for_empty_result
644
646
 
645
647
  def execute(self, context: Context):
646
648
  hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
649
+ initial_dataset_size = self._get_number_of_ds_items(
650
+ dataset=hook.get_dataset(
651
+ dataset_id=self.dataset_id,
652
+ project_id=self.project_id,
653
+ location=self.location,
654
+ retry=self.retry,
655
+ timeout=self.timeout,
656
+ metadata=self.metadata,
657
+ ),
658
+ total_key_name="example_count",
659
+ )
647
660
  self.log.info("Importing data to dataset...")
648
661
  operation = hook.import_dataset_data(
649
662
  dataset_id=self.dataset_id,
@@ -657,12 +670,27 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
657
670
  project_id = self.project_id or hook.project_id
658
671
  TranslationNativeDatasetLink.persist(
659
672
  context=context,
660
- task_instance=self,
661
673
  dataset_id=self.dataset_id,
662
674
  project_id=project_id,
675
+ location=self.location,
663
676
  )
664
677
  hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
678
+
679
+ result_dataset_size = self._get_number_of_ds_items(
680
+ dataset=hook.get_dataset(
681
+ dataset_id=self.dataset_id,
682
+ project_id=self.project_id,
683
+ location=self.location,
684
+ retry=self.retry,
685
+ timeout=self.timeout,
686
+ metadata=self.metadata,
687
+ ),
688
+ total_key_name="example_count",
689
+ )
690
+ if self.raise_for_empty_result:
691
+ self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
665
692
  self.log.info("Importing data finished!")
693
+ return {"total_imported": int(result_dataset_size) - int(initial_dataset_size)}
666
694
 
667
695
 
668
696
  class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
@@ -821,16 +849,16 @@ class TranslateCreateModelOperator(GoogleCloudBaseOperator):
821
849
  result = hook.wait_for_operation_result(operation=result_operation)
822
850
  result = type(result).to_dict(result)
823
851
  model_id = hook.extract_object_id(result)
824
- self.xcom_push(context, key="model_id", value=model_id)
852
+ context["ti"].xcom_push(key="model_id", value=model_id)
825
853
  self.log.info("Model creation complete. The model_id: %s.", model_id)
826
854
 
827
855
  project_id = self.project_id or hook.project_id
828
856
  TranslationModelLink.persist(
829
857
  context=context,
830
- task_instance=self,
831
858
  dataset_id=self.dataset_id,
832
859
  model_id=model_id,
833
860
  project_id=project_id,
861
+ location=self.location,
834
862
  )
835
863
  return result
836
864
 
@@ -898,7 +926,6 @@ class TranslateModelsListOperator(GoogleCloudBaseOperator):
898
926
  project_id = self.project_id or hook.project_id
899
927
  TranslationModelsListLink.persist(
900
928
  context=context,
901
- task_instance=self,
902
929
  project_id=project_id,
903
930
  )
904
931
  self.log.info("Requesting models list")
@@ -1141,7 +1168,6 @@ class TranslateDocumentOperator(GoogleCloudBaseOperator):
1141
1168
  if self.document_output_config:
1142
1169
  TranslateResultByOutputConfigLink.persist(
1143
1170
  context=context,
1144
- task_instance=self,
1145
1171
  project_id=self.project_id or hook.project_id,
1146
1172
  output_config=self.document_output_config,
1147
1173
  )
@@ -1304,7 +1330,6 @@ class TranslateDocumentBatchOperator(GoogleCloudBaseOperator):
1304
1330
  self.log.info("Batch document translation job started.")
1305
1331
  TranslateResultByOutputConfigLink.persist(
1306
1332
  context=context,
1307
- task_instance=self,
1308
1333
  project_id=self.project_id or hook.project_id,
1309
1334
  output_config=self.output_config,
1310
1335
  )
@@ -1411,7 +1436,7 @@ class TranslateCreateGlossaryOperator(GoogleCloudBaseOperator):
1411
1436
  result = type(result).to_dict(result)
1412
1437
 
1413
1438
  glossary_id = hook.extract_object_id(result)
1414
- self.xcom_push(context, key="glossary_id", value=glossary_id)
1439
+ context["ti"].xcom_push(key="glossary_id", value=glossary_id)
1415
1440
  self.log.info("Glossary creation complete. The glossary_id: %s.", glossary_id)
1416
1441
  return result
1417
1442
 
@@ -1610,7 +1635,6 @@ class TranslateListGlossariesOperator(GoogleCloudBaseOperator):
1610
1635
  project_id = self.project_id or hook.project_id
1611
1636
  TranslationGlossariesListLink.persist(
1612
1637
  context=context,
1613
- task_instance=self,
1614
1638
  project_id=project_id,
1615
1639
  )
1616
1640
  self.log.info("Requesting glossaries list")
@@ -173,7 +173,6 @@ class CloudTranslateSpeechOperator(GoogleCloudBaseOperator):
173
173
  if self.audio.uri:
174
174
  FileDetailsLink.persist(
175
175
  context=context,
176
- task_instance=self,
177
176
  # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
178
177
  uri=self.audio.uri[5:],
179
178
  project_id=self.project_id or translate_hook.project_id,
@@ -21,7 +21,7 @@
21
21
  from __future__ import annotations
22
22
 
23
23
  from collections.abc import Sequence
24
- from typing import TYPE_CHECKING
24
+ from typing import TYPE_CHECKING, Any
25
25
 
26
26
  from google.api_core.exceptions import NotFound
27
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -91,6 +91,13 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
91
91
  self.impersonation_chain = impersonation_chain
92
92
  self.hook: AutoMLHook | None = None
93
93
 
94
+ @property
95
+ def extra_links_params(self) -> dict[str, Any]:
96
+ return {
97
+ "region": self.region,
98
+ "project_id": self.project_id,
99
+ }
100
+
94
101
  def on_kill(self) -> None:
95
102
  """Act as a callback called when the operator is killed; cancel any running job."""
96
103
  if self.hook:
@@ -242,12 +249,12 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
242
249
  if model:
243
250
  result = Model.to_dict(model)
244
251
  model_id = self.hook.extract_model_id(result)
245
- self.xcom_push(context, key="model_id", value=model_id)
246
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
252
+ context["ti"].xcom_push(key="model_id", value=model_id)
253
+ VertexAIModelLink.persist(context=context, model_id=model_id)
247
254
  else:
248
255
  result = model # type: ignore
249
- self.xcom_push(context, key="training_id", value=training_id)
250
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
256
+ context["ti"].xcom_push(key="training_id", value=training_id)
257
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
251
258
  return result
252
259
 
253
260
 
@@ -334,12 +341,12 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
334
341
  if model:
335
342
  result = Model.to_dict(model)
336
343
  model_id = self.hook.extract_model_id(result)
337
- self.xcom_push(context, key="model_id", value=model_id)
338
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
344
+ context["ti"].xcom_push(key="model_id", value=model_id)
345
+ VertexAIModelLink.persist(context=context, model_id=model_id)
339
346
  else:
340
347
  result = model # type: ignore
341
- self.xcom_push(context, key="training_id", value=training_id)
342
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
348
+ context["ti"].xcom_push(key="training_id", value=training_id)
349
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
343
350
  return result
344
351
 
345
352
 
@@ -457,12 +464,12 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
457
464
  if model:
458
465
  result = Model.to_dict(model)
459
466
  model_id = self.hook.extract_model_id(result)
460
- self.xcom_push(context, key="model_id", value=model_id)
461
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
467
+ context["ti"].xcom_push(key="model_id", value=model_id)
468
+ VertexAIModelLink.persist(context=context, model_id=model_id)
462
469
  else:
463
470
  result = model # type: ignore
464
- self.xcom_push(context, key="training_id", value=training_id)
465
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
471
+ context["ti"].xcom_push(key="training_id", value=training_id)
472
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
466
473
  return result
467
474
 
468
475
 
@@ -531,12 +538,12 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
531
538
  if model:
532
539
  result = Model.to_dict(model)
533
540
  model_id = self.hook.extract_model_id(result)
534
- self.xcom_push(context, key="model_id", value=model_id)
535
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
541
+ context["ti"].xcom_push(key="model_id", value=model_id)
542
+ VertexAIModelLink.persist(context=context, model_id=model_id)
536
543
  else:
537
544
  result = model # type: ignore
538
- self.xcom_push(context, key="training_id", value=training_id)
539
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
545
+ context["ti"].xcom_push(key="training_id", value=training_id)
546
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
540
547
  return result
541
548
 
542
549
 
@@ -640,6 +647,12 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
640
647
  self.gcp_conn_id = gcp_conn_id
641
648
  self.impersonation_chain = impersonation_chain
642
649
 
650
+ @property
651
+ def extra_links_params(self) -> dict[str, Any]:
652
+ return {
653
+ "project_id": self.project_id,
654
+ }
655
+
643
656
  def execute(self, context: Context):
644
657
  hook = AutoMLHook(
645
658
  gcp_conn_id=self.gcp_conn_id,
@@ -656,5 +669,5 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
656
669
  timeout=self.timeout,
657
670
  metadata=self.metadata,
658
671
  )
659
- VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
672
+ VertexAITrainingPipelinesLink.persist(context=context)
660
673
  return [TrainingPipeline.to_dict(result) for result in results]
@@ -231,6 +231,13 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
231
231
  impersonation_chain=self.impersonation_chain,
232
232
  )
233
233
 
234
+ @property
235
+ def extra_links_params(self) -> dict[str, Any]:
236
+ return {
237
+ "region": self.region,
238
+ "project_id": self.project_id,
239
+ }
240
+
234
241
  def execute(self, context: Context):
235
242
  self.log.info("Creating Batch prediction job")
236
243
  batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
@@ -262,9 +269,10 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
262
269
  batch_prediction_job_id = batch_prediction_job.name
263
270
  self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id)
264
271
 
265
- self.xcom_push(context, key="batch_prediction_job_id", value=batch_prediction_job_id)
272
+ context["ti"].xcom_push(key="batch_prediction_job_id", value=batch_prediction_job_id)
266
273
  VertexAIBatchPredictionJobLink.persist(
267
- context=context, task_instance=self, batch_prediction_job_id=batch_prediction_job_id
274
+ context=context,
275
+ batch_prediction_job_id=batch_prediction_job_id,
268
276
  )
269
277
 
270
278
  if self.deferrable:
@@ -295,13 +303,11 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
295
303
  job: dict[str, Any] = event["job"]
296
304
  self.log.info("Batch prediction job %s created and completed successfully.", job["name"])
297
305
  job_id = self.hook.extract_batch_prediction_job_id(job)
298
- self.xcom_push(
299
- context,
306
+ context["ti"].xcom_push(
300
307
  key="batch_prediction_job_id",
301
308
  value=job_id,
302
309
  )
303
- self.xcom_push(
304
- context,
310
+ context["ti"].xcom_push(
305
311
  key="training_conf",
306
312
  value={
307
313
  "training_conf_id": job_id,
@@ -427,6 +433,13 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
427
433
  self.gcp_conn_id = gcp_conn_id
428
434
  self.impersonation_chain = impersonation_chain
429
435
 
436
+ @property
437
+ def extra_links_params(self) -> dict[str, Any]:
438
+ return {
439
+ "region": self.region,
440
+ "project_id": self.project_id,
441
+ }
442
+
430
443
  def execute(self, context: Context):
431
444
  hook = BatchPredictionJobHook(
432
445
  gcp_conn_id=self.gcp_conn_id,
@@ -445,7 +458,8 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
445
458
  )
446
459
  self.log.info("Batch prediction job was gotten.")
447
460
  VertexAIBatchPredictionJobLink.persist(
448
- context=context, task_instance=self, batch_prediction_job_id=self.batch_prediction_job
461
+ context=context,
462
+ batch_prediction_job_id=self.batch_prediction_job,
449
463
  )
450
464
  return BatchPredictionJob.to_dict(result)
451
465
  except NotFound:
@@ -517,6 +531,12 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
517
531
  self.gcp_conn_id = gcp_conn_id
518
532
  self.impersonation_chain = impersonation_chain
519
533
 
534
+ @property
535
+ def extra_links_params(self) -> dict[str, Any]:
536
+ return {
537
+ "project_id": self.project_id,
538
+ }
539
+
520
540
  def execute(self, context: Context):
521
541
  hook = BatchPredictionJobHook(
522
542
  gcp_conn_id=self.gcp_conn_id,
@@ -533,5 +553,5 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
533
553
  timeout=self.timeout,
534
554
  metadata=self.metadata,
535
555
  )
536
- VertexAIBatchPredictionJobListLink.persist(context=context, task_instance=self)
556
+ VertexAIBatchPredictionJobListLink.persist(context=context)
537
557
  return [BatchPredictionJob.to_dict(result) for result in results]
@@ -170,17 +170,24 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
170
170
  self.gcp_conn_id = gcp_conn_id
171
171
  self.impersonation_chain = impersonation_chain
172
172
 
173
+ @property
174
+ def extra_links_params(self) -> dict[str, Any]:
175
+ return {
176
+ "region": self.region,
177
+ "project_id": self.project_id,
178
+ }
179
+
173
180
  def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
174
181
  if event["status"] == "error":
175
182
  raise AirflowException(event["message"])
176
183
  training_pipeline = event["job"]
177
184
  custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline)
178
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
185
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
179
186
  try:
180
187
  model = training_pipeline["model_to_upload"]
181
188
  model_id = self.hook.extract_model_id(model)
182
- self.xcom_push(context, key="model_id", value=model_id)
183
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
189
+ context["ti"].xcom_push(key="model_id", value=model_id)
190
+ VertexAIModelLink.persist(context=context, model_id=model_id)
184
191
  return model
185
192
  except KeyError:
186
193
  self.log.warning(
@@ -584,13 +591,13 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
584
591
  if model:
585
592
  result = Model.to_dict(model)
586
593
  model_id = self.hook.extract_model_id(result)
587
- self.xcom_push(context, key="model_id", value=model_id)
588
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
594
+ context["ti"].xcom_push(key="model_id", value=model_id)
595
+ VertexAIModelLink.persist(context=context, model_id=model_id)
589
596
  else:
590
597
  result = model # type: ignore
591
- self.xcom_push(context, key="training_id", value=training_id)
592
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
593
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
598
+ context["ti"].xcom_push(key="training_id", value=training_id)
599
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
600
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
594
601
  return result
595
602
 
596
603
  def invoke_defer(self, context: Context) -> None:
@@ -648,8 +655,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
648
655
  )
649
656
  custom_container_training_job_obj.wait_for_resource_creation()
650
657
  training_pipeline_id: str = custom_container_training_job_obj.name
651
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
652
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
658
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
659
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
653
660
  self.defer(
654
661
  trigger=CustomContainerTrainingJobTrigger(
655
662
  conn_id=self.gcp_conn_id,
@@ -1041,13 +1048,13 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
1041
1048
  if model:
1042
1049
  result = Model.to_dict(model)
1043
1050
  model_id = self.hook.extract_model_id(result)
1044
- self.xcom_push(context, key="model_id", value=model_id)
1045
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1051
+ context["ti"].xcom_push(key="model_id", value=model_id)
1052
+ VertexAIModelLink.persist(context=context, model_id=model_id)
1046
1053
  else:
1047
1054
  result = model # type: ignore
1048
- self.xcom_push(context, key="training_id", value=training_id)
1049
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1050
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
1055
+ context["ti"].xcom_push(key="training_id", value=training_id)
1056
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
1057
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
1051
1058
  return result
1052
1059
 
1053
1060
  def invoke_defer(self, context: Context) -> None:
@@ -1106,8 +1113,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
1106
1113
  )
1107
1114
  custom_python_training_job_obj.wait_for_resource_creation()
1108
1115
  training_pipeline_id: str = custom_python_training_job_obj.name
1109
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
1110
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1116
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
1117
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
1111
1118
  self.defer(
1112
1119
  trigger=CustomPythonPackageTrainingJobTrigger(
1113
1120
  conn_id=self.gcp_conn_id,
@@ -1504,13 +1511,13 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1504
1511
  if model:
1505
1512
  result = Model.to_dict(model)
1506
1513
  model_id = self.hook.extract_model_id(result)
1507
- self.xcom_push(context, key="model_id", value=model_id)
1508
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1514
+ context["ti"].xcom_push(key="model_id", value=model_id)
1515
+ VertexAIModelLink.persist(context=context, model_id=model_id)
1509
1516
  else:
1510
1517
  result = model # type: ignore
1511
- self.xcom_push(context, key="training_id", value=training_id)
1512
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1513
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
1518
+ context["ti"].xcom_push(key="training_id", value=training_id)
1519
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
1520
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
1514
1521
  return result
1515
1522
 
1516
1523
  def invoke_defer(self, context: Context) -> None:
@@ -1569,8 +1576,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1569
1576
  )
1570
1577
  custom_training_job_obj.wait_for_resource_creation()
1571
1578
  training_pipeline_id: str = custom_training_job_obj.name
1572
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
1573
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1579
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
1580
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
1574
1581
  self.defer(
1575
1582
  trigger=CustomTrainingJobTrigger(
1576
1583
  conn_id=self.gcp_conn_id,
@@ -1748,6 +1755,12 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
1748
1755
  self.gcp_conn_id = gcp_conn_id
1749
1756
  self.impersonation_chain = impersonation_chain
1750
1757
 
1758
+ @property
1759
+ def extra_links_params(self) -> dict[str, Any]:
1760
+ return {
1761
+ "project_id": self.project_id,
1762
+ }
1763
+
1751
1764
  def execute(self, context: Context):
1752
1765
  hook = CustomJobHook(
1753
1766
  gcp_conn_id=self.gcp_conn_id,
@@ -1764,5 +1777,5 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
1764
1777
  timeout=self.timeout,
1765
1778
  metadata=self.metadata,
1766
1779
  )
1767
- VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
1780
+ VertexAITrainingPipelinesLink.persist(context=context)
1768
1781
  return [TrainingPipeline.to_dict(result) for result in results]