apache-airflow-providers-google 10.19.0rc1__py3-none-any.whl → 10.20.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. airflow/providers/google/LICENSE +4 -4
  2. airflow/providers/google/__init__.py +1 -1
  3. airflow/providers/google/ads/hooks/ads.py +4 -4
  4. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +26 -0
  5. airflow/providers/google/cloud/hooks/dataflow.py +132 -1
  6. airflow/providers/google/cloud/hooks/datapipeline.py +22 -73
  7. airflow/providers/google/cloud/hooks/gcs.py +21 -0
  8. airflow/providers/google/cloud/hooks/pubsub.py +10 -1
  9. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -0
  10. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +15 -3
  11. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
  12. airflow/providers/google/cloud/links/dataflow.py +25 -0
  13. airflow/providers/google/cloud/openlineage/mixins.py +271 -0
  14. airflow/providers/google/cloud/openlineage/utils.py +5 -218
  15. airflow/providers/google/cloud/operators/bigquery.py +74 -20
  16. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +76 -0
  17. airflow/providers/google/cloud/operators/dataflow.py +235 -1
  18. airflow/providers/google/cloud/operators/datapipeline.py +29 -121
  19. airflow/providers/google/cloud/operators/dataplex.py +1 -1
  20. airflow/providers/google/cloud/operators/dataproc_metastore.py +17 -6
  21. airflow/providers/google/cloud/operators/kubernetes_engine.py +9 -6
  22. airflow/providers/google/cloud/operators/pubsub.py +18 -0
  23. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +6 -0
  24. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +16 -0
  25. airflow/providers/google/cloud/sensors/cloud_composer.py +171 -2
  26. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +13 -0
  27. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +56 -1
  28. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +6 -12
  29. airflow/providers/google/cloud/triggers/cloud_composer.py +115 -0
  30. airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -0
  31. airflow/providers/google/cloud/utils/credentials_provider.py +81 -6
  32. airflow/providers/google/cloud/utils/external_token_supplier.py +175 -0
  33. airflow/providers/google/common/hooks/base_google.py +35 -1
  34. airflow/providers/google/common/utils/id_token_credentials.py +1 -1
  35. airflow/providers/google/get_provider_info.py +19 -14
  36. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/METADATA +41 -35
  37. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/RECORD +39 -37
  38. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/WHEEL +0 -0
  39. {apache_airflow_providers_google-10.19.0rc1.dist-info → apache_airflow_providers_google-10.20.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
@@ -19,20 +19,42 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
+from airflow.providers.google.cloud.utils.bigquery_get_data import bigquery_get_data
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class BigQueryToPostgresOperator(BigQueryToSqlBaseOperator):
     """
     Fetch data from a BigQuery table (alternatively fetch selected columns) and insert into a PostgreSQL table.
 
+    Due to the constraints of PostgreSQL's ON CONFLICT clause, both the `selected_fields` and `replace_index`
+    parameters need to be specified when using the operator with `replace=True`.
+    In effect, this means that in order to run this operator with `replace=True` your target table MUST
+    already have a unique index column (or columns), otherwise the INSERT command will fail with an error.
+    See more at https://www.postgresql.org/docs/current/sql-insert.html.
+
+    Please note that currently most of the clauses that can be used with PostgreSQL's INSERT
+    command, such as ON CONSTRAINT, WHERE, DEFAULT, etc., are not supported by this operator.
+    If you need those clauses for your queries, `SQLExecuteQueryOperator` will be a more suitable option.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:BigQueryToPostgresOperator`
 
     :param target_table_name: target Postgres table (templated)
     :param postgres_conn_id: Reference to :ref:`postgres connection id <howto/connection:postgres>`.
+    :param replace: Whether to replace instead of insert
+    :param selected_fields: List of fields to return (comma-separated). If
+        unspecified, all fields are returned. Must be specified if `replace` is True
+    :param replace_index: the column or list of column names to act as
+        index for the ON CONFLICT clause. Must be specified if `replace` is True
     """
 
     def __init__(
@@ -40,10 +62,43 @@ class BigQueryToPostgresOperator(BigQueryToSqlBaseOperator):
         *,
         target_table_name: str,
         postgres_conn_id: str = "postgres_default",
+        replace: bool = False,
+        selected_fields: list[str] | str | None = None,
+        replace_index: list[str] | str | None = None,
         **kwargs,
     ) -> None:
-        super().__init__(target_table_name=target_table_name, **kwargs)
+        if replace and not (selected_fields and replace_index):
+            raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names and a unique index.")
+        super().__init__(
+            target_table_name=target_table_name, replace=replace, selected_fields=selected_fields, **kwargs
+        )
         self.postgres_conn_id = postgres_conn_id
+        self.replace_index = replace_index
 
     def get_sql_hook(self) -> PostgresHook:
         return PostgresHook(schema=self.database, postgres_conn_id=self.postgres_conn_id)
+
+    def execute(self, context: Context) -> None:
+        big_query_hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            location=self.location,
+            impersonation_chain=self.impersonation_chain,
+        )
+        self.persist_links(context)
+        sql_hook: PostgresHook = self.get_sql_hook()
+        for rows in bigquery_get_data(
+            self.log,
+            self.dataset_id,
+            self.table_id,
+            big_query_hook,
+            self.batch_size,
+            self.selected_fields,
+        ):
+            sql_hook.insert_rows(
+                table=self.target_table_name,
+                rows=rows,
+                target_fields=self.selected_fields,
+                replace=self.replace,
+                commit_every=self.batch_size,
+                replace_index=self.replace_index,
+            )
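A minimal usage sketch of the new upsert path may help; none of this appears in the diff. The `dataset_table` parameter is assumed to come from the `BigQueryToSqlBaseOperator` base class, and all table and connection names are placeholders. It assumes the target Postgres table already has a unique index on "id", as the docstring above requires.

    from airflow.providers.google.cloud.transfers.bigquery_to_postgres import (
        BigQueryToPostgresOperator,
    )

    upsert_bq_to_pg = BigQueryToPostgresOperator(
        task_id="upsert_bq_to_pg",
        dataset_table="my_dataset.my_table",   # source BigQuery table (assumed base-class param)
        target_table_name="my_pg_table",       # target Postgres table
        postgres_conn_id="postgres_default",
        replace=True,                          # upsert via ON CONFLICT ... DO UPDATE
        selected_fields=["id", "value"],       # required when replace=True
        replace_index="id",                    # unique index column(s) for ON CONFLICT
    )

Omitting `selected_fields` or `replace_index` with `replace=True` now raises the ValueError added in `__init__` above.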
airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -408,20 +408,9 @@ class GCSToGCSOperator(BaseOperator):
            msg = f"{prefix} does not exist in bucket {self.source_bucket}"
            self.log.warning(msg)
            raise AirflowException(msg)
-        if len(objects) == 1 and objects[0][-1] != "/":
-            self._copy_file(hook=hook, source_object=objects[0])
        elif len(objects):
            self._copy_multiple_objects(hook=hook, source_objects=objects, prefix=prefix)
 
-    def _copy_file(self, hook, source_object):
-        destination_object = self.destination_object or source_object
-        if self.destination_object and self.destination_object[-1] == "/":
-            file_name = source_object.split("/")[-1]
-            destination_object += file_name
-        self._copy_single_object(
-            hook=hook, source_object=source_object, destination_object=destination_object
-        )
-
     def _copy_multiple_objects(self, hook, source_objects, prefix):
         # Check whether the prefix is a root directory for all the rest of objects.
         _pref = prefix.rstrip("/")
@@ -441,7 +430,12 @@ class GCSToGCSOperator(BaseOperator):
                destination_object = source_obj
            else:
                file_name_postfix = source_obj.replace(base_path, "", 1)
-                destination_object = self.destination_object.rstrip("/") + "/" + file_name_postfix
+
+                destination_object = (
+                    self.destination_object.rstrip("/")[0 : self.destination_object.rfind("/")]
+                    + "/"
+                    + file_name_postfix
+                )
 
            self._copy_single_object(
                hook=hook, source_object=source_obj, destination_object=destination_object
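A worked example of the new destination mapping, using only the expression shown above; the values are illustrative, not from the diff:

    destination_object = "dest/dir"   # hypothetical operator argument, no trailing slash
    file_name_postfix = "file.txt"    # source object with the matched prefix removed

    result = (
        destination_object.rstrip("/")[0 : destination_object.rfind("/")]
        + "/"
        + file_name_postfix
    )
    print(result)  # "dest/file.txt"; the old line would have produced "dest/dir/file.txt"

When `destination_object` ends with "/", `rfind` points at that trailing slash and the result is unchanged from the old behavior; the difference only appears for destinations without a trailing slash, where the last path component is now treated as a prefix rather than a directory.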
airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -19,8 +19,13 @@
 from __future__ import annotations
 
 import asyncio
+import json
+from datetime import datetime
 from typing import Any, Sequence
 
+from dateutil import parser
+from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
+
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -146,3 +151,113 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger):
                 }
             )
             return
+
+
+class CloudComposerDAGRunTrigger(BaseTrigger):
+    """The trigger that waits for a DAG run to complete."""
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_dag_id: str,
+        start_date: datetime,
+        end_date: datetime,
+        allowed_states: list[str],
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poll_interval: int = 10,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.composer_dag_id = composer_dag_id
+        self.start_date = start_date
+        self.end_date = end_date
+        self.allowed_states = allowed_states
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.poll_interval = poll_interval
+
+        self.gcp_hook = CloudComposerAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger",
+            {
+                "project_id": self.project_id,
+                "region": self.region,
+                "environment_id": self.environment_id,
+                "composer_dag_id": self.composer_dag_id,
+                "start_date": self.start_date,
+                "end_date": self.end_date,
+                "allowed_states": self.allowed_states,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def _pull_dag_runs(self) -> list[dict]:
+        """Pull the list of dag runs."""
+        dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command="dags",
+            subcommand="list-runs",
+            parameters=["-d", self.composer_dag_id, "-o", "json"],
+        )
+        cmd_result = await self.gcp_hook.wait_command_execution_result(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+        )
+        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        return dag_runs
+
+    def _check_dag_runs_states(
+        self,
+        dag_runs: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> bool:
+        for dag_run in dag_runs:
+            if (
+                start_date.timestamp()
+                < parser.parse(dag_run["execution_date"]).timestamp()
+                < end_date.timestamp()
+            ) and dag_run["state"] not in self.allowed_states:
+                return False
+        return True
+
+    async def run(self):
+        try:
+            while True:
+                if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp():
+                    dag_runs = await self._pull_dag_runs()
+
+                    self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+                    if self._check_dag_runs_states(
+                        dag_runs=dag_runs,
+                        start_date=self.start_date,
+                        end_date=self.end_date,
+                    ):
+                        yield TriggerEvent({"status": "success"})
+                        return
+                self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                await asyncio.sleep(self.poll_interval)
+        except AirflowException as ex:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(ex),
+                }
+            )
+            return
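A hedged sketch of handing control to the new trigger from a deferrable sensor; the project, environment, and DAG ids, the dates, and the `execute_complete` callback name are placeholders, not taken from the diff:

    from datetime import datetime, timedelta, timezone

    from airflow.providers.google.cloud.triggers.cloud_composer import (
        CloudComposerDAGRunTrigger,
    )

    trigger = CloudComposerDAGRunTrigger(
        project_id="my-project",
        region="us-central1",
        environment_id="my-environment",
        composer_dag_id="target_dag",
        start_date=datetime.now(timezone.utc) - timedelta(days=1),
        end_date=datetime.now(timezone.utc) + timedelta(hours=1),
        allowed_states=["success"],
        poll_interval=10,
    )
    # Inside a deferrable sensor's execute():
    #     self.defer(trigger=trigger, method_name="execute_complete")

Note the loop above only starts pulling DAG runs once `end_date` has passed; until then the trigger just sleeps for `poll_interval` seconds per iteration.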
airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -142,6 +142,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
                 "on_finish_action": self.on_finish_action.value,
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
+                "logging_interval": self.logging_interval,
+                "last_log_time": self.last_log_time,
             },
         )
 
airflow/providers/google/cloud/utils/credentials_provider.py
@@ -35,6 +35,9 @@ from google.auth.environment_vars import CREDENTIALS, LEGACY_PROJECT, PROJECT
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud._internal_client.secret_manager_client import _SecretManagerClient
+from airflow.providers.google.cloud.utils.external_token_supplier import (
+    ClientCredentialsGrantFlowTokenSupplier,
+)
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.process_utils import patch_environ
 
@@ -210,6 +213,10 @@ class _CredentialProvider(LoggingMixin):
         target_principal: str | None = None,
         delegates: Sequence[str] | None = None,
         is_anonymous: bool | None = None,
+        idp_issuer_url: str | None = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        idp_extra_params_dict: dict[str, str] | None = None,
     ) -> None:
         super().__init__()
         key_options = [key_path, keyfile_dict, credential_config_file, key_secret_name, is_anonymous]
@@ -229,6 +236,10 @@
         self.target_principal = target_principal
         self.delegates = delegates
         self.is_anonymous = is_anonymous
+        self.idp_issuer_url = idp_issuer_url
+        self.client_id = client_id
+        self.client_secret = client_secret
+        self.idp_extra_params_dict = idp_extra_params_dict
 
     def get_credentials_and_project(self) -> tuple[Credentials, str]:
         """
@@ -239,7 +250,8 @@
         :return: Google Auth Credentials
         """
         if self.is_anonymous:
-            credentials, project_id = AnonymousCredentials(), ""
+            credentials: Credentials = AnonymousCredentials()
+            project_id = ""
         else:
             if self.key_path:
                 credentials, project_id = self._get_credentials_using_key_path()
@@ -247,6 +259,10 @@
                 credentials, project_id = self._get_credentials_using_key_secret_name()
             elif self.keyfile_dict:
                 credentials, project_id = self._get_credentials_using_keyfile_dict()
+            elif self.idp_issuer_url:
+                credentials, project_id = (
+                    self._get_credentials_using_credential_config_file_and_token_supplier()
+                )
             elif self.credential_config_file:
                 credentials, project_id = self._get_credentials_using_credential_config_file()
             else:
@@ -273,10 +289,12 @@
 
         return credentials, project_id
 
-    def _get_credentials_using_keyfile_dict(self):
+    def _get_credentials_using_keyfile_dict(self) -> tuple[Credentials, str]:
         self._log_debug("Getting connection using JSON Dict")
         # Depending on how the JSON was formatted, it may contain
         # escaped newlines. Convert those to actual newlines.
+        if self.keyfile_dict is None:
+            raise ValueError("The keyfile_dict field is None, and we need it for keyfile_dict auth.")
         self.keyfile_dict["private_key"] = self.keyfile_dict["private_key"].replace("\\n", "\n")
         credentials = google.oauth2.service_account.Credentials.from_service_account_info(
             self.keyfile_dict, scopes=self.scopes
@@ -284,7 +302,9 @@
         project_id = credentials.project_id
         return credentials, project_id
 
-    def _get_credentials_using_key_path(self):
+    def _get_credentials_using_key_path(self) -> tuple[Credentials, str]:
+        if self.key_path is None:
+            raise ValueError("The key_path field is None, and we need it for key_path auth.")
         if self.key_path.endswith(".p12"):
             raise AirflowException("Legacy P12 key files are not supported, use a JSON key file.")
 
@@ -298,13 +318,15 @@
         project_id = credentials.project_id
         return credentials, project_id
 
-    def _get_credentials_using_key_secret_name(self):
+    def _get_credentials_using_key_secret_name(self) -> tuple[Credentials, str]:
         self._log_debug("Getting connection using JSON key data from GCP secret: %s", self.key_secret_name)
 
         # Use ADC to access GCP Secret Manager.
         adc_credentials, adc_project_id = google.auth.default(scopes=self.scopes)
         secret_manager_client = _SecretManagerClient(credentials=adc_credentials)
 
+        if self.key_secret_name is None:
+            raise ValueError("The key_secret_name field is None, and we need it for key_secret_name auth.")
         if not secret_manager_client.is_valid_secret_name(self.key_secret_name):
             raise AirflowException("Invalid secret name specified for fetching JSON key data.")
 
@@ -326,7 +348,7 @@
         project_id = credentials.project_id
         return credentials, project_id
 
-    def _get_credentials_using_credential_config_file(self):
+    def _get_credentials_using_credential_config_file(self) -> tuple[Credentials, str]:
         if isinstance(self.credential_config_file, str) and os.path.exists(self.credential_config_file):
             self._log_info(
                 f"Getting connection using credential configuration file: `{self.credential_config_file}`"
@@ -350,7 +372,25 @@
 
         return credentials, project_id
 
-    def _get_credentials_using_adc(self):
+    def _get_credentials_using_credential_config_file_and_token_supplier(self):
+        self._log_info(
+            "Getting connection using credential configuration file and external Identity Provider."
+        )
+
+        if not self.credential_config_file:
+            raise AirflowException(
+                "Credential Configuration File is needed to use authentication by External Identity Provider."
+            )
+
+        info = _get_info_from_credential_configuration_file(self.credential_config_file)
+        info["subject_token_supplier"] = ClientCredentialsGrantFlowTokenSupplier(
+            oidc_issuer_url=self.idp_issuer_url, client_id=self.client_id, client_secret=self.client_secret
+        )
+
+        credentials, project_id = google.auth.load_credentials_from_dict(info=info, scopes=self.scopes)
+        return credentials, project_id
+
+    def _get_credentials_using_adc(self) -> tuple[Credentials, str]:
         self._log_info(
             "Getting connection using `google.auth.default()` since no explicit credentials are provided."
         )
@@ -419,3 +459,38 @@ def _get_project_id_from_service_account_email(service_account_email: str) -> str:
     raise AirflowException(
         f"Could not extract project_id from service account's email: {service_account_email}."
     )
+
+
+def _get_info_from_credential_configuration_file(
+    credential_configuration_file: str | dict[str, str],
+) -> dict[str, str]:
+    """
+    Extract the Credential Configuration File information, either from a json file, json string or dictionary.
+
+    :param credential_configuration_file: File path or content (as json string or dictionary) of a GCP credential configuration file.
+
+    :return: Returns a dictionary containing the Credential Configuration File information.
+    """
+    # if it's already a dict, just return it
+    if isinstance(credential_configuration_file, dict):
+        return credential_configuration_file
+
+    if not isinstance(credential_configuration_file, str):
+        raise AirflowException(
+            f"Invalid argument type, expected str or dict, got {type(credential_configuration_file)}."
+        )
+
+    if os.path.exists(credential_configuration_file):  # attempts to load from json file
+        with open(credential_configuration_file) as file_obj:
+            try:
+                return json.load(file_obj)
+            except ValueError:
+                raise AirflowException(
+                    f"Credential Configuration File '{credential_configuration_file}' is not a valid json file."
                )
+
+    # if not a file, attempt to load it from a json string
+    try:
+        return json.loads(credential_configuration_file)
+    except ValueError:
+        raise AirflowException("Credential Configuration File is not a valid json string.")
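For orientation, a hypothetical sketch of the new code path; `_CredentialProvider` is internal and normally driven through a Google Cloud connection rather than called directly, and the configuration values below are placeholders shaped like a GCP workload identity federation config:

    from airflow.providers.google.cloud.utils.credentials_provider import _CredentialProvider

    config = {
        "type": "external_account",
        "audience": "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
        "subject_token_type": "urn:ietf:params:oauth:token-type:id_token",
        "token_url": "https://sts.googleapis.com/v1/token",
    }

    provider = _CredentialProvider(
        credential_config_file=config,  # a dict, JSON string, or file path are all accepted
        idp_issuer_url="https://idp.example.com/oauth2/token",  # routes dispatch to the token-supplier branch
        client_id="my-client-id",
        client_secret="my-client-secret",
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
    credentials, project_id = provider.get_credentials_and_project()

Because `idp_issuer_url` is checked before `credential_config_file` in `get_credentials_and_project`, setting it switches the provider from the plain credential-config path to the new supplier-backed one.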
airflow/providers/google/cloud/utils/external_token_supplier.py (new file)
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import abc
+import time
+from functools import wraps
+from typing import TYPE_CHECKING, Any
+
+import requests
+from google.auth.exceptions import RefreshError
+from google.auth.identity_pool import SubjectTokenSupplier
+
+if TYPE_CHECKING:
+    from google.auth.external_account import SupplierContext
+    from google.auth.transport import Request
+
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+def cache_token_decorator(get_subject_token_method):
+    """Cache calls to ``SubjectTokenSupplier`` instances' ``get_subject_token`` methods.
+
+    Different instances of the same SubjectTokenSupplier class with the same attributes
+    share the OIDC token cache.
+
+    :param get_subject_token_method: A method that returns both a token and an integer specifying
+        the time in seconds until the token expires
+
+    See also:
+        https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier.get_subject_token
+    """
+    cache = {}
+
+    @wraps(get_subject_token_method)
+    def wrapper(supplier_instance: CacheTokenSupplier, *args, **kwargs) -> str:
+        """Obeys the interface set by ``SubjectTokenSupplier`` for ``get_subject_token`` methods.
+
+        :param supplier_instance: the SubjectTokenSupplier instance whose get_subject_token method is being decorated
+        :return: The token string
+        """
+        nonlocal cache
+
+        cache_key = supplier_instance.get_subject_key()
+        token: dict[str, str | float] = {}
+
+        if cache_key not in cache or cache[cache_key]["expiration_time"] < time.monotonic():
+            supplier_instance.log.info("OIDC token missing or expired")
+            try:
+                access_token, expires_in = get_subject_token_method(supplier_instance, *args, **kwargs)
+                if not isinstance(expires_in, int) or not isinstance(access_token, str):
+                    raise RefreshError  # assume error if strange values are provided
+
+            except RefreshError:
+                supplier_instance.log.error("Failed retrieving new OIDC Token from IdP")
+                raise
+
+            expiration_time = time.monotonic() + float(expires_in)
+            token["access_token"] = access_token
+            token["expiration_time"] = expiration_time
+            cache[cache_key] = token
+
+            supplier_instance.log.info("New OIDC token retrieved, expires in %s seconds.", expires_in)
+
+        return cache[cache_key]["access_token"]
+
+    return wrapper
+
+
+class CacheTokenSupplier(LoggingMixin, SubjectTokenSupplier):
+    """
+    A superclass for all Subject Token Supplier classes that wish to implement a caching mechanism.
+
+    Child classes must implement the ``get_subject_key`` method to generate a string that serves as the cache key,
+    ensuring that tokens are shared appropriately among instances.
+
+    Methods:
+        get_subject_key: Abstract method to be implemented by child classes. It should return a string that serves as the cache key.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @abc.abstractmethod
+    def get_subject_key(self) -> str:
+        raise NotImplementedError("")
+
+
+class ClientCredentialsGrantFlowTokenSupplier(CacheTokenSupplier):
+    """
+    Class that retrieves an OIDC token from an external IdP using the OAuth2.0 Client Credentials Grant flow.
+
+    This class implements the ``SubjectTokenSupplier`` interface used by ``google.auth.identity_pool.Credentials``.
+
+    :params oidc_issuer_url: URL of the IdP that performs the OAuth2.0 Client Credentials Grant flow and returns an OIDC token.
+    :params client_id: Client ID of the application requesting the token
+    :params client_secret: Client secret of the application requesting the token
+    :params extra_params_kwargs: Extra parameters to be passed in the payload of the POST request to the `oidc_issuer_url`
+
+    See also:
+        https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier
+    """
+
+    def __init__(
+        self,
+        oidc_issuer_url: str,
+        client_id: str,
+        client_secret: str,
+        **extra_params_kwargs: Any,
+    ) -> None:
+        super().__init__()
+        self.oidc_issuer_url = oidc_issuer_url
+        self.client_id = client_id
+        self.client_secret = client_secret
+        self.extra_params_kwargs = extra_params_kwargs
+
+    @cache_token_decorator
+    def get_subject_token(self, context: SupplierContext, request: Request) -> tuple[str, int]:
+        """Perform the Client Credentials Grant flow with the IdP and retrieve an OIDC token and expiration time."""
+        self.log.info("Requesting new OIDC token from external IdP.")
+        try:
+            response = requests.post(
+                self.oidc_issuer_url,
+                data={
+                    "grant_type": "client_credentials",
+                    "client_id": self.client_id,
+                    "client_secret": self.client_secret,
+                    **self.extra_params_kwargs,
+                },
+            )
+            response.raise_for_status()
+        except requests.HTTPError as e:
+            raise RefreshError(str(e))
+        except requests.ConnectionError as e:
+            raise RefreshError(str(e))
+
+        try:
+            response_dict = response.json()
+        except requests.JSONDecodeError:
+            raise RefreshError(f"Didn't get a json response from {self.oidc_issuer_url}")
+
+        # These fields are required
+        if {"access_token", "expires_in"} - set(response_dict.keys()):
+            # TODO: more information about the error can be provided in the exception by inspecting the response
+            raise RefreshError(f"No access token returned from {self.oidc_issuer_url}")
+
+        return response_dict["access_token"], response_dict["expires_in"]
+
+    def get_subject_key(self) -> str:
+        """
+        Create a cache key using the OIDC issuer URL, client ID, client secret and additional parameters.
+
+        Instances with the same credentials will share tokens.
+        """
+        cache_key = (
+            self.oidc_issuer_url
+            + self.client_id
+            + self.client_secret
+            + ",".join(sorted(self.extra_params_kwargs))
+        )
+        return cache_key
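A hedged sketch of using the supplier directly with google-auth identity pool credentials, assuming a google-auth version whose `identity_pool.Credentials` accepts a `subject_token_supplier` keyword (the import at the top of the new module implies one); every URL and id below is a placeholder:

    from google.auth import identity_pool

    from airflow.providers.google.cloud.utils.external_token_supplier import (
        ClientCredentialsGrantFlowTokenSupplier,
    )

    supplier = ClientCredentialsGrantFlowTokenSupplier(
        oidc_issuer_url="https://idp.example.com/oauth2/token",
        client_id="my-client-id",
        client_secret="my-client-secret",
        audience="my-api",  # extra kwargs are forwarded in the POST payload to the IdP
    )

    credentials = identity_pool.Credentials(
        audience="//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
        subject_token_type="urn:ietf:params:oauth:token-type:id_token",
        token_url="https://sts.googleapis.com/v1/token",
        subject_token_supplier=supplier,
    )

Because the cache key in `get_subject_key` is built from the issuer URL, client id, client secret, and extra parameter names, two suppliers constructed with the same values would share one cached OIDC token.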
airflow/providers/google/common/hooks/base_google.py
@@ -248,6 +248,20 @@ class GoogleBaseHook(BaseHook):
             "impersonation_chain": StringField(
                 lazy_gettext("Impersonation Chain"), widget=BS3TextFieldWidget()
             ),
+            "idp_issuer_url": StringField(
+                lazy_gettext("IdP Token Issue URL (Client Credentials Grant Flow)"),
+                widget=BS3TextFieldWidget(),
+            ),
+            "client_id": StringField(
+                lazy_gettext("Client ID (Client Credentials Grant Flow)"), widget=BS3TextFieldWidget()
+            ),
+            "client_secret": StringField(
+                lazy_gettext("Client Secret (Client Credentials Grant Flow)"),
+                widget=BS3PasswordFieldWidget(),
+            ),
+            "idp_extra_parameters": StringField(
+                lazy_gettext("IdP Extra Request Parameters"), widget=BS3TextFieldWidget()
+            ),
             "is_anonymous": BooleanField(
                 lazy_gettext("Anonymous credentials (ignores all other settings)"), default=False
             ),
@@ -305,6 +319,18 @@ class GoogleBaseHook(BaseHook):
         target_principal, delegates = _get_target_principal_and_delegates(self.impersonation_chain)
         is_anonymous = self._get_field("is_anonymous")
 
+        idp_issuer_url: str | None = self._get_field("idp_issuer_url", None)
+        client_id: str | None = self._get_field("client_id", None)
+        client_secret: str | None = self._get_field("client_secret", None)
+        idp_extra_params: str | None = self._get_field("idp_extra_params", None)
+
+        idp_extra_params_dict: dict[str, str] | None = None
+        if idp_extra_params:
+            try:
+                idp_extra_params_dict = json.loads(idp_extra_params)
+            except json.decoder.JSONDecodeError:
+                raise AirflowException("Invalid JSON.")
+
         credentials, project_id = get_credentials_and_project_id(
             key_path=key_path,
             keyfile_dict=keyfile_dict_json,
@@ -316,6 +342,10 @@ class GoogleBaseHook(BaseHook):
             target_principal=target_principal,
             delegates=delegates,
             is_anonymous=is_anonymous,
+            idp_issuer_url=idp_issuer_url,
+            client_id=client_id,
+            client_secret=client_secret,
+            idp_extra_params_dict=idp_extra_params_dict,
         )
 
         overridden_project_id = self._get_field("project")
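Illustration only, not from the diff: the connection extras that would populate the new fields. The keys mirror the form field names added above and all values are placeholders; note the hook parses the extra IdP parameters with `json.loads`, so that value must be a JSON object string.

    conn_extra = {
        "credential_config_file": "/files/wif_config.json",
        "idp_issuer_url": "https://idp.example.com/oauth2/token",
        "client_id": "my-client-id",
        "client_secret": "my-client-secret",
    }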
airflow/providers/google/common/hooks/base_google.py
@@ -731,7 +761,11 @@ class GoogleBaseAsyncHook(BaseHook):
 
     sync_hook_class: Any = None
 
-    def __init__(self, **kwargs: Any):
+    def __init__(self, **kwargs: Any) -> None:
+        # add a default value for gcp_conn_id
+        if "gcp_conn_id" not in kwargs:
+            kwargs["gcp_conn_id"] = "google_cloud_default"
+
         self._hook_kwargs = kwargs
         self._sync_hook = None
 
airflow/providers/google/common/utils/id_token_credentials.py
@@ -190,7 +190,7 @@ def _get_gce_credentials(
 
 
 def get_default_id_token_credentials(
-    target_audience: str | None, request: google.auth.transport.Request = None
+    target_audience: str | None, request: google.auth.transport.Request | None = None
 ) -> google_auth_credentials.Credentials:
     """Get the default ID Token credentials for the current environment.