apache-airflow-providers-google 14.0.0__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
- airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/_vendor/__init__.py +0 -0
- airflow/providers/google/_vendor/json_merge_patch.py +91 -0
- airflow/providers/google/ads/hooks/ads.py +52 -43
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
- airflow/providers/google/cloud/hooks/bigquery.py +195 -318
- airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
- airflow/providers/google/cloud/hooks/bigtable.py +3 -2
- airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
- airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
- airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
- airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
- airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
- airflow/providers/google/cloud/hooks/compute.py +7 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
- airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
- airflow/providers/google/cloud/hooks/dataflow.py +87 -242
- airflow/providers/google/cloud/hooks/dataform.py +9 -14
- airflow/providers/google/cloud/hooks/datafusion.py +7 -9
- airflow/providers/google/cloud/hooks/dataplex.py +13 -12
- airflow/providers/google/cloud/hooks/dataprep.py +2 -2
- airflow/providers/google/cloud/hooks/dataproc.py +76 -74
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/hooks/dlp.py +5 -4
- airflow/providers/google/cloud/hooks/gcs.py +144 -33
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kms.py +3 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
- airflow/providers/google/cloud/hooks/mlengine.py +7 -8
- airflow/providers/google/cloud/hooks/natural_language.py +3 -2
- airflow/providers/google/cloud/hooks/os_login.py +3 -2
- airflow/providers/google/cloud/hooks/pubsub.py +6 -6
- airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
- airflow/providers/google/cloud/hooks/spanner.py +75 -10
- airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
- airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
- airflow/providers/google/cloud/hooks/tasks.py +4 -3
- airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
- airflow/providers/google/cloud/hooks/translate.py +8 -17
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
- airflow/providers/google/cloud/hooks/vision.py +7 -7
- airflow/providers/google/cloud/hooks/workflows.py +4 -3
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -7
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -90
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -89
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
- airflow/providers/google/cloud/links/managed_kafka.py +11 -51
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +141 -40
- airflow/providers/google/cloud/openlineage/mixins.py +14 -13
- airflow/providers/google/cloud/openlineage/utils.py +19 -3
- airflow/providers/google/cloud/operators/alloy_db.py +76 -61
- airflow/providers/google/cloud/operators/bigquery.py +104 -667
- airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
- airflow/providers/google/cloud/operators/bigtable.py +38 -7
- airflow/providers/google/cloud/operators/cloud_base.py +22 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
- airflow/providers/google/cloud/operators/cloud_build.py +80 -36
- airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
- airflow/providers/google/cloud/operators/cloud_run.py +39 -20
- airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
- airflow/providers/google/cloud/operators/compute.py +18 -50
- airflow/providers/google/cloud/operators/datacatalog.py +167 -29
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +19 -7
- airflow/providers/google/cloud/operators/datafusion.py +43 -43
- airflow/providers/google/cloud/operators/dataplex.py +212 -126
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +134 -207
- airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +24 -45
- airflow/providers/google/cloud/operators/functions.py +21 -14
- airflow/providers/google/cloud/operators/gcs.py +15 -12
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
- airflow/providers/google/cloud/operators/natural_language.py +5 -3
- airflow/providers/google/cloud/operators/pubsub.py +69 -21
- airflow/providers/google/cloud/operators/spanner.py +53 -45
- airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
- airflow/providers/google/cloud/operators/stackdriver.py +5 -11
- airflow/providers/google/cloud/operators/tasks.py +6 -15
- airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
- airflow/providers/google/cloud/operators/translate.py +46 -20
- airflow/providers/google/cloud/operators/translate_speech.py +4 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
- airflow/providers/google/cloud/operators/vision.py +7 -5
- airflow/providers/google/cloud/operators/workflows.py +24 -19
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
- airflow/providers/google/cloud/sensors/bigtable.py +14 -6
- airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
- airflow/providers/google/cloud/sensors/dataflow.py +27 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +7 -5
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +10 -9
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/sensors/gcs.py +22 -21
- airflow/providers/google/cloud/sensors/looker.py +5 -5
- airflow/providers/google/cloud/sensors/pubsub.py +20 -20
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +6 -4
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
- airflow/providers/google/cloud/triggers/dataflow.py +125 -2
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +16 -3
- airflow/providers/google/cloud/triggers/dataproc.py +124 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +17 -20
- airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
- airflow/providers/google/cloud/utils/bigquery.py +5 -7
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/cloud/utils/validators.py +43 -0
- airflow/providers/google/common/auth_backend/google_openid.py +26 -9
- airflow/providers/google/common/consts.py +2 -1
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +40 -43
- airflow/providers/google/common/hooks/operation_helpers.py +78 -0
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +4 -5
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +61 -216
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/drive.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
- airflow/providers/google/cloud/hooks/automl.py +0 -679
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1360
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -1515
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
- apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
- /airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
|
@@ -22,18 +22,20 @@ from __future__ import annotations
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
23
|
from typing import TYPE_CHECKING
|
|
24
24
|
|
|
25
|
+
from google.protobuf.json_format import MessageToDict
|
|
26
|
+
|
|
25
27
|
from airflow.exceptions import AirflowException
|
|
26
28
|
from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook
|
|
27
29
|
from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook
|
|
28
30
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
29
31
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
|
|
30
32
|
from airflow.providers.google.common.links.storage import FileDetailsLink
|
|
31
|
-
from google.protobuf.json_format import MessageToDict
|
|
32
33
|
|
|
33
34
|
if TYPE_CHECKING:
|
|
34
|
-
from airflow.utils.context import Context
|
|
35
35
|
from google.cloud.speech_v1.types import RecognitionAudio, RecognitionConfig
|
|
36
36
|
|
|
37
|
+
from airflow.providers.common.compat.sdk import Context
|
|
38
|
+
|
|
37
39
|
|
|
38
40
|
class CloudTranslateSpeechOperator(GoogleCloudBaseOperator):
|
|
39
41
|
"""
|
|
@@ -171,7 +173,6 @@ class CloudTranslateSpeechOperator(GoogleCloudBaseOperator):
|
|
|
171
173
|
if self.audio.uri:
|
|
172
174
|
FileDetailsLink.persist(
|
|
173
175
|
context=context,
|
|
174
|
-
task_instance=self,
|
|
175
176
|
# Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
|
|
176
177
|
uri=self.audio.uri[5:],
|
|
177
178
|
project_id=self.project_id or translate_hook.project_id,
|
|
@@ -21,7 +21,13 @@
|
|
|
21
21
|
from __future__ import annotations
|
|
22
22
|
|
|
23
23
|
from collections.abc import Sequence
|
|
24
|
-
from typing import TYPE_CHECKING
|
|
24
|
+
from typing import TYPE_CHECKING, Any
|
|
25
|
+
|
|
26
|
+
from google.api_core.exceptions import NotFound
|
|
27
|
+
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
28
|
+
from google.cloud.aiplatform import datasets
|
|
29
|
+
from google.cloud.aiplatform.models import Model
|
|
30
|
+
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
|
|
25
31
|
|
|
26
32
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
|
27
33
|
from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
|
|
@@ -32,16 +38,12 @@ from airflow.providers.google.cloud.links.vertex_ai import (
|
|
|
32
38
|
)
|
|
33
39
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
34
40
|
from airflow.providers.google.common.deprecated import deprecated
|
|
35
|
-
from google.api_core.exceptions import NotFound
|
|
36
|
-
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
37
|
-
from google.cloud.aiplatform import datasets
|
|
38
|
-
from google.cloud.aiplatform.models import Model
|
|
39
|
-
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
|
|
40
41
|
|
|
41
42
|
if TYPE_CHECKING:
|
|
42
|
-
from airflow.utils.context import Context
|
|
43
43
|
from google.api_core.retry import Retry
|
|
44
44
|
|
|
45
|
+
from airflow.providers.common.compat.sdk import Context
|
|
46
|
+
|
|
45
47
|
|
|
46
48
|
class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
47
49
|
"""The base class for operators that launch AutoML jobs on VertexAI."""
|
|
@@ -91,6 +93,13 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
|
91
93
|
self.impersonation_chain = impersonation_chain
|
|
92
94
|
self.hook: AutoMLHook | None = None
|
|
93
95
|
|
|
96
|
+
@property
|
|
97
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
98
|
+
return {
|
|
99
|
+
"region": self.region,
|
|
100
|
+
"project_id": self.project_id,
|
|
101
|
+
}
|
|
102
|
+
|
|
94
103
|
def on_kill(self) -> None:
|
|
95
104
|
"""Act as a callback called when the operator is killed; cancel any running job."""
|
|
96
105
|
if self.hook:
|
|
@@ -242,12 +251,12 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
242
251
|
if model:
|
|
243
252
|
result = Model.to_dict(model)
|
|
244
253
|
model_id = self.hook.extract_model_id(result)
|
|
245
|
-
|
|
246
|
-
VertexAIModelLink.persist(context=context,
|
|
254
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
255
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
247
256
|
else:
|
|
248
257
|
result = model # type: ignore
|
|
249
|
-
|
|
250
|
-
VertexAITrainingLink.persist(context=context,
|
|
258
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
259
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
251
260
|
return result
|
|
252
261
|
|
|
253
262
|
|
|
@@ -334,12 +343,12 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
334
343
|
if model:
|
|
335
344
|
result = Model.to_dict(model)
|
|
336
345
|
model_id = self.hook.extract_model_id(result)
|
|
337
|
-
|
|
338
|
-
VertexAIModelLink.persist(context=context,
|
|
346
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
347
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
339
348
|
else:
|
|
340
349
|
result = model # type: ignore
|
|
341
|
-
|
|
342
|
-
VertexAITrainingLink.persist(context=context,
|
|
350
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
351
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
343
352
|
return result
|
|
344
353
|
|
|
345
354
|
|
|
@@ -457,15 +466,20 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
457
466
|
if model:
|
|
458
467
|
result = Model.to_dict(model)
|
|
459
468
|
model_id = self.hook.extract_model_id(result)
|
|
460
|
-
|
|
461
|
-
VertexAIModelLink.persist(context=context,
|
|
469
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
470
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
462
471
|
else:
|
|
463
472
|
result = model # type: ignore
|
|
464
|
-
|
|
465
|
-
VertexAITrainingLink.persist(context=context,
|
|
473
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
474
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
466
475
|
return result
|
|
467
476
|
|
|
468
477
|
|
|
478
|
+
@deprecated(
|
|
479
|
+
planned_removal_date="March 24, 2026",
|
|
480
|
+
use_instead="airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator",
|
|
481
|
+
category=AirflowProviderDeprecationWarning,
|
|
482
|
+
)
|
|
469
483
|
class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
470
484
|
"""Create Auto ML Video Training job."""
|
|
471
485
|
|
|
@@ -531,12 +545,12 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
|
|
|
531
545
|
if model:
|
|
532
546
|
result = Model.to_dict(model)
|
|
533
547
|
model_id = self.hook.extract_model_id(result)
|
|
534
|
-
|
|
535
|
-
VertexAIModelLink.persist(context=context,
|
|
548
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
549
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
536
550
|
else:
|
|
537
551
|
result = model # type: ignore
|
|
538
|
-
|
|
539
|
-
VertexAITrainingLink.persist(context=context,
|
|
552
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
553
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
540
554
|
return result
|
|
541
555
|
|
|
542
556
|
|
|
@@ -573,16 +587,6 @@ class DeleteAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
573
587
|
self.gcp_conn_id = gcp_conn_id
|
|
574
588
|
self.impersonation_chain = impersonation_chain
|
|
575
589
|
|
|
576
|
-
@property
|
|
577
|
-
@deprecated(
|
|
578
|
-
planned_removal_date="March 01, 2025",
|
|
579
|
-
use_instead="training_pipeline_id",
|
|
580
|
-
category=AirflowProviderDeprecationWarning,
|
|
581
|
-
)
|
|
582
|
-
def training_pipeline(self):
|
|
583
|
-
"""Alias for ``training_pipeline_id``, used for compatibility (deprecated)."""
|
|
584
|
-
return self.training_pipeline_id
|
|
585
|
-
|
|
586
590
|
def execute(self, context: Context):
|
|
587
591
|
hook = AutoMLHook(
|
|
588
592
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -650,6 +654,12 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
650
654
|
self.gcp_conn_id = gcp_conn_id
|
|
651
655
|
self.impersonation_chain = impersonation_chain
|
|
652
656
|
|
|
657
|
+
@property
|
|
658
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
659
|
+
return {
|
|
660
|
+
"project_id": self.project_id,
|
|
661
|
+
}
|
|
662
|
+
|
|
653
663
|
def execute(self, context: Context):
|
|
654
664
|
hook = AutoMLHook(
|
|
655
665
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -666,5 +676,5 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
666
676
|
timeout=self.timeout,
|
|
667
677
|
metadata=self.metadata,
|
|
668
678
|
)
|
|
669
|
-
VertexAITrainingPipelinesLink.persist(context=context
|
|
679
|
+
VertexAITrainingPipelinesLink.persist(context=context)
|
|
670
680
|
return [TrainingPipeline.to_dict(result) for result in results]
|
|
@@ -24,6 +24,10 @@ from collections.abc import Sequence
|
|
|
24
24
|
from functools import cached_property
|
|
25
25
|
from typing import TYPE_CHECKING, Any
|
|
26
26
|
|
|
27
|
+
from google.api_core.exceptions import NotFound
|
|
28
|
+
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
29
|
+
from google.cloud.aiplatform_v1.types import BatchPredictionJob
|
|
30
|
+
|
|
27
31
|
from airflow.configuration import conf
|
|
28
32
|
from airflow.exceptions import AirflowException
|
|
29
33
|
from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobHook
|
|
@@ -33,15 +37,13 @@ from airflow.providers.google.cloud.links.vertex_ai import (
|
|
|
33
37
|
)
|
|
34
38
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
35
39
|
from airflow.providers.google.cloud.triggers.vertex_ai import CreateBatchPredictionJobTrigger
|
|
36
|
-
from google.api_core.exceptions import NotFound
|
|
37
|
-
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
38
|
-
from google.cloud.aiplatform_v1.types import BatchPredictionJob
|
|
39
40
|
|
|
40
41
|
if TYPE_CHECKING:
|
|
41
|
-
from airflow.utils.context import Context
|
|
42
42
|
from google.api_core.retry import Retry
|
|
43
43
|
from google.cloud.aiplatform import BatchPredictionJob as BatchPredictionJobObject, Model, explain
|
|
44
44
|
|
|
45
|
+
from airflow.providers.common.compat.sdk import Context
|
|
46
|
+
|
|
45
47
|
|
|
46
48
|
class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
47
49
|
"""
|
|
@@ -229,6 +231,13 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
229
231
|
impersonation_chain=self.impersonation_chain,
|
|
230
232
|
)
|
|
231
233
|
|
|
234
|
+
@property
|
|
235
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
236
|
+
return {
|
|
237
|
+
"region": self.region,
|
|
238
|
+
"project_id": self.project_id,
|
|
239
|
+
}
|
|
240
|
+
|
|
232
241
|
def execute(self, context: Context):
|
|
233
242
|
self.log.info("Creating Batch prediction job")
|
|
234
243
|
batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
|
|
@@ -260,9 +269,10 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
260
269
|
batch_prediction_job_id = batch_prediction_job.name
|
|
261
270
|
self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id)
|
|
262
271
|
|
|
263
|
-
|
|
272
|
+
context["ti"].xcom_push(key="batch_prediction_job_id", value=batch_prediction_job_id)
|
|
264
273
|
VertexAIBatchPredictionJobLink.persist(
|
|
265
|
-
context=context,
|
|
274
|
+
context=context,
|
|
275
|
+
batch_prediction_job_id=batch_prediction_job_id,
|
|
266
276
|
)
|
|
267
277
|
|
|
268
278
|
if self.deferrable:
|
|
@@ -293,13 +303,11 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
293
303
|
job: dict[str, Any] = event["job"]
|
|
294
304
|
self.log.info("Batch prediction job %s created and completed successfully.", job["name"])
|
|
295
305
|
job_id = self.hook.extract_batch_prediction_job_id(job)
|
|
296
|
-
|
|
297
|
-
context,
|
|
306
|
+
context["ti"].xcom_push(
|
|
298
307
|
key="batch_prediction_job_id",
|
|
299
308
|
value=job_id,
|
|
300
309
|
)
|
|
301
|
-
|
|
302
|
-
context,
|
|
310
|
+
context["ti"].xcom_push(
|
|
303
311
|
key="training_conf",
|
|
304
312
|
value={
|
|
305
313
|
"training_conf_id": job_id,
|
|
@@ -425,6 +433,13 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
425
433
|
self.gcp_conn_id = gcp_conn_id
|
|
426
434
|
self.impersonation_chain = impersonation_chain
|
|
427
435
|
|
|
436
|
+
@property
|
|
437
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
438
|
+
return {
|
|
439
|
+
"region": self.region,
|
|
440
|
+
"project_id": self.project_id,
|
|
441
|
+
}
|
|
442
|
+
|
|
428
443
|
def execute(self, context: Context):
|
|
429
444
|
hook = BatchPredictionJobHook(
|
|
430
445
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -443,7 +458,8 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
|
|
|
443
458
|
)
|
|
444
459
|
self.log.info("Batch prediction job was gotten.")
|
|
445
460
|
VertexAIBatchPredictionJobLink.persist(
|
|
446
|
-
context=context,
|
|
461
|
+
context=context,
|
|
462
|
+
batch_prediction_job_id=self.batch_prediction_job,
|
|
447
463
|
)
|
|
448
464
|
return BatchPredictionJob.to_dict(result)
|
|
449
465
|
except NotFound:
|
|
@@ -515,6 +531,12 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
|
|
|
515
531
|
self.gcp_conn_id = gcp_conn_id
|
|
516
532
|
self.impersonation_chain = impersonation_chain
|
|
517
533
|
|
|
534
|
+
@property
|
|
535
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
536
|
+
return {
|
|
537
|
+
"project_id": self.project_id,
|
|
538
|
+
}
|
|
539
|
+
|
|
518
540
|
def execute(self, context: Context):
|
|
519
541
|
hook = BatchPredictionJobHook(
|
|
520
542
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -531,5 +553,5 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
|
|
|
531
553
|
timeout=self.timeout,
|
|
532
554
|
metadata=self.metadata,
|
|
533
555
|
)
|
|
534
|
-
VertexAIBatchPredictionJobListLink.persist(context=context
|
|
556
|
+
VertexAIBatchPredictionJobListLink.persist(context=context)
|
|
535
557
|
return [BatchPredictionJob.to_dict(result) for result in results]
|
|
@@ -23,8 +23,14 @@ from collections.abc import Sequence
|
|
|
23
23
|
from functools import cached_property
|
|
24
24
|
from typing import TYPE_CHECKING, Any
|
|
25
25
|
|
|
26
|
+
from google.api_core.exceptions import NotFound
|
|
27
|
+
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
28
|
+
from google.cloud.aiplatform.models import Model
|
|
29
|
+
from google.cloud.aiplatform_v1.types.dataset import Dataset
|
|
30
|
+
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
|
|
31
|
+
|
|
26
32
|
from airflow.configuration import conf
|
|
27
|
-
from airflow.exceptions import AirflowException
|
|
33
|
+
from airflow.exceptions import AirflowException
|
|
28
34
|
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
|
|
29
35
|
from airflow.providers.google.cloud.links.vertex_ai import (
|
|
30
36
|
VertexAIModelLink,
|
|
@@ -37,21 +43,17 @@ from airflow.providers.google.cloud.triggers.vertex_ai import (
|
|
|
37
43
|
CustomPythonPackageTrainingJobTrigger,
|
|
38
44
|
CustomTrainingJobTrigger,
|
|
39
45
|
)
|
|
40
|
-
from airflow.providers.google.common.deprecated import deprecated
|
|
41
|
-
from google.api_core.exceptions import NotFound
|
|
42
|
-
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
43
|
-
from google.cloud.aiplatform.models import Model
|
|
44
|
-
from google.cloud.aiplatform_v1.types.dataset import Dataset
|
|
45
|
-
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
|
|
46
46
|
|
|
47
47
|
if TYPE_CHECKING:
|
|
48
|
-
from airflow.utils.context import Context
|
|
49
48
|
from google.api_core.retry import Retry
|
|
50
49
|
from google.cloud.aiplatform import (
|
|
51
50
|
CustomContainerTrainingJob,
|
|
52
51
|
CustomPythonPackageTrainingJob,
|
|
53
52
|
CustomTrainingJob,
|
|
54
53
|
)
|
|
54
|
+
from google.cloud.aiplatform_v1.types import PscInterfaceConfig
|
|
55
|
+
|
|
56
|
+
from airflow.providers.common.compat.sdk import Context
|
|
55
57
|
|
|
56
58
|
|
|
57
59
|
class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
@@ -109,6 +111,7 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
|
109
111
|
predefined_split_column_name: str | None = None,
|
|
110
112
|
timestamp_split_column_name: str | None = None,
|
|
111
113
|
tensorboard: str | None = None,
|
|
114
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
112
115
|
gcp_conn_id: str = "google_cloud_default",
|
|
113
116
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
114
117
|
**kwargs,
|
|
@@ -165,21 +168,29 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
|
|
|
165
168
|
self.predefined_split_column_name = predefined_split_column_name
|
|
166
169
|
self.timestamp_split_column_name = timestamp_split_column_name
|
|
167
170
|
self.tensorboard = tensorboard
|
|
171
|
+
self.psc_interface_config = psc_interface_config
|
|
168
172
|
# END Run param
|
|
169
173
|
self.gcp_conn_id = gcp_conn_id
|
|
170
174
|
self.impersonation_chain = impersonation_chain
|
|
171
175
|
|
|
176
|
+
@property
|
|
177
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
178
|
+
return {
|
|
179
|
+
"region": self.region,
|
|
180
|
+
"project_id": self.project_id,
|
|
181
|
+
}
|
|
182
|
+
|
|
172
183
|
def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
|
|
173
184
|
if event["status"] == "error":
|
|
174
185
|
raise AirflowException(event["message"])
|
|
175
186
|
training_pipeline = event["job"]
|
|
176
187
|
custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline)
|
|
177
|
-
|
|
188
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
178
189
|
try:
|
|
179
190
|
model = training_pipeline["model_to_upload"]
|
|
180
191
|
model_id = self.hook.extract_model_id(model)
|
|
181
|
-
|
|
182
|
-
VertexAIModelLink.persist(context=context,
|
|
192
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
193
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
183
194
|
return model
|
|
184
195
|
except KeyError:
|
|
185
196
|
self.log.warning(
|
|
@@ -465,6 +476,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
465
476
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
466
477
|
For more information on configuring your service account please visit:
|
|
467
478
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
479
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
480
|
+
training.
|
|
468
481
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
469
482
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
470
483
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -578,18 +591,19 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
578
591
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
579
592
|
tensorboard=self.tensorboard,
|
|
580
593
|
sync=True,
|
|
594
|
+
psc_interface_config=self.psc_interface_config,
|
|
581
595
|
)
|
|
582
596
|
|
|
583
597
|
if model:
|
|
584
598
|
result = Model.to_dict(model)
|
|
585
599
|
model_id = self.hook.extract_model_id(result)
|
|
586
|
-
|
|
587
|
-
VertexAIModelLink.persist(context=context,
|
|
600
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
601
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
588
602
|
else:
|
|
589
603
|
result = model # type: ignore
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
VertexAITrainingLink.persist(context=context,
|
|
604
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
605
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
606
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
593
607
|
return result
|
|
594
608
|
|
|
595
609
|
def invoke_defer(self, context: Context) -> None:
|
|
@@ -644,11 +658,12 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
644
658
|
predefined_split_column_name=self.predefined_split_column_name,
|
|
645
659
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
646
660
|
tensorboard=self.tensorboard,
|
|
661
|
+
psc_interface_config=self.psc_interface_config,
|
|
647
662
|
)
|
|
648
663
|
custom_container_training_job_obj.wait_for_resource_creation()
|
|
649
664
|
training_pipeline_id: str = custom_container_training_job_obj.name
|
|
650
|
-
|
|
651
|
-
VertexAITrainingLink.persist(context=context,
|
|
665
|
+
context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
|
|
666
|
+
VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
|
|
652
667
|
self.defer(
|
|
653
668
|
trigger=CustomContainerTrainingJobTrigger(
|
|
654
669
|
conn_id=self.gcp_conn_id,
|
|
@@ -923,6 +938,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
|
|
|
923
938
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
924
939
|
For more information on configuring your service account please visit:
|
|
925
940
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
941
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
942
|
+
training.
|
|
926
943
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
927
944
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
928
945
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -1035,18 +1052,19 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
|
|
|
1035
1052
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1036
1053
|
tensorboard=self.tensorboard,
|
|
1037
1054
|
sync=True,
|
|
1055
|
+
psc_interface_config=self.psc_interface_config,
|
|
1038
1056
|
)
|
|
1039
1057
|
|
|
1040
1058
|
if model:
|
|
1041
1059
|
result = Model.to_dict(model)
|
|
1042
1060
|
model_id = self.hook.extract_model_id(result)
|
|
1043
|
-
|
|
1044
|
-
VertexAIModelLink.persist(context=context,
|
|
1061
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
1062
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
1045
1063
|
else:
|
|
1046
1064
|
result = model # type: ignore
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
VertexAITrainingLink.persist(context=context,
|
|
1065
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
1066
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
1067
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
1050
1068
|
return result
|
|
1051
1069
|
|
|
1052
1070
|
def invoke_defer(self, context: Context) -> None:
|
|
@@ -1102,11 +1120,12 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
|
|
|
1102
1120
|
predefined_split_column_name=self.predefined_split_column_name,
|
|
1103
1121
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1104
1122
|
tensorboard=self.tensorboard,
|
|
1123
|
+
psc_interface_config=self.psc_interface_config,
|
|
1105
1124
|
)
|
|
1106
1125
|
custom_python_training_job_obj.wait_for_resource_creation()
|
|
1107
1126
|
training_pipeline_id: str = custom_python_training_job_obj.name
|
|
1108
|
-
|
|
1109
|
-
VertexAITrainingLink.persist(context=context,
|
|
1127
|
+
context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
|
|
1128
|
+
VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
|
|
1110
1129
|
self.defer(
|
|
1111
1130
|
trigger=CustomPythonPackageTrainingJobTrigger(
|
|
1112
1131
|
conn_id=self.gcp_conn_id,
|
|
@@ -1381,6 +1400,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
1381
1400
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
1382
1401
|
For more information on configuring your service account please visit:
|
|
1383
1402
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
1403
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
1404
|
+
training.
|
|
1384
1405
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
1385
1406
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
1386
1407
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -1498,18 +1519,19 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
1498
1519
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1499
1520
|
tensorboard=self.tensorboard,
|
|
1500
1521
|
sync=True,
|
|
1522
|
+
psc_interface_config=self.psc_interface_config,
|
|
1501
1523
|
)
|
|
1502
1524
|
|
|
1503
1525
|
if model:
|
|
1504
1526
|
result = Model.to_dict(model)
|
|
1505
1527
|
model_id = self.hook.extract_model_id(result)
|
|
1506
|
-
|
|
1507
|
-
VertexAIModelLink.persist(context=context,
|
|
1528
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
1529
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
1508
1530
|
else:
|
|
1509
1531
|
result = model # type: ignore
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
VertexAITrainingLink.persist(context=context,
|
|
1532
|
+
context["ti"].xcom_push(key="training_id", value=training_id)
|
|
1533
|
+
context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
|
|
1534
|
+
VertexAITrainingLink.persist(context=context, training_id=training_id)
|
|
1513
1535
|
return result
|
|
1514
1536
|
|
|
1515
1537
|
def invoke_defer(self, context: Context) -> None:
|
|
@@ -1565,11 +1587,12 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
|
|
|
1565
1587
|
predefined_split_column_name=self.predefined_split_column_name,
|
|
1566
1588
|
timestamp_split_column_name=self.timestamp_split_column_name,
|
|
1567
1589
|
tensorboard=self.tensorboard,
|
|
1590
|
+
psc_interface_config=self.psc_interface_config,
|
|
1568
1591
|
)
|
|
1569
1592
|
custom_training_job_obj.wait_for_resource_creation()
|
|
1570
1593
|
training_pipeline_id: str = custom_training_job_obj.name
|
|
1571
|
-
|
|
1572
|
-
VertexAITrainingLink.persist(context=context,
|
|
1594
|
+
context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
|
|
1595
|
+
VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
|
|
1573
1596
|
self.defer(
|
|
1574
1597
|
trigger=CustomTrainingJobTrigger(
|
|
1575
1598
|
conn_id=self.gcp_conn_id,
|
|
@@ -1632,26 +1655,6 @@ class DeleteCustomTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
1632
1655
|
self.gcp_conn_id = gcp_conn_id
|
|
1633
1656
|
self.impersonation_chain = impersonation_chain
|
|
1634
1657
|
|
|
1635
|
-
@property
|
|
1636
|
-
@deprecated(
|
|
1637
|
-
planned_removal_date="March 01, 2025",
|
|
1638
|
-
use_instead="training_pipeline_id",
|
|
1639
|
-
category=AirflowProviderDeprecationWarning,
|
|
1640
|
-
)
|
|
1641
|
-
def training_pipeline(self):
|
|
1642
|
-
"""Alias for ``training_pipeline_id``, used for compatibility (deprecated)."""
|
|
1643
|
-
return self.training_pipeline_id
|
|
1644
|
-
|
|
1645
|
-
@property
|
|
1646
|
-
@deprecated(
|
|
1647
|
-
planned_removal_date="March 01, 2025",
|
|
1648
|
-
use_instead="custom_job_id",
|
|
1649
|
-
category=AirflowProviderDeprecationWarning,
|
|
1650
|
-
)
|
|
1651
|
-
def custom_job(self):
|
|
1652
|
-
"""Alias for ``custom_job_id``, used for compatibility (deprecated)."""
|
|
1653
|
-
return self.custom_job_id
|
|
1654
|
-
|
|
1655
1658
|
def execute(self, context: Context):
|
|
1656
1659
|
hook = CustomJobHook(
|
|
1657
1660
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -1767,6 +1770,12 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
1767
1770
|
self.gcp_conn_id = gcp_conn_id
|
|
1768
1771
|
self.impersonation_chain = impersonation_chain
|
|
1769
1772
|
|
|
1773
|
+
@property
|
|
1774
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
1775
|
+
return {
|
|
1776
|
+
"project_id": self.project_id,
|
|
1777
|
+
}
|
|
1778
|
+
|
|
1770
1779
|
def execute(self, context: Context):
|
|
1771
1780
|
hook = CustomJobHook(
|
|
1772
1781
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -1783,5 +1792,5 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
|
|
|
1783
1792
|
timeout=self.timeout,
|
|
1784
1793
|
metadata=self.metadata,
|
|
1785
1794
|
)
|
|
1786
|
-
VertexAITrainingPipelinesLink.persist(context=context
|
|
1795
|
+
VertexAITrainingPipelinesLink.persist(context=context)
|
|
1787
1796
|
return [TrainingPipeline.to_dict(result) for result in results]
|