apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -5
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/bigquery.py +166 -281
- airflow/providers/google/cloud/hooks/cloud_composer.py +287 -14
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_run.py +17 -9
- airflow/providers/google/cloud/hooks/cloud_sql.py +101 -22
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +27 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +5 -1
- airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
- airflow/providers/google/cloud/hooks/dataflow.py +71 -94
- airflow/providers/google/cloud/hooks/datafusion.py +1 -1
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +72 -71
- airflow/providers/google/cloud/hooks/gcs.py +111 -14
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/mlengine.py +3 -2
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +73 -8
- airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -209
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +27 -1
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +2 -2
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +17 -9
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +70 -55
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +3 -5
- airflow/providers/google/cloud/operators/bigtable.py +36 -7
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +2 -2
- airflow/providers/google/cloud/operators/cloud_build.py +75 -32
- airflow/providers/google/cloud/operators/cloud_composer.py +128 -40
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +23 -5
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -16
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -11
- airflow/providers/google/cloud/operators/compute.py +8 -40
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +41 -20
- airflow/providers/google/cloud/operators/dataplex.py +193 -109
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +78 -35
- airflow/providers/google/cloud/operators/dataproc_metastore.py +96 -88
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +16 -7
- airflow/providers/google/cloud/operators/gcs.py +10 -8
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +60 -99
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +107 -52
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +60 -14
- airflow/providers/google/cloud/operators/spanner.py +25 -12
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -2
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -2
- airflow/providers/google/cloud/operators/translate.py +40 -16
- airflow/providers/google/cloud/operators/translate_speech.py +1 -2
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +29 -9
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +54 -26
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +11 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +30 -7
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -2
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -29
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/cloud/sensors/dataflow.py +26 -9
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +2 -2
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -2
- airflow/providers/google/cloud/sensors/gcs.py +4 -4
- airflow/providers/google/cloud/sensors/looker.py +2 -2
- airflow/providers/google/cloud/sensors/pubsub.py +4 -4
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +2 -2
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +20 -12
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +13 -4
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +302 -46
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +91 -1
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +122 -52
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -8
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +26 -1
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +1 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -63
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +92 -48
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
|
@@ -21,7 +21,7 @@
|
|
|
21
21
|
from __future__ import annotations
|
|
22
22
|
|
|
23
23
|
from collections.abc import Sequence
|
|
24
|
-
from typing import TYPE_CHECKING
|
|
24
|
+
from typing import TYPE_CHECKING, Any
|
|
25
25
|
|
|
26
26
|
from google.api_core.exceptions import NotFound
|
|
27
27
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
@@ -29,6 +29,7 @@ from google.cloud.aiplatform import datasets
|
|
|
29
29
|
from google.cloud.aiplatform.models import Model
|
|
30
30
|
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
|
|
31
31
|
|
|
32
|
+
from airflow.exceptions import AirflowProviderDeprecationWarning
|
|
32
33
|
from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
|
|
33
34
|
from airflow.providers.google.cloud.links.vertex_ai import (
|
|
34
35
|
VertexAIModelLink,
|
|
@@ -36,11 +37,12 @@ from airflow.providers.google.cloud.links.vertex_ai import (
|
|
|
36
37
|
VertexAITrainingPipelinesLink,
|
|
37
38
|
)
|
|
38
39
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
40
|
+
from airflow.providers.google.common.deprecated import deprecated
|
|
39
41
|
|
|
40
42
|
if TYPE_CHECKING:
|
|
41
43
|
from google.api_core.retry import Retry
|
|
42
44
|
|
|
43
|
-
from airflow.
|
|
45
|
+
from airflow.providers.common.compat.sdk import Context
|
|
44
46
|
|
|
45
47
|
|
|
46
48
|
class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
@@ -91,6 +93,13 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
|
91
93
|
self.impersonation_chain = impersonation_chain
|
|
92
94
|
self.hook: AutoMLHook | None = None
|
|
93
95
|
|
|
96
|
+
@property
|
|
97
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
98
|
+
return {
|
|
99
|
+
"region": self.region,
|
|
100
|
+
"project_id": self.project_id,
|
|
101
|
+
}
|
|
102
|
+
|
|
94
103
|
def on_kill(self) -> None:
|
|
95
104
|
"""Act as a callback called when the operator is killed; cancel any running job."""
|
|
96
105
|
if self.hook:
|
|
@@ -242,12 +251,12 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
242
251
|
if model:
|
|
243
252
|
result = Model.to_dict(model)
|
|
244
253
|
model_id = self.hook.extract_model_id(result)
|
|
245
|
-
|
|
246
|
-
VertexAIModelLink.persist(context=context,
|
|
254
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
255
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
247
256
|
else:
|
|
248
257
|
result = model # type: ignore
|
|
249
|
-
|
|
250
|
-
VertexAITrainingLink.persist(context=context,
|
|
258
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
259
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
251
260
|
return result
|
|
252
261
|
|
|
253
262
|
|
|
@@ -334,12 +343,12 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
334
343
|
if model:
|
|
335
344
|
result = Model.to_dict(model)
|
|
336
345
|
model_id = self.hook.extract_model_id(result)
|
|
337
|
-
|
|
338
|
-
VertexAIModelLink.persist(context=context,
|
|
346
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
347
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
339
348
|
else:
|
|
340
349
|
result = model # type: ignore
|
|
341
|
-
|
|
342
|
-
VertexAITrainingLink.persist(context=context,
|
|
350
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
351
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
343
352
|
return result
|
|
344
353
|
|
|
345
354
|
|
|
@@ -457,15 +466,20 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
457
466
|
if model:
|
|
458
467
|
result = Model.to_dict(model)
|
|
459
468
|
model_id = self.hook.extract_model_id(result)
|
|
460
|
-
|
|
461
|
-
VertexAIModelLink.persist(context=context,
|
|
469
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
470
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
462
471
|
else:
|
|
463
472
|
result = model # type: ignore
|
|
464
|
-
|
|
465
|
-
VertexAITrainingLink.persist(context=context,
|
|
473
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
474
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
466
475
|
return result
|
|
467
476
|
|
|
468
477
|
|
|
478
|
+
@deprecated(
|
|
479
|
+
planned_removal_date="March 24, 2026",
|
|
480
|
+
use_instead="airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator",
|
|
481
|
+
category=AirflowProviderDeprecationWarning,
|
|
482
|
+
)
|
|
469
483
|
class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
470
484
|
"""Create Auto ML Video Training job."""
|
|
471
485
|
|
|
@@ -531,12 +545,12 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
531
545
|
if model:
|
|
532
546
|
result = Model.to_dict(model)
|
|
533
547
|
model_id = self.hook.extract_model_id(result)
|
|
534
|
-
|
|
535
|
-
VertexAIModelLink.persist(context=context,
|
|
548
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
549
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
536
550
|
else:
|
|
537
551
|
result = model # type: ignore
|
|
538
|
-
|
|
539
|
-
VertexAITrainingLink.persist(context=context,
|
|
552
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
553
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
540
554
|
return result
|
|
541
555
|
|
|
542
556
|
|
|
@@ -640,6 +654,12 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
640
654
|
self.gcp_conn_id = gcp_conn_id
|
|
641
655
|
self.impersonation_chain = impersonation_chain
|
|
642
656
|
|
|
657
|
+
@property
|
|
658
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
659
|
+
return {
|
|
660
|
+
"project_id": self.project_id,
|
|
661
|
+
}
|
|
662
|
+
|
|
643
663
|
def execute(self, context: Context):
|
|
644
664
|
hook = AutoMLHook(
|
|
645
665
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -656,5 +676,5 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
656
676
|
timeout=self.timeout,
|
|
657
677
|
metadata=self.metadata,
|
|
658
678
|
)
|
|
659
|
-
VertexAITrainingPipelinesLink.persist(context=context
|
|
679
|
+
VertexAITrainingPipelinesLink.persist(context=context)
|
|
660
680
|
return [TrainingPipeline.to_dict(result) for result in results]
|
|
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
|
|
|
42
42
|
from google.api_core.retry import Retry
|
|
43
43
|
from google.cloud.aiplatform import BatchPredictionJob as BatchPredictionJobObject, Model, explain
|
|
44
44
|
|
|
45
|
-
from airflow.
|
|
45
|
+
from airflow.providers.common.compat.sdk import Context
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
@@ -231,6 +231,13 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
231
231
|
impersonation_chain=self.impersonation_chain,
|
|
232
232
|
)
|
|
233
233
|
|
|
234
|
+
@property
|
|
235
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
236
|
+
return {
|
|
237
|
+
"region": self.region,
|
|
238
|
+
"project_id": self.project_id,
|
|
239
|
+
}
|
|
240
|
+
|
|
234
241
|
def execute(self, context: Context):
|
|
235
242
|
self.log.info("Creating Batch prediction job")
|
|
236
243
|
batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
|
|
@@ -262,9 +269,10 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
262
269
|
batch_prediction_job_id = batch_prediction_job.name
|
|
263
270
|
self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id)
|
|
264
271
|
|
|
265
|
-
|
|
272
|
+
context["ti"].xcom_push(key="batch_prediction_job_id", value=batch_prediction_job_id)
|
|
266
273
|
VertexAIBatchPredictionJobLink.persist(
|
|
267
|
-
context=context,
|
|
274
|
+
context=context,
|
|
275
|
+
batch_prediction_job_id=batch_prediction_job_id,
|
|
268
276
|
)
|
|
269
277
|
|
|
270
278
|
if self.deferrable:
|
|
@@ -295,13 +303,11 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
295
303
|
job: dict[str, Any] = event["job"]
|
|
296
304
|
self.log.info("Batch prediction job %s created and completed successfully.", job["name"])
|
|
297
305
|
job_id = self.hook.extract_batch_prediction_job_id(job)
|
|
298
|
-
|
|
299
|
-
context,
|
|
306
|
+
context["ti"].xcom_push(
|
|
300
307
|
key="batch_prediction_job_id",
|
|
301
308
|
value=job_id,
|
|
302
309
|
)
|
|
303
|
-
|
|
304
|
-
context,
|
|
310
|
+
context["ti"].xcom_push(
|
|
305
311
|
key="training_conf",
|
|
306
312
|
value={
|
|
307
313
|
"training_conf_id": job_id,
|
|
@@ -427,6 +433,13 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
427
433
|
self.gcp_conn_id = gcp_conn_id
|
|
428
434
|
self.impersonation_chain = impersonation_chain
|
|
429
435
|
|
|
436
|
+
@property
|
|
437
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
438
|
+
return {
|
|
439
|
+
"region": self.region,
|
|
440
|
+
"project_id": self.project_id,
|
|
441
|
+
}
|
|
442
|
+
|
|
430
443
|
def execute(self, context: Context):
|
|
431
444
|
hook = BatchPredictionJobHook(
|
|
432
445
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -445,7 +458,8 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
445
458
|
)
|
|
446
459
|
self.log.info("Batch prediction job was gotten.")
|
|
447
460
|
VertexAIBatchPredictionJobLink.persist(
|
|
448
|
-
context=context,
|
|
461
|
+
context=context,
|
|
462
|
+
batch_prediction_job_id=self.batch_prediction_job,
|
|
449
463
|
)
|
|
450
464
|
return BatchPredictionJob.to_dict(result)
|
|
451
465
|
except NotFound:
|
|
@@ -517,6 +531,12 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
|
|
|
517
531
|
self.gcp_conn_id = gcp_conn_id
|
|
518
532
|
self.impersonation_chain = impersonation_chain
|
|
519
533
|
|
|
534
|
+
@property
|
|
535
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
536
|
+
return {
|
|
537
|
+
"project_id": self.project_id,
|
|
538
|
+
}
|
|
539
|
+
|
|
520
540
|
def execute(self, context: Context):
|
|
521
541
|
hook = BatchPredictionJobHook(
|
|
522
542
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -533,5 +553,5 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
|
|
|
533
553
|
timeout=self.timeout,
|
|
534
554
|
metadata=self.metadata,
|
|
535
555
|
)
|
|
536
|
-
VertexAIBatchPredictionJobListLink.persist(context=context
|
|
556
|
+
VertexAIBatchPredictionJobListLink.persist(context=context)
|
|
537
557
|
return [BatchPredictionJob.to_dict(result) for result in results]
|
|
@@ -51,8 +51,9 @@ if TYPE_CHECKING:
|
|
|
51
51
|
CustomPythonPackageTrainingJob,
|
|
52
52
|
CustomTrainingJob,
|
|
53
53
|
)
|
|
54
|
+
from google.cloud.aiplatform_v1.types import PscInterfaceConfig
|
|
54
55
|
|
|
55
|
-
from airflow.
|
|
56
|
+
from airflow.providers.common.compat.sdk import Context
|
|
56
57
|
|
|
57
58
|
|
|
58
59
|
class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
@@ -110,6 +111,7 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
|
110
111
|
predefined_split_column_name: str | None = None,
|
|
111
112
|
timestamp_split_column_name: str | None = None,
|
|
112
113
|
tensorboard: str | None = None,
|
|
114
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
113
115
|
gcp_conn_id: str = "google_cloud_default",
|
|
114
116
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
115
117
|
**kwargs,
|
|
@@ -166,21 +168,29 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
|
166
168
|
self.predefined_split_column_name = predefined_split_column_name
|
|
167
169
|
self.timestamp_split_column_name = timestamp_split_column_name
|
|
168
170
|
self.tensorboard = tensorboard
|
|
171
|
+
self.psc_interface_config = psc_interface_config
|
|
169
172
|
# END Run param
|
|
170
173
|
self.gcp_conn_id = gcp_conn_id
|
|
171
174
|
self.impersonation_chain = impersonation_chain
|
|
172
175
|
|
|
176
|
+
@property
|
|
177
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
178
|
+
return {
|
|
179
|
+
"region": self.region,
|
|
180
|
+
"project_id": self.project_id,
|
|
181
|
+
}
|
|
182
|
+
|
|
173
183
|
def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
|
|
174
184
|
if event["status"] == "error":
|
|
175
185
|
raise AirflowException(event["message"])
|
|
176
186
|
training_pipeline = event["job"]
|
|
177
187
|
custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline)
|
|
178
|
-
|
|
188
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
179
189
|
try:
|
|
180
190
|
model = training_pipeline["model_to_upload"]
|
|
181
191
|
model_id = self.hook.extract_model_id(model)
|
|
182
|
-
|
|
183
|
-
VertexAIModelLink.persist(context=context,
|
|
192
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
193
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
184
194
|
return model
|
|
185
195
|
except KeyError:
|
|
186
196
|
self.log.warning(
|
|
@@ -466,6 +476,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
466
476
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
467
477
|
For more information on configuring your service account please visit:
|
|
468
478
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
479
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
480
|
+
training.
|
|
469
481
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
470
482
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
471
483
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -579,18 +591,19 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
579
591
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
580
592
|
tensorboard=self.tensorboard,
|
|
581
593
|
sync=True,
|
|
594
|
+
psc_interface_config=self.psc_interface_config,
|
|
582
595
|
)
|
|
583
596
|
|
|
584
597
|
if model:
|
|
585
598
|
result = Model.to_dict(model)
|
|
586
599
|
model_id = self.hook.extract_model_id(result)
|
|
587
|
-
|
|
588
|
-
VertexAIModelLink.persist(context=context,
|
|
600
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
601
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
589
602
|
else:
|
|
590
603
|
result = model # type: ignore
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
VertexAITrainingLink.persist(context=context,
|
|
604
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
605
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
606
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
594
607
|
return result
|
|
595
608
|
|
|
596
609
|
def invoke_defer(self, context: Context) -> None:
|
|
@@ -645,11 +658,12 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
645
658
|
predefined_split_column_name=self.predefined_split_column_name,
|
|
646
659
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
647
660
|
tensorboard=self.tensorboard,
|
|
661
|
+
psc_interface_config=self.psc_interface_config,
|
|
648
662
|
)
|
|
649
663
|
custom_container_training_job_obj.wait_for_resource_creation()
|
|
650
664
|
training_pipeline_id: str = custom_container_training_job_obj.name
|
|
651
|
-
|
|
652
|
-
VertexAITrainingLink.persist(context=context,
|
|
665
|
+
context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
|
|
666
|
+
VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
|
|
653
667
|
self.defer(
|
|
654
668
|
trigger=CustomContainerTrainingJobTrigger(
|
|
655
669
|
conn_id=self.gcp_conn_id,
|
|
@@ -924,6 +938,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
|
|
|
924
938
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
925
939
|
For more information on configuring your service account please visit:
|
|
926
940
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
941
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
942
|
+
training.
|
|
927
943
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
928
944
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
929
945
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -1036,18 +1052,19 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
|
|
|
1036
1052
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1037
1053
|
tensorboard=self.tensorboard,
|
|
1038
1054
|
sync=True,
|
|
1055
|
+
psc_interface_config=self.psc_interface_config,
|
|
1039
1056
|
)
|
|
1040
1057
|
|
|
1041
1058
|
if model:
|
|
1042
1059
|
result = Model.to_dict(model)
|
|
1043
1060
|
model_id = self.hook.extract_model_id(result)
|
|
1044
|
-
|
|
1045
|
-
VertexAIModelLink.persist(context=context,
|
|
1061
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
1062
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
1046
1063
|
else:
|
|
1047
1064
|
result = model # type: ignore
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
VertexAITrainingLink.persist(context=context,
|
|
1065
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
1066
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
1067
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
1051
1068
|
return result
|
|
1052
1069
|
|
|
1053
1070
|
def invoke_defer(self, context: Context) -> None:
|
|
@@ -1103,11 +1120,12 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
|
|
|
1103
1120
|
predefined_split_column_name=self.predefined_split_column_name,
|
|
1104
1121
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1105
1122
|
tensorboard=self.tensorboard,
|
|
1123
|
+
psc_interface_config=self.psc_interface_config,
|
|
1106
1124
|
)
|
|
1107
1125
|
custom_python_training_job_obj.wait_for_resource_creation()
|
|
1108
1126
|
training_pipeline_id: str = custom_python_training_job_obj.name
|
|
1109
|
-
|
|
1110
|
-
VertexAITrainingLink.persist(context=context,
|
|
1127
|
+
context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
|
|
1128
|
+
VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
|
|
1111
1129
|
self.defer(
|
|
1112
1130
|
trigger=CustomPythonPackageTrainingJobTrigger(
|
|
1113
1131
|
conn_id=self.gcp_conn_id,
|
|
@@ -1382,6 +1400,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
1382
1400
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
1383
1401
|
For more information on configuring your service account please visit:
|
|
1384
1402
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
1403
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
1404
|
+
training.
|
|
1385
1405
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
1386
1406
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
1387
1407
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -1499,18 +1519,19 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
1499
1519
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1500
1520
|
tensorboard=self.tensorboard,
|
|
1501
1521
|
sync=True,
|
|
1522
|
+
psc_interface_config=None,
|
|
1502
1523
|
)
|
|
1503
1524
|
|
|
1504
1525
|
if model:
|
|
1505
1526
|
result = Model.to_dict(model)
|
|
1506
1527
|
model_id = self.hook.extract_model_id(result)
|
|
1507
|
-
|
|
1508
|
-
VertexAIModelLink.persist(context=context,
|
|
1528
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
1529
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
1509
1530
|
else:
|
|
1510
1531
|
result = model # type: ignore
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
VertexAITrainingLink.persist(context=context,
|
|
1532
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
1533
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
1534
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
1514
1535
|
return result
|
|
1515
1536
|
|
|
1516
1537
|
def invoke_defer(self, context: Context) -> None:
|
|
@@ -1566,11 +1587,12 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
1566
1587
|
predefined_split_column_name=self.predefined_split_column_name,
|
|
1567
1588
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1568
1589
|
tensorboard=self.tensorboard,
|
|
1590
|
+
psc_interface_config=self.psc_interface_config,
|
|
1569
1591
|
)
|
|
1570
1592
|
custom_training_job_obj.wait_for_resource_creation()
|
|
1571
1593
|
training_pipeline_id: str = custom_training_job_obj.name
|
|
1572
|
-
|
|
1573
|
-
VertexAITrainingLink.persist(context=context,
|
|
1594
|
+
context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
|
|
1595
|
+
VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
|
|
1574
1596
|
self.defer(
|
|
1575
1597
|
trigger=CustomTrainingJobTrigger(
|
|
1576
1598
|
conn_id=self.gcp_conn_id,
|
|
@@ -1748,6 +1770,12 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
1748
1770
|
self.gcp_conn_id = gcp_conn_id
|
|
1749
1771
|
self.impersonation_chain = impersonation_chain
|
|
1750
1772
|
|
|
1773
|
+
@property
|
|
1774
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
1775
|
+
return {
|
|
1776
|
+
"project_id": self.project_id,
|
|
1777
|
+
}
|
|
1778
|
+
|
|
1751
1779
|
def execute(self, context: Context):
|
|
1752
1780
|
hook = CustomJobHook(
|
|
1753
1781
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -1764,5 +1792,5 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
1764
1792
|
timeout=self.timeout,
|
|
1765
1793
|
metadata=self.metadata,
|
|
1766
1794
|
)
|
|
1767
|
-
VertexAITrainingPipelinesLink.persist(context=context
|
|
1795
|
+
VertexAITrainingPipelinesLink.persist(context=context)
|
|
1768
1796
|
return [TrainingPipeline.to_dict(result) for result in results]
|
|
@@ -20,12 +20,13 @@
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any
|
|
24
24
|
|
|
25
25
|
from google.api_core.exceptions import NotFound
|
|
26
26
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
27
27
|
from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig
|
|
28
28
|
|
|
29
|
+
from airflow.exceptions import AirflowException
|
|
29
30
|
from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
|
|
30
31
|
from airflow.providers.google.cloud.links.vertex_ai import VertexAIDatasetLink, VertexAIDatasetListLink
|
|
31
32
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
|
|
|
34
35
|
from google.api_core.retry import Retry
|
|
35
36
|
from google.protobuf.field_mask_pb2 import FieldMask
|
|
36
37
|
|
|
37
|
-
from airflow.
|
|
38
|
+
from airflow.providers.common.compat.sdk import Context
|
|
38
39
|
|
|
39
40
|
|
|
40
41
|
class CreateDatasetOperator(GoogleCloudBaseOperator):
|
|
@@ -85,6 +86,13 @@ class CreateDatasetOperator(GoogleCloudBaseOperator):
|
|
|
85
86
|
self.gcp_conn_id = gcp_conn_id
|
|
86
87
|
self.impersonation_chain = impersonation_chain
|
|
87
88
|
|
|
89
|
+
@property
|
|
90
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
91
|
+
return {
|
|
92
|
+
"region": self.region,
|
|
93
|
+
"project_id": self.project_id,
|
|
94
|
+
}
|
|
95
|
+
|
|
88
96
|
def execute(self, context: Context):
|
|
89
97
|
hook = DatasetHook(
|
|
90
98
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -106,8 +114,8 @@ class CreateDatasetOperator(GoogleCloudBaseOperator):
|
|
|
106
114
|
dataset_id = hook.extract_dataset_id(dataset)
|
|
107
115
|
self.log.info("Dataset was created. Dataset id: %s", dataset_id)
|
|
108
116
|
|
|
109
|
-
|
|
110
|
-
VertexAIDatasetLink.persist(context=context,
|
|
117
|
+
context["ti"].xcom_push(key="dataset_id", value=dataset_id)
|
|
118
|
+
VertexAIDatasetLink.persist(context=context, dataset_id=dataset_id)
|
|
111
119
|
return dataset
|
|
112
120
|
|
|
113
121
|
|
|
@@ -160,6 +168,13 @@ class GetDatasetOperator(GoogleCloudBaseOperator):
|
|
|
160
168
|
self.gcp_conn_id = gcp_conn_id
|
|
161
169
|
self.impersonation_chain = impersonation_chain
|
|
162
170
|
|
|
171
|
+
@property
|
|
172
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
173
|
+
return {
|
|
174
|
+
"region": self.region,
|
|
175
|
+
"project_id": self.project_id,
|
|
176
|
+
}
|
|
177
|
+
|
|
163
178
|
def execute(self, context: Context):
|
|
164
179
|
hook = DatasetHook(
|
|
165
180
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -177,7 +192,7 @@ class GetDatasetOperator(GoogleCloudBaseOperator):
|
|
|
177
192
|
timeout=self.timeout,
|
|
178
193
|
metadata=self.metadata,
|
|
179
194
|
)
|
|
180
|
-
VertexAIDatasetLink.persist(context=context,
|
|
195
|
+
VertexAIDatasetLink.persist(context=context, dataset_id=self.dataset_id)
|
|
181
196
|
self.log.info("Dataset was gotten.")
|
|
182
197
|
return Dataset.to_dict(dataset_obj)
|
|
183
198
|
except NotFound:
|
|
@@ -321,7 +336,21 @@ class ExportDataOperator(GoogleCloudBaseOperator):
|
|
|
321
336
|
self.log.info("Export was done successfully")
|
|
322
337
|
|
|
323
338
|
|
|
324
|
-
class
|
|
339
|
+
class DatasetImportDataResultsCheckHelper:
|
|
340
|
+
"""Helper utils to verify import dataset data results."""
|
|
341
|
+
|
|
342
|
+
@staticmethod
|
|
343
|
+
def _get_number_of_ds_items(dataset, total_key_name):
|
|
344
|
+
number_of_items = type(dataset).to_dict(dataset).get(total_key_name, 0)
|
|
345
|
+
return number_of_items
|
|
346
|
+
|
|
347
|
+
@staticmethod
|
|
348
|
+
def _raise_for_empty_import_result(dataset_id, initial_size, size_after_import):
|
|
349
|
+
if int(size_after_import) - int(initial_size) <= 0:
|
|
350
|
+
raise AirflowException(f"Empty results of data import for the dataset_id {dataset_id}.")
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
class ImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
|
|
325
354
|
"""
|
|
326
355
|
Imports data into a Dataset.
|
|
327
356
|
|
|
@@ -342,6 +371,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
|
|
|
342
371
|
If set as a sequence, the identities from the list must grant
|
|
343
372
|
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
344
373
|
account from the list granting this role to the originating account (templated).
|
|
374
|
+
:param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
|
|
345
375
|
"""
|
|
346
376
|
|
|
347
377
|
template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
|
|
@@ -358,6 +388,7 @@ class ImportDataOperator(GoogleCloudBaseOperator):
|
|
|
358
388
|
metadata: Sequence[tuple[str, str]] = (),
|
|
359
389
|
gcp_conn_id: str = "google_cloud_default",
|
|
360
390
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
391
|
+
raise_for_empty_result: bool = False,
|
|
361
392
|
**kwargs,
|
|
362
393
|
) -> None:
|
|
363
394
|
super().__init__(**kwargs)
|
|
@@ -370,13 +401,24 @@ class ImportDataOperator(GoogleCloudBaseOperator):
|
|
|
370
401
|
self.metadata = metadata
|
|
371
402
|
self.gcp_conn_id = gcp_conn_id
|
|
372
403
|
self.impersonation_chain = impersonation_chain
|
|
404
|
+
self.raise_for_empty_result = raise_for_empty_result
|
|
373
405
|
|
|
374
406
|
def execute(self, context: Context):
|
|
375
407
|
hook = DatasetHook(
|
|
376
408
|
gcp_conn_id=self.gcp_conn_id,
|
|
377
409
|
impersonation_chain=self.impersonation_chain,
|
|
378
410
|
)
|
|
379
|
-
|
|
411
|
+
initial_dataset_size = self._get_number_of_ds_items(
|
|
412
|
+
dataset=hook.get_dataset(
|
|
413
|
+
dataset=self.dataset_id,
|
|
414
|
+
project_id=self.project_id,
|
|
415
|
+
region=self.region,
|
|
416
|
+
retry=self.retry,
|
|
417
|
+
timeout=self.timeout,
|
|
418
|
+
metadata=self.metadata,
|
|
419
|
+
),
|
|
420
|
+
total_key_name="data_item_count",
|
|
421
|
+
)
|
|
380
422
|
self.log.info("Importing data: %s", self.dataset_id)
|
|
381
423
|
operation = hook.import_data(
|
|
382
424
|
project_id=self.project_id,
|
|
@@ -388,7 +430,21 @@ class ImportDataOperator(GoogleCloudBaseOperator):
|
|
|
388
430
|
metadata=self.metadata,
|
|
389
431
|
)
|
|
390
432
|
hook.wait_for_operation(timeout=self.timeout, operation=operation)
|
|
433
|
+
result_dataset_size = self._get_number_of_ds_items(
|
|
434
|
+
dataset=hook.get_dataset(
|
|
435
|
+
dataset=self.dataset_id,
|
|
436
|
+
project_id=self.project_id,
|
|
437
|
+
region=self.region,
|
|
438
|
+
retry=self.retry,
|
|
439
|
+
timeout=self.timeout,
|
|
440
|
+
metadata=self.metadata,
|
|
441
|
+
),
|
|
442
|
+
total_key_name="data_item_count",
|
|
443
|
+
)
|
|
444
|
+
if self.raise_for_empty_result:
|
|
445
|
+
self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
|
|
391
446
|
self.log.info("Import was done successfully")
|
|
447
|
+
return {"total_data_items_imported": int(result_dataset_size) - int(initial_dataset_size)}
|
|
392
448
|
|
|
393
449
|
|
|
394
450
|
class ListDatasetsOperator(GoogleCloudBaseOperator):
|
|
@@ -451,6 +507,12 @@ class ListDatasetsOperator(GoogleCloudBaseOperator):
|
|
|
451
507
|
self.gcp_conn_id = gcp_conn_id
|
|
452
508
|
self.impersonation_chain = impersonation_chain
|
|
453
509
|
|
|
510
|
+
@property
|
|
511
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
512
|
+
return {
|
|
513
|
+
"project_id": self.project_id,
|
|
514
|
+
}
|
|
515
|
+
|
|
454
516
|
def execute(self, context: Context):
|
|
455
517
|
hook = DatasetHook(
|
|
456
518
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -468,7 +530,7 @@ class ListDatasetsOperator(GoogleCloudBaseOperator):
|
|
|
468
530
|
timeout=self.timeout,
|
|
469
531
|
metadata=self.metadata,
|
|
470
532
|
)
|
|
471
|
-
VertexAIDatasetListLink.persist(context=context
|
|
533
|
+
VertexAIDatasetListLink.persist(context=context)
|
|
472
534
|
return [Dataset.to_dict(result) for result in results]
|
|
473
535
|
|
|
474
536
|
|