apache-airflow-providers-google 10.18.0rc1__py3-none-any.whl → 10.18.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +2 -5
- airflow/providers/google/cloud/hooks/automl.py +34 -0
- airflow/providers/google/cloud/hooks/bigquery.py +62 -8
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +91 -0
- airflow/providers/google/cloud/operators/automl.py +230 -25
- airflow/providers/google/cloud/operators/bigquery.py +128 -40
- airflow/providers/google/cloud/operators/dataproc.py +1 -1
- airflow/providers/google/cloud/operators/kubernetes_engine.py +24 -37
- airflow/providers/google/cloud/operators/workflows.py +2 -5
- airflow/providers/google/cloud/triggers/bigquery.py +64 -6
- airflow/providers/google/cloud/triggers/dataproc.py +82 -3
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -3
- airflow/providers/google/get_provider_info.py +3 -2
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/METADATA +7 -7
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/RECORD +17 -16
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Sequence, SupportsAbs
 import attr
 from deprecated import deprecated
 from google.api_core.exceptions import Conflict
-from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob
+from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob, Row
 from google.cloud.bigquery.table import RowIterator

 from airflow.configuration import conf
@@ -57,6 +57,7 @@ from airflow.providers.google.cloud.triggers.bigquery import (
 )
 from airflow.providers.google.cloud.utils.bigquery import convert_job_id
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.utils.helpers import exactly_one

 if TYPE_CHECKING:
 from google.api_core.retry import Retry
@@ -67,7 +68,7 @@ if TYPE_CHECKING:

 BIGQUERY_JOB_DETAILS_LINK_FMT = "https://console.cloud.google.com/bigquery?j={job_id}"

-LABEL_REGEX = re.compile(r"^[
+LABEL_REGEX = re.compile(r"^[\w-]{0,63}$")


 class BigQueryUIColors(enum.Enum):
@@ -202,7 +203,25 @@ class _BigQueryOpenLineageMixin:
 )


-class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
+class _BigQueryOperatorsEncryptionConfigurationMixin:
+"""A class to handle the configuration for BigQueryHook.insert_job method."""
+
+# Note: If you want to add this feature to a new operator you can include the class name in the type
+# annotation of the `self`. Then you can inherit this class in the target operator.
+# e.g: BigQueryCheckOperator, BigQueryTableCheckOperator
+def include_encryption_configuration(  # type:ignore[misc]
+self: BigQueryCheckOperator | BigQueryTableCheckOperator,
+configuration: dict,
+config_key: str,
+) -> None:
+"""Add encryption_configuration to destinationEncryptionConfiguration key if it is not None."""
+if self.encryption_configuration is not None:
+configuration[config_key]["destinationEncryptionConfiguration"] = self.encryption_configuration
+
+
+class BigQueryCheckOperator(
+_BigQueryDbHookMixin, SQLCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin
+):
 """Performs checks against BigQuery.

 This operator expects a SQL query that returns a single row. Each value on
@@ -247,6 +266,13 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
 Token Creator IAM role to the directly preceding identity, with first
 account from the list granting this role to the originating account. (templated)
 :param labels: a dictionary containing labels for the table, passed to BigQuery.
+:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).
+
+.. code-block:: python
+
+encryption_configuration = {
+"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
+}
 :param deferrable: Run operator in the deferrable mode.
 :param poll_interval: (Deferrable mode only) polling period in seconds to
 check for the status of job.
@@ -271,6 +297,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
 location: str | None = None,
 impersonation_chain: str | Sequence[str] | None = None,
 labels: dict | None = None,
+encryption_configuration: dict | None = None,
 deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
 poll_interval: float = 4.0,
 **kwargs,
@@ -281,6 +308,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
 self.location = location
 self.impersonation_chain = impersonation_chain
 self.labels = labels
+self.encryption_configuration = encryption_configuration
 self.deferrable = deferrable
 self.poll_interval = poll_interval

@@ -292,6 +320,8 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
 """Submit a new job and get the job id for polling the status using Trigger."""
 configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}

+self.include_encryption_configuration(configuration, "query")
+
 return hook.insert_job(
 configuration=configuration,
 project_id=hook.project_id,
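Taken together, the hunks above add an optional encryption_configuration argument to BigQueryCheckOperator (and, further below, to BigQueryTableCheckOperator) and inject it into the submitted query job. A minimal usage sketch; the project, dataset, table, and KMS key names are placeholders, not values from this diff:

    from airflow.providers.google.cloud.operators.bigquery import BigQueryCheckOperator

    check_not_empty = BigQueryCheckOperator(
        task_id="check_not_empty",
        sql="SELECT COUNT(*) FROM `my-project.my_dataset.my_table`",
        use_legacy_sql=False,
        encryption_configuration={
            "kmsKeyName": "projects/my-project/locations/us/keyRings/my-kr/cryptoKeys/my-key",
        },
    )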
@@ -766,7 +796,9 @@ class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
 self.log.info("All tests have passed")


-class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
+class BigQueryTableCheckOperator(
+_BigQueryDbHookMixin, SQLTableCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin
+):
 """
 Subclasses the SQLTableCheckOperator in order to provide a job id for OpenLineage to parse.

@@ -794,6 +826,13 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
 Service Account Token Creator IAM role to the directly preceding identity, with first
 account from the list granting this role to the originating account (templated).
 :param labels: a dictionary containing labels for the table, passed to BigQuery
+:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).
+
+.. code-block:: python
+
+encryption_configuration = {
+"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
+}
 """

 template_fields: Sequence[str] = tuple(set(SQLTableCheckOperator.template_fields) | {"gcp_conn_id"})
@@ -811,6 +850,7 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
 location: str | None = None,
 impersonation_chain: str | Sequence[str] | None = None,
 labels: dict | None = None,
+encryption_configuration: dict | None = None,
 **kwargs,
 ) -> None:
 super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs)
@@ -819,6 +859,7 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
 self.location = location
 self.impersonation_chain = impersonation_chain
 self.labels = labels
+self.encryption_configuration = encryption_configuration

 def _submit_job(
 self,
@@ -828,6 +869,8 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
 """Submit a new job and get the job id for polling the status using Trigger."""
 configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}

+self.include_encryption_configuration(configuration, "query")
+
 return hook.insert_job(
 configuration=configuration,
 project_id=hook.project_id,
@@ -871,9 +914,10 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):

 class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 """
-
+Fetch data and return it, either from a BigQuery table, or results of a query job.

-Data
+Data could be narrowed down by specific columns or retrieved as a whole.
+It is returned in either of the following two formats, based on "as_dict" value:
 1. False (Default) - A Python list of lists, with the number of nested lists equal to the number of rows
 fetched. Each nested list represents a row, where the elements within it correspond to the column values
 for that particular row.
@@ -893,27 +937,42 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 .. note::
 If you pass fields to ``selected_fields`` which are in different order than the
 order of columns already in
-BQ table, the data will still be in the order of BQ table.
+BQ table/job, the data will still be in the order of BQ table.
 For example if the BQ table has 3 columns as
 ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields``
 the data would still be of the form ``'A,B'``.

-
+.. note::
+When utilizing job id not in deferrable mode, the job should be in DONE state.
+
+**Example - Retrieve data from BigQuery using table**::

 get_data = BigQueryGetDataOperator(
 task_id="get_data_from_bq",
 dataset_id="test_dataset",
 table_id="Transaction_partitions",
-
+table_project_id="internal-gcp-project",
+max_results=100,
+selected_fields="DATE",
+gcp_conn_id="airflow-conn-id",
+)
+
+**Example - Retrieve data from BigQuery using a job id**::
+
+get_data = BigQueryGetDataOperator(
+job_id="airflow_8999918812727394_86a1cecc69c5e3028d28247affd7563",
+job_project_id="internal-gcp-project",
 max_results=100,
 selected_fields="DATE",
 gcp_conn_id="airflow-conn-id",
 )

 :param dataset_id: The dataset ID of the requested table. (templated)
-:param table_id: The table ID of the requested table. (templated)
+:param table_id: The table ID of the requested table. Mutually exclusive with job_id. (templated)
 :param table_project_id: (Optional) The project ID of the requested table.
 If None, it will be derived from the hook's project ID. (templated)
+:param job_id: The job ID from which query results are retrieved.
+Mutually exclusive with table_id. (templated)
 :param job_project_id: (Optional) Google Cloud Project where the job is running.
 If None, it will be derived from the hook's project ID. (templated)
 :param project_id: (Deprecated) (Optional) The name of the project where the data
@@ -944,6 +1003,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 "dataset_id",
 "table_id",
 "table_project_id",
+"job_id",
 "job_project_id",
 "project_id",
 "max_results",
@@ -955,9 +1015,10 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 def __init__(
 self,
 *,
-dataset_id: str,
-table_id: str,
+dataset_id: str | None = None,
+table_id: str | None = None,
 table_project_id: str | None = None,
+job_id: str | None = None,
 job_project_id: str | None = None,
 project_id: str = PROVIDE_PROJECT_ID,
 max_results: int = 100,
@@ -977,6 +1038,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 self.dataset_id = dataset_id
 self.table_id = table_id
 self.job_project_id = job_project_id
+self.job_id = job_id
 self.max_results = max_results
 self.selected_fields = selected_fields
 self.gcp_conn_id = gcp_conn_id
@@ -1013,7 +1075,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 query += "*"
 query += (
 f" from `{self.table_project_id or hook.project_id}.{self.dataset_id}"
-f".{self.table_id}` limit {
+f".{self.table_id}` limit {self.max_results}"
 )
 return query

@@ -1026,7 +1088,13 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 if not self.table_project_id:
 self.table_project_id = self.project_id
 else:
-self.log.info("Ignoring project_id parameter, as table_project_id is found.")
+self.log.info("Ignoring 'project_id' parameter, as 'table_project_id' is found.")
+
+if not exactly_one(self.job_id, self.table_id):
+raise AirflowException(
+"'job_id' and 'table_id' parameters are mutually exclusive, "
+"ensure that exactly one of them is specified"
+)

 hook = BigQueryHook(
 gcp_conn_id=self.gcp_conn_id,
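The new guard uses exactly_one from airflow.utils.helpers (imported in an earlier hunk), which is true only when exactly one of its arguments is truthy. A rough illustration of the accepted and rejected combinations; the literal values are invented:

    from airflow.utils.helpers import exactly_one

    exactly_one("airflow_job_123", None)        # True  -> only job_id given, accepted
    exactly_one(None, "my_table")               # True  -> only table_id given, accepted
    exactly_one("airflow_job_123", "my_table")  # False -> both given, AirflowException
    exactly_one(None, None)                     # False -> neither given, AirflowException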
@@ -1035,31 +1103,45 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 )

 if not self.deferrable:
-self.
-
-
-
-
-
-
-
-
+if not self.job_id:
+self.log.info(
+"Fetching Data from %s.%s.%s max results: %s",
+self.table_project_id or hook.project_id,
+self.dataset_id,
+self.table_id,
+self.max_results,
+)
+if not self.selected_fields:
+schema: dict[str, list] = hook.get_schema(
+dataset_id=self.dataset_id,
+table_id=self.table_id,
+project_id=self.table_project_id or hook.project_id,
+)
+if "fields" in schema:
+self.selected_fields = ",".join([field["name"] for field in schema["fields"]])
+rows: list[Row] | RowIterator | list[dict[str, Any]] = hook.list_rows(
 dataset_id=self.dataset_id,
 table_id=self.table_id,
+max_results=self.max_results,
+selected_fields=self.selected_fields,
+location=self.location,
 project_id=self.table_project_id or hook.project_id,
 )
-
-
-
-
-
-
-
-
-
-
-
-
+else:
+self.log.info(
+"Fetching data from job '%s:%s.%s' max results: %s",
+self.job_project_id or hook.project_id,
+self.location,
+self.job_id,
+self.max_results,
+)
+rows = hook.get_query_results(
+job_id=self.job_id,
+location=self.location,
+selected_fields=self.selected_fields,
+max_results=self.max_results,
+project_id=self.job_project_id or hook.project_id,
+)
 if isinstance(rows, RowIterator):
 raise TypeError(
 "BigQueryHook.list_rows() returns iterator when return_iterator is False (default)"
@@ -1069,11 +1151,16 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 if self.as_dict:
 table_data = [dict(row) for row in rows]
 else:
-table_data = [row.values() for row in rows]
+table_data = [row.values() if isinstance(row, Row) else list(row.values()) for row in rows]

 return table_data

-
+if not self.job_id:
+job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id="")
+else:
+job = hook.get_job(
+job_id=self.job_id, project_id=self.job_project_id or hook.project_id, location=self.location
+)

 context["ti"].xcom_push(key="job_id", value=job.job_id)
 self.defer(
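The reworked comprehension has to branch on the row type because list_rows yields google.cloud.bigquery Row objects while get_query_results appears to return plain dicts (per the new type annotation in the previous hunk). A small sketch of the two shapes, with invented sample data:

    from google.cloud.bigquery import Row

    row_from_table = Row(("alice", 3), {"name": 0, "total": 1})  # shape produced by list_rows
    row_from_job = {"name": "alice", "total": 3}                 # shape produced by get_query_results

    table_data = [
        row.values() if isinstance(row, Row) else list(row.values())
        for row in (row_from_table, row_from_job)
    ]
    # both entries now hold a plain sequence of the values ("alice", 3)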
@@ -1088,6 +1175,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 poll_interval=self.poll_interval,
 as_dict=self.as_dict,
 impersonation_chain=self.impersonation_chain,
+selected_fields=self.selected_fields,
 ),
 method_name="execute_complete",
 )
@@ -1176,7 +1264,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
 .. code-block:: python

 encryption_configuration = {
-"kmsKeyName": "projects/
+"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
 }
 :param impersonation_chain: Optional service account to impersonate using short-term
 credentials, or chained list of accounts required to get the access_token
@@ -1416,7 +1504,7 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
 .. code-block:: python

 encryption_configuration = {
-"kmsKeyName": "projects/
+"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
 }
 :param location: The location used for the operation.
 :param cluster_fields: [Optional] The fields used for clustering.
@@ -1644,7 +1732,7 @@ class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator):
 .. code-block:: python

 encryption_configuration = {
-"kmsKeyName": "projects/
+"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
 }
 :param location: The location used for the operation.
 :param impersonation_chain: Optional service account to impersonate using short-term
@@ -2023,7 +2023,7 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):

 @staticmethod
 def _generate_temp_filename(filename):
-return f"{time
+return f"{time.strftime('%Y%m%d%H%M%S')}_{uuid.uuid4()!s:.8}_{ntpath.basename(filename)}"

 def _upload_file_temp(self, bucket, local_file):
 """Upload a local file to a Google Cloud Storage bucket."""
@@ -19,7 +19,6 @@

 from __future__ import annotations

-import re
 import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
@@ -43,12 +42,8 @@ from airflow.providers.cncf.kubernetes.operators.resource import (
 )
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
 from airflow.providers.google.cloud.hooks.kubernetes_engine import (
-GKECustomResourceHook,
-GKEDeploymentHook,
 GKEHook,
-GKEJobHook,
 GKEKubernetesHook,
-GKEPodHook,
 )
 from airflow.providers.google.cloud.links.kubernetes_engine import (
 KubernetesEngineClusterLink,
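With this import change the operators module no longer pulls in the specialized GKECustomResourceHook, GKEDeploymentHook, GKEJobHook, and GKEPodHook; the Kubernetes API work in the hunks below goes through the consolidated GKEKubernetesHook instead. A hedged migration sketch for code that built such a hook directly, using only keyword arguments that appear in this diff; the connection id, endpoint, and CA certificate are placeholders:

    from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEKubernetesHook

    hook = GKEKubernetesHook(
        gcp_conn_id="google_cloud_default",
        cluster_url="https://34.0.0.1",    # placeholder cluster endpoint
        ssl_ca_cert="LS0tLS1CRUdJTi...",   # placeholder base64-encoded CA certificate
        enable_tcp_keepalive=True,
    )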
@@ -534,13 +529,13 @@ class GKEStartKueueInsideClusterOperator(GoogleCloudBaseOperator):
 )

 @cached_property
-def deployment_hook(self) ->
+def deployment_hook(self) -> GKEKubernetesHook:
 if self._cluster_url is None or self._ssl_ca_cert is None:
 raise AttributeError(
-"Cluster url and ssl_ca_cert should be defined before using self.
+"Cluster url and ssl_ca_cert should be defined before using self.deployment_hook method. "
 "Try to use self.get_kube_creds method",
 )
-return
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 impersonation_chain=self.impersonation_chain,
 cluster_url=self._cluster_url,
@@ -548,13 +543,14 @@ class GKEStartKueueInsideClusterOperator(GoogleCloudBaseOperator):
 )

 @cached_property
-def pod_hook(self) ->
+def pod_hook(self) -> GKEKubernetesHook:
 if self._cluster_url is None or self._ssl_ca_cert is None:
 raise AttributeError(
-"Cluster url and ssl_ca_cert should be defined before using self.
+"Cluster url and ssl_ca_cert should be defined before using self.pod_hook method. "
 "Try to use self.get_kube_creds method",
 )
-
+
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 impersonation_chain=self.impersonation_chain,
 cluster_url=self._cluster_url,
@@ -566,17 +562,10 @@ class GKEStartKueueInsideClusterOperator(GoogleCloudBaseOperator):
 def _get_yaml_content_from_file(kueue_yaml_url) -> list[dict]:
 """Download content of YAML file and separate it into several dictionaries."""
 response = requests.get(kueue_yaml_url, allow_redirects=True)
-
-if response.status_code == 200:
-yaml_data = response.text
-documents = re.split(r"---\n", yaml_data)
-
-for document in documents:
-document_dict = yaml.safe_load(document)
-yaml_dicts.append(document_dict)
-else:
+if response.status_code != 200:
 raise AirflowException("Was not able to read the yaml file from given URL")
-
+
+return list(yaml.safe_load_all(response.text))

 def execute(self, context: Context):
 self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
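The rewritten _get_yaml_content_from_file delegates multi-document parsing to yaml.safe_load_all instead of splitting on "---" by hand. A self-contained sketch of the equivalent behavior; the manifest content is invented:

    import yaml

    manifest = (
        "apiVersion: v1\n"
        "kind: Namespace\n"
        "metadata:\n"
        "  name: kueue-system\n"
        "---\n"
        "apiVersion: v1\n"
        "kind: ServiceAccount\n"
        "metadata:\n"
        "  name: kueue-controller\n"
        "  namespace: kueue-system\n"
    )

    documents = list(yaml.safe_load_all(manifest))
    print(len(documents), documents[0]["kind"])  # 2 Namespace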
@@ -750,21 +739,20 @@ class GKEStartPodOperator(KubernetesPodOperator):
 )

 @cached_property
-def hook(self) ->
+def hook(self) -> GKEKubernetesHook:
 if self._cluster_url is None or self._ssl_ca_cert is None:
 raise AttributeError(
 "Cluster url and ssl_ca_cert should be defined before using self.hook method. "
 "Try to use self.get_kube_creds method",
 )

-
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 cluster_url=self._cluster_url,
 ssl_ca_cert=self._ssl_ca_cert,
 impersonation_chain=self.impersonation_chain,
 enable_tcp_keepalive=True,
 )
-return hook

 def execute(self, context: Context):
 """Execute process of creating pod and executing provided command inside it."""
@@ -909,19 +897,18 @@ class GKEStartJobOperator(KubernetesJobOperator):
 )

 @cached_property
-def hook(self) ->
+def hook(self) -> GKEKubernetesHook:
 if self._cluster_url is None or self._ssl_ca_cert is None:
 raise AttributeError(
 "Cluster url and ssl_ca_cert should be defined before using self.hook method. "
 "Try to use self.get_kube_creds method",
 )

-
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 cluster_url=self._cluster_url,
 ssl_ca_cert=self._ssl_ca_cert,
 )
-return hook

 def execute(self, context: Context):
 """Execute process of creating Job."""
@@ -1035,7 +1022,7 @@ class GKEDescribeJobOperator(GoogleCloudBaseOperator):
 )

 @cached_property
-def hook(self) ->
+def hook(self) -> GKEKubernetesHook:
 self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
 cluster_name=self.cluster_name,
 project_id=self.project_id,
@@ -1043,7 +1030,7 @@
 cluster_hook=self.cluster_hook,
 ).fetch_cluster_info()

-return
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 cluster_url=self._cluster_url,
 ssl_ca_cert=self._ssl_ca_cert,
@@ -1136,7 +1123,7 @@ class GKEListJobsOperator(GoogleCloudBaseOperator):
 )

 @cached_property
-def hook(self) ->
+def hook(self) -> GKEKubernetesHook:
 self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
 cluster_name=self.cluster_name,
 project_id=self.project_id,
@@ -1144,7 +1131,7 @@
 cluster_hook=self.cluster_hook,
 ).fetch_cluster_info()

-return
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 cluster_url=self._cluster_url,
 ssl_ca_cert=self._ssl_ca_cert,
@@ -1242,13 +1229,13 @@ class GKECreateCustomResourceOperator(KubernetesCreateResourceOperator):
 )

 @cached_property
-def hook(self) ->
+def hook(self) -> GKEKubernetesHook:
 if self._cluster_url is None or self._ssl_ca_cert is None:
 raise AttributeError(
 "Cluster url and ssl_ca_cert should be defined before using self.hook method. "
 "Try to use self.get_kube_creds method",
 )
-return
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 cluster_url=self._cluster_url,
 ssl_ca_cert=self._ssl_ca_cert,
@@ -1344,13 +1331,13 @@ class GKEDeleteCustomResourceOperator(KubernetesDeleteResourceOperator):
 )

 @cached_property
-def hook(self) ->
+def hook(self) -> GKEKubernetesHook:
 if self._cluster_url is None or self._ssl_ca_cert is None:
 raise AttributeError(
 "Cluster url and ssl_ca_cert should be defined before using self.hook method. "
 "Try to use self.get_kube_creds method",
 )
-return
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 cluster_url=self._cluster_url,
 ssl_ca_cert=self._ssl_ca_cert,
@@ -1483,14 +1470,14 @@ class GKEDeleteJobOperator(KubernetesDeleteJobOperator):
 )

 @cached_property
-def hook(self) ->
+def hook(self) -> GKEKubernetesHook:
 if self._cluster_url is None or self._ssl_ca_cert is None:
 raise AttributeError(
 "Cluster url and ssl_ca_cert should be defined before using self.hook method. "
 "Try to use self.get_kube_creds method",
 )

-return
+return GKEKubernetesHook(
 gcp_conn_id=self.gcp_conn_id,
 cluster_url=self._cluster_url,
 ssl_ca_cert=self._ssl_ca_cert,
@@ -41,11 +41,8 @@ if TYPE_CHECKING:
 from google.protobuf.field_mask_pb2 import FieldMask

 from airflow.utils.context import Context
-
-
-except ModuleNotFoundError:
-# Remove when Airflow providers min Airflow version is "2.7.0"
-from hashlib import md5
+
+from airflow.utils.hashlib_wrapper import md5


 class WorkflowsCreateWorkflowOperator(GoogleCloudBaseOperator):
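The try/except fallback around the md5 import is gone: the module now imports airflow.utils.hashlib_wrapper.md5 unconditionally (the wrapper is available once the minimum Airflow version is 2.7.0, matching the removed comment), and it is used as a drop-in for hashlib.md5. A minimal sketch, assuming only that it accepts bytes like its hashlib counterpart; the hashed value is invented:

    from airflow.utils.hashlib_wrapper import md5

    digest = md5("my-workflow-id".encode("utf-8")).hexdigest()
    print(digest)  # 32-character hex digest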