apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. airflow/providers/google/__init__.py +3 -3
  2. airflow/providers/google/cloud/hooks/automl.py +1 -1
  3. airflow/providers/google/cloud/hooks/bigquery.py +64 -33
  4. airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
  5. airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
  6. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
  7. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +246 -32
  9. airflow/providers/google/cloud/hooks/dataplex.py +6 -2
  10. airflow/providers/google/cloud/hooks/dlp.py +14 -14
  11. airflow/providers/google/cloud/hooks/gcs.py +6 -2
  12. airflow/providers/google/cloud/hooks/gdm.py +2 -2
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/mlengine.py +8 -4
  15. airflow/providers/google/cloud/hooks/pubsub.py +1 -1
  16. airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
  18. airflow/providers/google/cloud/links/vertex_ai.py +2 -1
  19. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
  20. airflow/providers/google/cloud/operators/automl.py +13 -12
  21. airflow/providers/google/cloud/operators/bigquery.py +36 -22
  22. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
  23. airflow/providers/google/cloud/operators/bigtable.py +7 -6
  24. airflow/providers/google/cloud/operators/cloud_build.py +12 -11
  25. airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
  26. airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
  27. airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
  29. airflow/providers/google/cloud/operators/compute.py +12 -11
  30. airflow/providers/google/cloud/operators/datacatalog.py +21 -20
  31. airflow/providers/google/cloud/operators/dataflow.py +59 -42
  32. airflow/providers/google/cloud/operators/datafusion.py +11 -10
  33. airflow/providers/google/cloud/operators/datapipeline.py +3 -2
  34. airflow/providers/google/cloud/operators/dataprep.py +5 -4
  35. airflow/providers/google/cloud/operators/dataproc.py +19 -16
  36. airflow/providers/google/cloud/operators/datastore.py +8 -7
  37. airflow/providers/google/cloud/operators/dlp.py +31 -30
  38. airflow/providers/google/cloud/operators/functions.py +4 -3
  39. airflow/providers/google/cloud/operators/gcs.py +66 -41
  40. airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
  41. airflow/providers/google/cloud/operators/life_sciences.py +2 -1
  42. airflow/providers/google/cloud/operators/mlengine.py +11 -10
  43. airflow/providers/google/cloud/operators/pubsub.py +6 -5
  44. airflow/providers/google/cloud/operators/spanner.py +7 -6
  45. airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
  46. airflow/providers/google/cloud/operators/stackdriver.py +11 -10
  47. airflow/providers/google/cloud/operators/tasks.py +14 -13
  48. airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
  49. airflow/providers/google/cloud/operators/translate_speech.py +2 -1
  50. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
  51. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
  52. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
  53. airflow/providers/google/cloud/operators/vision.py +13 -12
  54. airflow/providers/google/cloud/operators/workflows.py +10 -9
  55. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  56. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
  57. airflow/providers/google/cloud/sensors/bigtable.py +2 -1
  58. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
  59. airflow/providers/google/cloud/sensors/dataflow.py +239 -52
  60. airflow/providers/google/cloud/sensors/datafusion.py +2 -1
  61. airflow/providers/google/cloud/sensors/dataproc.py +3 -2
  62. airflow/providers/google/cloud/sensors/gcs.py +14 -12
  63. airflow/providers/google/cloud/sensors/tasks.py +2 -1
  64. airflow/providers/google/cloud/sensors/workflows.py +2 -1
  65. airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
  66. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
  67. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
  68. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  69. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
  70. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  71. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
  72. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
  73. airflow/providers/google/cloud/triggers/bigquery.py +14 -3
  74. airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
  75. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
  76. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
  77. airflow/providers/google/cloud/triggers/dataflow.py +504 -4
  78. airflow/providers/google/cloud/triggers/dataproc.py +110 -26
  79. airflow/providers/google/cloud/triggers/mlengine.py +2 -1
  80. airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
  81. airflow/providers/google/common/hooks/base_google.py +45 -7
  82. airflow/providers/google/firebase/hooks/firestore.py +2 -2
  83. airflow/providers/google/firebase/operators/firestore.py +2 -1
  84. airflow/providers/google/get_provider_info.py +3 -2
  85. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
  86. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
  87. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
  88. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
  89. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
@@ -18,13 +18,24 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any, Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
 
 from google.cloud.dataflow_v1beta3 import JobState
+from google.cloud.dataflow_v1beta3.types import (
+    AutoscalingEvent,
+    JobMessage,
+    JobMetrics,
+    MetricUpdate,
+)
 
-from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
+from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook, DataflowJobStatus
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
+if TYPE_CHECKING:
+    from google.cloud.dataflow_v1beta3.services.messages_v1_beta3.pagers import ListJobMessagesAsyncPager
+
+
 DEFAULT_DATAFLOW_LOCATION = "us-central1"
 
 
@@ -59,7 +70,6 @@ class TemplateJobStartTrigger(BaseTrigger):
         cancel_timeout: int | None = 5 * 60,
     ):
         super().__init__()
-
         self.project_id = project_id
         self.job_id = job_id
         self.location = location
@@ -128,7 +138,7 @@ class TemplateJobStartTrigger(BaseTrigger):
                     return
                 else:
                     self.log.info("Job is still running...")
-                    self.log.info("Current job status is: %s", status)
+                    self.log.info("Current job status is: %s", status.name)
                     self.log.info("Sleeping for %s seconds.", self.poll_sleep)
                     await asyncio.sleep(self.poll_sleep)
         except Exception as e:
@@ -142,3 +152,493 @@
             impersonation_chain=self.impersonation_chain,
             cancel_timeout=self.cancel_timeout,
         )
+
+
+class DataflowJobStatusTrigger(BaseTrigger):
+    """
+    Trigger that monitors if a Dataflow job has reached any of the expected statuses.
+
+    :param job_id: Required. ID of the job.
+    :param expected_statuses: The expected state(s) of the operation.
+        See: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None then
+        the value of DEFAULT_DATAFLOW_LOCATION will be used.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
+    :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        expected_statuses: set[str],
+        project_id: str | None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.expected_statuses = expected_statuses
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStatusTrigger",
+            {
+                "job_id": self.job_id,
+                "expected_statuses": self.expected_statuses,
+                "project_id": self.project_id,
+                "location": self.location,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_sleep": self.poll_sleep,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self):
+        """
+        Loop until the job reaches an expected or terminal state.
+
+        Yields a TriggerEvent with success status, if the client returns an expected job status.
+
+        Yields a TriggerEvent with error status, if the client returns an unexpected terminal
+        job status or any exception is raised while looping.
+
+        In any other case the Trigger will wait for a specified amount of time
+        stored in self.poll_sleep variable.
+        """
+        try:
+            while True:
+                job_status = await self.async_hook.get_job_status(
+                    job_id=self.job_id,
+                    project_id=self.project_id,
+                    location=self.location,
+                )
+                if job_status.name in self.expected_statuses:
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": f"Job with id '{self.job_id}' has reached an expected state: {job_status.name}",
+                        }
+                    )
+                    return
+                elif job_status.name in DataflowJobStatus.TERMINAL_STATES:
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}",
+                        }
+                    )
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+                await asyncio.sleep(self.poll_sleep)
+        except Exception as e:
+            self.log.error("Exception occurred while checking for job status!")
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(e),
+                }
+            )
+
+    @cached_property
+    def async_hook(self) -> AsyncDataflowHook:
+        return AsyncDataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+
+class DataflowJobMetricsTrigger(BaseTrigger):
+    """
+    Trigger that checks for metrics associated with a Dataflow job.
+
+    :param job_id: Required. ID of the job.
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None then
+        the value of DEFAULT_DATAFLOW_LOCATION will be used.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
+    :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param fail_on_terminal_state: If set to True the trigger will yield a TriggerEvent with
+        error status if the job reaches a terminal state.
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        project_id: str | None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+        fail_on_terminal_state: bool = True,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.job_id = job_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+        self.fail_on_terminal_state = fail_on_terminal_state
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMetricsTrigger",
+            {
+                "project_id": self.project_id,
+                "job_id": self.job_id,
+                "location": self.location,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_sleep": self.poll_sleep,
+                "impersonation_chain": self.impersonation_chain,
+                "fail_on_terminal_state": self.fail_on_terminal_state,
+            },
+        )
+
+    async def run(self):
+        """
+        Loop until a terminal job status or any job metrics are returned.
+
+        Yields a TriggerEvent with success status, if the client returns any job metrics
+        and fail_on_terminal_state attribute is False.
+
+        Yields a TriggerEvent with error status, if the client returns a job status with
+        a terminal state value and fail_on_terminal_state attribute is True.
+
+        Yields a TriggerEvent with error status, if any exception is raised while looping.
+
+        In any other case the Trigger will wait for a specified amount of time
+        stored in self.poll_sleep variable.
+        """
+        try:
+            while True:
+                job_status = await self.async_hook.get_job_status(
+                    job_id=self.job_id,
+                    project_id=self.project_id,
+                    location=self.location,
+                )
+                job_metrics = await self.get_job_metrics()
+                if self.fail_on_terminal_state and job_status.name in DataflowJobStatus.TERMINAL_STATES:
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}",
+                            "result": None,
+                        }
+                    )
+                    return
+                if job_metrics:
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": f"Detected {len(job_metrics)} metrics for job '{self.job_id}'",
+                            "result": job_metrics,
+                        }
+                    )
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+                await asyncio.sleep(self.poll_sleep)
+        except Exception as e:
+            self.log.error("Exception occurred while checking for job's metrics!")
+            yield TriggerEvent({"status": "error", "message": str(e), "result": None})
+
+    async def get_job_metrics(self) -> list[dict[str, Any]]:
+        """Wait for the Dataflow client response and then return it in a serialized list."""
+        job_response: JobMetrics = await self.async_hook.get_job_metrics(
+            job_id=self.job_id,
+            project_id=self.project_id,
+            location=self.location,
+        )
+        return self._get_metrics_from_job_response(job_response)
+
+    def _get_metrics_from_job_response(self, job_response: JobMetrics) -> list[dict[str, Any]]:
+        """Return a list of serialized MetricUpdate objects."""
+        return [MetricUpdate.to_dict(metric) for metric in job_response.metrics]
+
+    @cached_property
+    def async_hook(self) -> AsyncDataflowHook:
+        return AsyncDataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+
+class DataflowJobAutoScalingEventTrigger(BaseTrigger):
+    """
+    Trigger that checks for autoscaling events associated with a Dataflow job.
+
+    :param job_id: Required. ID of the job.
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None then
+        the value of DEFAULT_DATAFLOW_LOCATION will be used.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
+    :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param fail_on_terminal_state: If set to True the trigger will yield a TriggerEvent with
+        error status if the job reaches a terminal state.
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        project_id: str | None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+        fail_on_terminal_state: bool = True,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.job_id = job_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+        self.fail_on_terminal_state = fail_on_terminal_state
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataflow.DataflowJobAutoScalingEventTrigger",
+            {
+                "project_id": self.project_id,
+                "job_id": self.job_id,
+                "location": self.location,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_sleep": self.poll_sleep,
+                "impersonation_chain": self.impersonation_chain,
+                "fail_on_terminal_state": self.fail_on_terminal_state,
+            },
+        )
+
+    async def run(self):
+        """
+        Loop until a terminal job status or any autoscaling events are returned.
+
+        Yields a TriggerEvent with success status, if the client returns any autoscaling events
+        and fail_on_terminal_state attribute is False.
+
+        Yields a TriggerEvent with error status, if the client returns a job status with
+        a terminal state value and fail_on_terminal_state attribute is True.
+
+        Yields a TriggerEvent with error status, if any exception is raised while looping.
+
+        In any other case the Trigger will wait for a specified amount of time
+        stored in self.poll_sleep variable.
+        """
+        try:
+            while True:
+                job_status = await self.async_hook.get_job_status(
+                    job_id=self.job_id,
+                    project_id=self.project_id,
+                    location=self.location,
+                )
+                autoscaling_events = await self.list_job_autoscaling_events()
+                if self.fail_on_terminal_state and job_status.name in DataflowJobStatus.TERMINAL_STATES:
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}",
+                            "result": None,
+                        }
+                    )
+                    return
+                if autoscaling_events:
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": f"Detected {len(autoscaling_events)} autoscaling events for job '{self.job_id}'",
+                            "result": autoscaling_events,
+                        }
+                    )
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+                await asyncio.sleep(self.poll_sleep)
+        except Exception as e:
+            self.log.error("Exception occurred while checking for job's autoscaling events!")
+            yield TriggerEvent({"status": "error", "message": str(e), "result": None})
+
+    async def list_job_autoscaling_events(self) -> list[dict[str, str | dict]]:
+        """Wait for the Dataflow client response and then return it in a serialized list."""
+        job_response: ListJobMessagesAsyncPager = await self.async_hook.list_job_messages(
+            job_id=self.job_id,
+            project_id=self.project_id,
+            location=self.location,
+        )
+        return self._get_autoscaling_events_from_job_response(job_response)
+
+    def _get_autoscaling_events_from_job_response(
+        self, job_response: ListJobMessagesAsyncPager
+    ) -> list[dict[str, str | dict]]:
+        """Return a list of serialized AutoscalingEvent objects."""
+        return [AutoscalingEvent.to_dict(event) for event in job_response.autoscaling_events]
+
+    @cached_property
+    def async_hook(self) -> AsyncDataflowHook:
+        return AsyncDataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+
+class DataflowJobMessagesTrigger(BaseTrigger):
+    """
+    Trigger that checks for job messages associated with a Dataflow job.
+
+    :param job_id: Required. ID of the job.
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None then
+        the value of DEFAULT_DATAFLOW_LOCATION will be used.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud.
+    :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param fail_on_terminal_state: If set to True the trigger will yield a TriggerEvent with
+        error status if the job reaches a terminal state.
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        project_id: str | None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_sleep: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+        fail_on_terminal_state: bool = True,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.job_id = job_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+        self.fail_on_terminal_state = fail_on_terminal_state
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataflow.DataflowJobMessagesTrigger",
+            {
+                "project_id": self.project_id,
+                "job_id": self.job_id,
+                "location": self.location,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_sleep": self.poll_sleep,
+                "impersonation_chain": self.impersonation_chain,
+                "fail_on_terminal_state": self.fail_on_terminal_state,
+            },
+        )
+
+    async def run(self):
+        """
+        Loop until a terminal job status or any job messages are returned.
+
+        Yields a TriggerEvent with success status, if the client returns any job messages
+        and fail_on_terminal_state attribute is False.
+
+        Yields a TriggerEvent with error status, if the client returns a job status with
+        a terminal state value and fail_on_terminal_state attribute is True.
+
+        Yields a TriggerEvent with error status, if any exception is raised while looping.
+
+        In any other case the Trigger will wait for a specified amount of time
+        stored in self.poll_sleep variable.
+        """
+        try:
+            while True:
+                job_status = await self.async_hook.get_job_status(
+                    job_id=self.job_id,
+                    project_id=self.project_id,
+                    location=self.location,
+                )
+                job_messages = await self.list_job_messages()
+                if self.fail_on_terminal_state and job_status.name in DataflowJobStatus.TERMINAL_STATES:
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Job with id '{self.job_id}' is already in terminal state: {job_status.name}",
+                            "result": None,
+                        }
+                    )
+                    return
+                if job_messages:
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": f"Detected {len(job_messages)} job messages for job '{self.job_id}'",
+                            "result": job_messages,
+                        }
+                    )
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+                await asyncio.sleep(self.poll_sleep)
+        except Exception as e:
+            self.log.error("Exception occurred while checking for job's messages!")
+            yield TriggerEvent({"status": "error", "message": str(e), "result": None})
+
+    async def list_job_messages(self) -> list[dict[str, str | dict]]:
+        """Wait for the Dataflow client response and then return it in a serialized list."""
+        job_response: ListJobMessagesAsyncPager = await self.async_hook.list_job_messages(
+            job_id=self.job_id,
+            project_id=self.project_id,
+            location=self.location,
+        )
+        return self._get_job_messages_from_job_response(job_response)
+
+    def _get_job_messages_from_job_response(
+        self, job_response: ListJobMessagesAsyncPager
+    ) -> list[dict[str, str | dict]]:
+        """Return a list of serialized JobMessage objects."""
+        return [JobMessage.to_dict(message) for message in job_response.job_messages]
+
+    @cached_property
+    def async_hook(self) -> AsyncDataflowHook:
+        return AsyncDataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=self.impersonation_chain,
+        )
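
The diff only shows the triggers themselves, so the following is a minimal, hypothetical sketch of how a deferrable operator might hand off waiting to the new DataflowJobStatusTrigger. The trigger's constructor arguments and its TriggerEvent payload come from the code above, and DataflowJobStatus is the constants class from the provider's dataflow hook; the WaitForDataflowJobOperator class and its method names are illustrative only, not part of this release.

```python
from __future__ import annotations

from typing import Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.providers.google.cloud.triggers.dataflow import DataflowJobStatusTrigger


class WaitForDataflowJobOperator(BaseOperator):
    """Hypothetical operator that defers until a Dataflow job reaches JOB_STATE_DONE."""

    def __init__(self, *, job_id: str, project_id: str, location: str = "us-central1", **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.project_id = project_id
        self.location = location

    def execute(self, context):
        # Hand control to the triggerer; the worker slot is released while the trigger polls.
        self.defer(
            trigger=DataflowJobStatusTrigger(
                job_id=self.job_id,
                expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
                project_id=self.project_id,
                location=self.location,
                gcp_conn_id="google_cloud_default",
                poll_sleep=30,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event: dict[str, Any]) -> str:
        # DataflowJobStatusTrigger yields {"status": "success" | "error", "message": ...}.
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info(event["message"])
        return self.job_id
```

The other three new triggers (DataflowJobMetricsTrigger, DataflowJobAutoScalingEventTrigger, DataflowJobMessagesTrigger) follow the same pattern but additionally carry a "result" key in the event payload with the serialized metrics, autoscaling events, or job messages.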