apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.3.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -6
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +176 -293
- airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_build.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_composer.py +288 -15
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_run.py +18 -10
- airflow/providers/google/cloud/hooks/cloud_sql.py +102 -23
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +29 -7
- airflow/providers/google/cloud/hooks/compute.py +1 -1
- airflow/providers/google/cloud/hooks/compute_ssh.py +6 -2
- airflow/providers/google/cloud/hooks/datacatalog.py +10 -1
- airflow/providers/google/cloud/hooks/dataflow.py +72 -95
- airflow/providers/google/cloud/hooks/dataform.py +1 -1
- airflow/providers/google/cloud/hooks/datafusion.py +21 -19
- airflow/providers/google/cloud/hooks/dataplex.py +2 -2
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +73 -72
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -1
- airflow/providers/google/cloud/hooks/dlp.py +1 -1
- airflow/providers/google/cloud/hooks/functions.py +1 -1
- airflow/providers/google/cloud/hooks/gcs.py +112 -15
- airflow/providers/google/cloud/hooks/gdm.py +1 -1
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +3 -3
- airflow/providers/google/cloud/hooks/looker.py +6 -2
- airflow/providers/google/cloud/hooks/managed_kafka.py +1 -1
- airflow/providers/google/cloud/hooks/mlengine.py +4 -3
- airflow/providers/google/cloud/hooks/pubsub.py +3 -0
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +74 -9
- airflow/providers/google/cloud/hooks/stackdriver.py +11 -9
- airflow/providers/google/cloud/hooks/tasks.py +1 -1
- airflow/providers/google/cloud/hooks/translate.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +2 -210
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +3 -3
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +28 -2
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +308 -8
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +3 -3
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +58 -22
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +71 -56
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +4 -6
- airflow/providers/google/cloud/operators/bigtable.py +37 -8
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +3 -3
- airflow/providers/google/cloud/operators/cloud_build.py +76 -33
- airflow/providers/google/cloud/operators/cloud_composer.py +129 -41
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +24 -6
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -17
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +93 -12
- airflow/providers/google/cloud/operators/compute.py +9 -41
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +40 -16
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +42 -21
- airflow/providers/google/cloud/operators/dataplex.py +194 -110
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +80 -36
- airflow/providers/google/cloud/operators/dataproc_metastore.py +97 -89
- airflow/providers/google/cloud/operators/datastore.py +23 -7
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +17 -8
- airflow/providers/google/cloud/operators/gcs.py +12 -9
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +62 -100
- airflow/providers/google/cloud/operators/looker.py +2 -2
- airflow/providers/google/cloud/operators/managed_kafka.py +108 -53
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +68 -15
- airflow/providers/google/cloud/operators/spanner.py +26 -13
- airflow/providers/google/cloud/operators/speech_to_text.py +2 -3
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +2 -3
- airflow/providers/google/cloud/operators/translate.py +41 -17
- airflow/providers/google/cloud/operators/translate_speech.py +2 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +30 -10
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +55 -27
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -115
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +12 -10
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +31 -8
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/secrets/secret_manager.py +3 -2
- airflow/providers/google/cloud/sensors/bigquery.py +3 -3
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -3
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -30
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -3
- airflow/providers/google/cloud/sensors/dataflow.py +26 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -3
- airflow/providers/google/cloud/sensors/datafusion.py +4 -5
- airflow/providers/google/cloud/sensors/dataplex.py +2 -3
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -3
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -3
- airflow/providers/google/cloud/sensors/gcs.py +4 -5
- airflow/providers/google/cloud/sensors/looker.py +2 -3
- airflow/providers/google/cloud/sensors/pubsub.py +4 -5
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -3
- airflow/providers/google/cloud/sensors/workflows.py +2 -3
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +4 -3
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +10 -5
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +21 -13
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +4 -3
- airflow/providers/google/cloud/transfers/gcs_to_local.py +6 -4
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +11 -5
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +13 -7
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +14 -5
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +76 -35
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +3 -3
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +92 -2
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +123 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +47 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/triggers/vertex_ai.py +1 -1
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
- airflow/providers/google/cloud/utils/field_sanitizer.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +2 -3
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -9
- airflow/providers/google/common/hooks/operation_helpers.py +1 -1
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/hooks/firestore.py +1 -1
- airflow/providers/google/firebase/operators/firestore.py +3 -3
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +27 -2
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -1
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +2 -3
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +6 -6
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -64
- airflow/providers/google/suite/hooks/calendar.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +16 -2
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +3 -3
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/METADATA +90 -46
- apache_airflow_providers_google-19.3.0.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.3.0.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.3.0.dist-info/licenses}/LICENSE +0 -0
@@ -788,3 +788,125 @@ class DataflowJobMessagesTrigger(BaseTrigger):
             poll_sleep=self.poll_sleep,
             impersonation_chain=self.impersonation_chain,
         )
+
+
+class DataflowJobStateCompleteTrigger(BaseTrigger):
+    """
+    Trigger that monitors if a Dataflow job has reached any of successful terminal state meant for that job.
+
+    :param job_id: Required. ID of the job.
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None then
+        the value of DEFAULT_DATAFLOW_LOCATION will be used.
+    :param wait_until_finished: Optional. Dataflow option to block pipeline until completion.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
+    :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        project_id: str | None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        wait_until_finished: bool | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.project_id = project_id
+        self.location = location
+        self.wait_until_finished = wait_until_finished
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStateCompleteTrigger",
+            {
+                "job_id": self.job_id,
+                "project_id": self.project_id,
+                "location": self.location,
+                "wait_until_finished": self.wait_until_finished,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_sleep": self.poll_sleep,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self):
+        """
+        Loop until the job reaches successful final or error state.
+
+        Yields a TriggerEvent with success status, if the job reaches successful state for own type.
+
+        Yields a TriggerEvent with error status, if the client returns an unexpected terminal
+        job status or any exception is raised while looping.
+
+        In any other case the Trigger will wait for a specified amount of time
+        stored in self.poll_sleep variable.
+        """
+        try:
+            while True:
+                job = await self.async_hook.get_job(
+                    project_id=self.project_id,
+                    job_id=self.job_id,
+                    location=self.location,
+                )
+                job_state = job.current_state.name
+                job_type_name = job.type_.name
+
+                FAILED_STATES = DataflowJobStatus.FAILED_END_STATES | {DataflowJobStatus.JOB_STATE_DRAINED}
+                if job_state in FAILED_STATES:
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": (
+                                f"Job with id '{self.job_id}' is in failed terminal state: {job_state}"
+                            ),
+                        }
+                    )
+                    return
+
+                if self.async_hook.job_reached_terminal_state(
+                    job={"id": self.job_id, "currentState": job_state, "type": job_type_name},
+                    wait_until_finished=self.wait_until_finished,
+                ):
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": (
+                                f"Job with id '{self.job_id}' has reached successful final state: {job_state}"
+                            ),
+                        }
+                    )
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+                await asyncio.sleep(self.poll_sleep)
+        except Exception as e:
+            self.log.error("Exception occurred while checking for job state!")
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(e),
+                }
+            )
+
+    @cached_property
+    def async_hook(self) -> AsyncDataflowHook:
+        return AsyncDataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+        )
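
The new DataflowJobStateCompleteTrigger follows the standard deferrable pattern: an operator serializes the trigger and defers, the triggerer iterates run() until a terminal TriggerEvent is yielded, and the task resumes in a completion callback. A minimal sketch of that hand-off, assuming a hypothetical operator class and attributes (MyDataflowOperator, self.job_id, etc. are illustrative, not part of this diff):

    from airflow.exceptions import AirflowException
    from airflow.models import BaseOperator
    from airflow.providers.google.cloud.triggers.dataflow import DataflowJobStateCompleteTrigger

    class MyDataflowOperator(BaseOperator):  # hypothetical, for illustration only
        def execute(self, context):
            # Defer: the triggerer iterates the trigger's run() and calls
            # execute_complete() with the final TriggerEvent payload.
            self.defer(
                trigger=DataflowJobStateCompleteTrigger(
                    job_id=self.job_id,
                    project_id=self.project_id,
                    location=self.location,
                    gcp_conn_id=self.gcp_conn_id,
                ),
                method_name="execute_complete",
            )

        def execute_complete(self, context, event):
            # event matches the dicts yielded above: {"status": ..., "message": ...}
            if event["status"] == "error":
                raise AirflowException(event["message"])
            return event["message"]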
@@ -86,7 +86,7 @@ class DataFusionStartPipelineTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
@@ -103,7 +103,13 @@ class DataplexDataQualityJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent(
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
@@ -185,7 +191,13 @@ class DataplexDataProfileJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent(
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
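
Both Dataplex triggers now emit a structured payload (job id, enum state as a string, serialized job) rather than a bare event. A sketch of how an operator's completion callback might consume it; the method is illustrative, not taken from this diff:

    from airflow.exceptions import AirflowException

    def execute_complete(self, context, event):
        # event comes from the TriggerEvent above: job_id, job_state, job.
        if event["job_state"] != "SUCCEEDED":  # DataScanJob.State name, as a string
            raise AirflowException(
                f"Data scan job {event['job_id']} ended in state {event['job_state']}"
            )
        return event["job"]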
@@ -25,21 +25,25 @@ import time
 from collections.abc import AsyncIterator, Sequence
 from typing import TYPE_CHECKING, Any
 
+from asgiref.sync import sync_to_async
 from google.api_core.exceptions import NotFound
-from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, Job, JobStatus
 
-from airflow.exceptions import AirflowException
-from airflow.models.taskinstance import TaskInstance
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+if not AIRFLOW_V_3_0_PLUS:
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.session import provide_session
+
 
 class DataprocBaseTrigger(BaseTrigger):
     """Base class for Dataproc triggers."""
@@ -117,40 +121,67 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        """
-        Get the task instance for the current task.
-
-        :param session: Sqlalchemy session
-        """
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """
+            Get the task instance for the current task.
+
+            :param session: Sqlalchemy session
+            """
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
         )
-        task_instance = query.one_or_none()
-        if task_instance is None:
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
             raise AirflowException(
-                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                 self.task_instance.dag_id,
                 self.task_instance.task_id,
                 self.task_instance.run_id,
                 self.task_instance.map_index,
             )
-        return task_instance
+        return task_state
 
-    def safe_to_cancel(self) -> bool:
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
@@ -163,11 +194,13 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                 if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
                     break
                 await asyncio.sleep(self.polling_interval_seconds)
-            yield TriggerEvent(
+            yield TriggerEvent(
+                {"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)}
+            )
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
                     self.log.info(
                         "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
                         " in deferred state."
@@ -181,7 +214,12 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                         job_id=self.job_id, project_id=self.project_id, region=self.region
                     )
                     self.log.info("Job: %s is cancelled", self.job_id)
-                    yield TriggerEvent(
+                    yield TriggerEvent(
+                        {
+                            "job_id": self.job_id,
+                            "job_state": ClusterStatus.State.DELETING.name,
+                        }
+                    )
             except Exception as e:
                 self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
                 raise e
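
Note the payload change in the hunks above: the Dataproc triggers stop yielding raw proto objects and enum members. TriggerEvent payloads cross a process boundary and must be JSON-serializable, hence `JobStatus.State(state).name` and `Job.to_dict(job)`. The round-trip this enables, as a small illustrative sketch (assumes the google-cloud-dataproc client library):

    from google.cloud.dataproc_v1 import ClusterStatus

    state = ClusterStatus.State.DELETING
    payload = {"cluster_state": state.name}                   # plain string "DELETING"
    restored = ClusterStatus.State[payload["cluster_state"]]  # back to the enum member
    assert restored == state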
@@ -224,35 +262,62 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
         )
-        task_instance = query.one_or_none()
-        if task_instance is None:
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
             raise AirflowException(
-                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                 self.task_instance.dag_id,
                 self.task_instance.task_id,
                 self.task_instance.run_id,
                 self.task_instance.map_index,
             )
-        return task_instance
+        return task_state
 
-    def safe_to_cancel(self) -> bool:
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
@@ -264,8 +329,8 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": ClusterStatus.State.DELETING,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State.DELETING.name,  # type: ignore
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
@@ -273,17 +338,18 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": state,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State(state).name,
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
-                self.log.info("Current state is %s", state)
-                self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
-                await asyncio.sleep(self.polling_interval_seconds)
+                else:
+                    self.log.info("Current state is %s", state)
+                    self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+                    await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
-                if self.delete_on_error and self.safe_to_cancel():
+                if self.delete_on_error and await self.safe_to_cancel():
                     self.log.info(
                         "Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in "
                         "deferred state."
@@ -369,12 +435,16 @@ class DataprocBatchTrigger(DataprocBaseTrigger):
 
             if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
                 break
-            self.log.info("Current state is %s", state)
+            self.log.info("Current state is %s", Batch.State(state).name)
             self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
             await asyncio.sleep(self.polling_interval_seconds)
 
         yield TriggerEvent(
-            {
+            {
+                "batch_id": self.batch_id,
+                "batch_state": Batch.State(state).name,
+                "batch_state_message": batch.state_message,
+            }
         )
@@ -432,9 +502,9 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
         try:
            while self.end_time > time.time():
                cluster = await self.get_async_hook().get_cluster(
-                    region=self.region,
+                    region=self.region,
                     cluster_name=self.cluster_name,
-                    project_id=self.project_id,
+                    project_id=self.project_id,
                     metadata=self.metadata,
                 )
                 self.log.info(
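
Both reworked Dataproc triggers use the same split in safe_to_cancel: on Airflow 3 the triggerer has no direct metadata-database access, so task state comes from the Task SDK (wrapped in sync_to_async because get_task_states is a blocking call), while on Airflow 2 it is still a @provide_session SQLAlchemy query. The branch, paraphrased as a standalone sketch against a trigger instance from this diff (the helper name is illustrative):

    from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS

    async def current_task_state(trigger):
        if AIRFLOW_V_3_0_PLUS:
            return await trigger.get_task_state()  # Task SDK, via sync_to_async
        # Airflow 2: database query for the latest TaskInstance state;
        # @provide_session supplies the session argument.
        return trigger.get_task_instance().state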
@@ -26,10 +26,11 @@ from typing import TYPE_CHECKING, Any
 from google.cloud.container_v1.types import Operation
 from packaging.version import parse as parse_version
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodManager
 from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.google.cloud.hooks.kubernetes_engine import (
     GKEAsyncHook,
     GKEKubernetesAsyncHook,
@@ -153,7 +154,7 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
         )
 
     @cached_property
-    def hook(self) -> GKEKubernetesAsyncHook:
+    def hook(self) -> GKEKubernetesAsyncHook:
         return GKEKubernetesAsyncHook(
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
@@ -200,7 +201,7 @@ class GKEOperationTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get operation status and yields corresponding event."""
         hook = self._get_hook()
         try:
@@ -260,9 +261,10 @@ class GKEJobTrigger(BaseTrigger):
         ssl_ca_cert: str,
         job_name: str,
         job_namespace: str,
-        pod_name: str,
+        pod_names: list[str],
         pod_namespace: str,
         base_container_name: str,
+        pod_name: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         poll_interval: float = 2,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -274,7 +276,13 @@ class GKEJobTrigger(BaseTrigger):
         self.ssl_ca_cert = ssl_ca_cert
         self.job_name = job_name
         self.job_namespace = job_namespace
-        self.pod_name = pod_name
+        if pod_name is not None:
+            self._pod_name = pod_name
+            self.pod_names = [
+                self.pod_name,
+            ]
+        else:
+            self.pod_names = pod_names
         self.pod_namespace = pod_namespace
         self.base_container_name = base_container_name
         self.gcp_conn_id = gcp_conn_id
@@ -283,6 +291,15 @@ class GKEJobTrigger(BaseTrigger):
         self.get_logs = get_logs
         self.do_xcom_push = do_xcom_push
 
+    @property
+    def pod_name(self):
+        warnings.warn(
+            "`pod_name` parameter is deprecated, please use `pod_names`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self._pod_name
+
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize KubernetesCreateJobTrigger arguments and classpath."""
         return (
@@ -292,7 +309,7 @@ class GKEJobTrigger(BaseTrigger):
                 "ssl_ca_cert": self.ssl_ca_cert,
                 "job_name": self.job_name,
                 "job_namespace": self.job_namespace,
-                "pod_name": self.pod_name,
+                "pod_names": self.pod_names,
                 "pod_namespace": self.pod_namespace,
                 "base_container_name": self.base_container_name,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -303,10 +320,8 @@ class GKEJobTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job status and yield a TriggerEvent."""
-        if self.get_logs or self.do_xcom_push:
-            pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
         if self.do_xcom_push:
             kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
             kubernetes_provider_name = kubernetes_provider.data["package-name"]
@@ -318,22 +333,26 @@ class GKEJobTrigger(BaseTrigger):
                     f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
                     f"support this feature. Please upgrade it to version higher than or equal to {min_version}."
                 )
-            await self.hook.wait_until_container_complete(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=self.base_container_name,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Checking if xcom sidecar container is started.")
-            await self.hook.wait_until_container_started(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Extracting result from xcom sidecar container.")
-            loop = asyncio.get_running_loop()
-            xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
+            xcom_results = []
+            for pod_name in self.pod_names:
+                pod = await self.hook.get_pod(name=pod_name, namespace=self.pod_namespace)
+                await self.hook.wait_until_container_complete(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=self.base_container_name,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Checking if xcom sidecar container is started.")
+                await self.hook.wait_until_container_started(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Extracting result from xcom sidecar container.")
+                loop = asyncio.get_running_loop()
+                xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
+                xcom_results.append(xcom_result)
         job: V1Job = await self.hook.wait_until_job_complete(
             name=self.job_name, namespace=self.job_namespace, poll_interval=self.poll_interval
         )
@@ -345,12 +364,12 @@ class GKEJobTrigger(BaseTrigger):
             {
                 "name": job.metadata.name,
                 "namespace": job.metadata.namespace,
-                "pod_name": self.pod_name if self.get_logs else None,
-                "pod_namespace": self.pod_namespace if self.get_logs else None,
+                "pod_names": [pod_name for pod_name in self.pod_names] if self.get_logs else None,
+                "pod_namespace": self.pod_namespace if self.get_logs else None,
                 "status": status,
                 "message": message,
                 "job": job_dict,
-                "xcom_result": xcom_result if self.do_xcom_push else None,
+                "xcom_result": xcom_results if self.do_xcom_push else None,
             }
         )
 
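
GKEJobTrigger keeps backwards compatibility while moving from a single pod_name to pod_names: the legacy argument is accepted, normalized into the list, stored privately, and a deprecation warning is raised on attribute access. The same shim in isolation, as an illustrative sketch (the class name is hypothetical; the provider uses AirflowProviderDeprecationWarning rather than the builtin category):

    import warnings

    class Example:
        def __init__(self, pod_names: list[str], pod_name: str | None = None):
            if pod_name is not None:
                self._pod_name = pod_name
                self.pod_names = [pod_name]  # normalize the legacy argument
            else:
                self.pod_names = pod_names

        @property
        def pod_name(self):
            warnings.warn(
                "`pod_name` is deprecated, please use `pod_names`",
                DeprecationWarning,
                stacklevel=2,
            )
            return self._pod_name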
@@ -90,7 +90,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
@@ -85,27 +85,23 @@ class PubsubPullTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
-        try:
-            while True:
-                if pulled_messages := await self.hook.pull(
-                    project_id=self.project_id,
-                    subscription=self.subscription,
-                    max_messages=self.max_messages,
-                    return_immediately=True,
-                ):
-                    if self.ack_messages:
-                        await self.message_acknowledgement(pulled_messages)
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        while True:
+            if pulled_messages := await self.hook.pull(
+                project_id=self.project_id,
+                subscription=self.subscription,
+                max_messages=self.max_messages,
+                return_immediately=True,
+            ):
+                if self.ack_messages:
+                    await self.message_acknowledgement(pulled_messages)
 
-                    messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
+                messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
 
-                    yield TriggerEvent({"status": "success", "message": messages_json})
-                    return
-                self.log.info("Sleeping for %s seconds.", self.poke_interval)
-                await asyncio.sleep(self.poke_interval)
-        except Exception as e:
-            yield TriggerEvent({"status": "error", "message": str(e)})
-            return
+                yield TriggerEvent({"status": "success", "message": messages_json})
+                return
+            self.log.info("Sleeping for %s seconds.", self.poke_interval)
+            await asyncio.sleep(self.poke_interval)
 
     async def message_acknowledgement(self, pulled_messages):
         await self.hook.acknowledge(