apache-airflow-providers-google 16.0.0rc1__py3-none-any.whl → 16.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +9 -5
  3. airflow/providers/google/ads/operators/ads.py +1 -1
  4. airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -1
  5. airflow/providers/google/cloud/hooks/bigquery.py +2 -3
  6. airflow/providers/google/cloud/hooks/cloud_sql.py +8 -4
  7. airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +2 -2
  9. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  10. airflow/providers/google/cloud/hooks/dataprep.py +4 -1
  11. airflow/providers/google/cloud/hooks/gcs.py +2 -2
  12. airflow/providers/google/cloud/hooks/looker.py +5 -1
  13. airflow/providers/google/cloud/hooks/mlengine.py +2 -1
  14. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  15. airflow/providers/google/cloud/hooks/spanner.py +2 -2
  16. airflow/providers/google/cloud/hooks/translate.py +1 -1
  17. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
  18. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +43 -14
  19. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +11 -2
  20. airflow/providers/google/cloud/hooks/vision.py +2 -2
  21. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  22. airflow/providers/google/cloud/links/base.py +75 -11
  23. airflow/providers/google/cloud/links/bigquery.py +0 -47
  24. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  25. airflow/providers/google/cloud/links/bigtable.py +0 -48
  26. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  27. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  28. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  29. airflow/providers/google/cloud/links/cloud_run.py +1 -33
  30. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  31. airflow/providers/google/cloud/links/cloud_storage_transfer.py +16 -43
  32. airflow/providers/google/cloud/links/cloud_tasks.py +6 -25
  33. airflow/providers/google/cloud/links/compute.py +0 -58
  34. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  35. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  36. airflow/providers/google/cloud/links/dataflow.py +0 -34
  37. airflow/providers/google/cloud/links/dataform.py +0 -64
  38. airflow/providers/google/cloud/links/datafusion.py +1 -96
  39. airflow/providers/google/cloud/links/dataplex.py +0 -154
  40. airflow/providers/google/cloud/links/dataprep.py +0 -24
  41. airflow/providers/google/cloud/links/dataproc.py +14 -90
  42. airflow/providers/google/cloud/links/datastore.py +0 -31
  43. airflow/providers/google/cloud/links/kubernetes_engine.py +5 -59
  44. airflow/providers/google/cloud/links/life_sciences.py +0 -19
  45. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  46. airflow/providers/google/cloud/links/mlengine.py +0 -70
  47. airflow/providers/google/cloud/links/pubsub.py +0 -32
  48. airflow/providers/google/cloud/links/spanner.py +0 -33
  49. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  50. airflow/providers/google/cloud/links/translate.py +16 -186
  51. airflow/providers/google/cloud/links/vertex_ai.py +8 -224
  52. airflow/providers/google/cloud/links/workflows.py +0 -52
  53. airflow/providers/google/cloud/operators/alloy_db.py +69 -54
  54. airflow/providers/google/cloud/operators/automl.py +16 -14
  55. airflow/providers/google/cloud/operators/bigquery.py +0 -15
  56. airflow/providers/google/cloud/operators/bigquery_dts.py +2 -4
  57. airflow/providers/google/cloud/operators/bigtable.py +35 -6
  58. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  59. airflow/providers/google/cloud/operators/cloud_build.py +74 -31
  60. airflow/providers/google/cloud/operators/cloud_composer.py +34 -35
  61. airflow/providers/google/cloud/operators/cloud_memorystore.py +68 -42
  62. airflow/providers/google/cloud/operators/cloud_run.py +0 -1
  63. airflow/providers/google/cloud/operators/cloud_sql.py +11 -15
  64. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -2
  65. airflow/providers/google/cloud/operators/compute.py +7 -39
  66. airflow/providers/google/cloud/operators/datacatalog.py +156 -20
  67. airflow/providers/google/cloud/operators/dataflow.py +37 -14
  68. airflow/providers/google/cloud/operators/dataform.py +14 -4
  69. airflow/providers/google/cloud/operators/datafusion.py +4 -12
  70. airflow/providers/google/cloud/operators/dataplex.py +180 -96
  71. airflow/providers/google/cloud/operators/dataprep.py +0 -4
  72. airflow/providers/google/cloud/operators/dataproc.py +10 -16
  73. airflow/providers/google/cloud/operators/dataproc_metastore.py +95 -87
  74. airflow/providers/google/cloud/operators/datastore.py +21 -5
  75. airflow/providers/google/cloud/operators/dlp.py +3 -26
  76. airflow/providers/google/cloud/operators/functions.py +15 -6
  77. airflow/providers/google/cloud/operators/gcs.py +0 -7
  78. airflow/providers/google/cloud/operators/kubernetes_engine.py +50 -7
  79. airflow/providers/google/cloud/operators/life_sciences.py +0 -1
  80. airflow/providers/google/cloud/operators/managed_kafka.py +106 -51
  81. airflow/providers/google/cloud/operators/mlengine.py +0 -1
  82. airflow/providers/google/cloud/operators/pubsub.py +2 -4
  83. airflow/providers/google/cloud/operators/spanner.py +0 -4
  84. airflow/providers/google/cloud/operators/speech_to_text.py +0 -1
  85. airflow/providers/google/cloud/operators/stackdriver.py +0 -8
  86. airflow/providers/google/cloud/operators/tasks.py +0 -11
  87. airflow/providers/google/cloud/operators/text_to_speech.py +0 -1
  88. airflow/providers/google/cloud/operators/translate.py +37 -13
  89. airflow/providers/google/cloud/operators/translate_speech.py +0 -1
  90. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18
  91. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8
  92. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25
  93. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +69 -7
  94. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +42 -8
  95. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +531 -0
  96. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +93 -25
  97. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +10 -8
  98. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +56 -10
  99. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +25 -6
  100. airflow/providers/google/cloud/operators/vertex_ai/ray.py +9 -6
  101. airflow/providers/google/cloud/operators/workflows.py +1 -9
  102. airflow/providers/google/cloud/sensors/bigquery.py +1 -1
  103. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -1
  104. airflow/providers/google/cloud/sensors/bigtable.py +15 -3
  105. airflow/providers/google/cloud/sensors/cloud_composer.py +6 -1
  106. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -1
  107. airflow/providers/google/cloud/sensors/dataflow.py +3 -3
  108. airflow/providers/google/cloud/sensors/dataform.py +6 -1
  109. airflow/providers/google/cloud/sensors/datafusion.py +6 -1
  110. airflow/providers/google/cloud/sensors/dataplex.py +6 -1
  111. airflow/providers/google/cloud/sensors/dataprep.py +6 -1
  112. airflow/providers/google/cloud/sensors/dataproc.py +6 -1
  113. airflow/providers/google/cloud/sensors/dataproc_metastore.py +6 -1
  114. airflow/providers/google/cloud/sensors/gcs.py +9 -3
  115. airflow/providers/google/cloud/sensors/looker.py +6 -1
  116. airflow/providers/google/cloud/sensors/pubsub.py +8 -3
  117. airflow/providers/google/cloud/sensors/tasks.py +6 -1
  118. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +6 -1
  119. airflow/providers/google/cloud/sensors/workflows.py +6 -1
  120. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +1 -1
  121. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +1 -1
  122. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +1 -2
  123. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -2
  124. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -1
  125. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -1
  126. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  127. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +1 -1
  128. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -1
  129. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
  130. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -1
  131. airflow/providers/google/cloud/transfers/gcs_to_local.py +1 -1
  132. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
  133. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +5 -1
  134. airflow/providers/google/cloud/transfers/gdrive_to_local.py +1 -1
  135. airflow/providers/google/cloud/transfers/http_to_gcs.py +1 -1
  136. airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -1
  137. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +1 -1
  138. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
  139. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +2 -2
  140. airflow/providers/google/cloud/transfers/sql_to_gcs.py +1 -1
  141. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  142. airflow/providers/google/common/auth_backend/google_openid.py +2 -1
  143. airflow/providers/google/common/deprecated.py +2 -1
  144. airflow/providers/google/common/hooks/base_google.py +7 -3
  145. airflow/providers/google/common/links/storage.py +0 -22
  146. airflow/providers/google/firebase/operators/firestore.py +1 -1
  147. airflow/providers/google/get_provider_info.py +0 -11
  148. airflow/providers/google/leveldb/hooks/leveldb.py +5 -1
  149. airflow/providers/google/leveldb/operators/leveldb.py +1 -1
  150. airflow/providers/google/marketing_platform/links/analytics_admin.py +3 -6
  151. airflow/providers/google/marketing_platform/operators/analytics_admin.py +0 -1
  152. airflow/providers/google/marketing_platform/operators/campaign_manager.py +4 -4
  153. airflow/providers/google/marketing_platform/operators/display_video.py +6 -6
  154. airflow/providers/google/marketing_platform/operators/search_ads.py +1 -1
  155. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +6 -1
  156. airflow/providers/google/marketing_platform/sensors/display_video.py +6 -1
  157. airflow/providers/google/suite/operators/sheets.py +3 -3
  158. airflow/providers/google/suite/sensors/drive.py +6 -1
  159. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +1 -1
  160. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  161. airflow/providers/google/suite/transfers/local_to_drive.py +1 -1
  162. airflow/providers/google/version_compat.py +28 -0
  163. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/METADATA +19 -20
  164. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/RECORD +166 -166
  165. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/WHEEL +0 -0
  166. {apache_airflow_providers_google-16.0.0rc1.dist-info → apache_airflow_providers_google-16.1.0rc1.dist-info}/entry_points.txt +0 -0
@@ -21,7 +21,7 @@
21
21
  from __future__ import annotations
22
22
 
23
23
  from collections.abc import Sequence
24
- from typing import TYPE_CHECKING
24
+ from typing import TYPE_CHECKING, Any
25
25
 
26
26
  from google.api_core.exceptions import NotFound
27
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -91,6 +91,13 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
91
91
  self.impersonation_chain = impersonation_chain
92
92
  self.hook: AutoMLHook | None = None
93
93
 
94
+ @property
95
+ def extra_links_params(self) -> dict[str, Any]:
96
+ return {
97
+ "region": self.region,
98
+ "project_id": self.project_id,
99
+ }
100
+
94
101
  def on_kill(self) -> None:
95
102
  """Act as a callback called when the operator is killed; cancel any running job."""
96
103
  if self.hook:
@@ -242,12 +249,12 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
242
249
  if model:
243
250
  result = Model.to_dict(model)
244
251
  model_id = self.hook.extract_model_id(result)
245
- self.xcom_push(context, key="model_id", value=model_id)
246
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
252
+ context["ti"].xcom_push(key="model_id", value=model_id)
253
+ VertexAIModelLink.persist(context=context, model_id=model_id)
247
254
  else:
248
255
  result = model # type: ignore
249
- self.xcom_push(context, key="training_id", value=training_id)
250
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
256
+ context["ti"].xcom_push(key="training_id", value=training_id)
257
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
251
258
  return result
252
259
 
253
260
 
@@ -334,12 +341,12 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
334
341
  if model:
335
342
  result = Model.to_dict(model)
336
343
  model_id = self.hook.extract_model_id(result)
337
- self.xcom_push(context, key="model_id", value=model_id)
338
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
344
+ context["ti"].xcom_push(key="model_id", value=model_id)
345
+ VertexAIModelLink.persist(context=context, model_id=model_id)
339
346
  else:
340
347
  result = model # type: ignore
341
- self.xcom_push(context, key="training_id", value=training_id)
342
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
348
+ context["ti"].xcom_push(key="training_id", value=training_id)
349
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
343
350
  return result
344
351
 
345
352
 
@@ -457,12 +464,12 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
457
464
  if model:
458
465
  result = Model.to_dict(model)
459
466
  model_id = self.hook.extract_model_id(result)
460
- self.xcom_push(context, key="model_id", value=model_id)
461
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
467
+ context["ti"].xcom_push(key="model_id", value=model_id)
468
+ VertexAIModelLink.persist(context=context, model_id=model_id)
462
469
  else:
463
470
  result = model # type: ignore
464
- self.xcom_push(context, key="training_id", value=training_id)
465
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
471
+ context["ti"].xcom_push(key="training_id", value=training_id)
472
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
466
473
  return result
467
474
 
468
475
 
@@ -531,12 +538,12 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
531
538
  if model:
532
539
  result = Model.to_dict(model)
533
540
  model_id = self.hook.extract_model_id(result)
534
- self.xcom_push(context, key="model_id", value=model_id)
535
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
541
+ context["ti"].xcom_push(key="model_id", value=model_id)
542
+ VertexAIModelLink.persist(context=context, model_id=model_id)
536
543
  else:
537
544
  result = model # type: ignore
538
- self.xcom_push(context, key="training_id", value=training_id)
539
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
545
+ context["ti"].xcom_push(key="training_id", value=training_id)
546
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
540
547
  return result
541
548
 
542
549
 
@@ -640,6 +647,12 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
640
647
  self.gcp_conn_id = gcp_conn_id
641
648
  self.impersonation_chain = impersonation_chain
642
649
 
650
+ @property
651
+ def extra_links_params(self) -> dict[str, Any]:
652
+ return {
653
+ "project_id": self.project_id,
654
+ }
655
+
643
656
  def execute(self, context: Context):
644
657
  hook = AutoMLHook(
645
658
  gcp_conn_id=self.gcp_conn_id,
@@ -656,5 +669,5 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
656
669
  timeout=self.timeout,
657
670
  metadata=self.metadata,
658
671
  )
659
- VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
672
+ VertexAITrainingPipelinesLink.persist(context=context)
660
673
  return [TrainingPipeline.to_dict(result) for result in results]
@@ -231,6 +231,13 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
231
231
  impersonation_chain=self.impersonation_chain,
232
232
  )
233
233
 
234
+ @property
235
+ def extra_links_params(self) -> dict[str, Any]:
236
+ return {
237
+ "region": self.region,
238
+ "project_id": self.project_id,
239
+ }
240
+
234
241
  def execute(self, context: Context):
235
242
  self.log.info("Creating Batch prediction job")
236
243
  batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
@@ -262,9 +269,10 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
262
269
  batch_prediction_job_id = batch_prediction_job.name
263
270
  self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id)
264
271
 
265
- self.xcom_push(context, key="batch_prediction_job_id", value=batch_prediction_job_id)
272
+ context["ti"].xcom_push(key="batch_prediction_job_id", value=batch_prediction_job_id)
266
273
  VertexAIBatchPredictionJobLink.persist(
267
- context=context, task_instance=self, batch_prediction_job_id=batch_prediction_job_id
274
+ context=context,
275
+ batch_prediction_job_id=batch_prediction_job_id,
268
276
  )
269
277
 
270
278
  if self.deferrable:
@@ -295,13 +303,11 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
295
303
  job: dict[str, Any] = event["job"]
296
304
  self.log.info("Batch prediction job %s created and completed successfully.", job["name"])
297
305
  job_id = self.hook.extract_batch_prediction_job_id(job)
298
- self.xcom_push(
299
- context,
306
+ context["ti"].xcom_push(
300
307
  key="batch_prediction_job_id",
301
308
  value=job_id,
302
309
  )
303
- self.xcom_push(
304
- context,
310
+ context["ti"].xcom_push(
305
311
  key="training_conf",
306
312
  value={
307
313
  "training_conf_id": job_id,
@@ -427,6 +433,13 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
427
433
  self.gcp_conn_id = gcp_conn_id
428
434
  self.impersonation_chain = impersonation_chain
429
435
 
436
+ @property
437
+ def extra_links_params(self) -> dict[str, Any]:
438
+ return {
439
+ "region": self.region,
440
+ "project_id": self.project_id,
441
+ }
442
+
430
443
  def execute(self, context: Context):
431
444
  hook = BatchPredictionJobHook(
432
445
  gcp_conn_id=self.gcp_conn_id,
@@ -445,7 +458,8 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
445
458
  )
446
459
  self.log.info("Batch prediction job was gotten.")
447
460
  VertexAIBatchPredictionJobLink.persist(
448
- context=context, task_instance=self, batch_prediction_job_id=self.batch_prediction_job
461
+ context=context,
462
+ batch_prediction_job_id=self.batch_prediction_job,
449
463
  )
450
464
  return BatchPredictionJob.to_dict(result)
451
465
  except NotFound:
@@ -517,6 +531,12 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
517
531
  self.gcp_conn_id = gcp_conn_id
518
532
  self.impersonation_chain = impersonation_chain
519
533
 
534
+ @property
535
+ def extra_links_params(self) -> dict[str, Any]:
536
+ return {
537
+ "project_id": self.project_id,
538
+ }
539
+
520
540
  def execute(self, context: Context):
521
541
  hook = BatchPredictionJobHook(
522
542
  gcp_conn_id=self.gcp_conn_id,
@@ -533,5 +553,5 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
533
553
  timeout=self.timeout,
534
554
  metadata=self.metadata,
535
555
  )
536
- VertexAIBatchPredictionJobListLink.persist(context=context, task_instance=self)
556
+ VertexAIBatchPredictionJobListLink.persist(context=context)
537
557
  return [BatchPredictionJob.to_dict(result) for result in results]
@@ -170,17 +170,24 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
170
170
  self.gcp_conn_id = gcp_conn_id
171
171
  self.impersonation_chain = impersonation_chain
172
172
 
173
+ @property
174
+ def extra_links_params(self) -> dict[str, Any]:
175
+ return {
176
+ "region": self.region,
177
+ "project_id": self.project_id,
178
+ }
179
+
173
180
  def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
174
181
  if event["status"] == "error":
175
182
  raise AirflowException(event["message"])
176
183
  training_pipeline = event["job"]
177
184
  custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline)
178
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
185
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
179
186
  try:
180
187
  model = training_pipeline["model_to_upload"]
181
188
  model_id = self.hook.extract_model_id(model)
182
- self.xcom_push(context, key="model_id", value=model_id)
183
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
189
+ context["ti"].xcom_push(key="model_id", value=model_id)
190
+ VertexAIModelLink.persist(context=context, model_id=model_id)
184
191
  return model
185
192
  except KeyError:
186
193
  self.log.warning(
@@ -584,13 +591,13 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
584
591
  if model:
585
592
  result = Model.to_dict(model)
586
593
  model_id = self.hook.extract_model_id(result)
587
- self.xcom_push(context, key="model_id", value=model_id)
588
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
594
+ context["ti"].xcom_push(key="model_id", value=model_id)
595
+ VertexAIModelLink.persist(context=context, model_id=model_id)
589
596
  else:
590
597
  result = model # type: ignore
591
- self.xcom_push(context, key="training_id", value=training_id)
592
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
593
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
598
+ context["ti"].xcom_push(key="training_id", value=training_id)
599
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
600
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
594
601
  return result
595
602
 
596
603
  def invoke_defer(self, context: Context) -> None:
@@ -648,8 +655,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
648
655
  )
649
656
  custom_container_training_job_obj.wait_for_resource_creation()
650
657
  training_pipeline_id: str = custom_container_training_job_obj.name
651
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
652
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
658
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
659
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
653
660
  self.defer(
654
661
  trigger=CustomContainerTrainingJobTrigger(
655
662
  conn_id=self.gcp_conn_id,
@@ -1041,13 +1048,13 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
1041
1048
  if model:
1042
1049
  result = Model.to_dict(model)
1043
1050
  model_id = self.hook.extract_model_id(result)
1044
- self.xcom_push(context, key="model_id", value=model_id)
1045
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1051
+ context["ti"].xcom_push(key="model_id", value=model_id)
1052
+ VertexAIModelLink.persist(context=context, model_id=model_id)
1046
1053
  else:
1047
1054
  result = model # type: ignore
1048
- self.xcom_push(context, key="training_id", value=training_id)
1049
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1050
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
1055
+ context["ti"].xcom_push(key="training_id", value=training_id)
1056
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
1057
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
1051
1058
  return result
1052
1059
 
1053
1060
  def invoke_defer(self, context: Context) -> None:
@@ -1106,8 +1113,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
1106
1113
  )
1107
1114
  custom_python_training_job_obj.wait_for_resource_creation()
1108
1115
  training_pipeline_id: str = custom_python_training_job_obj.name
1109
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
1110
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1116
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
1117
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
1111
1118
  self.defer(
1112
1119
  trigger=CustomPythonPackageTrainingJobTrigger(
1113
1120
  conn_id=self.gcp_conn_id,
@@ -1504,13 +1511,13 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1504
1511
  if model:
1505
1512
  result = Model.to_dict(model)
1506
1513
  model_id = self.hook.extract_model_id(result)
1507
- self.xcom_push(context, key="model_id", value=model_id)
1508
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1514
+ context["ti"].xcom_push(key="model_id", value=model_id)
1515
+ VertexAIModelLink.persist(context=context, model_id=model_id)
1509
1516
  else:
1510
1517
  result = model # type: ignore
1511
- self.xcom_push(context, key="training_id", value=training_id)
1512
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1513
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
1518
+ context["ti"].xcom_push(key="training_id", value=training_id)
1519
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
1520
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
1514
1521
  return result
1515
1522
 
1516
1523
  def invoke_defer(self, context: Context) -> None:
@@ -1569,8 +1576,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1569
1576
  )
1570
1577
  custom_training_job_obj.wait_for_resource_creation()
1571
1578
  training_pipeline_id: str = custom_training_job_obj.name
1572
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
1573
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1579
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
1580
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
1574
1581
  self.defer(
1575
1582
  trigger=CustomTrainingJobTrigger(
1576
1583
  conn_id=self.gcp_conn_id,
@@ -1748,6 +1755,12 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
1748
1755
  self.gcp_conn_id = gcp_conn_id
1749
1756
  self.impersonation_chain = impersonation_chain
1750
1757
 
1758
+ @property
1759
+ def extra_links_params(self) -> dict[str, Any]:
1760
+ return {
1761
+ "project_id": self.project_id,
1762
+ }
1763
+
1751
1764
  def execute(self, context: Context):
1752
1765
  hook = CustomJobHook(
1753
1766
  gcp_conn_id=self.gcp_conn_id,
@@ -1764,5 +1777,5 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
1764
1777
  timeout=self.timeout,
1765
1778
  metadata=self.metadata,
1766
1779
  )
1767
- VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
1780
+ VertexAITrainingPipelinesLink.persist(context=context)
1768
1781
  return [TrainingPipeline.to_dict(result) for result in results]
@@ -20,12 +20,13 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any
24
24
 
25
25
  from google.api_core.exceptions import NotFound
26
26
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
27
27
  from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig
28
28
 
29
+ from airflow.exceptions import AirflowException
29
30
  from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
30
31
  from airflow.providers.google.cloud.links.vertex_ai import VertexAIDatasetLink, VertexAIDatasetListLink
31
32
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -85,6 +86,13 @@ class CreateDatasetOperator(GoogleCloudBaseOperator):
85
86
  self.gcp_conn_id = gcp_conn_id
86
87
  self.impersonation_chain = impersonation_chain
87
88
 
89
+ @property
90
+ def extra_links_params(self) -> dict[str, Any]:
91
+ return {
92
+ "region": self.region,
93
+ "project_id": self.project_id,
94
+ }
95
+
88
96
  def execute(self, context: Context):
89
97
  hook = DatasetHook(
90
98
  gcp_conn_id=self.gcp_conn_id,
@@ -106,8 +114,8 @@ class CreateDatasetOperator(GoogleCloudBaseOperator):
106
114
  dataset_id = hook.extract_dataset_id(dataset)
107
115
  self.log.info("Dataset was created. Dataset id: %s", dataset_id)
108
116
 
109
- self.xcom_push(context, key="dataset_id", value=dataset_id)
110
- VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=dataset_id)
117
+ context["ti"].xcom_push(key="dataset_id", value=dataset_id)
118
+ VertexAIDatasetLink.persist(context=context, dataset_id=dataset_id)
111
119
  return dataset
112
120
 
113
121
 
@@ -160,6 +168,13 @@ class GetDatasetOperator(GoogleCloudBaseOperator):
160
168
  self.gcp_conn_id = gcp_conn_id
161
169
  self.impersonation_chain = impersonation_chain
162
170
 
171
+ @property
172
+ def extra_links_params(self) -> dict[str, Any]:
173
+ return {
174
+ "region": self.region,
175
+ "project_id": self.project_id,
176
+ }
177
+
163
178
  def execute(self, context: Context):
164
179
  hook = DatasetHook(
165
180
  gcp_conn_id=self.gcp_conn_id,
@@ -177,7 +192,7 @@ class GetDatasetOperator(GoogleCloudBaseOperator):
177
192
  timeout=self.timeout,
178
193
  metadata=self.metadata,
179
194
  )
180
- VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=self.dataset_id)
195
+ VertexAIDatasetLink.persist(context=context, dataset_id=self.dataset_id)
181
196
  self.log.info("Dataset was gotten.")
182
197
  return Dataset.to_dict(dataset_obj)
183
198
  except NotFound:
@@ -321,7 +336,21 @@ class ExportDataOperator(GoogleCloudBaseOperator):
321
336
  self.log.info("Export was done successfully")
322
337
 
323
338
 
324
- class ImportDataOperator(GoogleCloudBaseOperator):
339
+ class DatasetImportDataResultsCheckHelper:
340
+ """Helper utils to verify import dataset data results."""
341
+
342
+ @staticmethod
343
+ def _get_number_of_ds_items(dataset, total_key_name):
344
+ number_of_items = type(dataset).to_dict(dataset).get(total_key_name, 0)
345
+ return number_of_items
346
+
347
+ @staticmethod
348
+ def _raise_for_empty_import_result(dataset_id, initial_size, size_after_import):
349
+ if int(size_after_import) - int(initial_size) <= 0:
350
+ raise AirflowException(f"Empty results of data import for the dataset_id {dataset_id}.")
351
+
352
+
353
+ class ImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
325
354
  """
326
355
  Imports data into a Dataset.
327
356
 
@@ -342,6 +371,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
342
371
  If set as a sequence, the identities from the list must grant
343
372
  Service Account Token Creator IAM role to the directly preceding identity, with first
344
373
  account from the list granting this role to the originating account (templated).
374
+ :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
345
375
  """
346
376
 
347
377
  template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
@@ -358,6 +388,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
358
388
  metadata: Sequence[tuple[str, str]] = (),
359
389
  gcp_conn_id: str = "google_cloud_default",
360
390
  impersonation_chain: str | Sequence[str] | None = None,
391
+ raise_for_empty_result: bool = False,
361
392
  **kwargs,
362
393
  ) -> None:
363
394
  super().__init__(**kwargs)
@@ -370,13 +401,24 @@ class ImportDataOperator(GoogleCloudBaseOperator):
370
401
  self.metadata = metadata
371
402
  self.gcp_conn_id = gcp_conn_id
372
403
  self.impersonation_chain = impersonation_chain
404
+ self.raise_for_empty_result = raise_for_empty_result
373
405
 
374
406
  def execute(self, context: Context):
375
407
  hook = DatasetHook(
376
408
  gcp_conn_id=self.gcp_conn_id,
377
409
  impersonation_chain=self.impersonation_chain,
378
410
  )
379
-
411
+ initial_dataset_size = self._get_number_of_ds_items(
412
+ dataset=hook.get_dataset(
413
+ dataset_id=self.dataset_id,
414
+ project_id=self.project_id,
415
+ region=self.region,
416
+ retry=self.retry,
417
+ timeout=self.timeout,
418
+ metadata=self.metadata,
419
+ ),
420
+ total_key_name="data_item_count",
421
+ )
380
422
  self.log.info("Importing data: %s", self.dataset_id)
381
423
  operation = hook.import_data(
382
424
  project_id=self.project_id,
@@ -388,7 +430,21 @@ class ImportDataOperator(GoogleCloudBaseOperator):
388
430
  metadata=self.metadata,
389
431
  )
390
432
  hook.wait_for_operation(timeout=self.timeout, operation=operation)
433
+ result_dataset_size = self._get_number_of_ds_items(
434
+ dataset=hook.get_dataset(
435
+ dataset_id=self.dataset_id,
436
+ project_id=self.project_id,
437
+ region=self.region,
438
+ retry=self.retry,
439
+ timeout=self.timeout,
440
+ metadata=self.metadata,
441
+ ),
442
+ total_key_name="data_item_count",
443
+ )
444
+ if self.raise_for_empty_result:
445
+ self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
391
446
  self.log.info("Import was done successfully")
447
+ return {"total_data_items_imported": int(result_dataset_size) - int(initial_dataset_size)}
392
448
 
393
449
 
394
450
  class ListDatasetsOperator(GoogleCloudBaseOperator):
@@ -451,6 +507,12 @@ class ListDatasetsOperator(GoogleCloudBaseOperator):
451
507
  self.gcp_conn_id = gcp_conn_id
452
508
  self.impersonation_chain = impersonation_chain
453
509
 
510
+ @property
511
+ def extra_links_params(self) -> dict[str, Any]:
512
+ return {
513
+ "project_id": self.project_id,
514
+ }
515
+
454
516
  def execute(self, context: Context):
455
517
  hook = DatasetHook(
456
518
  gcp_conn_id=self.gcp_conn_id,
@@ -468,7 +530,7 @@ class ListDatasetsOperator(GoogleCloudBaseOperator):
468
530
  timeout=self.timeout,
469
531
  metadata=self.metadata,
470
532
  )
471
- VertexAIDatasetListLink.persist(context=context, task_instance=self)
533
+ VertexAIDatasetListLink.persist(context=context)
472
534
  return [Dataset.to_dict(result) for result in results]
473
535
 
474
536
 
@@ -21,7 +21,7 @@
21
21
  from __future__ import annotations
22
22
 
23
23
  from collections.abc import Sequence
24
- from typing import TYPE_CHECKING
24
+ from typing import TYPE_CHECKING, Any
25
25
 
26
26
  from google.api_core.exceptions import NotFound
27
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -93,6 +93,13 @@ class CreateEndpointOperator(GoogleCloudBaseOperator):
93
93
  self.gcp_conn_id = gcp_conn_id
94
94
  self.impersonation_chain = impersonation_chain
95
95
 
96
+ @property
97
+ def extra_links_params(self) -> dict[str, Any]:
98
+ return {
99
+ "region": self.region,
100
+ "project_id": self.project_id,
101
+ }
102
+
96
103
  def execute(self, context: Context):
97
104
  hook = EndpointServiceHook(
98
105
  gcp_conn_id=self.gcp_conn_id,
@@ -115,8 +122,8 @@ class CreateEndpointOperator(GoogleCloudBaseOperator):
115
122
  endpoint_id = hook.extract_endpoint_id(endpoint)
116
123
  self.log.info("Endpoint was created. Endpoint ID: %s", endpoint_id)
117
124
 
118
- self.xcom_push(context, key="endpoint_id", value=endpoint_id)
119
- VertexAIEndpointLink.persist(context=context, task_instance=self, endpoint_id=endpoint_id)
125
+ context["ti"].xcom_push(key="endpoint_id", value=endpoint_id)
126
+ VertexAIEndpointLink.persist(context=context, endpoint_id=endpoint_id)
120
127
  return endpoint
121
128
 
122
129
 
@@ -255,6 +262,13 @@ class DeployModelOperator(GoogleCloudBaseOperator):
255
262
  self.gcp_conn_id = gcp_conn_id
256
263
  self.impersonation_chain = impersonation_chain
257
264
 
265
+ @property
266
+ def extra_links_params(self) -> dict[str, Any]:
267
+ return {
268
+ "region": self.region,
269
+ "project_id": self.project_id,
270
+ }
271
+
258
272
  def execute(self, context: Context):
259
273
  hook = EndpointServiceHook(
260
274
  gcp_conn_id=self.gcp_conn_id,
@@ -278,8 +292,8 @@ class DeployModelOperator(GoogleCloudBaseOperator):
278
292
  deployed_model_id = hook.extract_deployed_model_id(deploy_model)
279
293
  self.log.info("Model was deployed. Deployed Model ID: %s", deployed_model_id)
280
294
 
281
- self.xcom_push(context, key="deployed_model_id", value=deployed_model_id)
282
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=deployed_model_id)
295
+ context["ti"].xcom_push(key="deployed_model_id", value=deployed_model_id)
296
+ VertexAIModelLink.persist(context=context, model_id=deployed_model_id)
283
297
  return deploy_model
284
298
 
285
299
 
@@ -330,6 +344,13 @@ class GetEndpointOperator(GoogleCloudBaseOperator):
330
344
  self.gcp_conn_id = gcp_conn_id
331
345
  self.impersonation_chain = impersonation_chain
332
346
 
347
+ @property
348
+ def extra_links_params(self) -> dict[str, Any]:
349
+ return {
350
+ "region": self.region,
351
+ "project_id": self.project_id,
352
+ }
353
+
333
354
  def execute(self, context: Context):
334
355
  hook = EndpointServiceHook(
335
356
  gcp_conn_id=self.gcp_conn_id,
@@ -346,7 +367,7 @@ class GetEndpointOperator(GoogleCloudBaseOperator):
346
367
  timeout=self.timeout,
347
368
  metadata=self.metadata,
348
369
  )
349
- VertexAIEndpointLink.persist(context=context, task_instance=self, endpoint_id=self.endpoint_id)
370
+ VertexAIEndpointLink.persist(context=context, endpoint_id=self.endpoint_id)
350
371
  self.log.info("Endpoint was gotten.")
351
372
  return Endpoint.to_dict(endpoint_obj)
352
373
  except NotFound:
@@ -429,6 +450,12 @@ class ListEndpointsOperator(GoogleCloudBaseOperator):
429
450
  self.gcp_conn_id = gcp_conn_id
430
451
  self.impersonation_chain = impersonation_chain
431
452
 
453
+ @property
454
+ def extra_links_params(self) -> dict[str, Any]:
455
+ return {
456
+ "project_id": self.project_id,
457
+ }
458
+
432
459
  def execute(self, context: Context):
433
460
  hook = EndpointServiceHook(
434
461
  gcp_conn_id=self.gcp_conn_id,
@@ -446,7 +473,7 @@ class ListEndpointsOperator(GoogleCloudBaseOperator):
446
473
  timeout=self.timeout,
447
474
  metadata=self.metadata,
448
475
  )
449
- VertexAIEndpointListLink.persist(context=context, task_instance=self)
476
+ VertexAIEndpointListLink.persist(context=context)
450
477
  return [Endpoint.to_dict(result) for result in results]
451
478
 
452
479
 
@@ -582,6 +609,13 @@ class UpdateEndpointOperator(GoogleCloudBaseOperator):
582
609
  self.gcp_conn_id = gcp_conn_id
583
610
  self.impersonation_chain = impersonation_chain
584
611
 
612
+ @property
613
+ def extra_links_params(self) -> dict[str, Any]:
614
+ return {
615
+ "region": self.region,
616
+ "project_id": self.project_id,
617
+ }
618
+
585
619
  def execute(self, context: Context):
586
620
  hook = EndpointServiceHook(
587
621
  gcp_conn_id=self.gcp_conn_id,
@@ -599,5 +633,5 @@ class UpdateEndpointOperator(GoogleCloudBaseOperator):
599
633
  metadata=self.metadata,
600
634
  )
601
635
  self.log.info("Endpoint was updated")
602
- VertexAIEndpointLink.persist(context=context, task_instance=self, endpoint_id=self.endpoint_id)
636
+ VertexAIEndpointLink.persist(context=context, endpoint_id=self.endpoint_id)
603
637
  return Endpoint.to_dict(result)