apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/cloud/hooks/automl.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +64 -33
- airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
- airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
- airflow/providers/google/cloud/hooks/dataflow.py +246 -32
- airflow/providers/google/cloud/hooks/dataplex.py +6 -2
- airflow/providers/google/cloud/hooks/dlp.py +14 -14
- airflow/providers/google/cloud/hooks/gcs.py +6 -2
- airflow/providers/google/cloud/hooks/gdm.py +2 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/mlengine.py +8 -4
- airflow/providers/google/cloud/hooks/pubsub.py +1 -1
- airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
- airflow/providers/google/cloud/links/vertex_ai.py +2 -1
- airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
- airflow/providers/google/cloud/operators/automl.py +13 -12
- airflow/providers/google/cloud/operators/bigquery.py +36 -22
- airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
- airflow/providers/google/cloud/operators/bigtable.py +7 -6
- airflow/providers/google/cloud/operators/cloud_build.py +12 -11
- airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
- airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
- airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
- airflow/providers/google/cloud/operators/compute.py +12 -11
- airflow/providers/google/cloud/operators/datacatalog.py +21 -20
- airflow/providers/google/cloud/operators/dataflow.py +59 -42
- airflow/providers/google/cloud/operators/datafusion.py +11 -10
- airflow/providers/google/cloud/operators/datapipeline.py +3 -2
- airflow/providers/google/cloud/operators/dataprep.py +5 -4
- airflow/providers/google/cloud/operators/dataproc.py +19 -16
- airflow/providers/google/cloud/operators/datastore.py +8 -7
- airflow/providers/google/cloud/operators/dlp.py +31 -30
- airflow/providers/google/cloud/operators/functions.py +4 -3
- airflow/providers/google/cloud/operators/gcs.py +66 -41
- airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
- airflow/providers/google/cloud/operators/life_sciences.py +2 -1
- airflow/providers/google/cloud/operators/mlengine.py +11 -10
- airflow/providers/google/cloud/operators/pubsub.py +6 -5
- airflow/providers/google/cloud/operators/spanner.py +7 -6
- airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
- airflow/providers/google/cloud/operators/stackdriver.py +11 -10
- airflow/providers/google/cloud/operators/tasks.py +14 -13
- airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
- airflow/providers/google/cloud/operators/translate_speech.py +2 -1
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
- airflow/providers/google/cloud/operators/vision.py +13 -12
- airflow/providers/google/cloud/operators/workflows.py +10 -9
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/sensors/bigtable.py +2 -1
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/sensors/dataflow.py +239 -52
- airflow/providers/google/cloud/sensors/datafusion.py +2 -1
- airflow/providers/google/cloud/sensors/dataproc.py +3 -2
- airflow/providers/google/cloud/sensors/gcs.py +14 -12
- airflow/providers/google/cloud/sensors/tasks.py +2 -1
- airflow/providers/google/cloud/sensors/workflows.py +2 -1
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
- airflow/providers/google/cloud/triggers/bigquery.py +14 -3
- airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
- airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/triggers/dataflow.py +504 -4
- airflow/providers/google/cloud/triggers/dataproc.py +110 -26
- airflow/providers/google/cloud/triggers/mlengine.py +2 -1
- airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
- airflow/providers/google/common/hooks/base_google.py +45 -7
- airflow/providers/google/firebase/hooks/firestore.py +2 -2
- airflow/providers/google/firebase/operators/firestore.py +2 -1
- airflow/providers/google/get_provider_info.py +3 -2
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
- airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/hooks/cloud_sql.py

@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import base64
 import errno
 import json
 import os
@@ -34,7 +35,7 @@ import uuid
 from inspect import signature
 from pathlib import Path
 from subprocess import PIPE, Popen
-from tempfile import gettempdir
+from tempfile import NamedTemporaryFile, _TemporaryFileWrapper, gettempdir
 from typing import TYPE_CHECKING, Any, Sequence
 from urllib.parse import quote_plus
 
@@ -49,12 +50,21 @@ from googleapiclient.errors import HttpError
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-from airflow.providers.google.
+from airflow.providers.google.cloud.hooks.secret_manager import (
+    GoogleCloudSecretManagerHook,
+)
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+    get_field,
+)
 from airflow.providers.mysql.hooks.mysql import MySqlHook
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
+    from google.cloud.secretmanager_v1 import AccessSecretVersionResponse
     from requests import Session
 
 UNIX_PATH_MAX = 108
@@ -377,6 +387,29 @@ class CloudSQLHook(GoogleBaseHook):
         except HttpError as ex:
             raise AirflowException(f"Cloning of instance {instance} failed: {ex.content}")
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_ssl_certificate(self, instance: str, body: dict, project_id: str):
+        """
+        Create SSL certificate for a Cloud SQL instance.
+
+        :param instance: Cloud SQL instance ID. This does not include the project ID.
+        :param body: The request body, as described in
+            https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/sslCerts/insert#SslCertsInsertRequest
+        :param project_id: Project ID of the project that contains the instance. If set
+            to None or missing, the default project_id from the Google Cloud connection is used.
+        :return: SslCert insert response. For more details see:
+            https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/sslCerts/insert#response-body
+        """
+        response = (
+            self.get_conn()
+            .sslCerts()
+            .insert(project=project_id, instance=instance, body=body)
+            .execute(num_retries=self.num_retries)
+        )
+        operation_name = response.get("operation", {}).get("name", {})
+        self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)
+        return response
+
     @GoogleBaseHook.fallback_to_default_project_id
     def _wait_for_operation_to_complete(
         self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS
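
The hunk above adds a `create_ssl_certificate` method to `CloudSQLHook`, a thin wrapper around the `sslCerts.insert` API that also waits for the returned operation. A minimal usage sketch; the instance ID, certificate common name, and project ID below are illustrative placeholders rather than values from this diff:

    from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook

    hook = CloudSQLHook(api_version="v1", gcp_conn_id="google_cloud_default")
    # Issues sslCerts().insert(...) and blocks until the returned operation completes.
    cert = hook.create_ssl_certificate(
        instance="my-cloudsql-instance",             # placeholder instance ID
        body={"commonName": "airflow-client-cert"},  # SslCertsInsertRequest body
        project_id="my-project",                     # placeholder project ID
    )
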
@@ -482,7 +515,7 @@ class CloudSqlProxyRunner(LoggingMixin):
         path_prefix: str,
         instance_specification: str,
         gcp_conn_id: str = "google_cloud_default",
-        project_id: str
+        project_id: str = PROVIDE_PROJECT_ID,
         sql_proxy_version: str | None = None,
         sql_proxy_binary_path: str | None = None,
     ) -> None:
@@ -758,7 +791,24 @@ class CloudSQLDatabaseHook(BaseHook):
     :param gcp_conn_id: The connection ID used to connect to Google Cloud for
         cloud-sql-proxy authentication.
     :param default_gcp_project_id: Default project id used if project_id not specified
-
+        in the connection URL
+    :param ssl_cert: Optional. Path to client certificate to authenticate when SSL is used. Overrides the
+        connection field ``sslcert``.
+    :param ssl_key: Optional. Path to client private key to authenticate when SSL is used. Overrides the
+        connection field ``sslkey``.
+    :param ssl_root_cert: Optional. Path to server's certificate to authenticate when SSL is used. Overrides
+        the connection field ``sslrootcert``.
+    :param ssl_secret_id: Optional. ID of the secret in Google Cloud Secret Manager that stores SSL
+        certificate in the format below:
+
+        {'sslcert': '',
+        'sslkey': '',
+        'sslrootcert': ''}
+
+        Overrides the connection fields ``sslcert``, ``sslkey``, ``sslrootcert``.
+        Note that according to the Secret Manager requirements, the mentioned dict should be saved as a
+        string, and encoded with base64.
+        Note that this parameter is incompatible with parameters ``ssl_cert``, ``ssl_key``, ``ssl_root_cert``.
     """
 
     conn_name_attr = "gcp_cloudsql_conn_id"
@@ -770,12 +820,18 @@ class CloudSQLDatabaseHook(BaseHook):
         self,
         gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
         gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
         default_gcp_project_id: str | None = None,
         sql_proxy_binary_path: str | None = None,
+        ssl_cert: str | None = None,
+        ssl_key: str | None = None,
+        ssl_root_cert: str | None = None,
+        ssl_secret_id: str | None = None,
     ) -> None:
         super().__init__()
         self.gcp_conn_id = gcp_conn_id
         self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
+        self.impersonation_chain = impersonation_chain
         self.cloudsql_connection = self.get_connection(self.gcp_cloudsql_conn_id)
         self.extras = self.cloudsql_connection.extra_dejson
         self.project_id = self.extras.get("project_id", default_gcp_project_id)
@@ -792,9 +848,11 @@ class CloudSQLDatabaseHook(BaseHook):
         self.password = self.cloudsql_connection.password
         self.public_ip = self.cloudsql_connection.host
         self.public_port = self.cloudsql_connection.port
-        self.
-        self.
-        self.
+        self.ssl_cert = ssl_cert
+        self.ssl_key = ssl_key
+        self.ssl_root_cert = ssl_root_cert
+        self.ssl_secret_id = ssl_secret_id
+        self._ssl_cert_temp_files: dict[str, _TemporaryFileWrapper] = {}
         # Port and socket path and db_hook are automatically generated
         self.sql_proxy_tcp_port = None
         self.sql_proxy_unique_path: str | None = None
@@ -805,6 +863,84 @@ class CloudSQLDatabaseHook(BaseHook):
         self.db_conn_id = str(uuid.uuid1())
         self._validate_inputs()
 
+    @property
+    def sslcert(self) -> str | None:
+        return self._get_ssl_temporary_file_path(cert_name="sslcert", cert_path=self.ssl_cert)
+
+    @property
+    def sslkey(self) -> str | None:
+        return self._get_ssl_temporary_file_path(cert_name="sslkey", cert_path=self.ssl_key)
+
+    @property
+    def sslrootcert(self) -> str | None:
+        return self._get_ssl_temporary_file_path(cert_name="sslrootcert", cert_path=self.ssl_root_cert)
+
+    def _get_ssl_temporary_file_path(self, cert_name: str, cert_path: str | None) -> str | None:
+        cert_value = self._get_cert_from_secret(cert_name)
+        original_cert_path = cert_path or self.extras.get(cert_name)
+        if cert_value or original_cert_path:
+            if cert_name not in self._ssl_cert_temp_files:
+                return self._set_temporary_ssl_file(
+                    cert_name=cert_name, cert_path=original_cert_path, cert_value=cert_value
+                )
+            return self._ssl_cert_temp_files[cert_name].name
+        return None
+
+    def _get_cert_from_secret(self, cert_name: str) -> str | None:
+        if not self.ssl_secret_id:
+            return None
+
+        secret_hook = GoogleCloudSecretManagerHook(
+            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+        )
+        secret: AccessSecretVersionResponse = secret_hook.access_secret(
+            project_id=self.project_id,
+            secret_id=self.ssl_secret_id,
+        )
+        secret_data = json.loads(base64.b64decode(secret.payload.data))
+        if cert_name in secret_data:
+            return secret_data[cert_name]
+        else:
+            raise AirflowException(
+                "Invalid secret format. Expected dictionary with keys: `sslcert`, `sslkey`, `sslrootcert`"
+            )
+
+    def _set_temporary_ssl_file(
+        self, cert_name: str, cert_path: str | None = None, cert_value: str | None = None
+    ) -> str | None:
+        """Save the certificate as a temporary file.
+
+        This method was implemented in order to overcome psql connection error caused by excessive file
+        permissions: "private key file "..." has group or world access; file must have permissions
+        u=rw (0600) or less if owned by the current user, or permissions u=rw,g=r (0640) or less if owned
+        by root". NamedTemporaryFile enforces using exactly one of create/read/write/append mode so the
+        created file obtains least required permissions "-rw-------" that satisfies the rules.
+
+        :param cert_name: Required. Name of the certificate (one of sslcert, sslkey, sslrootcert).
+        :param cert_path: Optional. Path to the certificate.
+        :param cert_value: Optional. The certificate content.
+
+        :returns: The path to the temporary certificate file.
+        """
+        if all([cert_path, cert_value]):
+            raise AirflowException(
+                "Both parameters were specified: `cert_path`, `cert_value`. Please use only one of them."
+            )
+        if not any([cert_path, cert_value]):
+            self.log.info("Neither cert path and cert value provided. Nothing to save.")
+            return None
+
+        _temp_file = NamedTemporaryFile(mode="w+b", prefix="/tmp/certs/")
+        if cert_path:
+            with open(cert_path, "rb") as cert_file:
+                _temp_file.write(cert_file.read())
+        elif cert_value:
+            _temp_file.write(cert_value.encode("ascii"))
+        _temp_file.flush()
+        self._ssl_cert_temp_files[cert_name] = _temp_file
+        self.log.info("Copied the certificate '%s' into a temporary file '%s'", cert_name, _temp_file.name)
+        return _temp_file.name
+
     @staticmethod
     def _get_bool(val: Any) -> bool:
         if val == "False" or val is False:
@@ -836,6 +972,17 @@ class CloudSQLDatabaseHook(BaseHook):
                 " SSL is not needed as Cloud SQL Proxy "
                 "provides encryption on its own"
             )
+        if any([self.ssl_key, self.ssl_cert, self.ssl_root_cert]) and self.ssl_secret_id:
+            raise AirflowException(
+                "Invalid SSL settings. Please use either all of parameters ['ssl_cert', 'ssl_cert', "
+                "'ssl_root_cert'] or a single parameter 'ssl_secret_id'."
+            )
+        if any([self.ssl_key, self.ssl_cert, self.ssl_root_cert]):
+            field_names = ["ssl_key", "ssl_cert", "ssl_root_cert"]
+            if missed_values := [field for field in field_names if not getattr(self, field)]:
+                s = "s are" if len(missed_values) > 1 else "is"
+                missed_values_str = ", ".join(f for f in missed_values)
+                raise AirflowException(f"Invalid SSL settings. Parameter{s} missing: {missed_values_str}")
 
     def validate_ssl_certs(self) -> None:
         """
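
Taken together, the `CloudSQLDatabaseHook` changes above allow client SSL material to come either from explicit file paths (`ssl_cert`, `ssl_key`, `ssl_root_cert`) or from a single Secret Manager secret (`ssl_secret_id`) whose value is a base64-encoded JSON dict; mixing the two raises an `AirflowException` during input validation. A sketch of preparing such a secret payload and constructing the hook, with placeholder connection and secret IDs:

    import base64
    import json

    from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook

    # Per the docstring above, the secret value must be a base64-encoded JSON string
    # containing exactly the keys sslcert, sslkey and sslrootcert.
    payload = base64.b64encode(
        json.dumps(
            {
                "sslcert": "<client certificate PEM>",
                "sslkey": "<client private key PEM>",
                "sslrootcert": "<server CA certificate PEM>",
            }
        ).encode()
    )

    hook = CloudSQLDatabaseHook(
        gcp_cloudsql_conn_id="google_cloud_sql_default",
        gcp_conn_id="google_cloud_default",
        ssl_secret_id="my-cloudsql-ssl-secret",  # placeholder secret ID
    )
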
airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py

@@ -46,7 +46,11 @@ from googleapiclient.errors import HttpError
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)
 
 if TYPE_CHECKING:
     from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
@@ -84,6 +88,7 @@ ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
 AWS_ACCESS_KEY = "awsAccessKey"
 AWS_SECRET_ACCESS_KEY = "secretAccessKey"
 AWS_S3_DATA_SOURCE = "awsS3DataSource"
+AWS_ROLE_ARN = "roleArn"
 BODY = "body"
 BUCKET_NAME = "bucketName"
 COUNTERS = "counters"
@@ -504,7 +509,7 @@ class CloudDataTransferServiceHook(GoogleBaseHook):
 class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
     """Asynchronous hook for Google Storage Transfer Service."""
 
-    def __init__(self, project_id: str
+    def __init__(self, project_id: str = PROVIDE_PROJECT_ID, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.project_id = project_id
         self._client: StorageTransferServiceAsyncClient | None = None
airflow/providers/google/cloud/hooks/compute_ssh.py

@@ -29,6 +29,7 @@ from paramiko.ssh_exception import SSHException
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
 from airflow.providers.google.cloud.hooks.os_login import OSLoginHook
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.utils.types import NOTSET, ArgNotSet
 
@@ -109,7 +110,7 @@ class ComputeEngineSSHHook(SSHHook):
         instance_name: str | None = None,
         zone: str | None = None,
         user: str | None = "root",
-        project_id: str
+        project_id: str = PROVIDE_PROJECT_ID,
         hostname: str | None = None,
         use_internal_ip: bool = False,
         use_iap_tunnel: bool = False,
airflow/providers/google/cloud/hooks/dataflow.py

@@ -31,9 +31,23 @@ from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, TypeVar, cast
 
 from deprecated import deprecated
-from google.cloud.dataflow_v1beta3 import
+from google.cloud.dataflow_v1beta3 import (
+    GetJobRequest,
+    Job,
+    JobState,
+    JobsV1Beta3AsyncClient,
+    JobView,
+    ListJobMessagesRequest,
+    MessagesV1Beta3AsyncClient,
+    MetricsV1Beta3AsyncClient,
+)
+from google.cloud.dataflow_v1beta3.types import (
+    GetJobMetricsRequest,
+    JobMessageImportance,
+    JobMetrics,
+)
 from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
-from googleapiclient.discovery import build
+from googleapiclient.discovery import Resource, build
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
@@ -47,6 +61,8 @@ from airflow.utils.timeout import timeout
 
 if TYPE_CHECKING:
     from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager
+    from google.cloud.dataflow_v1beta3.services.messages_v1_beta3.pagers import ListJobMessagesAsyncPager
+    from google.protobuf.timestamp_pb2 import Timestamp
 
 
 # This is the default location
@@ -561,7 +577,7 @@ class DataflowHook(GoogleBaseHook):
             impersonation_chain=impersonation_chain,
         )
 
-    def get_conn(self) ->
+    def get_conn(self) -> Resource:
         """Return a Google Cloud Dataflow service object."""
         http_authorized = self._authorize()
         return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
@@ -641,9 +657,9 @@ class DataflowHook(GoogleBaseHook):
         on_new_job_callback: Callable[[dict], None] | None = None,
         location: str = DEFAULT_DATAFLOW_LOCATION,
         environment: dict | None = None,
-    ) -> dict:
+    ) -> dict[str, str]:
         """
-
+        Launch a Dataflow job with a Classic Template and wait for its completion.
 
         :param job_name: The name of the job.
         :param variables: Map of job runtime environment options.
@@ -676,26 +692,14 @@ class DataflowHook(GoogleBaseHook):
             environment=environment,
         )
 
-
-
-
-
-
-
-
-                projectId=project_id,
-                location=location,
-                gcsPath=dataflow_template,
-                body={
-                    "jobName": name,
-                    "parameters": parameters,
-                    "environment": environment,
-                },
-            )
+        job: dict[str, str] = self.send_launch_template_request(
+            project_id=project_id,
+            location=location,
+            gcs_path=dataflow_template,
+            job_name=name,
+            parameters=parameters,
+            environment=environment,
         )
-        response = request.execute(num_retries=self.num_retries)
-
-        job = response["job"]
 
         if on_new_job_id_callback:
             warnings.warn(
@@ -703,7 +707,7 @@ class DataflowHook(GoogleBaseHook):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,
             )
-            on_new_job_id_callback(job
+            on_new_job_id_callback(job["id"])
 
         if on_new_job_callback:
             on_new_job_callback(job)
@@ -722,7 +726,62 @@ class DataflowHook(GoogleBaseHook):
                 expected_terminal_state=self.expected_terminal_state,
             )
         jobs_controller.wait_for_done()
-        return
+        return job
+
+    @_fallback_to_location_from_variables
+    @_fallback_to_project_id_from_variables
+    @GoogleBaseHook.fallback_to_default_project_id
+    def launch_job_with_template(
+        self,
+        *,
+        job_name: str,
+        variables: dict,
+        parameters: dict,
+        dataflow_template: str,
+        project_id: str,
+        append_job_name: bool = True,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        environment: dict | None = None,
+    ) -> dict[str, str]:
+        """
+        Launch a Dataflow job with a Classic Template and exit without waiting for its completion.
+
+        :param job_name: The name of the job.
+        :param variables: Map of job runtime environment options.
+            It will update environment argument if passed.
+
+            .. seealso::
+                For more information on possible configurations, look at the API documentation
+                `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+                <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+
+        :param parameters: Parameters for the template
+        :param dataflow_template: GCS path to the template.
+        :param project_id: Optional, the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param append_job_name: True if unique suffix has to be appended to job name.
+        :param location: Job location.
+
+            .. seealso::
+                For more information on possible configurations, look at the API documentation
+                `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+                <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+        :return: the Dataflow job response
+        """
+        name = self.build_dataflow_job_name(job_name, append_job_name)
+        environment = self._update_environment(
+            variables=variables,
+            environment=environment,
+        )
+        job: dict[str, str] = self.send_launch_template_request(
+            project_id=project_id,
+            location=location,
+            gcs_path=dataflow_template,
+            job_name=name,
+            parameters=parameters,
+            environment=environment,
+        )
+        return job
 
     def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
         environment = environment or {}
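
The new `launch_job_with_template` reuses the same launch request as the waiting variant above, but returns the job resource immediately after submission. A usage sketch; the template path, bucket, and project ID are placeholders:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook(gcp_conn_id="google_cloud_default")
    job = hook.launch_job_with_template(
        job_name="example-wordcount",
        variables={},  # runtime environment options, merged into `environment`
        parameters={
            "inputFile": "gs://my-bucket/input.txt",   # placeholder paths
            "output": "gs://my-bucket/output/results",
        },
        dataflow_template="gs://dataflow-templates/latest/Word_Count",
        project_id="my-project",  # placeholder project ID
        location="europe-west1",
    )
    job_id = job["id"]  # "id" key of the job resource returned by templates.launch
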
@@ -758,6 +817,35 @@ class DataflowHook(GoogleBaseHook):
 
         return environment
 
+    def send_launch_template_request(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        gcs_path: str,
+        job_name: str,
+        parameters: dict,
+        environment: dict,
+    ) -> dict[str, str]:
+        service: Resource = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .templates()
+            .launch(
+                projectId=project_id,
+                location=location,
+                gcsPath=gcs_path,
+                body={
+                    "jobName": job_name,
+                    "parameters": parameters,
+                    "environment": environment,
+                },
+            )
+        )
+        response: dict = request.execute(num_retries=self.num_retries)
+        return response["job"]
+
     @GoogleBaseHook.fallback_to_default_project_id
     def start_flex_template(
         self,
@@ -766,9 +854,9 @@ class DataflowHook(GoogleBaseHook):
         project_id: str,
         on_new_job_id_callback: Callable[[str], None] | None = None,
         on_new_job_callback: Callable[[dict], None] | None = None,
-    ) -> dict:
+    ) -> dict[str, str]:
         """
-
+        Launch a Dataflow job with a Flex Template and wait for its completion.
 
         :param body: The request body. See:
             https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -779,15 +867,16 @@ class DataflowHook(GoogleBaseHook):
         :param on_new_job_callback: A callback that is called when a Job is detected.
         :return: the Job
         """
-        service = self.get_conn()
+        service: Resource = self.get_conn()
         request = (
             service.projects()
             .locations()
             .flexTemplates()
             .launch(projectId=project_id, body=body, location=location)
         )
-        response = request.execute(num_retries=self.num_retries)
+        response: dict = request.execute(num_retries=self.num_retries)
         job = response["job"]
+        job_id: str = job["id"]
 
         if on_new_job_id_callback:
             warnings.warn(
@@ -795,7 +884,7 @@ class DataflowHook(GoogleBaseHook):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,
             )
-            on_new_job_id_callback(
+            on_new_job_id_callback(job_id)
 
         if on_new_job_callback:
             on_new_job_callback(job)
@@ -803,7 +892,7 @@ class DataflowHook(GoogleBaseHook):
         jobs_controller = _DataflowJobsController(
             dataflow=self.get_conn(),
             project_number=project_id,
-            job_id=
+            job_id=job_id,
             location=location,
             poll_sleep=self.poll_sleep,
             num_retries=self.num_retries,
@@ -814,6 +903,42 @@ class DataflowHook(GoogleBaseHook):
 
         return jobs_controller.get_jobs(refresh=True)[0]
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def launch_job_with_flex_template(
+        self,
+        body: dict,
+        location: str,
+        project_id: str,
+    ) -> dict[str, str]:
+        """
+        Launch a Dataflow Job with a Flex Template and exit without waiting for the job completion.
+
+        :param body: The request body. See:
+            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
+        :param location: The location of the Dataflow job (for example europe-west1)
+        :param project_id: The ID of the GCP project that owns the job.
+            If set to ``None`` or missing, the default project_id from the GCP connection is used.
+        :return: a Dataflow job response
+        """
+        service: Resource = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .flexTemplates()
+            .launch(projectId=project_id, body=body, location=location)
+        )
+        response: dict = request.execute(num_retries=self.num_retries)
+        return response["job"]
+
+    @staticmethod
+    def extract_job_id(job: dict) -> str:
+        try:
+            return job["id"]
+        except KeyError:
+            raise AirflowException(
+                "While reading job object after template execution error occurred. Job object has no id."
+            )
+
     @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
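
Similarly, `launch_job_with_flex_template` submits a Flex Template launch request and returns the job resource without polling, and `extract_job_id` is a small helper for pulling the job ID out of that response. A sketch with a placeholder container spec path and project ID; the body layout follows the flexTemplates.launch request referenced in the docstring above:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook(gcp_conn_id="google_cloud_default")
    job = hook.launch_job_with_flex_template(
        body={
            "launchParameter": {
                "jobName": "example-flex-job",
                "containerSpecGcsPath": "gs://my-bucket/templates/spec.json",  # placeholder
                "parameters": {"output": "gs://my-bucket/output/"},
            }
        },
        location="europe-west1",
        project_id="my-project",  # placeholder project ID
    )
    job_id = DataflowHook.extract_job_id(job)  # raises AirflowException if "id" is missing
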
@@ -1353,3 +1478,92 @@ class AsyncDataflowHook(GoogleBaseAsyncHook):
         )
         page_result: ListJobsAsyncPager = await client.list_jobs(request=request)
         return page_result
+
+    async def list_job_messages(
+        self,
+        job_id: str,
+        project_id: str | None = PROVIDE_PROJECT_ID,
+        minimum_importance: int = JobMessageImportance.JOB_MESSAGE_BASIC,
+        page_size: int | None = None,
+        page_token: str | None = None,
+        start_time: Timestamp | None = None,
+        end_time: Timestamp | None = None,
+        location: str | None = DEFAULT_DATAFLOW_LOCATION,
+    ) -> ListJobMessagesAsyncPager:
+        """
+        Return ListJobMessagesAsyncPager object from MessagesV1Beta3AsyncClient.
+
+        This method wraps around a similar method of MessagesV1Beta3AsyncClient. ListJobMessagesAsyncPager can be iterated
+        over to extract messages associated with a specific Job ID.
+
+        For more details see the MessagesV1Beta3AsyncClient method description at:
+        https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.messages_v1_beta3.MessagesV1Beta3AsyncClient
+
+        :param job_id: ID of the Dataflow job to get messages about.
+        :param project_id: Optional. The Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param minimum_importance: Optional. Filter to only get messages with importance >= level.
+            For more details see the description at:
+            https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.JobMessageImportance
+        :param page_size: Optional. If specified, determines the maximum number of messages to return.
+            If unspecified, the service may choose an appropriate default, or may return an arbitrarily large number of results.
+        :param page_token: Optional. If supplied, this should be the value of next_page_token returned by an earlier call.
+            This will cause the next page of results to be returned.
+        :param start_time: Optional. If specified, return only messages with timestamps >= start_time.
+            The default is the job creation time (i.e. beginning of messages).
+        :param end_time: Optional. If specified, return only messages with timestamps < end_time. The default is the current time.
+        :param location: Optional. The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains
+            the job specified by job_id.
+        """
+        project_id = project_id or (await self.get_project_id())
+        client = await self.initialize_client(MessagesV1Beta3AsyncClient)
+        request = ListJobMessagesRequest(
+            {
+                "project_id": project_id,
+                "job_id": job_id,
+                "minimum_importance": minimum_importance,
+                "page_size": page_size,
+                "page_token": page_token,
+                "start_time": start_time,
+                "end_time": end_time,
+                "location": location,
+            }
+        )
+        page_results: ListJobMessagesAsyncPager = await client.list_job_messages(request=request)
+        return page_results
+
+    async def get_job_metrics(
+        self,
+        job_id: str,
+        project_id: str | None = PROVIDE_PROJECT_ID,
+        start_time: Timestamp | None = None,
+        location: str | None = DEFAULT_DATAFLOW_LOCATION,
+    ) -> JobMetrics:
+        """
+        Return JobMetrics object from MetricsV1Beta3AsyncClient.
+
+        This method wraps around a similar method of MetricsV1Beta3AsyncClient.
+
+        For more details see the MetricsV1Beta3AsyncClient method description at:
+        https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.metrics_v1_beta3.MetricsV1Beta3AsyncClient
+
+        :param job_id: ID of the Dataflow job to get metrics for.
+        :param project_id: Optional. The Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param start_time: Optional. Return only metric data that has changed since this time.
+            Default is to return all information about all metrics for the job.
+        :param location: Optional. The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains
+            the job specified by job_id.
+        """
+        project_id = project_id or (await self.get_project_id())
+        client: MetricsV1Beta3AsyncClient = await self.initialize_client(MetricsV1Beta3AsyncClient)
+        request = GetJobMetricsRequest(
+            {
+                "project_id": project_id,
+                "job_id": job_id,
+                "start_time": start_time,
+                "location": location,
+            }
+        )
+        job_metrics: JobMetrics = await client.get_job_metrics(request=request)
+        return job_metrics
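
Finally, the two new async methods on `AsyncDataflowHook` expose job messages and job metrics through the `dataflow_v1beta3` async clients. A rough sketch of reading both for an existing job; the job and project IDs are placeholders, and the field access on the returned protobuf objects is an assumption based on the dataflow_v1beta3 types:

    import asyncio

    from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook

    async def summarize(job_id: str, project_id: str) -> None:
        hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")
        # ListJobMessagesAsyncPager can be iterated asynchronously, yielding JobMessage objects.
        messages = await hook.list_job_messages(job_id=job_id, project_id=project_id)
        async for message in messages:
            print(message.time, message.message_text)
        # JobMetrics.metrics is a list of MetricUpdate entries.
        metrics = await hook.get_job_metrics(job_id=job_id, project_id=project_id)
        for update in metrics.metrics:
            print(update.name.name, update.scalar)

    asyncio.run(summarize("placeholder-job-id", "my-project"))
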