apache-airflow-providers-google 14.0.0__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
- airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/_vendor/__init__.py +0 -0
- airflow/providers/google/_vendor/json_merge_patch.py +91 -0
- airflow/providers/google/ads/hooks/ads.py +52 -43
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
- airflow/providers/google/cloud/hooks/bigquery.py +195 -318
- airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
- airflow/providers/google/cloud/hooks/bigtable.py +3 -2
- airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
- airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
- airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
- airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
- airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
- airflow/providers/google/cloud/hooks/compute.py +7 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
- airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
- airflow/providers/google/cloud/hooks/dataflow.py +87 -242
- airflow/providers/google/cloud/hooks/dataform.py +9 -14
- airflow/providers/google/cloud/hooks/datafusion.py +7 -9
- airflow/providers/google/cloud/hooks/dataplex.py +13 -12
- airflow/providers/google/cloud/hooks/dataprep.py +2 -2
- airflow/providers/google/cloud/hooks/dataproc.py +76 -74
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/hooks/dlp.py +5 -4
- airflow/providers/google/cloud/hooks/gcs.py +144 -33
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kms.py +3 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
- airflow/providers/google/cloud/hooks/mlengine.py +7 -8
- airflow/providers/google/cloud/hooks/natural_language.py +3 -2
- airflow/providers/google/cloud/hooks/os_login.py +3 -2
- airflow/providers/google/cloud/hooks/pubsub.py +6 -6
- airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
- airflow/providers/google/cloud/hooks/spanner.py +75 -10
- airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
- airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
- airflow/providers/google/cloud/hooks/tasks.py +4 -3
- airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
- airflow/providers/google/cloud/hooks/translate.py +8 -17
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
- airflow/providers/google/cloud/hooks/vision.py +7 -7
- airflow/providers/google/cloud/hooks/workflows.py +4 -3
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -7
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -90
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -89
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
- airflow/providers/google/cloud/links/managed_kafka.py +11 -51
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +141 -40
- airflow/providers/google/cloud/openlineage/mixins.py +14 -13
- airflow/providers/google/cloud/openlineage/utils.py +19 -3
- airflow/providers/google/cloud/operators/alloy_db.py +76 -61
- airflow/providers/google/cloud/operators/bigquery.py +104 -667
- airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
- airflow/providers/google/cloud/operators/bigtable.py +38 -7
- airflow/providers/google/cloud/operators/cloud_base.py +22 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
- airflow/providers/google/cloud/operators/cloud_build.py +80 -36
- airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
- airflow/providers/google/cloud/operators/cloud_run.py +39 -20
- airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
- airflow/providers/google/cloud/operators/compute.py +18 -50
- airflow/providers/google/cloud/operators/datacatalog.py +167 -29
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +19 -7
- airflow/providers/google/cloud/operators/datafusion.py +43 -43
- airflow/providers/google/cloud/operators/dataplex.py +212 -126
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +134 -207
- airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +24 -45
- airflow/providers/google/cloud/operators/functions.py +21 -14
- airflow/providers/google/cloud/operators/gcs.py +15 -12
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
- airflow/providers/google/cloud/operators/natural_language.py +5 -3
- airflow/providers/google/cloud/operators/pubsub.py +69 -21
- airflow/providers/google/cloud/operators/spanner.py +53 -45
- airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
- airflow/providers/google/cloud/operators/stackdriver.py +5 -11
- airflow/providers/google/cloud/operators/tasks.py +6 -15
- airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
- airflow/providers/google/cloud/operators/translate.py +46 -20
- airflow/providers/google/cloud/operators/translate_speech.py +4 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
- airflow/providers/google/cloud/operators/vision.py +7 -5
- airflow/providers/google/cloud/operators/workflows.py +24 -19
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
- airflow/providers/google/cloud/sensors/bigtable.py +14 -6
- airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
- airflow/providers/google/cloud/sensors/dataflow.py +27 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +7 -5
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +10 -9
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/sensors/gcs.py +22 -21
- airflow/providers/google/cloud/sensors/looker.py +5 -5
- airflow/providers/google/cloud/sensors/pubsub.py +20 -20
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +6 -4
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
- airflow/providers/google/cloud/triggers/dataflow.py +125 -2
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +16 -3
- airflow/providers/google/cloud/triggers/dataproc.py +124 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +17 -20
- airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
- airflow/providers/google/cloud/utils/bigquery.py +5 -7
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/cloud/utils/validators.py +43 -0
- airflow/providers/google/common/auth_backend/google_openid.py +26 -9
- airflow/providers/google/common/consts.py +2 -1
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +40 -43
- airflow/providers/google/common/hooks/operation_helpers.py +78 -0
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +4 -5
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +61 -216
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/drive.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
- airflow/providers/google/cloud/hooks/automl.py +0 -679
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1360
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -1515
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
- apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
- /airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
|
@@ -22,8 +22,6 @@ from collections.abc import Sequence
|
|
|
22
22
|
from functools import cached_property
|
|
23
23
|
from typing import TYPE_CHECKING, Any
|
|
24
24
|
|
|
25
|
-
from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook, DataflowJobStatus
|
|
26
|
-
from airflow.triggers.base import BaseTrigger, TriggerEvent
|
|
27
25
|
from google.cloud.dataflow_v1beta3 import JobState
|
|
28
26
|
from google.cloud.dataflow_v1beta3.types import (
|
|
29
27
|
AutoscalingEvent,
|
|
@@ -34,6 +32,9 @@ from google.cloud.dataflow_v1beta3.types import (
|
|
|
34
32
|
MetricUpdate,
|
|
35
33
|
)
|
|
36
34
|
|
|
35
|
+
from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook, DataflowJobStatus
|
|
36
|
+
from airflow.triggers.base import BaseTrigger, TriggerEvent
|
|
37
|
+
|
|
37
38
|
if TYPE_CHECKING:
|
|
38
39
|
from google.cloud.dataflow_v1beta3.services.messages_v1_beta3.pagers import ListJobMessagesAsyncPager
|
|
39
40
|
|
|
@@ -787,3 +788,125 @@ class DataflowJobMessagesTrigger(BaseTrigger):
|
|
|
787
788
|
poll_sleep=self.poll_sleep,
|
|
788
789
|
impersonation_chain=self.impersonation_chain,
|
|
789
790
|
)
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
class DataflowJobStateCompleteTrigger(BaseTrigger):
|
|
794
|
+
"""
|
|
795
|
+
Trigger that monitors if a Dataflow job has reached any of successful terminal state meant for that job.
|
|
796
|
+
|
|
797
|
+
:param job_id: Required. ID of the job.
|
|
798
|
+
:param project_id: Required. The Google Cloud project ID in which the job was started.
|
|
799
|
+
:param location: Optional. The location where the job is executed. If set to None then
|
|
800
|
+
the value of DEFAULT_DATAFLOW_LOCATION will be used.
|
|
801
|
+
:param wait_until_finished: Optional. Dataflow option to block pipeline until completion.
|
|
802
|
+
:param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
|
|
803
|
+
:param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
|
|
804
|
+
:param impersonation_chain: Optional. Service account to impersonate using short-term
|
|
805
|
+
credentials, or chained list of accounts required to get the access_token
|
|
806
|
+
of the last account in the list, which will be impersonated in the request.
|
|
807
|
+
If set as a string, the account must grant the originating account
|
|
808
|
+
the Service Account Token Creator IAM role.
|
|
809
|
+
If set as a sequence, the identities from the list must grant
|
|
810
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
811
|
+
account from the list granting this role to the originating account (templated).
|
|
812
|
+
"""
|
|
813
|
+
|
|
814
|
+
def __init__(
|
|
815
|
+
self,
|
|
816
|
+
job_id: str,
|
|
817
|
+
project_id: str | None,
|
|
818
|
+
location: str = DEFAULT_DATAFLOW_LOCATION,
|
|
819
|
+
wait_until_finished: bool | None = None,
|
|
820
|
+
gcp_conn_id: str = "google_cloud_default",
|
|
821
|
+
poll_sleep: int = 10,
|
|
822
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
|
823
|
+
):
|
|
824
|
+
super().__init__()
|
|
825
|
+
self.job_id = job_id
|
|
826
|
+
self.project_id = project_id
|
|
827
|
+
self.location = location
|
|
828
|
+
self.wait_until_finished = wait_until_finished
|
|
829
|
+
self.gcp_conn_id = gcp_conn_id
|
|
830
|
+
self.poll_sleep = poll_sleep
|
|
831
|
+
self.impersonation_chain = impersonation_chain
|
|
832
|
+
|
|
833
|
+
def serialize(self) -> tuple[str, dict[str, Any]]:
|
|
834
|
+
"""Serialize class arguments and classpath."""
|
|
835
|
+
return (
|
|
836
|
+
"airflow.providers.google.cloud.triggers.dataflow.DataflowJobStateCompleteTrigger",
|
|
837
|
+
{
|
|
838
|
+
"job_id": self.job_id,
|
|
839
|
+
"project_id": self.project_id,
|
|
840
|
+
"location": self.location,
|
|
841
|
+
"wait_until_finished": self.wait_until_finished,
|
|
842
|
+
"gcp_conn_id": self.gcp_conn_id,
|
|
843
|
+
"poll_sleep": self.poll_sleep,
|
|
844
|
+
"impersonation_chain": self.impersonation_chain,
|
|
845
|
+
},
|
|
846
|
+
)
|
|
847
|
+
|
|
848
|
+
async def run(self):
|
|
849
|
+
"""
|
|
850
|
+
Loop until the job reaches successful final or error state.
|
|
851
|
+
|
|
852
|
+
Yields a TriggerEvent with success status, if the job reaches successful state for own type.
|
|
853
|
+
|
|
854
|
+
Yields a TriggerEvent with error status, if the client returns an unexpected terminal
|
|
855
|
+
job status or any exception is raised while looping.
|
|
856
|
+
|
|
857
|
+
In any other case the Trigger will wait for a specified amount of time
|
|
858
|
+
stored in self.poll_sleep variable.
|
|
859
|
+
"""
|
|
860
|
+
try:
|
|
861
|
+
while True:
|
|
862
|
+
job = await self.async_hook.get_job(
|
|
863
|
+
project_id=self.project_id,
|
|
864
|
+
job_id=self.job_id,
|
|
865
|
+
location=self.location,
|
|
866
|
+
)
|
|
867
|
+
job_state = job.current_state.name
|
|
868
|
+
job_type_name = job.type_.name
|
|
869
|
+
|
|
870
|
+
FAILED_STATES = DataflowJobStatus.FAILED_END_STATES | {DataflowJobStatus.JOB_STATE_DRAINED}
|
|
871
|
+
if job_state in FAILED_STATES:
|
|
872
|
+
yield TriggerEvent(
|
|
873
|
+
{
|
|
874
|
+
"status": "error",
|
|
875
|
+
"message": (
|
|
876
|
+
f"Job with id '{self.job_id}' is in failed terminal state: {job_state}"
|
|
877
|
+
),
|
|
878
|
+
}
|
|
879
|
+
)
|
|
880
|
+
return
|
|
881
|
+
|
|
882
|
+
if self.async_hook.job_reached_terminal_state(
|
|
883
|
+
job={"id": self.job_id, "currentState": job_state, "type": job_type_name},
|
|
884
|
+
wait_until_finished=self.wait_until_finished,
|
|
885
|
+
):
|
|
886
|
+
yield TriggerEvent(
|
|
887
|
+
{
|
|
888
|
+
"status": "success",
|
|
889
|
+
"message": (
|
|
890
|
+
f"Job with id '{self.job_id}' has reached successful final state: {job_state}"
|
|
891
|
+
),
|
|
892
|
+
}
|
|
893
|
+
)
|
|
894
|
+
return
|
|
895
|
+
self.log.info("Sleeping for %s seconds.", self.poll_sleep)
|
|
896
|
+
await asyncio.sleep(self.poll_sleep)
|
|
897
|
+
except Exception as e:
|
|
898
|
+
self.log.error("Exception occurred while checking for job state!")
|
|
899
|
+
yield TriggerEvent(
|
|
900
|
+
{
|
|
901
|
+
"status": "error",
|
|
902
|
+
"message": str(e),
|
|
903
|
+
}
|
|
904
|
+
)
|
|
905
|
+
|
|
906
|
+
@cached_property
|
|
907
|
+
def async_hook(self) -> AsyncDataflowHook:
|
|
908
|
+
return AsyncDataflowHook(
|
|
909
|
+
gcp_conn_id=self.gcp_conn_id,
|
|
910
|
+
poll_sleep=self.poll_sleep,
|
|
911
|
+
impersonation_chain=self.impersonation_chain,
|
|
912
|
+
)
|
|
@@ -86,7 +86,7 @@ class DataFusionStartPipelineTrigger(BaseTrigger):
|
|
|
86
86
|
},
|
|
87
87
|
)
|
|
88
88
|
|
|
89
|
-
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
89
|
+
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
90
90
|
"""Get current pipeline status and yields a TriggerEvent."""
|
|
91
91
|
hook = self._get_async_hook()
|
|
92
92
|
try:
|
|
@@ -22,9 +22,10 @@ from __future__ import annotations
|
|
|
22
22
|
import asyncio
|
|
23
23
|
from collections.abc import AsyncIterator, Sequence
|
|
24
24
|
|
|
25
|
+
from google.cloud.dataplex_v1.types import DataScanJob
|
|
26
|
+
|
|
25
27
|
from airflow.providers.google.cloud.hooks.dataplex import DataplexAsyncHook
|
|
26
28
|
from airflow.triggers.base import BaseTrigger, TriggerEvent
|
|
27
|
-
from google.cloud.dataplex_v1.types import DataScanJob
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
class DataplexDataQualityJobTrigger(BaseTrigger):
|
|
@@ -102,7 +103,13 @@ class DataplexDataQualityJobTrigger(BaseTrigger):
|
|
|
102
103
|
self.polling_interval_seconds,
|
|
103
104
|
)
|
|
104
105
|
await asyncio.sleep(self.polling_interval_seconds)
|
|
105
|
-
yield TriggerEvent(
|
|
106
|
+
yield TriggerEvent(
|
|
107
|
+
{
|
|
108
|
+
"job_id": self.job_id,
|
|
109
|
+
"job_state": DataScanJob.State(state).name,
|
|
110
|
+
"job": self._convert_to_dict(job),
|
|
111
|
+
}
|
|
112
|
+
)
|
|
106
113
|
|
|
107
114
|
def _convert_to_dict(self, job: DataScanJob) -> dict:
|
|
108
115
|
"""Return a representation of a DataScanJob instance as a dict."""
|
|
@@ -184,7 +191,13 @@ class DataplexDataProfileJobTrigger(BaseTrigger):
|
|
|
184
191
|
self.polling_interval_seconds,
|
|
185
192
|
)
|
|
186
193
|
await asyncio.sleep(self.polling_interval_seconds)
|
|
187
|
-
yield TriggerEvent(
|
|
194
|
+
yield TriggerEvent(
|
|
195
|
+
{
|
|
196
|
+
"job_id": self.job_id,
|
|
197
|
+
"job_state": DataScanJob.State(state).name,
|
|
198
|
+
"job": self._convert_to_dict(job),
|
|
199
|
+
}
|
|
200
|
+
)
|
|
188
201
|
|
|
189
202
|
def _convert_to_dict(self, job: DataScanJob) -> dict:
|
|
190
203
|
"""Return a representation of a DataScanJob instance as a dict."""
|
|
@@ -25,20 +25,25 @@ import time
|
|
|
25
25
|
from collections.abc import AsyncIterator, Sequence
|
|
26
26
|
from typing import TYPE_CHECKING, Any
|
|
27
27
|
|
|
28
|
+
from asgiref.sync import sync_to_async
|
|
29
|
+
from google.api_core.exceptions import NotFound
|
|
30
|
+
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, Job, JobStatus
|
|
31
|
+
|
|
28
32
|
from airflow.exceptions import AirflowException
|
|
29
|
-
from airflow.models.taskinstance import TaskInstance
|
|
30
33
|
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
|
|
31
34
|
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
|
|
32
35
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
|
|
36
|
+
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
|
|
33
37
|
from airflow.triggers.base import BaseTrigger, TriggerEvent
|
|
34
|
-
from airflow.utils.session import provide_session
|
|
35
38
|
from airflow.utils.state import TaskInstanceState
|
|
36
|
-
from google.api_core.exceptions import NotFound
|
|
37
|
-
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
|
|
38
39
|
|
|
39
40
|
if TYPE_CHECKING:
|
|
40
41
|
from sqlalchemy.orm.session import Session
|
|
41
42
|
|
|
43
|
+
if not AIRFLOW_V_3_0_PLUS:
|
|
44
|
+
from airflow.models.taskinstance import TaskInstance
|
|
45
|
+
from airflow.utils.session import provide_session
|
|
46
|
+
|
|
42
47
|
|
|
43
48
|
class DataprocBaseTrigger(BaseTrigger):
|
|
44
49
|
"""Base class for Dataproc triggers."""
|
|
@@ -116,40 +121,67 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
|
|
|
116
121
|
},
|
|
117
122
|
)
|
|
118
123
|
|
|
119
|
-
|
|
120
|
-
def get_task_instance(self, session: Session) -> TaskInstance:
|
|
121
|
-
"""
|
|
122
|
-
Get the task instance for the current task.
|
|
124
|
+
if not AIRFLOW_V_3_0_PLUS:
|
|
123
125
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
126
|
+
@provide_session
|
|
127
|
+
def get_task_instance(self, session: Session) -> TaskInstance:
|
|
128
|
+
"""
|
|
129
|
+
Get the task instance for the current task.
|
|
130
|
+
|
|
131
|
+
:param session: Sqlalchemy session
|
|
132
|
+
"""
|
|
133
|
+
query = session.query(TaskInstance).filter(
|
|
134
|
+
TaskInstance.dag_id == self.task_instance.dag_id,
|
|
135
|
+
TaskInstance.task_id == self.task_instance.task_id,
|
|
136
|
+
TaskInstance.run_id == self.task_instance.run_id,
|
|
137
|
+
TaskInstance.map_index == self.task_instance.map_index,
|
|
138
|
+
)
|
|
139
|
+
task_instance = query.one_or_none()
|
|
140
|
+
if task_instance is None:
|
|
141
|
+
raise AirflowException(
|
|
142
|
+
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
|
|
143
|
+
self.task_instance.dag_id,
|
|
144
|
+
self.task_instance.task_id,
|
|
145
|
+
self.task_instance.run_id,
|
|
146
|
+
self.task_instance.map_index,
|
|
147
|
+
)
|
|
148
|
+
return task_instance
|
|
149
|
+
|
|
150
|
+
async def get_task_state(self):
|
|
151
|
+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
|
|
152
|
+
|
|
153
|
+
task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
|
|
154
|
+
dag_id=self.task_instance.dag_id,
|
|
155
|
+
task_ids=[self.task_instance.task_id],
|
|
156
|
+
run_ids=[self.task_instance.run_id],
|
|
157
|
+
map_index=self.task_instance.map_index,
|
|
131
158
|
)
|
|
132
|
-
|
|
133
|
-
|
|
159
|
+
try:
|
|
160
|
+
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
|
|
161
|
+
except Exception:
|
|
134
162
|
raise AirflowException(
|
|
135
|
-
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
|
|
163
|
+
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
|
|
136
164
|
self.task_instance.dag_id,
|
|
137
165
|
self.task_instance.task_id,
|
|
138
166
|
self.task_instance.run_id,
|
|
139
167
|
self.task_instance.map_index,
|
|
140
168
|
)
|
|
141
|
-
return
|
|
169
|
+
return task_state
|
|
142
170
|
|
|
143
|
-
def safe_to_cancel(self) -> bool:
|
|
171
|
+
async def safe_to_cancel(self) -> bool:
|
|
144
172
|
"""
|
|
145
173
|
Whether it is safe to cancel the external job which is being executed by this trigger.
|
|
146
174
|
|
|
147
175
|
This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
|
|
148
176
|
Because in those cases, we should NOT cancel the external job.
|
|
149
177
|
"""
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
178
|
+
if AIRFLOW_V_3_0_PLUS:
|
|
179
|
+
task_state = await self.get_task_state()
|
|
180
|
+
else:
|
|
181
|
+
# Database query is needed to get the latest state of the task instance.
|
|
182
|
+
task_instance = self.get_task_instance() # type: ignore[call-arg]
|
|
183
|
+
task_state = task_instance.state
|
|
184
|
+
return task_state != TaskInstanceState.DEFERRED
|
|
153
185
|
|
|
154
186
|
async def run(self):
|
|
155
187
|
try:
|
|
@@ -162,11 +194,13 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
|
|
|
162
194
|
if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
|
|
163
195
|
break
|
|
164
196
|
await asyncio.sleep(self.polling_interval_seconds)
|
|
165
|
-
yield TriggerEvent(
|
|
197
|
+
yield TriggerEvent(
|
|
198
|
+
{"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)}
|
|
199
|
+
)
|
|
166
200
|
except asyncio.CancelledError:
|
|
167
201
|
self.log.info("Task got cancelled.")
|
|
168
202
|
try:
|
|
169
|
-
if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
|
|
203
|
+
if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
|
|
170
204
|
self.log.info(
|
|
171
205
|
"Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
|
|
172
206
|
" in deferred state."
|
|
@@ -180,7 +214,12 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
|
|
|
180
214
|
job_id=self.job_id, project_id=self.project_id, region=self.region
|
|
181
215
|
)
|
|
182
216
|
self.log.info("Job: %s is cancelled", self.job_id)
|
|
183
|
-
yield TriggerEvent(
|
|
217
|
+
yield TriggerEvent(
|
|
218
|
+
{
|
|
219
|
+
"job_id": self.job_id,
|
|
220
|
+
"job_state": ClusterStatus.State.DELETING.name,
|
|
221
|
+
}
|
|
222
|
+
)
|
|
184
223
|
except Exception as e:
|
|
185
224
|
self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
|
|
186
225
|
raise e
|
|
@@ -223,35 +262,62 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
|
|
|
223
262
|
},
|
|
224
263
|
)
|
|
225
264
|
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
265
|
+
if not AIRFLOW_V_3_0_PLUS:
|
|
266
|
+
|
|
267
|
+
@provide_session
|
|
268
|
+
def get_task_instance(self, session: Session) -> TaskInstance:
|
|
269
|
+
query = session.query(TaskInstance).filter(
|
|
270
|
+
TaskInstance.dag_id == self.task_instance.dag_id,
|
|
271
|
+
TaskInstance.task_id == self.task_instance.task_id,
|
|
272
|
+
TaskInstance.run_id == self.task_instance.run_id,
|
|
273
|
+
TaskInstance.map_index == self.task_instance.map_index,
|
|
274
|
+
)
|
|
275
|
+
task_instance = query.one_or_none()
|
|
276
|
+
if task_instance is None:
|
|
277
|
+
raise AirflowException(
|
|
278
|
+
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
|
|
279
|
+
self.task_instance.dag_id,
|
|
280
|
+
self.task_instance.task_id,
|
|
281
|
+
self.task_instance.run_id,
|
|
282
|
+
self.task_instance.map_index,
|
|
283
|
+
)
|
|
284
|
+
return task_instance
|
|
285
|
+
|
|
286
|
+
async def get_task_state(self):
|
|
287
|
+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
|
|
288
|
+
|
|
289
|
+
task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
|
|
290
|
+
dag_id=self.task_instance.dag_id,
|
|
291
|
+
task_ids=[self.task_instance.task_id],
|
|
292
|
+
run_ids=[self.task_instance.run_id],
|
|
293
|
+
map_index=self.task_instance.map_index,
|
|
233
294
|
)
|
|
234
|
-
|
|
235
|
-
|
|
295
|
+
try:
|
|
296
|
+
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
|
|
297
|
+
except Exception:
|
|
236
298
|
raise AirflowException(
|
|
237
|
-
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found
|
|
299
|
+
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
|
|
238
300
|
self.task_instance.dag_id,
|
|
239
301
|
self.task_instance.task_id,
|
|
240
302
|
self.task_instance.run_id,
|
|
241
303
|
self.task_instance.map_index,
|
|
242
304
|
)
|
|
243
|
-
return
|
|
305
|
+
return task_state
|
|
244
306
|
|
|
245
|
-
def safe_to_cancel(self) -> bool:
|
|
307
|
+
async def safe_to_cancel(self) -> bool:
|
|
246
308
|
"""
|
|
247
309
|
Whether it is safe to cancel the external job which is being executed by this trigger.
|
|
248
310
|
|
|
249
311
|
This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
|
|
250
312
|
Because in those cases, we should NOT cancel the external job.
|
|
251
313
|
"""
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
314
|
+
if AIRFLOW_V_3_0_PLUS:
|
|
315
|
+
task_state = await self.get_task_state()
|
|
316
|
+
else:
|
|
317
|
+
# Database query is needed to get the latest state of the task instance.
|
|
318
|
+
task_instance = self.get_task_instance() # type: ignore[call-arg]
|
|
319
|
+
task_state = task_instance.state
|
|
320
|
+
return task_state != TaskInstanceState.DEFERRED
|
|
255
321
|
|
|
256
322
|
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
257
323
|
try:
|
|
@@ -263,8 +329,8 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
|
|
|
263
329
|
yield TriggerEvent(
|
|
264
330
|
{
|
|
265
331
|
"cluster_name": self.cluster_name,
|
|
266
|
-
"cluster_state": ClusterStatus.State.DELETING,
|
|
267
|
-
"cluster": cluster,
|
|
332
|
+
"cluster_state": ClusterStatus.State.DELETING.name, # type: ignore
|
|
333
|
+
"cluster": Cluster.to_dict(cluster),
|
|
268
334
|
}
|
|
269
335
|
)
|
|
270
336
|
return
|
|
@@ -272,17 +338,18 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
|
|
|
272
338
|
yield TriggerEvent(
|
|
273
339
|
{
|
|
274
340
|
"cluster_name": self.cluster_name,
|
|
275
|
-
"cluster_state": state,
|
|
276
|
-
"cluster": cluster,
|
|
341
|
+
"cluster_state": ClusterStatus.State(state).name,
|
|
342
|
+
"cluster": Cluster.to_dict(cluster),
|
|
277
343
|
}
|
|
278
344
|
)
|
|
279
345
|
return
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
346
|
+
else:
|
|
347
|
+
self.log.info("Current state is %s", state)
|
|
348
|
+
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
|
|
349
|
+
await asyncio.sleep(self.polling_interval_seconds)
|
|
283
350
|
except asyncio.CancelledError:
|
|
284
351
|
try:
|
|
285
|
-
if self.delete_on_error and self.safe_to_cancel():
|
|
352
|
+
if self.delete_on_error and await self.safe_to_cancel():
|
|
286
353
|
self.log.info(
|
|
287
354
|
"Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in "
|
|
288
355
|
"deferred state."
|
|
@@ -368,12 +435,16 @@ class DataprocBatchTrigger(DataprocBaseTrigger):
|
|
|
368
435
|
|
|
369
436
|
if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
|
|
370
437
|
break
|
|
371
|
-
self.log.info("Current state is %s", state)
|
|
438
|
+
self.log.info("Current state is %s", Batch.State(state).name)
|
|
372
439
|
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
|
|
373
440
|
await asyncio.sleep(self.polling_interval_seconds)
|
|
374
441
|
|
|
375
442
|
yield TriggerEvent(
|
|
376
|
-
{
|
|
443
|
+
{
|
|
444
|
+
"batch_id": self.batch_id,
|
|
445
|
+
"batch_state": Batch.State(state).name,
|
|
446
|
+
"batch_state_message": batch.state_message,
|
|
447
|
+
}
|
|
377
448
|
)
|
|
378
449
|
|
|
379
450
|
|
|
@@ -431,9 +502,9 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
|
|
|
431
502
|
try:
|
|
432
503
|
while self.end_time > time.time():
|
|
433
504
|
cluster = await self.get_async_hook().get_cluster(
|
|
434
|
-
region=self.region,
|
|
505
|
+
region=self.region,
|
|
435
506
|
cluster_name=self.cluster_name,
|
|
436
|
-
project_id=self.project_id,
|
|
507
|
+
project_id=self.project_id,
|
|
437
508
|
metadata=self.metadata,
|
|
438
509
|
)
|
|
439
510
|
self.log.info(
|
|
@@ -23,6 +23,7 @@ from collections.abc import AsyncIterator, Sequence
|
|
|
23
23
|
from functools import cached_property
|
|
24
24
|
from typing import TYPE_CHECKING, Any
|
|
25
25
|
|
|
26
|
+
from google.cloud.container_v1.types import Operation
|
|
26
27
|
from packaging.version import parse as parse_version
|
|
27
28
|
|
|
28
29
|
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
|
@@ -36,7 +37,6 @@ from airflow.providers.google.cloud.hooks.kubernetes_engine import (
|
|
|
36
37
|
)
|
|
37
38
|
from airflow.providers_manager import ProvidersManager
|
|
38
39
|
from airflow.triggers.base import BaseTrigger, TriggerEvent
|
|
39
|
-
from google.cloud.container_v1.types import Operation
|
|
40
40
|
|
|
41
41
|
if TYPE_CHECKING:
|
|
42
42
|
from datetime import datetime
|
|
@@ -153,7 +153,7 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
|
|
|
153
153
|
)
|
|
154
154
|
|
|
155
155
|
@cached_property
|
|
156
|
-
def hook(self) -> GKEKubernetesAsyncHook:
|
|
156
|
+
def hook(self) -> GKEKubernetesAsyncHook:
|
|
157
157
|
return GKEKubernetesAsyncHook(
|
|
158
158
|
cluster_url=self._cluster_url,
|
|
159
159
|
ssl_ca_cert=self._ssl_ca_cert,
|
|
@@ -200,7 +200,7 @@ class GKEOperationTrigger(BaseTrigger):
|
|
|
200
200
|
},
|
|
201
201
|
)
|
|
202
202
|
|
|
203
|
-
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
203
|
+
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
204
204
|
"""Get operation status and yields corresponding event."""
|
|
205
205
|
hook = self._get_hook()
|
|
206
206
|
try:
|
|
@@ -260,9 +260,10 @@ class GKEJobTrigger(BaseTrigger):
|
|
|
260
260
|
ssl_ca_cert: str,
|
|
261
261
|
job_name: str,
|
|
262
262
|
job_namespace: str,
|
|
263
|
-
|
|
263
|
+
pod_names: list[str],
|
|
264
264
|
pod_namespace: str,
|
|
265
265
|
base_container_name: str,
|
|
266
|
+
pod_name: str | None = None,
|
|
266
267
|
gcp_conn_id: str = "google_cloud_default",
|
|
267
268
|
poll_interval: float = 2,
|
|
268
269
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
@@ -274,7 +275,13 @@ class GKEJobTrigger(BaseTrigger):
|
|
|
274
275
|
self.ssl_ca_cert = ssl_ca_cert
|
|
275
276
|
self.job_name = job_name
|
|
276
277
|
self.job_namespace = job_namespace
|
|
277
|
-
|
|
278
|
+
if pod_name is not None:
|
|
279
|
+
self._pod_name = pod_name
|
|
280
|
+
self.pod_names = [
|
|
281
|
+
self.pod_name,
|
|
282
|
+
]
|
|
283
|
+
else:
|
|
284
|
+
self.pod_names = pod_names
|
|
278
285
|
self.pod_namespace = pod_namespace
|
|
279
286
|
self.base_container_name = base_container_name
|
|
280
287
|
self.gcp_conn_id = gcp_conn_id
|
|
@@ -283,6 +290,15 @@ class GKEJobTrigger(BaseTrigger):
|
|
|
283
290
|
self.get_logs = get_logs
|
|
284
291
|
self.do_xcom_push = do_xcom_push
|
|
285
292
|
|
|
293
|
+
@property
|
|
294
|
+
def pod_name(self):
|
|
295
|
+
warnings.warn(
|
|
296
|
+
"`pod_name` parameter is deprecated, please use `pod_names`",
|
|
297
|
+
AirflowProviderDeprecationWarning,
|
|
298
|
+
stacklevel=2,
|
|
299
|
+
)
|
|
300
|
+
return self._pod_name
|
|
301
|
+
|
|
286
302
|
def serialize(self) -> tuple[str, dict[str, Any]]:
|
|
287
303
|
"""Serialize KubernetesCreateJobTrigger arguments and classpath."""
|
|
288
304
|
return (
|
|
@@ -292,7 +308,7 @@ class GKEJobTrigger(BaseTrigger):
|
|
|
292
308
|
"ssl_ca_cert": self.ssl_ca_cert,
|
|
293
309
|
"job_name": self.job_name,
|
|
294
310
|
"job_namespace": self.job_namespace,
|
|
295
|
-
"
|
|
311
|
+
"pod_names": self.pod_names,
|
|
296
312
|
"pod_namespace": self.pod_namespace,
|
|
297
313
|
"base_container_name": self.base_container_name,
|
|
298
314
|
"gcp_conn_id": self.gcp_conn_id,
|
|
@@ -303,10 +319,8 @@ class GKEJobTrigger(BaseTrigger):
|
|
|
303
319
|
},
|
|
304
320
|
)
|
|
305
321
|
|
|
306
|
-
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
322
|
+
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
307
323
|
"""Get current job status and yield a TriggerEvent."""
|
|
308
|
-
if self.get_logs or self.do_xcom_push:
|
|
309
|
-
pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
|
|
310
324
|
if self.do_xcom_push:
|
|
311
325
|
kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
|
|
312
326
|
kubernetes_provider_name = kubernetes_provider.data["package-name"]
|
|
@@ -318,22 +332,26 @@ class GKEJobTrigger(BaseTrigger):
|
|
|
318
332
|
f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
|
|
319
333
|
f"support this feature. Please upgrade it to version higher than or equal to {min_version}."
|
|
320
334
|
)
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
namespace=self.pod_namespace
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
335
|
+
xcom_results = []
|
|
336
|
+
for pod_name in self.pod_names:
|
|
337
|
+
pod = await self.hook.get_pod(name=pod_name, namespace=self.pod_namespace)
|
|
338
|
+
await self.hook.wait_until_container_complete(
|
|
339
|
+
name=pod_name,
|
|
340
|
+
namespace=self.pod_namespace,
|
|
341
|
+
container_name=self.base_container_name,
|
|
342
|
+
poll_interval=self.poll_interval,
|
|
343
|
+
)
|
|
344
|
+
self.log.info("Checking if xcom sidecar container is started.")
|
|
345
|
+
await self.hook.wait_until_container_started(
|
|
346
|
+
name=pod_name,
|
|
347
|
+
namespace=self.pod_namespace,
|
|
348
|
+
container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
|
|
349
|
+
poll_interval=self.poll_interval,
|
|
350
|
+
)
|
|
351
|
+
self.log.info("Extracting result from xcom sidecar container.")
|
|
352
|
+
loop = asyncio.get_running_loop()
|
|
353
|
+
xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
|
|
354
|
+
xcom_results.append(xcom_result)
|
|
337
355
|
job: V1Job = await self.hook.wait_until_job_complete(
|
|
338
356
|
name=self.job_name, namespace=self.job_namespace, poll_interval=self.poll_interval
|
|
339
357
|
)
|
|
@@ -345,12 +363,12 @@ class GKEJobTrigger(BaseTrigger):
|
|
|
345
363
|
{
|
|
346
364
|
"name": job.metadata.name,
|
|
347
365
|
"namespace": job.metadata.namespace,
|
|
348
|
-
"
|
|
349
|
-
"pod_namespace":
|
|
366
|
+
"pod_names": [pod_name for pod_name in self.pod_names] if self.get_logs else None,
|
|
367
|
+
"pod_namespace": self.pod_namespace if self.get_logs else None,
|
|
350
368
|
"status": status,
|
|
351
369
|
"message": message,
|
|
352
370
|
"job": job_dict,
|
|
353
|
-
"xcom_result":
|
|
371
|
+
"xcom_result": xcom_results if self.do_xcom_push else None,
|
|
354
372
|
}
|
|
355
373
|
)
|
|
356
374
|
|
|
@@ -90,7 +90,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
|
|
|
90
90
|
},
|
|
91
91
|
)
|
|
92
92
|
|
|
93
|
-
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
93
|
+
async def run(self) -> AsyncIterator[TriggerEvent]:
|
|
94
94
|
"""Get current job execution status and yields a TriggerEvent."""
|
|
95
95
|
hook = self._get_async_hook()
|
|
96
96
|
try:
|