apache-airflow-providers-google 16.0.0a1__py3-none-any.whl → 16.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +43 -5
- airflow/providers/google/ads/operators/ads.py +1 -1
- airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +63 -77
- airflow/providers/google/cloud/hooks/cloud_sql.py +8 -4
- airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
- airflow/providers/google/cloud/hooks/dataflow.py +2 -2
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +4 -1
- airflow/providers/google/cloud/hooks/gcs.py +5 -5
- airflow/providers/google/cloud/hooks/looker.py +10 -1
- airflow/providers/google/cloud/hooks/mlengine.py +2 -1
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +2 -2
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +44 -80
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +11 -2
- airflow/providers/google/cloud/hooks/vision.py +2 -2
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +75 -11
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/cloud_run.py +27 -0
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +16 -43
- airflow/providers/google/cloud/links/cloud_tasks.py +6 -25
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +14 -90
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +5 -59
- airflow/providers/google/cloud/links/life_sciences.py +0 -19
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +16 -186
- airflow/providers/google/cloud/links/vertex_ai.py +8 -224
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +4 -4
- airflow/providers/google/cloud/operators/alloy_db.py +69 -54
- airflow/providers/google/cloud/operators/automl.py +16 -14
- airflow/providers/google/cloud/operators/bigquery.py +49 -25
- airflow/providers/google/cloud/operators/bigquery_dts.py +2 -4
- airflow/providers/google/cloud/operators/bigtable.py +35 -6
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_build.py +74 -31
- airflow/providers/google/cloud/operators/cloud_composer.py +34 -35
- airflow/providers/google/cloud/operators/cloud_memorystore.py +68 -42
- airflow/providers/google/cloud/operators/cloud_run.py +9 -1
- airflow/providers/google/cloud/operators/cloud_sql.py +11 -15
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -2
- airflow/providers/google/cloud/operators/compute.py +7 -39
- airflow/providers/google/cloud/operators/datacatalog.py +156 -20
- airflow/providers/google/cloud/operators/dataflow.py +37 -14
- airflow/providers/google/cloud/operators/dataform.py +14 -4
- airflow/providers/google/cloud/operators/datafusion.py +4 -12
- airflow/providers/google/cloud/operators/dataplex.py +180 -96
- airflow/providers/google/cloud/operators/dataprep.py +0 -4
- airflow/providers/google/cloud/operators/dataproc.py +10 -16
- airflow/providers/google/cloud/operators/dataproc_metastore.py +95 -87
- airflow/providers/google/cloud/operators/datastore.py +21 -5
- airflow/providers/google/cloud/operators/dlp.py +3 -26
- airflow/providers/google/cloud/operators/functions.py +15 -6
- airflow/providers/google/cloud/operators/gcs.py +1 -7
- airflow/providers/google/cloud/operators/kubernetes_engine.py +53 -92
- airflow/providers/google/cloud/operators/life_sciences.py +0 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +106 -51
- airflow/providers/google/cloud/operators/mlengine.py +0 -1
- airflow/providers/google/cloud/operators/pubsub.py +4 -5
- airflow/providers/google/cloud/operators/spanner.py +0 -4
- airflow/providers/google/cloud/operators/speech_to_text.py +0 -1
- airflow/providers/google/cloud/operators/stackdriver.py +0 -8
- airflow/providers/google/cloud/operators/tasks.py +0 -11
- airflow/providers/google/cloud/operators/text_to_speech.py +0 -1
- airflow/providers/google/cloud/operators/translate.py +37 -13
- airflow/providers/google/cloud/operators/translate_speech.py +0 -1
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +69 -7
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +42 -8
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +531 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +93 -117
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +10 -8
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +56 -10
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +29 -6
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +9 -6
- airflow/providers/google/cloud/operators/workflows.py +1 -9
- airflow/providers/google/cloud/sensors/bigquery.py +1 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -1
- airflow/providers/google/cloud/sensors/bigtable.py +15 -3
- airflow/providers/google/cloud/sensors/cloud_composer.py +6 -1
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -1
- airflow/providers/google/cloud/sensors/dataflow.py +3 -3
- airflow/providers/google/cloud/sensors/dataform.py +6 -1
- airflow/providers/google/cloud/sensors/datafusion.py +6 -1
- airflow/providers/google/cloud/sensors/dataplex.py +6 -1
- airflow/providers/google/cloud/sensors/dataprep.py +6 -1
- airflow/providers/google/cloud/sensors/dataproc.py +6 -1
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +6 -1
- airflow/providers/google/cloud/sensors/gcs.py +9 -3
- airflow/providers/google/cloud/sensors/looker.py +6 -1
- airflow/providers/google/cloud/sensors/pubsub.py +8 -3
- airflow/providers/google/cloud/sensors/tasks.py +6 -1
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +6 -1
- airflow/providers/google/cloud/sensors/workflows.py +6 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +10 -7
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -2
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -1
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/gcs_to_local.py +1 -1
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +5 -1
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +1 -1
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +1 -1
- airflow/providers/google/cloud/triggers/bigquery.py +32 -5
- airflow/providers/google/cloud/triggers/dataproc.py +62 -10
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/common/auth_backend/google_openid.py +2 -1
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +7 -3
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/firebase/operators/firestore.py +1 -1
- airflow/providers/google/get_provider_info.py +14 -16
- airflow/providers/google/leveldb/hooks/leveldb.py +30 -1
- airflow/providers/google/leveldb/operators/leveldb.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +3 -6
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +0 -1
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +4 -4
- airflow/providers/google/marketing_platform/operators/display_video.py +6 -6
- airflow/providers/google/marketing_platform/operators/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +6 -1
- airflow/providers/google/marketing_platform/sensors/display_video.py +6 -1
- airflow/providers/google/suite/operators/sheets.py +3 -3
- airflow/providers/google/suite/sensors/drive.py +6 -1
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +1 -1
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +1 -1
- airflow/providers/google/version_compat.py +28 -0
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/METADATA +35 -35
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/RECORD +171 -170
- airflow/providers/google/cloud/links/automl.py +0 -193
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/operators/translate.py +37 -13

@@ -37,6 +37,7 @@ from airflow.providers.google.cloud.links.translate import (
     TranslationNativeDatasetLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import DatasetImportDataResultsCheckHelper
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 
 if TYPE_CHECKING:
@@ -394,7 +395,6 @@ class TranslateTextBatchOperator(GoogleCloudBaseOperator):
         self.log.info("Translate text batch job started.")
         TranslateTextBatchLink.persist(
             context=context,
-            task_instance=self,
             project_id=self.project_id or hook.project_id,
             output_config=self.output_config,
         )
@@ -480,15 +480,15 @@ class TranslateCreateDatasetOperator(GoogleCloudBaseOperator):
         result = hook.wait_for_operation_result(result_operation)
         result = type(result).to_dict(result)
         dataset_id = hook.extract_object_id(result)
-
+        context["ti"].xcom_push(key="dataset_id", value=dataset_id)
         self.log.info("Dataset creation complete. The dataset_id: %s.", dataset_id)
 
         project_id = self.project_id or hook.project_id
         TranslationNativeDatasetLink.persist(
             context=context,
-            task_instance=self,
             dataset_id=dataset_id,
             project_id=project_id,
+            location=self.location,
         )
         return result
 
@@ -556,7 +556,6 @@ class TranslateDatasetsListOperator(GoogleCloudBaseOperator):
         project_id = self.project_id or hook.project_id
         TranslationDatasetsListLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
         )
         self.log.info("Requesting datasets list")
@@ -577,7 +576,7 @@ class TranslateDatasetsListOperator(GoogleCloudBaseOperator):
         return result_ids
 
 
-class TranslateImportDataOperator(GoogleCloudBaseOperator):
+class TranslateImportDataOperator(GoogleCloudBaseOperator, DatasetImportDataResultsCheckHelper):
     """
     Import data to the translation dataset.
 
@@ -604,6 +603,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param raise_for_empty_result: Raise an error if no additional data has been populated after the import.
     """
 
     template_fields: Sequence[str] = (
@@ -629,6 +629,7 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         retry: Retry | _MethodDefault = DEFAULT,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        raise_for_empty_result: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -641,9 +642,21 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         self.retry = retry
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.raise_for_empty_result = raise_for_empty_result
 
     def execute(self, context: Context):
         hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        initial_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
         self.log.info("Importing data to dataset...")
         operation = hook.import_dataset_data(
             dataset_id=self.dataset_id,
@@ -657,12 +670,27 @@ class TranslateImportDataOperator(GoogleCloudBaseOperator):
         project_id = self.project_id or hook.project_id
         TranslationNativeDatasetLink.persist(
             context=context,
-            task_instance=self,
             dataset_id=self.dataset_id,
             project_id=project_id,
+            location=self.location,
         )
         hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
+
+        result_dataset_size = self._get_number_of_ds_items(
+            dataset=hook.get_dataset(
+                dataset_id=self.dataset_id,
+                project_id=self.project_id,
+                location=self.location,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            ),
+            total_key_name="example_count",
+        )
+        if self.raise_for_empty_result:
+            self._raise_for_empty_import_result(self.dataset_id, initial_dataset_size, result_dataset_size)
         self.log.info("Importing data finished!")
+        return {"total_imported": int(result_dataset_size) - int(initial_dataset_size)}
 
 
 class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
@@ -821,16 +849,16 @@ class TranslateCreateModelOperator(GoogleCloudBaseOperator):
         result = hook.wait_for_operation_result(operation=result_operation)
         result = type(result).to_dict(result)
         model_id = hook.extract_object_id(result)
-
+        context["ti"].xcom_push(key="model_id", value=model_id)
         self.log.info("Model creation complete. The model_id: %s.", model_id)
 
         project_id = self.project_id or hook.project_id
         TranslationModelLink.persist(
             context=context,
-            task_instance=self,
             dataset_id=self.dataset_id,
             model_id=model_id,
             project_id=project_id,
+            location=self.location,
         )
         return result
 
@@ -898,7 +926,6 @@ class TranslateModelsListOperator(GoogleCloudBaseOperator):
         project_id = self.project_id or hook.project_id
         TranslationModelsListLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
         )
         self.log.info("Requesting models list")
@@ -1141,7 +1168,6 @@ class TranslateDocumentOperator(GoogleCloudBaseOperator):
         if self.document_output_config:
             TranslateResultByOutputConfigLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=self.project_id or hook.project_id,
                 output_config=self.document_output_config,
             )
@@ -1304,7 +1330,6 @@ class TranslateDocumentBatchOperator(GoogleCloudBaseOperator):
         self.log.info("Batch document translation job started.")
         TranslateResultByOutputConfigLink.persist(
             context=context,
-            task_instance=self,
             project_id=self.project_id or hook.project_id,
             output_config=self.output_config,
         )
@@ -1411,7 +1436,7 @@ class TranslateCreateGlossaryOperator(GoogleCloudBaseOperator):
         result = type(result).to_dict(result)
 
         glossary_id = hook.extract_object_id(result)
-
+        context["ti"].xcom_push(key="glossary_id", value=glossary_id)
         self.log.info("Glossary creation complete. The glossary_id: %s.", glossary_id)
         return result
 
@@ -1610,7 +1635,6 @@ class TranslateListGlossariesOperator(GoogleCloudBaseOperator):
         project_id = self.project_id or hook.project_id
         TranslationGlossariesListLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
         )
         self.log.info("Requesting glossaries list")
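The translate.py hunks above add a raise_for_empty_result flag to TranslateImportDataOperator and make the operator return the number of newly imported examples. A minimal usage sketch of the new flag follows; the task id, ids, location, and input config are illustrative assumptions, not values taken from this diff:

    # Illustrative sketch; dataset_id, project_id, location and the input config are assumed example values.
    from airflow.providers.google.cloud.operators.translate import TranslateImportDataOperator

    # Shaped like the Cloud Translation DatasetInputConfig message (assumed example payload).
    DATASET_INPUT_CONFIG = {
        "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "gs://example-bucket/import_data.tsv"}}]
    }

    import_translate_data = TranslateImportDataOperator(
        task_id="import_translate_dataset_data",
        dataset_id="example-native-dataset-id",
        project_id="example-project",
        location="us-central1",
        input_config=DATASET_INPUT_CONFIG,
        raise_for_empty_result=True,  # new flag: fail the task if the import added no examples
    )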
airflow/providers/google/cloud/operators/translate_speech.py +0 -1

@@ -173,7 +173,6 @@ class CloudTranslateSpeechOperator(GoogleCloudBaseOperator):
         if self.audio.uri:
             FileDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
                 uri=self.audio.uri[5:],
                 project_id=self.project_id or translate_hook.project_id,
airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18

@@ -21,7 +21,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from google.api_core.exceptions import NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -91,6 +91,13 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
         self.impersonation_chain = impersonation_chain
         self.hook: AutoMLHook | None = None
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def on_kill(self) -> None:
         """Act as a callback called when the operator is killed; cancel any running job."""
         if self.hook:
@@ -242,12 +249,12 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
-
-            VertexAITrainingLink.persist(context=context,
+            context["ti"].xcom_push(key="training_id", value=training_id)
+            VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result
 
 
@@ -334,12 +341,12 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
-
-            VertexAITrainingLink.persist(context=context,
+            context["ti"].xcom_push(key="training_id", value=training_id)
+            VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result
 
 
@@ -457,12 +464,12 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
-
-            VertexAITrainingLink.persist(context=context,
+            context["ti"].xcom_push(key="training_id", value=training_id)
+            VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result
 
 
@@ -531,12 +538,12 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
-
-            VertexAITrainingLink.persist(context=context,
+            context["ti"].xcom_push(key="training_id", value=training_id)
+            VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result
 
 
@@ -640,6 +647,12 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = AutoMLHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -656,5 +669,5 @@ class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        VertexAITrainingPipelinesLink.persist(context=context
+        VertexAITrainingPipelinesLink.persist(context=context)
         return [TrainingPipeline.to_dict(result) for result in results]
airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8

@@ -231,6 +231,13 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
             impersonation_chain=self.impersonation_chain,
         )
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         self.log.info("Creating Batch prediction job")
         batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
@@ -262,9 +269,10 @@
         batch_prediction_job_id = batch_prediction_job.name
         self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id)
 
-
+        context["ti"].xcom_push(key="batch_prediction_job_id", value=batch_prediction_job_id)
         VertexAIBatchPredictionJobLink.persist(
-            context=context,
+            context=context,
+            batch_prediction_job_id=batch_prediction_job_id,
         )
 
         if self.deferrable:
@@ -295,13 +303,11 @@
         job: dict[str, Any] = event["job"]
         self.log.info("Batch prediction job %s created and completed successfully.", job["name"])
         job_id = self.hook.extract_batch_prediction_job_id(job)
-
-            context,
+        context["ti"].xcom_push(
             key="batch_prediction_job_id",
             value=job_id,
         )
-
-            context,
+        context["ti"].xcom_push(
             key="training_conf",
             value={
                 "training_conf_id": job_id,
@@ -427,6 +433,13 @@ class GetBatchPredictionJobOperator(GoogleCloudBaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = BatchPredictionJobHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -445,7 +458,8 @@
             )
             self.log.info("Batch prediction job was gotten.")
             VertexAIBatchPredictionJobLink.persist(
-                context=context,
+                context=context,
+                batch_prediction_job_id=self.batch_prediction_job,
             )
             return BatchPredictionJob.to_dict(result)
         except NotFound:
@@ -517,6 +531,12 @@ class ListBatchPredictionJobsOperator(GoogleCloudBaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = BatchPredictionJobHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -533,5 +553,5 @@
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        VertexAIBatchPredictionJobListLink.persist(context=context
+        VertexAIBatchPredictionJobListLink.persist(context=context)
         return [BatchPredictionJob.to_dict(result) for result in results]
airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25

@@ -170,17 +170,24 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
         if event["status"] == "error":
             raise AirflowException(event["message"])
         training_pipeline = event["job"]
         custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline)
-
+        context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
         try:
             model = training_pipeline["model_to_upload"]
             model_id = self.hook.extract_model_id(model)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
             return model
         except KeyError:
             self.log.warning(
@@ -584,13 +591,13 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
-
-
-            VertexAITrainingLink.persist(context=context,
+            context["ti"].xcom_push(key="training_id", value=training_id)
+            context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
+            VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result
 
     def invoke_defer(self, context: Context) -> None:
@@ -648,8 +655,8 @@
         )
         custom_container_training_job_obj.wait_for_resource_creation()
         training_pipeline_id: str = custom_container_training_job_obj.name
-
-        VertexAITrainingLink.persist(context=context,
+        context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
         self.defer(
             trigger=CustomContainerTrainingJobTrigger(
                 conn_id=self.gcp_conn_id,
@@ -1041,13 +1048,13 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
-
-
-            VertexAITrainingLink.persist(context=context,
+            context["ti"].xcom_push(key="training_id", value=training_id)
+            context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
+            VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result
 
     def invoke_defer(self, context: Context) -> None:
@@ -1106,8 +1113,8 @@
         )
         custom_python_training_job_obj.wait_for_resource_creation()
         training_pipeline_id: str = custom_python_training_job_obj.name
-
-        VertexAITrainingLink.persist(context=context,
+        context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
         self.defer(
             trigger=CustomPythonPackageTrainingJobTrigger(
                 conn_id=self.gcp_conn_id,
@@ -1504,13 +1511,13 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
-
-            VertexAIModelLink.persist(context=context,
+            context["ti"].xcom_push(key="model_id", value=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
-
-
-            VertexAITrainingLink.persist(context=context,
+            context["ti"].xcom_push(key="training_id", value=training_id)
+            context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
+            VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result
 
     def invoke_defer(self, context: Context) -> None:
@@ -1569,8 +1576,8 @@
         )
         custom_training_job_obj.wait_for_resource_creation()
         training_pipeline_id: str = custom_training_job_obj.name
-
-        VertexAITrainingLink.persist(context=context,
+        context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
         self.defer(
             trigger=CustomTrainingJobTrigger(
                 conn_id=self.gcp_conn_id,
@@ -1748,6 +1755,12 @@ class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = CustomJobHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -1764,5 +1777,5 @@
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        VertexAITrainingPipelinesLink.persist(context=context
+        VertexAITrainingPipelinesLink.persist(context=context)
         return [TrainingPipeline.to_dict(result) for result in results]
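Across the Vertex AI hunks above, the operators now push identifiers (model_id, training_id, custom_job_id, batch_prediction_job_id) to XCom explicitly via context["ti"].xcom_push(...) instead of passing task_instance to the link classes. Downstream tasks can read those keys with the standard XCom API; a minimal sketch, with an assumed upstream task id:

    # Minimal sketch; the upstream task id is an assumed example, the XCom keys match this diff.
    from airflow.decorators import task

    @task
    def report_training_result(ti=None):
        # Pull the identifiers pushed by the training operator.
        model_id = ti.xcom_pull(task_ids="create_auto_ml_forecasting_training_job", key="model_id")
        training_id = ti.xcom_pull(task_ids="create_auto_ml_forecasting_training_job", key="training_id")
        print(f"model_id={model_id}, training_id={training_id}")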