apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. airflow/providers/google/__init__.py +3 -3
  2. airflow/providers/google/cloud/hooks/automl.py +1 -1
  3. airflow/providers/google/cloud/hooks/bigquery.py +64 -33
  4. airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
  5. airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
  6. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
  7. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +246 -32
  9. airflow/providers/google/cloud/hooks/dataplex.py +6 -2
  10. airflow/providers/google/cloud/hooks/dlp.py +14 -14
  11. airflow/providers/google/cloud/hooks/gcs.py +6 -2
  12. airflow/providers/google/cloud/hooks/gdm.py +2 -2
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/mlengine.py +8 -4
  15. airflow/providers/google/cloud/hooks/pubsub.py +1 -1
  16. airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
  18. airflow/providers/google/cloud/links/vertex_ai.py +2 -1
  19. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
  20. airflow/providers/google/cloud/operators/automl.py +13 -12
  21. airflow/providers/google/cloud/operators/bigquery.py +36 -22
  22. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
  23. airflow/providers/google/cloud/operators/bigtable.py +7 -6
  24. airflow/providers/google/cloud/operators/cloud_build.py +12 -11
  25. airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
  26. airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
  27. airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
  29. airflow/providers/google/cloud/operators/compute.py +12 -11
  30. airflow/providers/google/cloud/operators/datacatalog.py +21 -20
  31. airflow/providers/google/cloud/operators/dataflow.py +59 -42
  32. airflow/providers/google/cloud/operators/datafusion.py +11 -10
  33. airflow/providers/google/cloud/operators/datapipeline.py +3 -2
  34. airflow/providers/google/cloud/operators/dataprep.py +5 -4
  35. airflow/providers/google/cloud/operators/dataproc.py +19 -16
  36. airflow/providers/google/cloud/operators/datastore.py +8 -7
  37. airflow/providers/google/cloud/operators/dlp.py +31 -30
  38. airflow/providers/google/cloud/operators/functions.py +4 -3
  39. airflow/providers/google/cloud/operators/gcs.py +66 -41
  40. airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
  41. airflow/providers/google/cloud/operators/life_sciences.py +2 -1
  42. airflow/providers/google/cloud/operators/mlengine.py +11 -10
  43. airflow/providers/google/cloud/operators/pubsub.py +6 -5
  44. airflow/providers/google/cloud/operators/spanner.py +7 -6
  45. airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
  46. airflow/providers/google/cloud/operators/stackdriver.py +11 -10
  47. airflow/providers/google/cloud/operators/tasks.py +14 -13
  48. airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
  49. airflow/providers/google/cloud/operators/translate_speech.py +2 -1
  50. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
  51. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
  52. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
  53. airflow/providers/google/cloud/operators/vision.py +13 -12
  54. airflow/providers/google/cloud/operators/workflows.py +10 -9
  55. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  56. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
  57. airflow/providers/google/cloud/sensors/bigtable.py +2 -1
  58. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
  59. airflow/providers/google/cloud/sensors/dataflow.py +239 -52
  60. airflow/providers/google/cloud/sensors/datafusion.py +2 -1
  61. airflow/providers/google/cloud/sensors/dataproc.py +3 -2
  62. airflow/providers/google/cloud/sensors/gcs.py +14 -12
  63. airflow/providers/google/cloud/sensors/tasks.py +2 -1
  64. airflow/providers/google/cloud/sensors/workflows.py +2 -1
  65. airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
  66. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
  67. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
  68. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  69. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
  70. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  71. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
  72. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
  73. airflow/providers/google/cloud/triggers/bigquery.py +14 -3
  74. airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
  75. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
  76. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
  77. airflow/providers/google/cloud/triggers/dataflow.py +504 -4
  78. airflow/providers/google/cloud/triggers/dataproc.py +110 -26
  79. airflow/providers/google/cloud/triggers/mlengine.py +2 -1
  80. airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
  81. airflow/providers/google/common/hooks/base_google.py +45 -7
  82. airflow/providers/google/firebase/hooks/firestore.py +2 -2
  83. airflow/providers/google/firebase/operators/firestore.py +2 -1
  84. airflow/providers/google/get_provider_info.py +3 -2
  85. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
  86. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
  87. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
  88. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
  89. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -19,6 +19,7 @@

 from __future__ import annotations

+import base64
 import errno
 import json
 import os
@@ -34,7 +35,7 @@ import uuid
 from inspect import signature
 from pathlib import Path
 from subprocess import PIPE, Popen
-from tempfile import gettempdir
+from tempfile import NamedTemporaryFile, _TemporaryFileWrapper, gettempdir
 from typing import TYPE_CHECKING, Any, Sequence
 from urllib.parse import quote_plus

@@ -49,12 +50,21 @@ from googleapiclient.errors import HttpError
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
+from airflow.providers.google.cloud.hooks.secret_manager import (
+    GoogleCloudSecretManagerHook,
+)
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+    get_field,
+)
 from airflow.providers.mysql.hooks.mysql import MySqlHook
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.utils.log.logging_mixin import LoggingMixin

 if TYPE_CHECKING:
+    from google.cloud.secretmanager_v1 import AccessSecretVersionResponse
     from requests import Session

 UNIX_PATH_MAX = 108
@@ -377,6 +387,29 @@ class CloudSQLHook(GoogleBaseHook):
         except HttpError as ex:
             raise AirflowException(f"Cloning of instance {instance} failed: {ex.content}")

+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_ssl_certificate(self, instance: str, body: dict, project_id: str):
+        """
+        Create SSL certificate for a Cloud SQL instance.
+
+        :param instance: Cloud SQL instance ID. This does not include the project ID.
+        :param body: The request body, as described in
+            https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/sslCerts/insert#SslCertsInsertRequest
+        :param project_id: Project ID of the project that contains the instance. If set
+            to None or missing, the default project_id from the Google Cloud connection is used.
+        :return: SslCert insert response. For more details see:
+            https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/sslCerts/insert#response-body
+        """
+        response = (
+            self.get_conn()
+            .sslCerts()
+            .insert(project=project_id, instance=instance, body=body)
+            .execute(num_retries=self.num_retries)
+        )
+        operation_name = response.get("operation", {}).get("name", {})
+        self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)
+        return response
+
     @GoogleBaseHook.fallback_to_default_project_id
     def _wait_for_operation_to_complete(
         self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS
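For orientation, a minimal sketch of calling the new hook method; the connection ID, instance name, request body, and project below are illustrative placeholders rather than values taken from this release:

```python
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook

# Hypothetical identifiers, for illustration only.
hook = CloudSQLHook(api_version="v1beta4", gcp_conn_id="google_cloud_default")
response = hook.create_ssl_certificate(
    instance="my-cloudsql-instance",
    body={"commonName": "airflow-client-cert"},  # SslCertsInsertRequest body
    project_id="my-gcp-project",
)
# The insert response carries the operation that the hook waits on before returning.
```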
@@ -482,7 +515,7 @@ class CloudSqlProxyRunner(LoggingMixin):
         path_prefix: str,
         instance_specification: str,
         gcp_conn_id: str = "google_cloud_default",
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         sql_proxy_version: str | None = None,
         sql_proxy_binary_path: str | None = None,
     ) -> None:
@@ -758,7 +791,24 @@ class CloudSQLDatabaseHook(BaseHook):
     :param gcp_conn_id: The connection ID used to connect to Google Cloud for
         cloud-sql-proxy authentication.
     :param default_gcp_project_id: Default project id used if project_id not specified
-        in the connection URL
+        in the connection URL
+    :param ssl_cert: Optional. Path to client certificate to authenticate when SSL is used. Overrides the
+        connection field ``sslcert``.
+    :param ssl_key: Optional. Path to client private key to authenticate when SSL is used. Overrides the
+        connection field ``sslkey``.
+    :param ssl_root_cert: Optional. Path to server's certificate to authenticate when SSL is used. Overrides
+        the connection field ``sslrootcert``.
+    :param ssl_secret_id: Optional. ID of the secret in Google Cloud Secret Manager that stores SSL
+        certificate in the format below:
+
+        {'sslcert': '',
+         'sslkey': '',
+         'sslrootcert': ''}
+
+        Overrides the connection fields ``sslcert``, ``sslkey``, ``sslrootcert``.
+        Note that according to the Secret Manager requirements, the mentioned dict should be saved as a
+        string, and encoded with base64.
+        Note that this parameter is incompatible with parameters ``ssl_cert``, ``ssl_key``, ``ssl_root_cert``.
     """

     conn_name_attr = "gcp_cloudsql_conn_id"
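Given the ``ssl_secret_id`` contract described above, the secret value is a base64-encoded JSON string; a rough sketch of preparing that payload with the standard library (the file names are placeholders, and the keys match the ones the hook looks up):

```python
import base64
import json
from pathlib import Path

# Hypothetical local PEM files; the values are the certificate/key contents, not paths.
ssl_payload = {
    "sslcert": Path("client-cert.pem").read_text(),
    "sslkey": Path("client-key.pem").read_text(),
    "sslrootcert": Path("server-ca.pem").read_text(),
}
# The hook base64-decodes the secret and then json.loads it (see the hunk further down),
# so store this string as the secret version's data in Secret Manager.
encoded_payload = base64.b64encode(json.dumps(ssl_payload).encode("utf-8")).decode("utf-8")
```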
@@ -770,12 +820,18 @@ class CloudSQLDatabaseHook(BaseHook):
         self,
         gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
         gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
         default_gcp_project_id: str | None = None,
         sql_proxy_binary_path: str | None = None,
+        ssl_cert: str | None = None,
+        ssl_key: str | None = None,
+        ssl_root_cert: str | None = None,
+        ssl_secret_id: str | None = None,
     ) -> None:
         super().__init__()
         self.gcp_conn_id = gcp_conn_id
         self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
+        self.impersonation_chain = impersonation_chain
         self.cloudsql_connection = self.get_connection(self.gcp_cloudsql_conn_id)
         self.extras = self.cloudsql_connection.extra_dejson
         self.project_id = self.extras.get("project_id", default_gcp_project_id)
@@ -792,9 +848,11 @@ class CloudSQLDatabaseHook(BaseHook):
         self.password = self.cloudsql_connection.password
         self.public_ip = self.cloudsql_connection.host
         self.public_port = self.cloudsql_connection.port
-        self.sslcert = self.extras.get("sslcert")
-        self.sslkey = self.extras.get("sslkey")
-        self.sslrootcert = self.extras.get("sslrootcert")
+        self.ssl_cert = ssl_cert
+        self.ssl_key = ssl_key
+        self.ssl_root_cert = ssl_root_cert
+        self.ssl_secret_id = ssl_secret_id
+        self._ssl_cert_temp_files: dict[str, _TemporaryFileWrapper] = {}
         # Port and socket path and db_hook are automatically generated
         self.sql_proxy_tcp_port = None
         self.sql_proxy_unique_path: str | None = None
@@ -805,6 +863,84 @@ class CloudSQLDatabaseHook(BaseHook):
         self.db_conn_id = str(uuid.uuid1())
         self._validate_inputs()

+    @property
+    def sslcert(self) -> str | None:
+        return self._get_ssl_temporary_file_path(cert_name="sslcert", cert_path=self.ssl_cert)
+
+    @property
+    def sslkey(self) -> str | None:
+        return self._get_ssl_temporary_file_path(cert_name="sslkey", cert_path=self.ssl_key)
+
+    @property
+    def sslrootcert(self) -> str | None:
+        return self._get_ssl_temporary_file_path(cert_name="sslrootcert", cert_path=self.ssl_root_cert)
+
+    def _get_ssl_temporary_file_path(self, cert_name: str, cert_path: str | None) -> str | None:
+        cert_value = self._get_cert_from_secret(cert_name)
+        original_cert_path = cert_path or self.extras.get(cert_name)
+        if cert_value or original_cert_path:
+            if cert_name not in self._ssl_cert_temp_files:
+                return self._set_temporary_ssl_file(
+                    cert_name=cert_name, cert_path=original_cert_path, cert_value=cert_value
+                )
+            return self._ssl_cert_temp_files[cert_name].name
+        return None
+
+    def _get_cert_from_secret(self, cert_name: str) -> str | None:
+        if not self.ssl_secret_id:
+            return None
+
+        secret_hook = GoogleCloudSecretManagerHook(
+            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+        )
+        secret: AccessSecretVersionResponse = secret_hook.access_secret(
+            project_id=self.project_id,
+            secret_id=self.ssl_secret_id,
+        )
+        secret_data = json.loads(base64.b64decode(secret.payload.data))
+        if cert_name in secret_data:
+            return secret_data[cert_name]
+        else:
+            raise AirflowException(
+                "Invalid secret format. Expected dictionary with keys: `sslcert`, `sslkey`, `sslrootcert`"
+            )
+
+    def _set_temporary_ssl_file(
+        self, cert_name: str, cert_path: str | None = None, cert_value: str | None = None
+    ) -> str | None:
+        """Save the certificate as a temporary file.
+
+        This method was implemented in order to overcome psql connection error caused by excessive file
+        permissions: "private key file "..." has group or world access; file must have permissions
+        u=rw (0600) or less if owned by the current user, or permissions u=rw,g=r (0640) or less if owned
+        by root". NamedTemporaryFile enforces using exactly one of create/read/write/append mode so the
+        created file obtains least required permissions "-rw-------" that satisfies the rules.
+
+        :param cert_name: Required. Name of the certificate (one of sslcert, sslkey, sslrootcert).
+        :param cert_path: Optional. Path to the certificate.
+        :param cert_value: Optional. The certificate content.
+
+        :returns: The path to the temporary certificate file.
+        """
+        if all([cert_path, cert_value]):
+            raise AirflowException(
+                "Both parameters were specified: `cert_path`, `cert_value`. Please use only one of them."
+            )
+        if not any([cert_path, cert_value]):
+            self.log.info("Neither cert path and cert value provided. Nothing to save.")
+            return None
+
+        _temp_file = NamedTemporaryFile(mode="w+b", prefix="/tmp/certs/")
+        if cert_path:
+            with open(cert_path, "rb") as cert_file:
+                _temp_file.write(cert_file.read())
+        elif cert_value:
+            _temp_file.write(cert_value.encode("ascii"))
+        _temp_file.flush()
+        self._ssl_cert_temp_files[cert_name] = _temp_file
+        self.log.info("Copied the certificate '%s' into a temporary file '%s'", cert_name, _temp_file.name)
+        return _temp_file.name
+
     @staticmethod
     def _get_bool(val: Any) -> bool:
         if val == "False" or val is False:
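The docstring above relies on NamedTemporaryFile creating its backing file with owner-only permissions; a quick self-contained check of that assumption:

```python
import os
import stat
from tempfile import NamedTemporaryFile

# NamedTemporaryFile is backed by mkstemp(), which creates the file with mode 0o600
# (owner read/write only) -- the permission level libpq requires for a client key file.
with NamedTemporaryFile(mode="w+b") as tmp:
    mode = stat.S_IMODE(os.stat(tmp.name).st_mode)
    assert mode == 0o600, oct(mode)
```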
@@ -836,6 +972,17 @@ class CloudSQLDatabaseHook(BaseHook):
                 " SSL is not needed as Cloud SQL Proxy "
                 "provides encryption on its own"
             )
+        if any([self.ssl_key, self.ssl_cert, self.ssl_root_cert]) and self.ssl_secret_id:
+            raise AirflowException(
+                "Invalid SSL settings. Please use either all of parameters ['ssl_cert', 'ssl_cert', "
+                "'ssl_root_cert'] or a single parameter 'ssl_secret_id'."
+            )
+        if any([self.ssl_key, self.ssl_cert, self.ssl_root_cert]):
+            field_names = ["ssl_key", "ssl_cert", "ssl_root_cert"]
+            if missed_values := [field for field in field_names if not getattr(self, field)]:
+                s = "s are" if len(missed_values) > 1 else "is"
+                missed_values_str = ", ".join(f for f in missed_values)
+                raise AirflowException(f"Invalid SSL settings. Parameter{s} missing: {missed_values_str}")

     def validate_ssl_certs(self) -> None:
         """
airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -46,7 +46,11 @@ from googleapiclient.errors import HttpError

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)

 if TYPE_CHECKING:
     from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
@@ -84,6 +88,7 @@ ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
 AWS_ACCESS_KEY = "awsAccessKey"
 AWS_SECRET_ACCESS_KEY = "secretAccessKey"
 AWS_S3_DATA_SOURCE = "awsS3DataSource"
+AWS_ROLE_ARN = "roleArn"
 BODY = "body"
 BUCKET_NAME = "bucketName"
 COUNTERS = "counters"
@@ -504,7 +509,7 @@ class CloudDataTransferServiceHook(GoogleBaseHook):
 class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
     """Asynchronous hook for Google Storage Transfer Service."""

-    def __init__(self, project_id: str | None = None, **kwargs: Any) -> None:
+    def __init__(self, project_id: str = PROVIDE_PROJECT_ID, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.project_id = project_id
         self._client: StorageTransferServiceAsyncClient | None = None
airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -29,6 +29,7 @@ from paramiko.ssh_exception import SSHException
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
 from airflow.providers.google.cloud.hooks.os_login import OSLoginHook
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.utils.types import NOTSET, ArgNotSet

@@ -109,7 +110,7 @@ class ComputeEngineSSHHook(SSHHook):
         instance_name: str | None = None,
         zone: str | None = None,
         user: str | None = "root",
-        project_id: str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
         hostname: str | None = None,
         use_internal_ip: bool = False,
         use_iap_tunnel: bool = False,
airflow/providers/google/cloud/hooks/dataflow.py
@@ -31,9 +31,23 @@ from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, TypeVar, cast

 from deprecated import deprecated
-from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView
+from google.cloud.dataflow_v1beta3 import (
+    GetJobRequest,
+    Job,
+    JobState,
+    JobsV1Beta3AsyncClient,
+    JobView,
+    ListJobMessagesRequest,
+    MessagesV1Beta3AsyncClient,
+    MetricsV1Beta3AsyncClient,
+)
+from google.cloud.dataflow_v1beta3.types import (
+    GetJobMetricsRequest,
+    JobMessageImportance,
+    JobMetrics,
+)
 from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
-from googleapiclient.discovery import build
+from googleapiclient.discovery import Resource, build

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
@@ -47,6 +61,8 @@ from airflow.utils.timeout import timeout

 if TYPE_CHECKING:
     from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager
+    from google.cloud.dataflow_v1beta3.services.messages_v1_beta3.pagers import ListJobMessagesAsyncPager
+    from google.protobuf.timestamp_pb2 import Timestamp


 # This is the default location
@@ -561,7 +577,7 @@ class DataflowHook(GoogleBaseHook):
             impersonation_chain=impersonation_chain,
         )

-    def get_conn(self) -> build:
+    def get_conn(self) -> Resource:
         """Return a Google Cloud Dataflow service object."""
         http_authorized = self._authorize()
         return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
@@ -641,9 +657,9 @@ class DataflowHook(GoogleBaseHook):
         on_new_job_callback: Callable[[dict], None] | None = None,
         location: str = DEFAULT_DATAFLOW_LOCATION,
         environment: dict | None = None,
-    ) -> dict:
+    ) -> dict[str, str]:
         """
-        Start Dataflow template job.
+        Launch a Dataflow job with a Classic Template and wait for its completion.

         :param job_name: The name of the job.
         :param variables: Map of job runtime environment options.
@@ -676,26 +692,14 @@ class DataflowHook(GoogleBaseHook):
             environment=environment,
         )

-        service = self.get_conn()
-
-        request = (
-            service.projects()
-            .locations()
-            .templates()
-            .launch(
-                projectId=project_id,
-                location=location,
-                gcsPath=dataflow_template,
-                body={
-                    "jobName": name,
-                    "parameters": parameters,
-                    "environment": environment,
-                },
-            )
+        job: dict[str, str] = self.send_launch_template_request(
+            project_id=project_id,
+            location=location,
+            gcs_path=dataflow_template,
+            job_name=name,
+            parameters=parameters,
+            environment=environment,
         )
-        response = request.execute(num_retries=self.num_retries)
-
-        job = response["job"]

         if on_new_job_id_callback:
             warnings.warn(
@@ -703,7 +707,7 @@ class DataflowHook(GoogleBaseHook):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,
             )
-            on_new_job_id_callback(job.get("id"))
+            on_new_job_id_callback(job["id"])

         if on_new_job_callback:
             on_new_job_callback(job)
@@ -722,7 +726,62 @@ class DataflowHook(GoogleBaseHook):
             expected_terminal_state=self.expected_terminal_state,
         )
         jobs_controller.wait_for_done()
-        return response["job"]
+        return job
+
+    @_fallback_to_location_from_variables
+    @_fallback_to_project_id_from_variables
+    @GoogleBaseHook.fallback_to_default_project_id
+    def launch_job_with_template(
+        self,
+        *,
+        job_name: str,
+        variables: dict,
+        parameters: dict,
+        dataflow_template: str,
+        project_id: str,
+        append_job_name: bool = True,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        environment: dict | None = None,
+    ) -> dict[str, str]:
+        """
+        Launch a Dataflow job with a Classic Template and exit without waiting for its completion.
+
+        :param job_name: The name of the job.
+        :param variables: Map of job runtime environment options.
+            It will update environment argument if passed.
+
+            .. seealso::
+                For more information on possible configurations, look at the API documentation
+                `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+                <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+
+        :param parameters: Parameters for the template
+        :param dataflow_template: GCS path to the template.
+        :param project_id: Optional, the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param append_job_name: True if unique suffix has to be appended to job name.
+        :param location: Job location.
+
+            .. seealso::
+                For more information on possible configurations, look at the API documentation
+                `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+                <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+        :return: the Dataflow job response
+        """
+        name = self.build_dataflow_job_name(job_name, append_job_name)
+        environment = self._update_environment(
+            variables=variables,
+            environment=environment,
+        )
+        job: dict[str, str] = self.send_launch_template_request(
+            project_id=project_id,
+            location=location,
+            gcs_path=dataflow_template,
+            job_name=name,
+            parameters=parameters,
+            environment=environment,
+        )
+        return job

     def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
         environment = environment or {}
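A rough sketch of driving the new non-blocking Classic Template method directly from the hook; the bucket, template path, and project are placeholders, and unlike start_template_dataflow the call returns as soon as the job is submitted:

```python
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
# Hypothetical template and parameters, for illustration only.
job = hook.launch_job_with_template(
    job_name="wordcount",
    variables={"tempLocation": "gs://my-bucket/tmp"},
    parameters={"inputFile": "gs://my-bucket/input.txt", "output": "gs://my-bucket/counts"},
    dataflow_template="gs://dataflow-templates/latest/Word_Count",
    project_id="my-gcp-project",
    location="europe-west1",
)
print(job["id"])  # the caller decides how (or whether) to wait for completion
```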
@@ -758,6 +817,35 @@

         return environment

+    def send_launch_template_request(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        gcs_path: str,
+        job_name: str,
+        parameters: dict,
+        environment: dict,
+    ) -> dict[str, str]:
+        service: Resource = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .templates()
+            .launch(
+                projectId=project_id,
+                location=location,
+                gcsPath=gcs_path,
+                body={
+                    "jobName": job_name,
+                    "parameters": parameters,
+                    "environment": environment,
+                },
+            )
+        )
+        response: dict = request.execute(num_retries=self.num_retries)
+        return response["job"]
+
     @GoogleBaseHook.fallback_to_default_project_id
     def start_flex_template(
         self,
@@ -766,9 +854,9 @@
         project_id: str,
         on_new_job_id_callback: Callable[[str], None] | None = None,
         on_new_job_callback: Callable[[dict], None] | None = None,
-    ) -> dict:
+    ) -> dict[str, str]:
         """
-        Start flex templates with the Dataflow pipeline.
+        Launch a Dataflow job with a Flex Template and wait for its completion.

         :param body: The request body. See:
             https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -779,15 +867,16 @@
         :param on_new_job_callback: A callback that is called when a Job is detected.
         :return: the Job
         """
-        service = self.get_conn()
+        service: Resource = self.get_conn()
         request = (
             service.projects()
             .locations()
             .flexTemplates()
             .launch(projectId=project_id, body=body, location=location)
         )
-        response = request.execute(num_retries=self.num_retries)
+        response: dict = request.execute(num_retries=self.num_retries)
         job = response["job"]
+        job_id: str = job["id"]

         if on_new_job_id_callback:
             warnings.warn(
@@ -795,7 +884,7 @@
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,
             )
-            on_new_job_id_callback(job.get("id"))
+            on_new_job_id_callback(job_id)

         if on_new_job_callback:
             on_new_job_callback(job)
@@ -803,7 +892,7 @@
         jobs_controller = _DataflowJobsController(
             dataflow=self.get_conn(),
             project_number=project_id,
-            job_id=job.get("id"),
+            job_id=job_id,
             location=location,
             poll_sleep=self.poll_sleep,
             num_retries=self.num_retries,
@@ -814,6 +903,42 @@

         return jobs_controller.get_jobs(refresh=True)[0]

+    @GoogleBaseHook.fallback_to_default_project_id
+    def launch_job_with_flex_template(
+        self,
+        body: dict,
+        location: str,
+        project_id: str,
+    ) -> dict[str, str]:
+        """
+        Launch a Dataflow Job with a Flex Template and exit without waiting for the job completion.
+
+        :param body: The request body. See:
+            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
+        :param location: The location of the Dataflow job (for example europe-west1)
+        :param project_id: The ID of the GCP project that owns the job.
+            If set to ``None`` or missing, the default project_id from the GCP connection is used.
+        :return: a Dataflow job response
+        """
+        service: Resource = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .flexTemplates()
+            .launch(projectId=project_id, body=body, location=location)
+        )
+        response: dict = request.execute(num_retries=self.num_retries)
+        return response["job"]
+
+    @staticmethod
+    def extract_job_id(job: dict) -> str:
+        try:
+            return job["id"]
+        except KeyError:
+            raise AirflowException(
+                "While reading job object after template execution error occurred. Job object has no id."
+            )
+
     @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
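The Flex Template counterpart works the same way; a minimal sketch pairing it with the new extract_job_id helper (the container spec path, topic, and project are placeholders):

```python
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
# Hypothetical launch body; see the flexTemplates.launch request-body reference above.
body = {
    "launchParameter": {
        "jobName": "streaming-beam",
        "containerSpecGcsPath": "gs://my-bucket/templates/streaming-beam.json",
        "parameters": {"inputTopic": "projects/my-gcp-project/topics/events"},
    }
}
job = hook.launch_job_with_flex_template(body=body, location="europe-west1", project_id="my-gcp-project")
job_id = DataflowHook.extract_job_id(job)  # raises AirflowException if the response has no "id"
```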
@@ -1353,3 +1478,92 @@ class AsyncDataflowHook(GoogleBaseAsyncHook):
         )
         page_result: ListJobsAsyncPager = await client.list_jobs(request=request)
         return page_result
+
+    async def list_job_messages(
+        self,
+        job_id: str,
+        project_id: str | None = PROVIDE_PROJECT_ID,
+        minimum_importance: int = JobMessageImportance.JOB_MESSAGE_BASIC,
+        page_size: int | None = None,
+        page_token: str | None = None,
+        start_time: Timestamp | None = None,
+        end_time: Timestamp | None = None,
+        location: str | None = DEFAULT_DATAFLOW_LOCATION,
+    ) -> ListJobMessagesAsyncPager:
+        """
+        Return ListJobMessagesAsyncPager object from MessagesV1Beta3AsyncClient.
+
+        This method wraps around a similar method of MessagesV1Beta3AsyncClient. ListJobMessagesAsyncPager can be iterated
+        over to extract messages associated with a specific Job ID.
+
+        For more details see the MessagesV1Beta3AsyncClient method description at:
+        https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.messages_v1_beta3.MessagesV1Beta3AsyncClient
+
+        :param job_id: ID of the Dataflow job to get messages about.
+        :param project_id: Optional. The Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param minimum_importance: Optional. Filter to only get messages with importance >= level.
+            For more details see the description at:
+            https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.JobMessageImportance
+        :param page_size: Optional. If specified, determines the maximum number of messages to return.
+            If unspecified, the service may choose an appropriate default, or may return an arbitrarily large number of results.
+        :param page_token: Optional. If supplied, this should be the value of next_page_token returned by an earlier call.
+            This will cause the next page of results to be returned.
+        :param start_time: Optional. If specified, return only messages with timestamps >= start_time.
+            The default is the job creation time (i.e. beginning of messages).
+        :param end_time: Optional. If specified, return only messages with timestamps < end_time. The default is the current time.
+        :param location: Optional. The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains
+            the job specified by job_id.
+        """
+        project_id = project_id or (await self.get_project_id())
+        client = await self.initialize_client(MessagesV1Beta3AsyncClient)
+        request = ListJobMessagesRequest(
+            {
+                "project_id": project_id,
+                "job_id": job_id,
+                "minimum_importance": minimum_importance,
+                "page_size": page_size,
+                "page_token": page_token,
+                "start_time": start_time,
+                "end_time": end_time,
+                "location": location,
+            }
+        )
+        page_results: ListJobMessagesAsyncPager = await client.list_job_messages(request=request)
+        return page_results
+
+    async def get_job_metrics(
+        self,
+        job_id: str,
+        project_id: str | None = PROVIDE_PROJECT_ID,
+        start_time: Timestamp | None = None,
+        location: str | None = DEFAULT_DATAFLOW_LOCATION,
+    ) -> JobMetrics:
+        """
+        Return JobMetrics object from MetricsV1Beta3AsyncClient.
+
+        This method wraps around a similar method of MetricsV1Beta3AsyncClient.
+
+        For more details see the MetricsV1Beta3AsyncClient method description at:
+        https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.metrics_v1_beta3.MetricsV1Beta3AsyncClient
+
+        :param job_id: ID of the Dataflow job to get metrics for.
+        :param project_id: Optional. The Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param start_time: Optional. Return only metric data that has changed since this time.
+            Default is to return all information about all metrics for the job.
+        :param location: Optional. The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains
+            the job specified by job_id.
+        """
+        project_id = project_id or (await self.get_project_id())
+        client: MetricsV1Beta3AsyncClient = await self.initialize_client(MetricsV1Beta3AsyncClient)
+        request = GetJobMetricsRequest(
+            {
+                "project_id": project_id,
+                "job_id": job_id,
+                "start_time": start_time,
+                "location": location,
+            }
+        )
+        job_metrics: JobMetrics = await client.get_job_metrics(request=request)
+        return job_metrics
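To show how the async additions compose, a rough sketch of reading messages and metrics for a job from an async context; the job ID and project are placeholders, and in this release these methods support the reworked Dataflow sensors and triggers listed above:

```python
import asyncio

from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook


async def inspect_job(job_id: str) -> None:
    hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")
    # ListJobMessagesAsyncPager can be iterated asynchronously, as the docstring above notes.
    messages = await hook.list_job_messages(job_id=job_id, project_id="my-gcp-project")
    async for message in messages:
        print(message.time, message.message_text)
    # JobMetrics carries a list of MetricUpdate entries for the job.
    metrics = await hook.get_job_metrics(job_id=job_id, project_id="my-gcp-project")
    print(f"{len(metrics.metrics)} metric updates")


asyncio.run(inspect_job("2024-05-01_00_00_00-1234567890123456789"))
```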