apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (234)
  1. airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
  2. airflow/providers/google/__init__.py +3 -3
  3. airflow/providers/google/ads/hooks/ads.py +39 -5
  4. airflow/providers/google/ads/operators/ads.py +2 -2
  5. airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
  6. airflow/providers/google/assets/gcs.py +1 -11
  7. airflow/providers/google/cloud/bundles/__init__.py +16 -0
  8. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  9. airflow/providers/google/cloud/hooks/bigquery.py +166 -281
  10. airflow/providers/google/cloud/hooks/cloud_composer.py +287 -14
  11. airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
  12. airflow/providers/google/cloud/hooks/cloud_run.py +17 -9
  13. airflow/providers/google/cloud/hooks/cloud_sql.py +101 -22
  14. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +27 -6
  15. airflow/providers/google/cloud/hooks/compute_ssh.py +5 -1
  16. airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
  17. airflow/providers/google/cloud/hooks/dataflow.py +71 -94
  18. airflow/providers/google/cloud/hooks/datafusion.py +1 -1
  19. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  20. airflow/providers/google/cloud/hooks/dataprep.py +1 -1
  21. airflow/providers/google/cloud/hooks/dataproc.py +72 -71
  22. airflow/providers/google/cloud/hooks/gcs.py +111 -14
  23. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  24. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  25. airflow/providers/google/cloud/hooks/looker.py +6 -1
  26. airflow/providers/google/cloud/hooks/mlengine.py +3 -2
  27. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  28. airflow/providers/google/cloud/hooks/spanner.py +73 -8
  29. airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
  30. airflow/providers/google/cloud/hooks/translate.py +1 -1
  31. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -209
  32. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -2
  33. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +27 -1
  34. airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
  35. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
  36. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
  37. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
  38. airflow/providers/google/cloud/hooks/vision.py +2 -2
  39. airflow/providers/google/cloud/hooks/workflows.py +1 -1
  40. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  41. airflow/providers/google/cloud/links/base.py +77 -13
  42. airflow/providers/google/cloud/links/bigquery.py +0 -47
  43. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  44. airflow/providers/google/cloud/links/bigtable.py +0 -48
  45. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  46. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  47. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  48. airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
  49. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  50. airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
  51. airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
  52. airflow/providers/google/cloud/links/compute.py +0 -58
  53. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  54. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  55. airflow/providers/google/cloud/links/dataflow.py +0 -34
  56. airflow/providers/google/cloud/links/dataform.py +0 -64
  57. airflow/providers/google/cloud/links/datafusion.py +1 -96
  58. airflow/providers/google/cloud/links/dataplex.py +0 -154
  59. airflow/providers/google/cloud/links/dataprep.py +0 -24
  60. airflow/providers/google/cloud/links/dataproc.py +11 -95
  61. airflow/providers/google/cloud/links/datastore.py +0 -31
  62. airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
  63. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  64. airflow/providers/google/cloud/links/mlengine.py +0 -70
  65. airflow/providers/google/cloud/links/pubsub.py +0 -32
  66. airflow/providers/google/cloud/links/spanner.py +0 -33
  67. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  68. airflow/providers/google/cloud/links/translate.py +17 -187
  69. airflow/providers/google/cloud/links/vertex_ai.py +28 -195
  70. airflow/providers/google/cloud/links/workflows.py +0 -52
  71. airflow/providers/google/cloud/log/gcs_task_handler.py +17 -9
  72. airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
  73. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  74. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  75. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  76. airflow/providers/google/cloud/openlineage/facets.py +102 -1
  77. airflow/providers/google/cloud/openlineage/mixins.py +10 -8
  78. airflow/providers/google/cloud/openlineage/utils.py +15 -1
  79. airflow/providers/google/cloud/operators/alloy_db.py +70 -55
  80. airflow/providers/google/cloud/operators/bigquery.py +73 -636
  81. airflow/providers/google/cloud/operators/bigquery_dts.py +3 -5
  82. airflow/providers/google/cloud/operators/bigtable.py +36 -7
  83. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  84. airflow/providers/google/cloud/operators/cloud_batch.py +2 -2
  85. airflow/providers/google/cloud/operators/cloud_build.py +75 -32
  86. airflow/providers/google/cloud/operators/cloud_composer.py +128 -40
  87. airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
  88. airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
  89. airflow/providers/google/cloud/operators/cloud_run.py +23 -5
  90. airflow/providers/google/cloud/operators/cloud_sql.py +8 -16
  91. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -11
  92. airflow/providers/google/cloud/operators/compute.py +8 -40
  93. airflow/providers/google/cloud/operators/datacatalog.py +157 -21
  94. airflow/providers/google/cloud/operators/dataflow.py +38 -15
  95. airflow/providers/google/cloud/operators/dataform.py +15 -5
  96. airflow/providers/google/cloud/operators/datafusion.py +41 -20
  97. airflow/providers/google/cloud/operators/dataplex.py +193 -109
  98. airflow/providers/google/cloud/operators/dataprep.py +1 -5
  99. airflow/providers/google/cloud/operators/dataproc.py +78 -35
  100. airflow/providers/google/cloud/operators/dataproc_metastore.py +96 -88
  101. airflow/providers/google/cloud/operators/datastore.py +22 -6
  102. airflow/providers/google/cloud/operators/dlp.py +6 -29
  103. airflow/providers/google/cloud/operators/functions.py +16 -7
  104. airflow/providers/google/cloud/operators/gcs.py +10 -8
  105. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  106. airflow/providers/google/cloud/operators/kubernetes_engine.py +60 -99
  107. airflow/providers/google/cloud/operators/looker.py +1 -1
  108. airflow/providers/google/cloud/operators/managed_kafka.py +107 -52
  109. airflow/providers/google/cloud/operators/natural_language.py +1 -1
  110. airflow/providers/google/cloud/operators/pubsub.py +60 -14
  111. airflow/providers/google/cloud/operators/spanner.py +25 -12
  112. airflow/providers/google/cloud/operators/speech_to_text.py +1 -2
  113. airflow/providers/google/cloud/operators/stackdriver.py +1 -9
  114. airflow/providers/google/cloud/operators/tasks.py +1 -12
  115. airflow/providers/google/cloud/operators/text_to_speech.py +1 -2
  116. airflow/providers/google/cloud/operators/translate.py +40 -16
  117. airflow/providers/google/cloud/operators/translate_speech.py +1 -2
  118. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
  119. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +29 -9
  120. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +54 -26
  121. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
  122. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
  123. airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
  124. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
  125. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
  126. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +11 -9
  127. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
  128. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +30 -7
  129. airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
  130. airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
  131. airflow/providers/google/cloud/operators/vision.py +2 -2
  132. airflow/providers/google/cloud/operators/workflows.py +18 -15
  133. airflow/providers/google/cloud/sensors/bigquery.py +2 -2
  134. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -2
  135. airflow/providers/google/cloud/sensors/bigtable.py +11 -4
  136. airflow/providers/google/cloud/sensors/cloud_composer.py +533 -29
  137. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -2
  138. airflow/providers/google/cloud/sensors/dataflow.py +26 -9
  139. airflow/providers/google/cloud/sensors/dataform.py +2 -2
  140. airflow/providers/google/cloud/sensors/datafusion.py +4 -4
  141. airflow/providers/google/cloud/sensors/dataplex.py +2 -2
  142. airflow/providers/google/cloud/sensors/dataprep.py +2 -2
  143. airflow/providers/google/cloud/sensors/dataproc.py +2 -2
  144. airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -2
  145. airflow/providers/google/cloud/sensors/gcs.py +4 -4
  146. airflow/providers/google/cloud/sensors/looker.py +2 -2
  147. airflow/providers/google/cloud/sensors/pubsub.py +4 -4
  148. airflow/providers/google/cloud/sensors/tasks.py +2 -2
  149. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
  150. airflow/providers/google/cloud/sensors/workflows.py +2 -2
  151. airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
  152. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
  153. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
  154. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
  155. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +4 -4
  156. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
  157. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
  158. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
  159. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
  160. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  161. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +2 -2
  162. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +3 -3
  163. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +20 -12
  164. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
  165. airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
  166. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
  167. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
  168. airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
  169. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  170. airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
  171. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  172. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
  173. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
  174. airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
  175. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
  176. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +13 -4
  177. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
  178. airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
  179. airflow/providers/google/cloud/triggers/bigquery.py +75 -34
  180. airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
  181. airflow/providers/google/cloud/triggers/cloud_composer.py +302 -46
  182. airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
  183. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +91 -1
  184. airflow/providers/google/cloud/triggers/dataflow.py +122 -0
  185. airflow/providers/google/cloud/triggers/datafusion.py +1 -1
  186. airflow/providers/google/cloud/triggers/dataplex.py +14 -2
  187. airflow/providers/google/cloud/triggers/dataproc.py +122 -52
  188. airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
  189. airflow/providers/google/cloud/triggers/mlengine.py +1 -1
  190. airflow/providers/google/cloud/triggers/pubsub.py +15 -19
  191. airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
  192. airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
  193. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  194. airflow/providers/google/common/auth_backend/google_openid.py +4 -4
  195. airflow/providers/google/common/deprecated.py +2 -1
  196. airflow/providers/google/common/hooks/base_google.py +27 -8
  197. airflow/providers/google/common/links/storage.py +0 -22
  198. airflow/providers/google/common/utils/get_secret.py +31 -0
  199. airflow/providers/google/common/utils/id_token_credentials.py +3 -4
  200. airflow/providers/google/firebase/operators/firestore.py +2 -2
  201. airflow/providers/google/get_provider_info.py +56 -52
  202. airflow/providers/google/go_module_utils.py +35 -3
  203. airflow/providers/google/leveldb/hooks/leveldb.py +26 -1
  204. airflow/providers/google/leveldb/operators/leveldb.py +2 -2
  205. airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
  206. airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
  207. airflow/providers/google/marketing_platform/operators/analytics_admin.py +1 -2
  208. airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
  209. airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
  210. airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
  211. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
  212. airflow/providers/google/marketing_platform/sensors/display_video.py +3 -63
  213. airflow/providers/google/suite/hooks/calendar.py +1 -1
  214. airflow/providers/google/suite/hooks/sheets.py +15 -1
  215. airflow/providers/google/suite/operators/sheets.py +8 -3
  216. airflow/providers/google/suite/sensors/drive.py +2 -2
  217. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
  218. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  219. airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
  220. airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
  221. airflow/providers/google/version_compat.py +15 -1
  222. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +92 -48
  223. apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
  224. apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
  225. airflow/providers/google/cloud/hooks/automl.py +0 -673
  226. airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
  227. airflow/providers/google/cloud/links/automl.py +0 -193
  228. airflow/providers/google/cloud/operators/automl.py +0 -1362
  229. airflow/providers/google/cloud/operators/life_sciences.py +0 -119
  230. airflow/providers/google/cloud/operators/mlengine.py +0 -112
  231. apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
  232. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +0 -0
  233. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
  234. {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -21,7 +21,7 @@
21
21
  from __future__ import annotations
22
22
 
23
23
  from collections.abc import Sequence
24
- from typing import TYPE_CHECKING
24
+ from typing import TYPE_CHECKING, Any
25
25
 
26
26
  from google.api_core.exceptions import NotFound
27
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -29,6 +29,7 @@ from google.cloud.aiplatform import datasets
29
29
  from google.cloud.aiplatform.models import Model
30
30
  from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
31
31
 
32
+ from airflow.exceptions import AirflowProviderDeprecationWarning
32
33
  from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
33
34
  from airflow.providers.google.cloud.links.vertex_ai import (
34
35
  VertexAIModelLink,
@@ -36,11 +37,12 @@ from airflow.providers.google.cloud.links.vertex_ai import (
36
37
  VertexAITrainingPipelinesLink,
37
38
  )
38
39
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
40
+ from airflow.providers.google.common.deprecated import deprecated
39
41
 
40
42
  if TYPE_CHECKING:
41
43
  from google.api_core.retry import Retry
42
44
 
43
- from airflow.utils.context import Context
45
+ from airflow.providers.common.compat.sdk import Context
44
46
 
45
47
 
46
48
  class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
@@ -91,6 +93,13 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
91
93
  self.impersonation_chain = impersonation_chain
92
94
  self.hook: AutoMLHook | None = None
93
95
 
96
+ @property
97
+ def extra_links_params(self) -> dict[str, Any]:
98
+ return {
99
+ "region": self.region,
100
+ "project_id": self.project_id,
101
+ }
102
+
94
103
  def on_kill(self) -> None:
95
104
  """Act as a callback called when the operator is killed; cancel any running job."""
96
105
  if self.hook:
@@ -242,12 +251,12 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
242
251
  if model:
243
252
  result = Model.to_dict(model)
244
253
  model_id = self.hook.extract_model_id(result)
245
- self.xcom_push(context, key="model_id", value=model_id)
246
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
254
+ context["ti"].xcom_push(key="model_id", value=model_id)
255
+ VertexAIModelLink.persist(context=context, model_id=model_id)
247
256
  else:
248
257
  result = model # type: ignore
249
- self.xcom_push(context, key="training_id", value=training_id)
250
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
258
+ context["ti"].xcom_push(key="training_id", value=training_id)
259
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
251
260
  return result
252
261
 
253
262
 
@@ -334,12 +343,12 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
334
343
  if model:
335
344
  result = Model.to_dict(model)
336
345
  model_id = self.hook.extract_model_id(result)
337
- self.xcom_push(context, key="model_id", value=model_id)
338
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
346
+ context["ti"].xcom_push(key="model_id", value=model_id)
347
+ VertexAIModelLink.persist(context=context, model_id=model_id)
339
348
  else:
340
349
  result = model # type: ignore
341
- self.xcom_push(context, key="training_id", value=training_id)
342
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
350
+ context["ti"].xcom_push(key="training_id", value=training_id)
351
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
343
352
  return result
344
353
 
345
354
 
@@ -457,15 +466,20 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
457
466
  if model:
458
467
  result = Model.to_dict(model)
459
468
  model_id = self.hook.extract_model_id(result)
460
- self.xcom_push(context, key="model_id", value=model_id)
461
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
469
+ context["ti"].xcom_push(key="model_id", value=model_id)
470
+ VertexAIModelLink.persist(context=context, model_id=model_id)
462
471
  else:
463
472
  result = model # type: ignore
464
- self.xcom_push(context, key="training_id", value=training_id)
465
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
473
+ context["ti"].xcom_push(key="training_id", value=training_id)
474
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
466
475
  return result
467
476
 
468
477
 
478
+ @deprecated(
479
+ planned_removal_date="March 24, 2026",
480
+ use_instead="airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator",
481
+ category=AirflowProviderDeprecationWarning,
482
+ )
469
483
  class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
470
484
  """Create Auto ML Video Training job."""
471
485
 
@@ -531,12 +545,12 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
531
545
  if model:
532
546
  result = Model.to_dict(model)
533
547
  model_id = self.hook.extract_model_id(result)
534
- self.xcom_push(context, key="model_id", value=model_id)
535
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
548
+ context["ti"].xcom_push(key="model_id", value=model_id)
549
+ VertexAIModelLink.persist(context=context, model_id=model_id)
536
550
  else:
537
551
  result = model # type: ignore
538
- self.xcom_push(context, key="training_id", value=training_id)
539
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
552
+ context["ti"].xcom_push(key="training_id", value=training_id)
553
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
540
554
  return result
541
555
 
542
556
 
@@ -640,6 +654,12 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
640
654
  self.gcp_conn_id = gcp_conn_id
641
655
  self.impersonation_chain = impersonation_chain
642
656
 
657
+ @property
658
+ def extra_links_params(self) -> dict[str, Any]:
659
+ return {
660
+ "project_id": self.project_id,
661
+ }
662
+
643
663
  def execute(self, context: Context):
644
664
  hook = AutoMLHook(
645
665
  gcp_conn_id=self.gcp_conn_id,
@@ -656,5 +676,5 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
656
676
  timeout=self.timeout,
657
677
  metadata=self.metadata,
658
678
  )
659
- VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
679
+ VertexAITrainingPipelinesLink.persist(context=context)
660
680
  return [TrainingPipeline.to_dict(result) for result in results]
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
42
42
  from google.api_core.retry import Retry
43
43
  from google.cloud.aiplatform import BatchPredictionJob as BatchPredictionJobObject, Model, explain
44
44
 
45
- from airflow.utils.context import Context
45
+ from airflow.providers.common.compat.sdk import Context
46
46
 
47
47
 
48
48
  class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
@@ -231,6 +231,13 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
231
231
  impersonation_chain=self.impersonation_chain,
232
232
  )
233
233
 
234
+ @property
235
+ def extra_links_params(self) -> dict[str, Any]:
236
+ return {
237
+ "region": self.region,
238
+ "project_id": self.project_id,
239
+ }
240
+
234
241
  def execute(self, context: Context):
235
242
  self.log.info("Creating Batch prediction job")
236
243
  batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
@@ -262,9 +269,10 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
262
269
  batch_prediction_job_id = batch_prediction_job.name
263
270
  self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id)
264
271
 
265
- self.xcom_push(context, key="batch_prediction_job_id", value=batch_prediction_job_id)
272
+ context["ti"].xcom_push(key="batch_prediction_job_id", value=batch_prediction_job_id)
266
273
  VertexAIBatchPredictionJobLink.persist(
267
- context=context, task_instance=self, batch_prediction_job_id=batch_prediction_job_id
274
+ context=context,
275
+ batch_prediction_job_id=batch_prediction_job_id,
268
276
  )
269
277
 
270
278
  if self.deferrable:
@@ -295,13 +303,11 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
295
303
  job: dict[str, Any] = event["job"]
296
304
  self.log.info("Batch prediction job %s created and completed successfully.", job["name"])
297
305
  job_id = self.hook.extract_batch_prediction_job_id(job)
298
- self.xcom_push(
299
- context,
306
+ context["ti"].xcom_push(
300
307
  key="batch_prediction_job_id",
301
308
  value=job_id,
302
309
  )
303
- self.xcom_push(
304
- context,
310
+ context["ti"].xcom_push(
305
311
  key="training_conf",
306
312
  value={
307
313
  "training_conf_id": job_id,
@@ -427,6 +433,13 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
427
433
  self.gcp_conn_id = gcp_conn_id
428
434
  self.impersonation_chain = impersonation_chain
429
435
 
436
+ @property
437
+ def extra_links_params(self) -> dict[str, Any]:
438
+ return {
439
+ "region": self.region,
440
+ "project_id": self.project_id,
441
+ }
442
+
430
443
  def execute(self, context: Context):
431
444
  hook = BatchPredictionJobHook(
432
445
  gcp_conn_id=self.gcp_conn_id,
@@ -445,7 +458,8 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
445
458
  )
446
459
  self.log.info("Batch prediction job was gotten.")
447
460
  VertexAIBatchPredictionJobLink.persist(
448
- context=context, task_instance=self, batch_prediction_job_id=self.batch_prediction_job
461
+ context=context,
462
+ batch_prediction_job_id=self.batch_prediction_job,
449
463
  )
450
464
  return BatchPredictionJob.to_dict(result)
451
465
  except NotFound:
@@ -517,6 +531,12 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
517
531
  self.gcp_conn_id = gcp_conn_id
518
532
  self.impersonation_chain = impersonation_chain
519
533
 
534
+ @property
535
+ def extra_links_params(self) -> dict[str, Any]:
536
+ return {
537
+ "project_id": self.project_id,
538
+ }
539
+
520
540
  def execute(self, context: Context):
521
541
  hook = BatchPredictionJobHook(
522
542
  gcp_conn_id=self.gcp_conn_id,
@@ -533,5 +553,5 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
533
553
  timeout=self.timeout,
534
554
  metadata=self.metadata,
535
555
  )
536
- VertexAIBatchPredictionJobListLink.persist(context=context, task_instance=self)
556
+ VertexAIBatchPredictionJobListLink.persist(context=context)
537
557
  return [BatchPredictionJob.to_dict(result) for result in results]
@@ -51,8 +51,9 @@ if TYPE_CHECKING:
51
51
  CustomPythonPackageTrainingJob,
52
52
  CustomTrainingJob,
53
53
  )
54
+ from google.cloud.aiplatform_v1.types import PscInterfaceConfig
54
55
 
55
- from airflow.utils.context import Context
56
+ from airflow.providers.common.compat.sdk import Context
56
57
 
57
58
 
58
59
  class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
@@ -110,6 +111,7 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
110
111
  predefined_split_column_name: str | None = None,
111
112
  timestamp_split_column_name: str | None = None,
112
113
  tensorboard: str | None = None,
114
+ psc_interface_config: PscInterfaceConfig | None = None,
113
115
  gcp_conn_id: str = "google_cloud_default",
114
116
  impersonation_chain: str | Sequence[str] | None = None,
115
117
  **kwargs,
@@ -166,21 +168,29 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
166
168
  self.predefined_split_column_name = predefined_split_column_name
167
169
  self.timestamp_split_column_name = timestamp_split_column_name
168
170
  self.tensorboard = tensorboard
171
+ self.psc_interface_config = psc_interface_config
169
172
  # END Run param
170
173
  self.gcp_conn_id = gcp_conn_id
171
174
  self.impersonation_chain = impersonation_chain
172
175
 
176
+ @property
177
+ def extra_links_params(self) -> dict[str, Any]:
178
+ return {
179
+ "region": self.region,
180
+ "project_id": self.project_id,
181
+ }
182
+
173
183
  def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
174
184
  if event["status"] == "error":
175
185
  raise AirflowException(event["message"])
176
186
  training_pipeline = event["job"]
177
187
  custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline)
178
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
188
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
179
189
  try:
180
190
  model = training_pipeline["model_to_upload"]
181
191
  model_id = self.hook.extract_model_id(model)
182
- self.xcom_push(context, key="model_id", value=model_id)
183
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
192
+ context["ti"].xcom_push(key="model_id", value=model_id)
193
+ VertexAIModelLink.persist(context=context, model_id=model_id)
184
194
  return model
185
195
  except KeyError:
186
196
  self.log.warning(
@@ -466,6 +476,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
466
476
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
467
477
  For more information on configuring your service account please visit:
468
478
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
479
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
480
+ training.
469
481
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
470
482
  :param impersonation_chain: Optional service account to impersonate using short-term
471
483
  credentials, or chained list of accounts required to get the access_token
@@ -579,18 +591,19 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
579
591
  timestamp_split_column_name=self.timestamp_split_column_name,
580
592
  tensorboard=self.tensorboard,
581
593
  sync=True,
594
+ psc_interface_config=self.psc_interface_config,
582
595
  )
583
596
 
584
597
  if model:
585
598
  result = Model.to_dict(model)
586
599
  model_id = self.hook.extract_model_id(result)
587
- self.xcom_push(context, key="model_id", value=model_id)
588
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
600
+ context["ti"].xcom_push(key="model_id", value=model_id)
601
+ VertexAIModelLink.persist(context=context, model_id=model_id)
589
602
  else:
590
603
  result = model # type: ignore
591
- self.xcom_push(context, key="training_id", value=training_id)
592
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
593
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
604
+ context["ti"].xcom_push(key="training_id", value=training_id)
605
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
606
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
594
607
  return result
595
608
 
596
609
  def invoke_defer(self, context: Context) -> None:
@@ -645,11 +658,12 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
645
658
  predefined_split_column_name=self.predefined_split_column_name,
646
659
  timestamp_split_column_name=self.timestamp_split_column_name,
647
660
  tensorboard=self.tensorboard,
661
+ psc_interface_config=self.psc_interface_config,
648
662
  )
649
663
  custom_container_training_job_obj.wait_for_resource_creation()
650
664
  training_pipeline_id: str = custom_container_training_job_obj.name
651
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
652
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
665
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
666
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
653
667
  self.defer(
654
668
  trigger=CustomContainerTrainingJobTrigger(
655
669
  conn_id=self.gcp_conn_id,
@@ -924,6 +938,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
924
938
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
925
939
  For more information on configuring your service account please visit:
926
940
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
941
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
942
+ training.
927
943
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
928
944
  :param impersonation_chain: Optional service account to impersonate using short-term
929
945
  credentials, or chained list of accounts required to get the access_token
@@ -1036,18 +1052,19 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
1036
1052
  timestamp_split_column_name=self.timestamp_split_column_name,
1037
1053
  tensorboard=self.tensorboard,
1038
1054
  sync=True,
1055
+ psc_interface_config=self.psc_interface_config,
1039
1056
  )
1040
1057
 
1041
1058
  if model:
1042
1059
  result = Model.to_dict(model)
1043
1060
  model_id = self.hook.extract_model_id(result)
1044
- self.xcom_push(context, key="model_id", value=model_id)
1045
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1061
+ context["ti"].xcom_push(key="model_id", value=model_id)
1062
+ VertexAIModelLink.persist(context=context, model_id=model_id)
1046
1063
  else:
1047
1064
  result = model # type: ignore
1048
- self.xcom_push(context, key="training_id", value=training_id)
1049
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1050
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
1065
+ context["ti"].xcom_push(key="training_id", value=training_id)
1066
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
1067
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
1051
1068
  return result
1052
1069
 
1053
1070
  def invoke_defer(self, context: Context) -> None:
@@ -1103,11 +1120,12 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
1103
1120
  predefined_split_column_name=self.predefined_split_column_name,
1104
1121
  timestamp_split_column_name=self.timestamp_split_column_name,
1105
1122
  tensorboard=self.tensorboard,
1123
+ psc_interface_config=self.psc_interface_config,
1106
1124
  )
1107
1125
  custom_python_training_job_obj.wait_for_resource_creation()
1108
1126
  training_pipeline_id: str = custom_python_training_job_obj.name
1109
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
1110
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1127
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
1128
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
1111
1129
  self.defer(
1112
1130
  trigger=CustomPythonPackageTrainingJobTrigger(
1113
1131
  conn_id=self.gcp_conn_id,
@@ -1382,6 +1400,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1382
1400
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
1383
1401
  For more information on configuring your service account please visit:
1384
1402
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
1403
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
1404
+ training.
1385
1405
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
1386
1406
  :param impersonation_chain: Optional service account to impersonate using short-term
1387
1407
  credentials, or chained list of accounts required to get the access_token
@@ -1499,18 +1519,19 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1499
1519
  timestamp_split_column_name=self.timestamp_split_column_name,
1500
1520
  tensorboard=self.tensorboard,
1501
1521
  sync=True,
1522
+ psc_interface_config=None,
1502
1523
  )
1503
1524
 
1504
1525
  if model:
1505
1526
  result = Model.to_dict(model)
1506
1527
  model_id = self.hook.extract_model_id(result)
1507
- self.xcom_push(context, key="model_id", value=model_id)
1508
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1528
+ context["ti"].xcom_push(key="model_id", value=model_id)
1529
+ VertexAIModelLink.persist(context=context, model_id=model_id)
1509
1530
  else:
1510
1531
  result = model # type: ignore
1511
- self.xcom_push(context, key="training_id", value=training_id)
1512
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1513
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
1532
+ context["ti"].xcom_push(key="training_id", value=training_id)
1533
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
1534
+ VertexAITrainingLink.persist(context=context, training_id=training_id)
1514
1535
  return result
1515
1536
 
1516
1537
  def invoke_defer(self, context: Context) -> None:
@@ -1566,11 +1587,12 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1566
1587
  predefined_split_column_name=self.predefined_split_column_name,
1567
1588
  timestamp_split_column_name=self.timestamp_split_column_name,
1568
1589
  tensorboard=self.tensorboard,
1590
+ psc_interface_config=self.psc_interface_config,
1569
1591
  )
1570
1592
  custom_training_job_obj.wait_for_resource_creation()
1571
1593
  training_pipeline_id: str = custom_training_job_obj.name
1572
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
1573
- VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1594
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
1595
+ VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
1574
1596
  self.defer(
1575
1597
  trigger=CustomTrainingJobTrigger(
1576
1598
  conn_id=self.gcp_conn_id,
@@ -1748,6 +1770,12 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
1748
1770
  self.gcp_conn_id = gcp_conn_id
1749
1771
  self.impersonation_chain = impersonation_chain
1750
1772
 
1773
@property
def extra_links_params(self) -> dict[str, Any]:
    """
    Build the keyword parameters attached to this operator's extra links.

    :return: A dict holding only ``project_id``, taken from the operator.
    """
    # NOTE(review): presumably consumed by the ``*Link.persist`` helpers, which
    # no longer receive ``task_instance`` explicitly — confirm against the link classes.
    return dict(project_id=self.project_id)
1778
+
1751
1779
  def execute(self, context: Context):
1752
1780
  hook = CustomJobHook(
1753
1781
  gcp_conn_id=self.gcp_conn_id,
@@ -1764,5 +1792,5 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
1764
1792
  timeout=self.timeout,
1765
1793
  metadata=self.metadata,
1766
1794
  )
1767
- VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
1795
+ VertexAITrainingPipelinesLink.persist(context=context)
1768
1796
  return [TrainingPipeline.to_dict(result) for result in results]
@@ -20,12 +20,13 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any
24
24
 
25
25
  from google.api_core.exceptions import NotFound
26
26
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
27
27
  from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig
28
28
 
29
+ from airflow.exceptions import AirflowException
29
30
  from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
30
31
  from airflow.providers.google.cloud.links.vertex_ai import VertexAIDatasetLink, VertexAIDatasetListLink
31
32
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
34
35
  from google.api_core.retry import Retry
35
36
  from google.protobuf.field_mask_pb2 import FieldMask
36
37
 
37
- from airflow.utils.context import Context
38
+ from airflow.providers.common.compat.sdk import Context
38
39
 
39
40
 
40
41
  class CreateDatasetOperator(GoogleCloudBaseOperator):
@@ -85,6 +86,13 @@ class CreateDatasetOperator(GoogleCloudBaseOperator):
85
86
  self.gcp_conn_id = gcp_conn_id
86
87
  self.impersonation_chain = impersonation_chain
87
88
 
89
@property
def extra_links_params(self) -> dict[str, Any]:
    """
    Build the keyword parameters attached to this operator's extra links.

    :return: A dict with ``region`` followed by ``project_id``, both read from the operator.
    """
    # NOTE(review): presumably consumed by ``VertexAIDatasetLink.persist`` — confirm.
    return dict(region=self.region, project_id=self.project_id)
95
+
88
96
  def execute(self, context: Context):
89
97
  hook = DatasetHook(
90
98
  gcp_conn_id=self.gcp_conn_id,
@@ -106,8 +114,8 @@ class CreateDatasetOperator(GoogleCloudBaseOperator):
106
114
  dataset_id = hook.extract_dataset_id(dataset)
107
115
  self.log.info("Dataset was created. Dataset id: %s", dataset_id)
108
116
 
109
- self.xcom_push(context, key="dataset_id", value=dataset_id)
110
- VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=dataset_id)
117
+ context["ti"].xcom_push(key="dataset_id", value=dataset_id)
118
+ VertexAIDatasetLink.persist(context=context, dataset_id=dataset_id)
111
119
  return dataset
112
120
 
113
121
 
@@ -160,6 +168,13 @@ class GetDatasetOperator(GoogleCloudBaseOperator):
160
168
  self.gcp_conn_id = gcp_conn_id
161
169
  self.impersonation_chain = impersonation_chain
162
170
 
171
@property
def extra_links_params(self) -> dict[str, Any]:
    """
    Build the keyword parameters attached to this operator's extra links.

    :return: A dict with ``region`` followed by ``project_id``, both read from the operator.
    """
    # NOTE(review): presumably consumed by ``VertexAIDatasetLink.persist`` — confirm.
    return dict(region=self.region, project_id=self.project_id)
177
+
163
178
  def execute(self, context: Context):
164
179
  hook = DatasetHook(
165
180
  gcp_conn_id=self.gcp_conn_id,
@@ -177,7 +192,7 @@ class GetDatasetOperator(GoogleCloudBaseOperator):
177
192
  timeout=self.timeout,
178
193
  metadata=self.metadata,
179
194
  )
180
- VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=self.dataset_id)
195
+ VertexAIDatasetLink.persist(context=context, dataset_id=self.dataset_id)
181
196
  self.log.info("Dataset was gotten.")
182
197
  return Dataset.to_dict(dataset_obj)
183
198
  except NotFound:
@@ -321,7 +336,21 @@ class ExportDataOperator(GoogleCloudBaseOperator):
321
336
  self.log.info("Export was done successfully")
322
337
 
323
338
 
324
- class ImportDataOperator(GoogleCloudBaseOperator):
339
class DatasetImportDataResultsCheckHelper:
    """
    Mixin with static helpers for checking the outcome of a dataset data import.

    Intended to be mixed into an operator that measures a dataset's item count
    before and after an import and compares the two.
    """

    @staticmethod
    def _get_number_of_ds_items(dataset, total_key_name):
        """
        Read an item counter out of a dataset message.

        :param dataset: A message object whose class provides a ``to_dict`` function
            accepting the instance (e.g. a Vertex AI ``Dataset``).
        :param total_key_name: Key of the counter inside the dict returned by ``to_dict``.
        :return: The counter exactly as stored in that dict (not converted), or ``0``
            when the key is absent.
        """
        converter = type(dataset).to_dict
        as_dict = converter(dataset)
        return as_dict.get(total_key_name, 0)

    @staticmethod
    def _raise_for_empty_import_result(dataset_id, initial_size, size_after_import):
        """
        Fail when an import did not increase the dataset's item count.

        Both sizes are coerced with ``int()`` because they may come straight from
        ``_get_number_of_ds_items`` without conversion.

        :param dataset_id: Dataset identifier, used only in the error message.
        :param initial_size: Item count before the import.
        :param size_after_import: Item count after the import.
        :raises AirflowException: If the count grew by zero or fewer items.
        """
        items_added = int(size_after_import) - int(initial_size)
        if items_added > 0:
            return
        raise AirflowException(f"Empty results of data import for the dataset_id {dataset_id}.")
351
+
352
+
353
+ class ImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
325
354
  """
326
355
  Imports data into a Dataset.
327
356
 
@@ -342,6 +371,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
342
371
  If set as a sequence, the identities from the list must grant
343
372
  Service Account Token Creator IAM role to the directly preceding identity, with first
344
373
  account from the list granting this role to the originating account (templated).
374
+ :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
345
375
  """
346
376
 
347
377
  template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
@@ -358,6 +388,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
358
388
  metadata: Sequence[tuple[str, str]] = (),
359
389
  gcp_conn_id: str = "google_cloud_default",
360
390
  impersonation_chain: str | Sequence[str] | None = None,
391
+ raise_for_empty_result: bool = False,
361
392
  **kwargs,
362
393
  ) -> None:
363
394
  super().__init__(**kwargs)
@@ -370,13 +401,24 @@ class ImportDataOperator(GoogleCloudBaseOperator):
370
401
  self.metadata = metadata
371
402
  self.gcp_conn_id = gcp_conn_id
372
403
  self.impersonation_chain = impersonation_chain
404
+ self.raise_for_empty_result = raise_for_empty_result
373
405
 
374
406
  def execute(self, context: Context):
375
407
  hook = DatasetHook(
376
408
  gcp_conn_id=self.gcp_conn_id,
377
409
  impersonation_chain=self.impersonation_chain,
378
410
  )
379
-
411
+ initial_dataset_size = self._get_number_of_ds_items(
412
+ dataset=hook.get_dataset(
413
+ dataset=self.dataset_id,
414
+ project_id=self.project_id,
415
+ region=self.region,
416
+ retry=self.retry,
417
+ timeout=self.timeout,
418
+ metadata=self.metadata,
419
+ ),
420
+ total_key_name="data_item_count",
421
+ )
380
422
  self.log.info("Importing data: %s", self.dataset_id)
381
423
  operation = hook.import_data(
382
424
  project_id=self.project_id,
@@ -388,7 +430,21 @@ class ImportDataOperator(GoogleCloudBaseOperator):
388
430
  metadata=self.metadata,
389
431
  )
390
432
  hook.wait_for_operation(timeout=self.timeout, operation=operation)
433
+ result_dataset_size = self._get_number_of_ds_items(
434
+ dataset=hook.get_dataset(
435
+ dataset=self.dataset_id,
436
+ project_id=self.project_id,
437
+ region=self.region,
438
+ retry=self.retry,
439
+ timeout=self.timeout,
440
+ metadata=self.metadata,
441
+ ),
442
+ total_key_name="data_item_count",
443
+ )
444
+ if self.raise_for_empty_result:
445
+ self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
391
446
  self.log.info("Import was done successfully")
447
+ return {"total_data_items_imported": int(result_dataset_size) - int(initial_dataset_size)}
392
448
 
393
449
 
394
450
  class ListDatasetsOperator(GoogleCloudBaseOperator):
@@ -451,6 +507,12 @@ class ListDatasetsOperator(GoogleCloudBaseOperator):
451
507
  self.gcp_conn_id = gcp_conn_id
452
508
  self.impersonation_chain = impersonation_chain
453
509
 
510
@property
def extra_links_params(self) -> dict[str, Any]:
    """
    Build the keyword parameters attached to this operator's extra links.

    :return: A dict holding only ``project_id``, taken from the operator.
    """
    # NOTE(review): presumably consumed by ``VertexAIDatasetListLink.persist`` — confirm.
    return dict(project_id=self.project_id)
515
+
454
516
  def execute(self, context: Context):
455
517
  hook = DatasetHook(
456
518
  gcp_conn_id=self.gcp_conn_id,
@@ -468,7 +530,7 @@ class ListDatasetsOperator(GoogleCloudBaseOperator):
468
530
  timeout=self.timeout,
469
531
  metadata=self.metadata,
470
532
  )
471
- VertexAIDatasetListLink.persist(context=context, task_instance=self)
533
+ VertexAIDatasetListLink.persist(context=context)
472
534
  return [Dataset.to_dict(result) for result in results]
473
535
 
474
536