apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
  2. airflow/providers/google/__init__.py +3 -3
  3. airflow/providers/google/ads/hooks/ads.py +39 -6
  4. airflow/providers/google/ads/operators/ads.py +2 -2
  5. airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
  6. airflow/providers/google/assets/gcs.py +1 -11
  7. airflow/providers/google/cloud/bundles/__init__.py +16 -0
  8. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  9. airflow/providers/google/cloud/hooks/alloy_db.py +1 -1
  10. airflow/providers/google/cloud/hooks/bigquery.py +176 -293
  11. airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
  12. airflow/providers/google/cloud/hooks/cloud_build.py +1 -1
  13. airflow/providers/google/cloud/hooks/cloud_composer.py +288 -15
  14. airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
  15. airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -1
  16. airflow/providers/google/cloud/hooks/cloud_run.py +18 -10
  17. airflow/providers/google/cloud/hooks/cloud_sql.py +102 -23
  18. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +29 -7
  19. airflow/providers/google/cloud/hooks/compute.py +1 -1
  20. airflow/providers/google/cloud/hooks/compute_ssh.py +6 -2
  21. airflow/providers/google/cloud/hooks/datacatalog.py +10 -1
  22. airflow/providers/google/cloud/hooks/dataflow.py +72 -95
  23. airflow/providers/google/cloud/hooks/dataform.py +1 -1
  24. airflow/providers/google/cloud/hooks/datafusion.py +21 -19
  25. airflow/providers/google/cloud/hooks/dataplex.py +2 -2
  26. airflow/providers/google/cloud/hooks/dataprep.py +1 -1
  27. airflow/providers/google/cloud/hooks/dataproc.py +73 -72
  28. airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -1
  29. airflow/providers/google/cloud/hooks/dlp.py +1 -1
  30. airflow/providers/google/cloud/hooks/functions.py +1 -1
  31. airflow/providers/google/cloud/hooks/gcs.py +112 -15
  32. airflow/providers/google/cloud/hooks/gdm.py +1 -1
  33. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  34. airflow/providers/google/cloud/hooks/kubernetes_engine.py +3 -3
  35. airflow/providers/google/cloud/hooks/looker.py +6 -2
  36. airflow/providers/google/cloud/hooks/managed_kafka.py +1 -1
  37. airflow/providers/google/cloud/hooks/mlengine.py +4 -3
  38. airflow/providers/google/cloud/hooks/pubsub.py +3 -0
  39. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  40. airflow/providers/google/cloud/hooks/spanner.py +74 -9
  41. airflow/providers/google/cloud/hooks/stackdriver.py +11 -9
  42. airflow/providers/google/cloud/hooks/tasks.py +1 -1
  43. airflow/providers/google/cloud/hooks/translate.py +2 -2
  44. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +2 -210
  45. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +3 -3
  46. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +28 -2
  47. airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
  48. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +308 -8
  49. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
  50. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
  51. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +1 -1
  52. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +1 -1
  53. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
  54. airflow/providers/google/cloud/hooks/vision.py +3 -3
  55. airflow/providers/google/cloud/hooks/workflows.py +1 -1
  56. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  57. airflow/providers/google/cloud/links/base.py +77 -13
  58. airflow/providers/google/cloud/links/bigquery.py +0 -47
  59. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  60. airflow/providers/google/cloud/links/bigtable.py +0 -48
  61. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  62. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  63. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  64. airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
  65. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  66. airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
  67. airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
  68. airflow/providers/google/cloud/links/compute.py +0 -58
  69. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  70. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  71. airflow/providers/google/cloud/links/dataflow.py +0 -34
  72. airflow/providers/google/cloud/links/dataform.py +0 -64
  73. airflow/providers/google/cloud/links/datafusion.py +1 -96
  74. airflow/providers/google/cloud/links/dataplex.py +0 -154
  75. airflow/providers/google/cloud/links/dataprep.py +0 -24
  76. airflow/providers/google/cloud/links/dataproc.py +11 -95
  77. airflow/providers/google/cloud/links/datastore.py +0 -31
  78. airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
  79. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  80. airflow/providers/google/cloud/links/mlengine.py +0 -70
  81. airflow/providers/google/cloud/links/pubsub.py +0 -32
  82. airflow/providers/google/cloud/links/spanner.py +0 -33
  83. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  84. airflow/providers/google/cloud/links/translate.py +17 -187
  85. airflow/providers/google/cloud/links/vertex_ai.py +28 -195
  86. airflow/providers/google/cloud/links/workflows.py +0 -52
  87. airflow/providers/google/cloud/log/gcs_task_handler.py +58 -22
  88. airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
  89. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  90. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  91. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  92. airflow/providers/google/cloud/openlineage/facets.py +102 -1
  93. airflow/providers/google/cloud/openlineage/mixins.py +10 -8
  94. airflow/providers/google/cloud/openlineage/utils.py +15 -1
  95. airflow/providers/google/cloud/operators/alloy_db.py +71 -56
  96. airflow/providers/google/cloud/operators/bigquery.py +73 -636
  97. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -6
  98. airflow/providers/google/cloud/operators/bigtable.py +37 -8
  99. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  100. airflow/providers/google/cloud/operators/cloud_batch.py +3 -3
  101. airflow/providers/google/cloud/operators/cloud_build.py +76 -33
  102. airflow/providers/google/cloud/operators/cloud_composer.py +129 -41
  103. airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
  104. airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
  105. airflow/providers/google/cloud/operators/cloud_run.py +24 -6
  106. airflow/providers/google/cloud/operators/cloud_sql.py +8 -17
  107. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +93 -12
  108. airflow/providers/google/cloud/operators/compute.py +9 -41
  109. airflow/providers/google/cloud/operators/datacatalog.py +157 -21
  110. airflow/providers/google/cloud/operators/dataflow.py +40 -16
  111. airflow/providers/google/cloud/operators/dataform.py +15 -5
  112. airflow/providers/google/cloud/operators/datafusion.py +42 -21
  113. airflow/providers/google/cloud/operators/dataplex.py +194 -110
  114. airflow/providers/google/cloud/operators/dataprep.py +1 -5
  115. airflow/providers/google/cloud/operators/dataproc.py +80 -36
  116. airflow/providers/google/cloud/operators/dataproc_metastore.py +97 -89
  117. airflow/providers/google/cloud/operators/datastore.py +23 -7
  118. airflow/providers/google/cloud/operators/dlp.py +6 -29
  119. airflow/providers/google/cloud/operators/functions.py +17 -8
  120. airflow/providers/google/cloud/operators/gcs.py +12 -9
  121. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  122. airflow/providers/google/cloud/operators/kubernetes_engine.py +62 -100
  123. airflow/providers/google/cloud/operators/looker.py +2 -2
  124. airflow/providers/google/cloud/operators/managed_kafka.py +108 -53
  125. airflow/providers/google/cloud/operators/natural_language.py +1 -1
  126. airflow/providers/google/cloud/operators/pubsub.py +68 -15
  127. airflow/providers/google/cloud/operators/spanner.py +26 -13
  128. airflow/providers/google/cloud/operators/speech_to_text.py +2 -3
  129. airflow/providers/google/cloud/operators/stackdriver.py +1 -9
  130. airflow/providers/google/cloud/operators/tasks.py +1 -12
  131. airflow/providers/google/cloud/operators/text_to_speech.py +2 -3
  132. airflow/providers/google/cloud/operators/translate.py +41 -17
  133. airflow/providers/google/cloud/operators/translate_speech.py +2 -3
  134. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
  135. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +30 -10
  136. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +55 -27
  137. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
  138. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
  139. airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
  140. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
  141. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -115
  142. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +12 -10
  143. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
  144. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +31 -8
  145. airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
  146. airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
  147. airflow/providers/google/cloud/operators/vision.py +2 -2
  148. airflow/providers/google/cloud/operators/workflows.py +18 -15
  149. airflow/providers/google/cloud/secrets/secret_manager.py +3 -2
  150. airflow/providers/google/cloud/sensors/bigquery.py +3 -3
  151. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -3
  152. airflow/providers/google/cloud/sensors/bigtable.py +11 -4
  153. airflow/providers/google/cloud/sensors/cloud_composer.py +533 -30
  154. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -3
  155. airflow/providers/google/cloud/sensors/dataflow.py +26 -10
  156. airflow/providers/google/cloud/sensors/dataform.py +2 -3
  157. airflow/providers/google/cloud/sensors/datafusion.py +4 -5
  158. airflow/providers/google/cloud/sensors/dataplex.py +2 -3
  159. airflow/providers/google/cloud/sensors/dataprep.py +2 -2
  160. airflow/providers/google/cloud/sensors/dataproc.py +2 -3
  161. airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -3
  162. airflow/providers/google/cloud/sensors/gcs.py +4 -5
  163. airflow/providers/google/cloud/sensors/looker.py +2 -3
  164. airflow/providers/google/cloud/sensors/pubsub.py +4 -5
  165. airflow/providers/google/cloud/sensors/tasks.py +2 -2
  166. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -3
  167. airflow/providers/google/cloud/sensors/workflows.py +2 -3
  168. airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
  169. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
  170. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +4 -3
  171. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
  172. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +10 -5
  173. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
  174. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
  175. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
  176. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
  177. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  178. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +3 -3
  179. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -4
  180. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +21 -13
  181. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +4 -3
  182. airflow/providers/google/cloud/transfers/gcs_to_local.py +6 -4
  183. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +11 -5
  184. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
  185. airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
  186. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  187. airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
  188. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  189. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
  190. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
  191. airflow/providers/google/cloud/transfers/s3_to_gcs.py +13 -7
  192. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
  193. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +14 -5
  194. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
  195. airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
  196. airflow/providers/google/cloud/triggers/bigquery.py +76 -35
  197. airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
  198. airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
  199. airflow/providers/google/cloud/triggers/cloud_run.py +3 -3
  200. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +92 -2
  201. airflow/providers/google/cloud/triggers/dataflow.py +122 -0
  202. airflow/providers/google/cloud/triggers/datafusion.py +1 -1
  203. airflow/providers/google/cloud/triggers/dataplex.py +14 -2
  204. airflow/providers/google/cloud/triggers/dataproc.py +123 -53
  205. airflow/providers/google/cloud/triggers/kubernetes_engine.py +47 -28
  206. airflow/providers/google/cloud/triggers/mlengine.py +1 -1
  207. airflow/providers/google/cloud/triggers/pubsub.py +15 -19
  208. airflow/providers/google/cloud/triggers/vertex_ai.py +1 -1
  209. airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
  210. airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
  211. airflow/providers/google/cloud/utils/field_sanitizer.py +1 -1
  212. airflow/providers/google/cloud/utils/field_validator.py +2 -3
  213. airflow/providers/google/common/auth_backend/google_openid.py +4 -4
  214. airflow/providers/google/common/deprecated.py +2 -1
  215. airflow/providers/google/common/hooks/base_google.py +27 -9
  216. airflow/providers/google/common/hooks/operation_helpers.py +1 -1
  217. airflow/providers/google/common/links/storage.py +0 -22
  218. airflow/providers/google/common/utils/get_secret.py +31 -0
  219. airflow/providers/google/common/utils/id_token_credentials.py +3 -4
  220. airflow/providers/google/firebase/hooks/firestore.py +1 -1
  221. airflow/providers/google/firebase/operators/firestore.py +3 -3
  222. airflow/providers/google/get_provider_info.py +56 -52
  223. airflow/providers/google/go_module_utils.py +35 -3
  224. airflow/providers/google/leveldb/hooks/leveldb.py +27 -2
  225. airflow/providers/google/leveldb/operators/leveldb.py +2 -2
  226. airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -1
  227. airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
  228. airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
  229. airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
  230. airflow/providers/google/marketing_platform/operators/analytics_admin.py +2 -3
  231. airflow/providers/google/marketing_platform/operators/campaign_manager.py +6 -6
  232. airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
  233. airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
  234. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
  235. airflow/providers/google/marketing_platform/sensors/display_video.py +3 -64
  236. airflow/providers/google/suite/hooks/calendar.py +2 -2
  237. airflow/providers/google/suite/hooks/sheets.py +16 -2
  238. airflow/providers/google/suite/operators/sheets.py +8 -3
  239. airflow/providers/google/suite/sensors/drive.py +2 -2
  240. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +3 -3
  241. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  242. airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
  243. airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
  244. airflow/providers/google/version_compat.py +15 -1
  245. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/METADATA +90 -46
  246. apache_airflow_providers_google-19.3.0.dist-info/RECORD +331 -0
  247. apache_airflow_providers_google-19.3.0.dist-info/licenses/NOTICE +5 -0
  248. airflow/providers/google/cloud/hooks/automl.py +0 -673
  249. airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
  250. airflow/providers/google/cloud/links/automl.py +0 -193
  251. airflow/providers/google/cloud/operators/automl.py +0 -1362
  252. airflow/providers/google/cloud/operators/life_sciences.py +0 -119
  253. airflow/providers/google/cloud/operators/mlengine.py +0 -112
  254. apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
  255. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/WHEEL +0 -0
  256. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/entry_points.txt +0 -0
  257. {airflow/providers/google → apache_airflow_providers_google-19.3.0.dist-info/licenses}/LICENSE +0 -0
```diff
--- a/airflow/providers/google/cloud/triggers/dataflow.py
+++ b/airflow/providers/google/cloud/triggers/dataflow.py
@@ -788,3 +788,125 @@ class DataflowJobMessagesTrigger(BaseTrigger):
             poll_sleep=self.poll_sleep,
             impersonation_chain=self.impersonation_chain,
         )
+
+
+class DataflowJobStateCompleteTrigger(BaseTrigger):
+    """
+    Trigger that monitors if a Dataflow job has reached any of successful terminal state meant for that job.
+
+    :param job_id: Required. ID of the job.
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None then
+        the value of DEFAULT_DATAFLOW_LOCATION will be used.
+    :param wait_until_finished: Optional. Dataflow option to block pipeline until completion.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
+    :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        project_id: str | None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        wait_until_finished: bool | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.project_id = project_id
+        self.location = location
+        self.wait_until_finished = wait_until_finished
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStateCompleteTrigger",
+            {
+                "job_id": self.job_id,
+                "project_id": self.project_id,
+                "location": self.location,
+                "wait_until_finished": self.wait_until_finished,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_sleep": self.poll_sleep,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self):
+        """
+        Loop until the job reaches successful final or error state.
+
+        Yields a TriggerEvent with success status, if the job reaches successful state for own type.
+
+        Yields a TriggerEvent with error status, if the client returns an unexpected terminal
+        job status or any exception is raised while looping.
+
+        In any other case the Trigger will wait for a specified amount of time
+        stored in self.poll_sleep variable.
+        """
+        try:
+            while True:
+                job = await self.async_hook.get_job(
+                    project_id=self.project_id,
+                    job_id=self.job_id,
+                    location=self.location,
+                )
+                job_state = job.current_state.name
+                job_type_name = job.type_.name
+
+                FAILED_STATES = DataflowJobStatus.FAILED_END_STATES | {DataflowJobStatus.JOB_STATE_DRAINED}
+                if job_state in FAILED_STATES:
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": (
+                                f"Job with id '{self.job_id}' is in failed terminal state: {job_state}"
+                            ),
+                        }
+                    )
+                    return
+
+                if self.async_hook.job_reached_terminal_state(
+                    job={"id": self.job_id, "currentState": job_state, "type": job_type_name},
+                    wait_until_finished=self.wait_until_finished,
+                ):
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": (
+                                f"Job with id '{self.job_id}' has reached successful final state: {job_state}"
+                            ),
+                        }
+                    )
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+                await asyncio.sleep(self.poll_sleep)
+        except Exception as e:
+            self.log.error("Exception occurred while checking for job state!")
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(e),
+                }
+            )
+
+    @cached_property
+    def async_hook(self) -> AsyncDataflowHook:
+        return AsyncDataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+        )
```
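For orientation, here is a minimal sketch (not part of the diff) of how a deferrable operator could hand off to the new `DataflowJobStateCompleteTrigger`. The operator class, its parameters, and the `execute_complete` handling are hypothetical; only the trigger import and its constructor arguments come from the code above, and the `BaseOperator` import assumes an Airflow 2.x-style layout:

```python
from airflow.models.baseoperator import BaseOperator
from airflow.providers.google.cloud.triggers.dataflow import DataflowJobStateCompleteTrigger


class WaitForDataflowJobStateOperator(BaseOperator):  # hypothetical example operator
    def __init__(self, *, job_id: str, project_id: str, location: str = "us-central1", **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.project_id = project_id
        self.location = location

    def execute(self, context):
        # Hand control to the triggerer; the trigger polls the job via AsyncDataflowHook
        # and fires a TriggerEvent once a terminal state is reached.
        self.defer(
            trigger=DataflowJobStateCompleteTrigger(
                job_id=self.job_id,
                project_id=self.project_id,
                location=self.location,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event):
        # The payload shape ({"status": ..., "message": ...}) matches run() above.
        if event["status"] == "error":
            raise RuntimeError(event["message"])
        return event["message"]
```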
```diff
--- a/airflow/providers/google/cloud/triggers/datafusion.py
+++ b/airflow/providers/google/cloud/triggers/datafusion.py
@@ -86,7 +86,7 @@ class DataFusionStartPipelineTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
```
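The same one-line change, dropping `# type: ignore[override]`, recurs below in the GKE, MLEngine, and Pub/Sub triggers. A minimal sketch of why the suppression becomes unnecessary, using stand-in classes rather than the real Airflow ones: once the base class declares `run()` with the same async-iterator signature, an identically annotated override no longer trips mypy.

```python
from collections.abc import AsyncIterator


class Event:
    """Stand-in for Airflow's TriggerEvent."""


class Base:
    """Stand-in for BaseTrigger, with run() declared as an async generator."""

    async def run(self) -> AsyncIterator[Event]:
        raise NotImplementedError
        yield  # unreachable, but marks run() as an async generator for the type checker


class Derived(Base):
    # Identical return annotation, so no "# type: ignore[override]" is needed.
    async def run(self) -> AsyncIterator[Event]:
        yield Event()
```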
```diff
--- a/airflow/providers/google/cloud/triggers/dataplex.py
+++ b/airflow/providers/google/cloud/triggers/dataplex.py
@@ -103,7 +103,13 @@ class DataplexDataQualityJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": self._convert_to_dict(job)})
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
@@ -185,7 +191,13 @@ class DataplexDataProfileJobTrigger(BaseTrigger):
                 self.polling_interval_seconds,
             )
             await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": self._convert_to_dict(job)})
+        yield TriggerEvent(
+            {
+                "job_id": self.job_id,
+                "job_state": DataScanJob.State(state).name,
+                "job": self._convert_to_dict(job),
+            }
+        )
 
     def _convert_to_dict(self, job: DataScanJob) -> dict:
         """Return a representation of a DataScanJob instance as a dict."""
```
```diff
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -25,21 +25,25 @@ import time
 from collections.abc import AsyncIterator, Sequence
 from typing import TYPE_CHECKING, Any
 
+from asgiref.sync import sync_to_async
 from google.api_core.exceptions import NotFound
-from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, Job, JobStatus
 
-from airflow.exceptions import AirflowException
-from airflow.models.taskinstance import TaskInstance
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+if not AIRFLOW_V_3_0_PLUS:
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.session import provide_session
+
 
 class DataprocBaseTrigger(BaseTrigger):
     """Base class for Dataproc triggers."""
@@ -117,40 +121,67 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        """
-        Get the task instance for the current task.
+    if not AIRFLOW_V_3_0_PLUS:
 
-        :param session: Sqlalchemy session
-        """
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """
+            Get the task instance for the current task.
+
+            :param session: Sqlalchemy session
+            """
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
         )
-        task_instance = query.one_or_none()
-        if task_instance is None:
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
             raise AirflowException(
-                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                 self.task_instance.dag_id,
                 self.task_instance.task_id,
                 self.task_instance.run_id,
                 self.task_instance.map_index,
             )
-        return task_instance
+        return task_state
 
-    def safe_to_cancel(self) -> bool:
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self):
         try:
```
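A side note on the `sync_to_async` usage introduced above: `RuntimeTaskInstance.get_task_states` is a blocking call, and wrapping it moves the work off the event loop so the triggerer keeps servicing other triggers. A self-contained sketch of the same pattern (the `lookup_state` function is a made-up stand-in):

```python
import asyncio

from asgiref.sync import sync_to_async


def lookup_state(run_id: str) -> str:
    # Stand-in for a blocking call such as RuntimeTaskInstance.get_task_states.
    return "deferred"


async def main() -> None:
    # sync_to_async wraps the blocking callable so it runs in a worker thread
    # instead of stalling the event loop.
    state = await sync_to_async(lookup_state)("manual__2024-01-01")
    print(state)


asyncio.run(main())
```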
```diff
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -163,11 +194,13 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                 if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
                     break
                 await asyncio.sleep(self.polling_interval_seconds)
-            yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+            yield TriggerEvent(
+                {"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)}
+            )
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
                     self.log.info(
                         "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
                         " in deferred state."
@@ -181,7 +214,12 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                         job_id=self.job_id, project_id=self.project_id, region=self.region
                     )
                     self.log.info("Job: %s is cancelled", self.job_id)
-                    yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
+                    yield TriggerEvent(
+                        {
+                            "job_id": self.job_id,
+                            "job_state": ClusterStatus.State.DELETING.name,
+                        }
+                    )
             except Exception as e:
                 self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
                 raise e
@@ -224,35 +262,62 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.task_instance.dag_id,
-            TaskInstance.task_id == self.task_instance.task_id,
-            TaskInstance.run_id == self.task_instance.run_id,
-            TaskInstance.map_index == self.task_instance.map_index,
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            query = session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = query.one_or_none()
+            if task_instance is None:
+                raise AirflowException(
+                    "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
+                    self.task_instance.dag_id,
+                    self.task_instance.task_id,
+                    self.task_instance.run_id,
+                    self.task_instance.map_index,
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
        )
-        task_instance = query.one_or_none()
-        if task_instance is None:
+        try:
+            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except Exception:
             raise AirflowException(
-                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                 self.task_instance.dag_id,
                 self.task_instance.task_id,
                 self.task_instance.run_id,
                 self.task_instance.map_index,
             )
-        return task_instance
+        return task_state
 
-    def safe_to_cancel(self) -> bool:
+    async def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
 
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
-        # Database query is needed to get the latest state of the task instance.
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state != TaskInstanceState.DEFERRED
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            # Database query is needed to get the latest state of the task instance.
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
@@ -264,8 +329,8 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": ClusterStatus.State.DELETING,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State.DELETING.name,  # type: ignore
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
@@ -273,17 +338,18 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": state,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State(state).name,
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
-                self.log.info("Current state is %s", state)
-                self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
-                await asyncio.sleep(self.polling_interval_seconds)
+                else:
+                    self.log.info("Current state is %s", state)
+                    self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+                    await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
-                if self.delete_on_error and self.safe_to_cancel():
+                if self.delete_on_error and await self.safe_to_cancel():
                     self.log.info(
                         "Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in "
                         "deferred state."
@@ -369,12 +435,16 @@ class DataprocBatchTrigger(DataprocBaseTrigger):
 
             if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
                 break
-            self.log.info("Current state is %s", state)
+            self.log.info("Current state is %s", Batch.State(state).name)
             self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
             await asyncio.sleep(self.polling_interval_seconds)
 
         yield TriggerEvent(
-            {"batch_id": self.batch_id, "batch_state": state, "batch_state_message": batch.state_message}
+            {
+                "batch_id": self.batch_id,
+                "batch_state": Batch.State(state).name,
+                "batch_state_message": batch.state_message,
+            }
         )
 
 
@@ -432,9 +502,9 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
         try:
             while self.end_time > time.time():
                 cluster = await self.get_async_hook().get_cluster(
-                    region=self.region,  # type: ignore[arg-type]
+                    region=self.region,
                     cluster_name=self.cluster_name,
-                    project_id=self.project_id,  # type: ignore[arg-type]
+                    project_id=self.project_id,
                     metadata=self.metadata,
                 )
                 self.log.info(
```
```diff
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -26,10 +26,11 @@ from typing import TYPE_CHECKING, Any
 from google.cloud.container_v1.types import Operation
 from packaging.version import parse as parse_version
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodManager
 from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.google.cloud.hooks.kubernetes_engine import (
     GKEAsyncHook,
     GKEKubernetesAsyncHook,
@@ -153,7 +154,7 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
         )
 
     @cached_property
-    def hook(self) -> GKEKubernetesAsyncHook:  # type: ignore[override]
+    def hook(self) -> GKEKubernetesAsyncHook:
         return GKEKubernetesAsyncHook(
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
@@ -200,7 +201,7 @@ class GKEOperationTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get operation status and yields corresponding event."""
         hook = self._get_hook()
         try:
@@ -260,9 +261,10 @@ class GKEJobTrigger(BaseTrigger):
         ssl_ca_cert: str,
         job_name: str,
         job_namespace: str,
-        pod_name: str,
+        pod_names: list[str],
         pod_namespace: str,
         base_container_name: str,
+        pod_name: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         poll_interval: float = 2,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -274,7 +276,13 @@ class GKEJobTrigger(BaseTrigger):
         self.ssl_ca_cert = ssl_ca_cert
         self.job_name = job_name
         self.job_namespace = job_namespace
-        self.pod_name = pod_name
+        if pod_name is not None:
+            self._pod_name = pod_name
+            self.pod_names = [
+                self.pod_name,
+            ]
+        else:
+            self.pod_names = pod_names
         self.pod_namespace = pod_namespace
         self.base_container_name = base_container_name
         self.gcp_conn_id = gcp_conn_id
@@ -283,6 +291,15 @@ class GKEJobTrigger(BaseTrigger):
         self.get_logs = get_logs
         self.do_xcom_push = do_xcom_push
 
+    @property
+    def pod_name(self):
+        warnings.warn(
+            "`pod_name` parameter is deprecated, please use `pod_names`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self._pod_name
+
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize KubernetesCreateJobTrigger arguments and classpath."""
         return (
```
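Distilled from the `GKEJobTrigger` hunks above, here is the backward-compatibility shim in isolation: a legacy `pod_name` argument is folded into the new `pod_names` list, and reads of `pod_name` go through a deprecation-warning property. This is a stand-alone sketch with a stand-in warning class, not the provider code itself:

```python
from __future__ import annotations

import warnings


class ProviderDeprecationWarning(DeprecationWarning):
    """Stand-in for AirflowProviderDeprecationWarning."""


class JobTriggerShim:
    """Minimal reproduction of the pod_name -> pod_names migration pattern."""

    def __init__(self, pod_names: list[str] | None = None, pod_name: str | None = None):
        if pod_name is not None:
            self._pod_name = pod_name
            self.pod_names = [pod_name]  # legacy single name folded into the new list
        else:
            self.pod_names = pod_names or []

    @property
    def pod_name(self) -> str:
        warnings.warn(
            "`pod_name` parameter is deprecated, please use `pod_names`",
            ProviderDeprecationWarning,
            stacklevel=2,
        )
        return self._pod_name


legacy = JobTriggerShim(pod_name="worker-0")  # old call style still works
print(legacy.pod_names)                       # ['worker-0']
```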
```diff
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -292,7 +309,7 @@ class GKEJobTrigger(BaseTrigger):
                 "ssl_ca_cert": self.ssl_ca_cert,
                 "job_name": self.job_name,
                 "job_namespace": self.job_namespace,
-                "pod_name": self.pod_name,
+                "pod_names": self.pod_names,
                 "pod_namespace": self.pod_namespace,
                 "base_container_name": self.base_container_name,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -303,10 +320,8 @@ class GKEJobTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job status and yield a TriggerEvent."""
-        if self.get_logs or self.do_xcom_push:
-            pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
         if self.do_xcom_push:
             kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
             kubernetes_provider_name = kubernetes_provider.data["package-name"]
@@ -318,22 +333,26 @@ class GKEJobTrigger(BaseTrigger):
                     f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
                     f"support this feature. Please upgrade it to version higher than or equal to {min_version}."
                 )
-            await self.hook.wait_until_container_complete(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=self.base_container_name,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Checking if xcom sidecar container is started.")
-            await self.hook.wait_until_container_started(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Extracting result from xcom sidecar container.")
-            loop = asyncio.get_running_loop()
-            xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
+            xcom_results = []
+            for pod_name in self.pod_names:
+                pod = await self.hook.get_pod(name=pod_name, namespace=self.pod_namespace)
+                await self.hook.wait_until_container_complete(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=self.base_container_name,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Checking if xcom sidecar container is started.")
+                await self.hook.wait_until_container_started(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Extracting result from xcom sidecar container.")
+                loop = asyncio.get_running_loop()
+                xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
+                xcom_results.append(xcom_result)
         job: V1Job = await self.hook.wait_until_job_complete(
             name=self.job_name, namespace=self.job_namespace, poll_interval=self.poll_interval
         )
@@ -345,12 +364,12 @@ class GKEJobTrigger(BaseTrigger):
             {
                 "name": job.metadata.name,
                 "namespace": job.metadata.namespace,
-                "pod_name": pod.metadata.name if self.get_logs else None,
-                "pod_namespace": pod.metadata.namespace if self.get_logs else None,
+                "pod_names": [pod_name for pod_name in self.pod_names] if self.get_logs else None,
+                "pod_namespace": self.pod_namespace if self.get_logs else None,
                 "status": status,
                 "message": message,
                 "job": job_dict,
-                "xcom_result": xcom_result if self.do_xcom_push else None,
+                "xcom_result": xcom_results if self.do_xcom_push else None,
             }
         )
 
```
```diff
--- a/airflow/providers/google/cloud/triggers/mlengine.py
+++ b/airflow/providers/google/cloud/triggers/mlengine.py
@@ -90,7 +90,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
```
```diff
--- a/airflow/providers/google/cloud/triggers/pubsub.py
+++ b/airflow/providers/google/cloud/triggers/pubsub.py
@@ -85,27 +85,23 @@ class PubsubPullTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
-        try:
-            while True:
-                if pulled_messages := await self.hook.pull(
-                    project_id=self.project_id,
-                    subscription=self.subscription,
-                    max_messages=self.max_messages,
-                    return_immediately=True,
-                ):
-                    if self.ack_messages:
-                        await self.message_acknowledgement(pulled_messages)
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        while True:
+            if pulled_messages := await self.hook.pull(
+                project_id=self.project_id,
+                subscription=self.subscription,
+                max_messages=self.max_messages,
+                return_immediately=True,
+            ):
+                if self.ack_messages:
+                    await self.message_acknowledgement(pulled_messages)
 
-                    messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
+                messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
 
-                    yield TriggerEvent({"status": "success", "message": messages_json})
-                    return
-                self.log.info("Sleeping for %s seconds.", self.poke_interval)
-                await asyncio.sleep(self.poke_interval)
-        except Exception as e:
-            yield TriggerEvent({"status": "error", "message": str(e)})
-            return
+                yield TriggerEvent({"status": "success", "message": messages_json})
+                return
+            self.log.info("Sleeping for %s seconds.", self.poke_interval)
+            await asyncio.sleep(self.poke_interval)
 
     async def message_acknowledgement(self, pulled_messages):
         await self.hook.acknowledge(
```
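Behaviorally, dropping the blanket `try/except` means a failure inside `hook.pull` now propagates out of the async generator instead of being converted into an error event. A self-contained sketch of the difference (all names invented for illustration):

```python
import asyncio


async def old_style():
    try:
        raise ValueError("pull failed")
    except Exception as e:
        yield {"status": "error", "message": str(e)}  # failure travels as an event


async def new_style():
    raise ValueError("pull failed")  # failure travels as a raised exception
    yield  # unreachable; marks this as an async generator


async def main() -> None:
    print([event async for event in old_style()])  # [{'status': 'error', ...}]
    try:
        [event async for event in new_style()]
    except ValueError as exc:
        print(f"raised: {exc}")


asyncio.run(main())
```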