apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -5
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/bigquery.py +166 -281
- airflow/providers/google/cloud/hooks/cloud_composer.py +287 -14
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_run.py +17 -9
- airflow/providers/google/cloud/hooks/cloud_sql.py +101 -22
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +27 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +5 -1
- airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
- airflow/providers/google/cloud/hooks/dataflow.py +71 -94
- airflow/providers/google/cloud/hooks/datafusion.py +1 -1
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +72 -71
- airflow/providers/google/cloud/hooks/gcs.py +111 -14
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/mlengine.py +3 -2
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +73 -8
- airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -209
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +27 -1
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +2 -2
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +17 -9
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +70 -55
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +3 -5
- airflow/providers/google/cloud/operators/bigtable.py +36 -7
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +2 -2
- airflow/providers/google/cloud/operators/cloud_build.py +75 -32
- airflow/providers/google/cloud/operators/cloud_composer.py +128 -40
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +23 -5
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -16
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -11
- airflow/providers/google/cloud/operators/compute.py +8 -40
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +41 -20
- airflow/providers/google/cloud/operators/dataplex.py +193 -109
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +78 -35
- airflow/providers/google/cloud/operators/dataproc_metastore.py +96 -88
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +16 -7
- airflow/providers/google/cloud/operators/gcs.py +10 -8
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +60 -99
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +107 -52
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +60 -14
- airflow/providers/google/cloud/operators/spanner.py +25 -12
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -2
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -2
- airflow/providers/google/cloud/operators/translate.py +40 -16
- airflow/providers/google/cloud/operators/translate_speech.py +1 -2
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +29 -9
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +54 -26
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +11 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +30 -7
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -2
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -29
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/cloud/sensors/dataflow.py +26 -9
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +2 -2
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -2
- airflow/providers/google/cloud/sensors/gcs.py +4 -4
- airflow/providers/google/cloud/sensors/looker.py +2 -2
- airflow/providers/google/cloud/sensors/pubsub.py +4 -4
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +2 -2
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +20 -12
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +13 -4
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +302 -46
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +91 -1
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +122 -52
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -8
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +26 -1
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +1 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -63
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +92 -48
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
airflow/providers/google/cloud/triggers/dataflow.py

@@ -788,3 +788,125 @@ class DataflowJobMessagesTrigger(BaseTrigger):
             poll_sleep=self.poll_sleep,
             impersonation_chain=self.impersonation_chain,
         )
+
+
+class DataflowJobStateCompleteTrigger(BaseTrigger):
+    """
+    Trigger that monitors if a Dataflow job has reached any of successful terminal state meant for that job.
+
+    :param job_id: Required. ID of the job.
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None then
+        the value of DEFAULT_DATAFLOW_LOCATION will be used.
+    :param wait_until_finished: Optional. Dataflow option to block pipeline until completion.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
+    :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        project_id: str | None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        wait_until_finished: bool | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.project_id = project_id
+        self.location = location
+        self.wait_until_finished = wait_until_finished
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStateCompleteTrigger",
+            {
+                "job_id": self.job_id,
+                "project_id": self.project_id,
+                "location": self.location,
+                "wait_until_finished": self.wait_until_finished,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_sleep": self.poll_sleep,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self):
+        """
+        Loop until the job reaches successful final or error state.
+
+        Yields a TriggerEvent with success status, if the job reaches successful state for own type.
+
+        Yields a TriggerEvent with error status, if the client returns an unexpected terminal
+        job status or any exception is raised while looping.
+
+        In any other case the Trigger will wait for a specified amount of time
+        stored in self.poll_sleep variable.
+        """
+        try:
+            while True:
+                job = await self.async_hook.get_job(
+                    project_id=self.project_id,
+                    job_id=self.job_id,
+                    location=self.location,
+                )
+                job_state = job.current_state.name
+                job_type_name = job.type_.name
+
+                FAILED_STATES = DataflowJobStatus.FAILED_END_STATES | {DataflowJobStatus.JOB_STATE_DRAINED}
+                if job_state in FAILED_STATES:
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": (
+                                f"Job with id '{self.job_id}' is in failed terminal state: {job_state}"
+                            ),
+                        }
+                    )
+                    return
+
+                if self.async_hook.job_reached_terminal_state(
+                    job={"id": self.job_id, "currentState": job_state, "type": job_type_name},
+                    wait_until_finished=self.wait_until_finished,
+                ):
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": (
+                                f"Job with id '{self.job_id}' has reached successful final state: {job_state}"
+                            ),
+                        }
+                    )
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+                await asyncio.sleep(self.poll_sleep)
+        except Exception as e:
+            self.log.error("Exception occurred while checking for job state!")
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(e),
+                }
+            )
+
+    @cached_property
+    def async_hook(self) -> AsyncDataflowHook:
+        return AsyncDataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+        )
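The new DataflowJobStateCompleteTrigger is meant to be handed off from a deferrable operator. A minimal sketch of that hand-off, assuming a hypothetical WaitForDataflowJobOperator (the operator name and its fields are illustrative, not part of this diff):

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.triggers.dataflow import DataflowJobStateCompleteTrigger


class WaitForDataflowJobOperator(BaseOperator):  # hypothetical operator, for illustration only
    def __init__(self, job_id: str, project_id: str, location: str = "us-central1", **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.project_id = project_id
        self.location = location

    def execute(self, context):
        # Hand the polling loop to the triggerer instead of blocking a worker slot.
        self.defer(
            trigger=DataflowJobStateCompleteTrigger(
                job_id=self.job_id,
                project_id=self.project_id,
                location=self.location,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event):
        # The trigger yields {"status": ..., "message": ...} as serialized above.
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info(event["message"])
        return self.job_id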
airflow/providers/google/cloud/triggers/datafusion.py

@@ -86,7 +86,7 @@ class DataFusionStartPipelineTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
airflow/providers/google/cloud/triggers/dataplex.py

@@ -103,7 +103,13 @@ class DataplexDataQualityJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent(
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
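Here, and in the DataplexDataProfileJobTrigger hunk that follows, the event payload now carries DataScanJob.State(state).name rather than the raw state. A small illustrative sketch of why (values are examples; proto-plus enums are IntEnums, so the raw value would serialize as an opaque integer):

import json

from google.cloud.dataplex_v1 import DataScanJob

state = DataScanJob.State.SUCCEEDED
# The raw enum serializes as its integer value; .name yields a stable,
# human-readable string for the TriggerEvent payload.
print(json.dumps({"job_state": state}))                          # e.g. {"job_state": 4}
print(json.dumps({"job_state": DataScanJob.State(state).name}))  # {"job_state": "SUCCEEDED"}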
@@ -185,7 +191,13 @@ class DataplexDataProfileJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent(
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
airflow/providers/google/cloud/triggers/dataproc.py

@@ -25,21 +25,25 @@ import time
 from collections.abc import AsyncIterator, Sequence
 from typing import TYPE_CHECKING, Any
 
+from asgiref.sync import sync_to_async
 from google.api_core.exceptions import NotFound
-from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, Job, JobStatus
 
 from airflow.exceptions import AirflowException
-from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+if not AIRFLOW_V_3_0_PLUS:
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.session import provide_session
+
 
 class DataprocBaseTrigger(BaseTrigger):
     """Base class for Dataproc triggers."""
@@ -117,40 +121,67 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        """
-        Get the task instance for the current task.
+    if not AIRFLOW_V_3_0_PLUS:
 
-        :param session: Sqlalchemy session
-        """
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """
+            Get the task instance for the current task.
+
+            :param session: Sqlalchemy session
+            """
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
         )
-        task_instance = query.one_or_none()
-        if task_instance is None:
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
             raise AirflowException(
-                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                 self.task_instance.dag_id,
                 self.task_instance.task_id,
                 self.task_instance.run_id,
                 self.task_instance.map_index,
             )
-        return task_instance
+        return task_state
 
-    def safe_to_cancel(self) -> bool:
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self):
         try:
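The Airflow 3 path calls RuntimeTaskInstance.get_task_states through asgiref's sync_to_async, which runs a blocking callable in a worker thread so the triggerer's event loop stays responsive. A standalone sketch of that pattern, with a stand-in fetch function (not the provider's code):

import asyncio

from asgiref.sync import sync_to_async


def fetch_state_blocking(run_id: str) -> str:
    # Stand-in for RuntimeTaskInstance.get_task_states, which performs
    # blocking I/O and must not run directly on the triggerer's event loop.
    return "deferred"


async def main():
    # sync_to_async wraps the callable; awaiting it runs the blocking work
    # in a thread and returns the result to the coroutine.
    state = await sync_to_async(fetch_state_blocking)("manual__2024-01-01")
    print(state)


asyncio.run(main())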
@@ -163,11 +194,13 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                 if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
                     break
                 await asyncio.sleep(self.polling_interval_seconds)
-            yield TriggerEvent(
+            yield TriggerEvent(
+                {"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)}
+            )
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
                     self.log.info(
                         "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
                         " in deferred state."
@@ -181,7 +214,12 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                         job_id=self.job_id, project_id=self.project_id, region=self.region
                     )
                     self.log.info("Job: %s is cancelled", self.job_id)
-                    yield TriggerEvent(
+                    yield TriggerEvent(
+                        {
+                            "job_id": self.job_id,
+                            "job_state": ClusterStatus.State.DELETING.name,
+                        }
+                    )
             except Exception as e:
                 self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
                 raise e
@@ -224,35 +262,62 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
         )
-        task_instance = query.one_or_none()
-        if task_instance is None:
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
             raise AirflowException(
-                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                 self.task_instance.dag_id,
                 self.task_instance.task_id,
                 self.task_instance.run_id,
                 self.task_instance.map_index,
             )
-        return task_instance
+        return task_state
 
-    def safe_to_cancel(self) -> bool:
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
@@ -264,8 +329,8 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": ClusterStatus.State.DELETING,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State.DELETING.name,  # type: ignore
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
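Both changes in this hunk serve serialization: TriggerEvent payloads cross process boundaries, so the proto-plus Cluster message is flattened with Cluster.to_dict and the enum becomes its .name string. A quick sketch of that conversion (field values are made up):

from google.cloud.dataproc_v1 import Cluster, ClusterStatus

cluster = Cluster(cluster_name="demo-cluster", project_id="demo-project")
# proto-plus messages are not JSON-serializable directly; to_dict converts
# them into plain dicts/lists/strings that a TriggerEvent can carry safely.
payload = {
    "cluster_state": ClusterStatus.State.DELETING.name,
    "cluster": Cluster.to_dict(cluster),
}
print(payload["cluster"]["cluster_name"])  # demo-cluster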
@@ -273,17 +338,18 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": state,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State(state).name,
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
-                self.log.info("Current state is %s", state)
-                self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
-                await asyncio.sleep(self.polling_interval_seconds)
+                else:
+                    self.log.info("Current state is %s", state)
+                    self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+                    await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
-                if self.delete_on_error and self.safe_to_cancel():
+                if self.delete_on_error and await self.safe_to_cancel():
                     self.log.info(
                         "Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in "
                         "deferred state."
@@ -369,12 +435,16 @@ class DataprocBatchTrigger(DataprocBaseTrigger):
 
             if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
                 break
-            self.log.info("Current state is %s", state)
+            self.log.info("Current state is %s", Batch.State(state).name)
             self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
             await asyncio.sleep(self.polling_interval_seconds)
 
         yield TriggerEvent(
-            {
+            {
+                "batch_id": self.batch_id,
+                "batch_state": Batch.State(state).name,
+                "batch_state_message": batch.state_message,
+            }
         )
@@ -432,9 +502,9 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
         try:
             while self.end_time > time.time():
                 cluster = await self.get_async_hook().get_cluster(
-                    region=self.region,
+                    region=self.region,
                     cluster_name=self.cluster_name,
-                    project_id=self.project_id,
+                    project_id=self.project_id,
                     metadata=self.metadata,
                 )
                 self.log.info(
airflow/providers/google/cloud/triggers/kubernetes_engine.py

@@ -153,7 +153,7 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
         )
 
     @cached_property
-    def hook(self) -> GKEKubernetesAsyncHook:
+    def hook(self) -> GKEKubernetesAsyncHook:
         return GKEKubernetesAsyncHook(
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
@@ -200,7 +200,7 @@ class GKEOperationTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get operation status and yields corresponding event."""
         hook = self._get_hook()
         try:
@@ -260,9 +260,10 @@ class GKEJobTrigger(BaseTrigger):
         ssl_ca_cert: str,
         job_name: str,
         job_namespace: str,
-        pod_name: str,
+        pod_names: list[str],
         pod_namespace: str,
         base_container_name: str,
+        pod_name: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         poll_interval: float = 2,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -274,7 +275,13 @@ class GKEJobTrigger(BaseTrigger):
         self.ssl_ca_cert = ssl_ca_cert
         self.job_name = job_name
         self.job_namespace = job_namespace
-        self.pod_name = pod_name
+        if pod_name is not None:
+            self._pod_name = pod_name
+            self.pod_names = [
+                self.pod_name,
+            ]
+        else:
+            self.pod_names = pod_names
         self.pod_namespace = pod_namespace
         self.base_container_name = base_container_name
         self.gcp_conn_id = gcp_conn_id
@@ -283,6 +290,15 @@ class GKEJobTrigger(BaseTrigger):
         self.get_logs = get_logs
         self.do_xcom_push = do_xcom_push
 
+    @property
+    def pod_name(self):
+        warnings.warn(
+            "`pod_name` parameter is deprecated, please use `pod_names`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self._pod_name
+
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize KubernetesCreateJobTrigger arguments and classpath."""
         return (
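The pod_name property above is a standard deprecation shim: the legacy attribute keeps working but warns on access. The same pattern in isolation, with stand-in class and warning names (not the provider's actual code):

import warnings


class ProviderDeprecationWarning(DeprecationWarning):
    # Stand-in for airflow.exceptions.AirflowProviderDeprecationWarning.
    pass


class JobTriggerSketch:  # hypothetical class, for illustration only
    def __init__(self, pod_names: list[str] | None = None, pod_name: str | None = None):
        if pod_name is not None:
            # Fold the legacy single-pod argument into the new list-valued attribute
            # so old callers keep working.
            self._pod_name = pod_name
            self.pod_names = [pod_name]
        else:
            self.pod_names = pod_names or []

    @property
    def pod_name(self):
        warnings.warn(
            "`pod_name` parameter is deprecated, please use `pod_names`",
            ProviderDeprecationWarning,
            stacklevel=2,
        )
        return self._pod_name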
@@ -292,7 +308,7 @@ class GKEJobTrigger(BaseTrigger):
                 "ssl_ca_cert": self.ssl_ca_cert,
                 "job_name": self.job_name,
                 "job_namespace": self.job_namespace,
-                "pod_name": self.pod_name,
+                "pod_names": self.pod_names,
                 "pod_namespace": self.pod_namespace,
                 "base_container_name": self.base_container_name,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -303,10 +319,8 @@ class GKEJobTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job status and yield a TriggerEvent."""
-        if self.get_logs or self.do_xcom_push:
-            pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
         if self.do_xcom_push:
             kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
             kubernetes_provider_name = kubernetes_provider.data["package-name"]
@@ -318,22 +332,26 @@ class GKEJobTrigger(BaseTrigger):
                     f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
                     f"support this feature. Please upgrade it to version higher than or equal to {min_version}."
                 )
-            await self.hook.wait_until_container_complete(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=self.base_container_name,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Checking if xcom sidecar container is started.")
-            await self.hook.wait_until_container_started(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Extracting result from xcom sidecar container.")
-            loop = asyncio.get_running_loop()
-            xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
+            xcom_results = []
+            for pod_name in self.pod_names:
+                pod = await self.hook.get_pod(name=pod_name, namespace=self.pod_namespace)
+                await self.hook.wait_until_container_complete(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=self.base_container_name,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Checking if xcom sidecar container is started.")
+                await self.hook.wait_until_container_started(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Extracting result from xcom sidecar container.")
+                loop = asyncio.get_running_loop()
+                xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
+                xcom_results.append(xcom_result)
         job: V1Job = await self.hook.wait_until_job_complete(
             name=self.job_name, namespace=self.job_namespace, poll_interval=self.poll_interval
         )
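Per pod, the loop offloads the blocking extract_xcom call with loop.run_in_executor(None, ...), which uses the default thread pool so the async trigger never blocks the triggerer's event loop. A self-contained sketch of the idiom, with a stand-in extraction function:

import asyncio
import time


def extract_xcom_blocking(pod_name: str) -> str:
    # Stand-in for PodManager.extract_xcom, which does blocking Kubernetes I/O.
    time.sleep(0.1)
    return '{"key": "value"}'


async def main():
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default ThreadPoolExecutor, so the event
    # loop keeps servicing other triggers while this call blocks in a thread.
    result = await loop.run_in_executor(None, extract_xcom_blocking, "job-pod-0")
    print(result)


asyncio.run(main())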
@@ -345,12 +363,12 @@ class GKEJobTrigger(BaseTrigger):
             {
                 "name": job.metadata.name,
                 "namespace": job.metadata.namespace,
-                "pod_name": pod.metadata.name if self.get_logs else None,
-                "pod_namespace": pod.metadata.namespace if self.get_logs else None,
+                "pod_names": [pod_name for pod_name in self.pod_names] if self.get_logs else None,
+                "pod_namespace": self.pod_namespace if self.get_logs else None,
                 "status": status,
                 "message": message,
                 "job": job_dict,
-                "xcom_result": xcom_result if self.do_xcom_push else None,
+                "xcom_result": xcom_results if self.do_xcom_push else None,
             }
         )
airflow/providers/google/cloud/triggers/mlengine.py

@@ -90,7 +90,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
airflow/providers/google/cloud/triggers/pubsub.py

@@ -85,27 +85,23 @@ class PubsubPullTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
-        try:
-            while True:
-                if pulled_messages := await self.hook.pull(
-                    project_id=self.project_id,
-                    subscription=self.subscription,
-                    max_messages=self.max_messages,
-                    return_immediately=True,
-                ):
-                    if self.ack_messages:
-                        await self.message_acknowledgement(pulled_messages)
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        while True:
+            if pulled_messages := await self.hook.pull(
+                project_id=self.project_id,
+                subscription=self.subscription,
+                max_messages=self.max_messages,
+                return_immediately=True,
+            ):
+                if self.ack_messages:
+                    await self.message_acknowledgement(pulled_messages)
 
-                    messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
+                messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
 
-                    yield TriggerEvent({"status": "success", "message": messages_json})
-                    return
-                self.log.info("Sleeping for %s seconds.", self.poke_interval)
-                await asyncio.sleep(self.poke_interval)
-        except Exception as e:
-            yield TriggerEvent({"status": "error", "message": str(e)})
-            return
+                yield TriggerEvent({"status": "success", "message": messages_json})
+                return
+            self.log.info("Sleeping for %s seconds.", self.poke_interval)
+            await asyncio.sleep(self.poke_interval)
 
     async def message_acknowledgement(self, pulled_messages):
         await self.hook.acknowledge(
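The rewritten run keeps the walrus-operator polling loop: pull once, and if nothing arrived, fall through to a sleep. A compact sketch of the same control flow, with a stub standing in for PubSubAsyncHook.pull:

import asyncio


async def pull_once(attempt: int) -> list[str]:
    # Stand-in for PubSubAsyncHook.pull with return_immediately=True:
    # returns whatever is available right now, possibly nothing.
    return ["msg"] if attempt >= 2 else []


async def poll(poke_interval: float = 0.1) -> list[str]:
    attempt = 0
    while True:
        # The walrus operator binds and tests the pull result in one step,
        # so an empty pull simply falls through to the sleep below.
        if pulled := await pull_once(attempt):
            return pulled
        attempt += 1
        await asyncio.sleep(poke_interval)


print(asyncio.run(poll()))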
airflow/providers/google/cloud/utils/bigquery_get_data.py

@@ -23,9 +23,9 @@ from google.cloud.bigquery.table import Row, RowIterator
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
-    from logging import Logger
 
     from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
+    from airflow.sdk.types import Logger
 
 
 def bigquery_get_data(
|