apache-airflow-providers-google 16.1.0__py3-none-any.whl → 17.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +1 -5
  3. airflow/providers/google/cloud/hooks/bigquery.py +1 -130
  4. airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
  5. airflow/providers/google/cloud/hooks/cloud_run.py +1 -1
  6. airflow/providers/google/cloud/hooks/cloud_sql.py +5 -5
  7. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +1 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +0 -85
  9. airflow/providers/google/cloud/hooks/datafusion.py +1 -1
  10. airflow/providers/google/cloud/hooks/dataprep.py +1 -4
  11. airflow/providers/google/cloud/hooks/dataproc.py +68 -70
  12. airflow/providers/google/cloud/hooks/gcs.py +3 -5
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/looker.py +1 -5
  15. airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
  16. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +4 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
  18. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +7 -0
  19. airflow/providers/google/cloud/links/kubernetes_engine.py +3 -0
  20. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -2
  21. airflow/providers/google/cloud/log/stackdriver_task_handler.py +1 -1
  22. airflow/providers/google/cloud/openlineage/mixins.py +7 -7
  23. airflow/providers/google/cloud/operators/automl.py +1 -1
  24. airflow/providers/google/cloud/operators/bigquery.py +8 -609
  25. airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
  26. airflow/providers/google/cloud/operators/cloud_sql.py +1 -5
  27. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +2 -2
  28. airflow/providers/google/cloud/operators/dataproc.py +1 -1
  29. airflow/providers/google/cloud/operators/dlp.py +2 -2
  30. airflow/providers/google/cloud/operators/kubernetes_engine.py +4 -4
  31. airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
  32. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +7 -1
  33. airflow/providers/google/cloud/operators/vertex_ai/ray.py +7 -5
  34. airflow/providers/google/cloud/operators/vision.py +1 -1
  35. airflow/providers/google/cloud/sensors/dataflow.py +23 -6
  36. airflow/providers/google/cloud/sensors/datafusion.py +2 -2
  37. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
  38. airflow/providers/google/cloud/transfers/gcs_to_local.py +3 -1
  39. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +9 -9
  40. airflow/providers/google/cloud/triggers/bigquery.py +11 -13
  41. airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
  42. airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
  43. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +1 -1
  44. airflow/providers/google/cloud/triggers/datafusion.py +1 -1
  45. airflow/providers/google/cloud/triggers/dataproc.py +10 -9
  46. airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
  47. airflow/providers/google/cloud/triggers/mlengine.py +1 -1
  48. airflow/providers/google/cloud/triggers/pubsub.py +1 -1
  49. airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
  50. airflow/providers/google/common/auth_backend/google_openid.py +2 -2
  51. airflow/providers/google/common/hooks/base_google.py +2 -6
  52. airflow/providers/google/common/utils/id_token_credentials.py +2 -2
  53. airflow/providers/google/get_provider_info.py +19 -16
  54. airflow/providers/google/leveldb/hooks/leveldb.py +1 -5
  55. airflow/providers/google/marketing_platform/hooks/display_video.py +47 -3
  56. airflow/providers/google/marketing_platform/links/analytics_admin.py +1 -1
  57. airflow/providers/google/marketing_platform/operators/display_video.py +64 -15
  58. airflow/providers/google/marketing_platform/sensors/display_video.py +9 -2
  59. airflow/providers/google/version_compat.py +10 -3
  60. {apache_airflow_providers_google-16.1.0.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/METADATA +106 -100
  61. {apache_airflow_providers_google-16.1.0.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/RECORD +63 -62
  62. airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
  63. airflow/providers/google/cloud/links/life_sciences.py +0 -30
  64. airflow/providers/google/cloud/operators/life_sciences.py +0 -118
  65. {apache_airflow_providers_google-16.1.0.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/WHEEL +0 -0
  66. {apache_airflow_providers_google-16.1.0.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/google/cloud/hooks/dataproc.py
@@ -47,7 +47,7 @@ from google.cloud.dataproc_v1 import (
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 from airflow.version import version as airflow_version
 
 if TYPE_CHECKING:
@@ -1269,7 +1269,7 @@ class DataprocHook(GoogleBaseHook):
         return all([word in error_msg for word in key_words])
 
 
-class DataprocAsyncHook(GoogleBaseHook):
+class DataprocAsyncHook(GoogleBaseAsyncHook):
     """
     Asynchronous interaction with Google Cloud Dataproc APIs.
 
@@ -1277,6 +1277,8 @@ class DataprocAsyncHook(GoogleBaseHook):
     keyword arguments rather than positional.
     """
 
+    sync_hook_class = DataprocHook
+
     def __init__(
         self,
         gcp_conn_id: str = "google_cloud_default",
@@ -1286,53 +1288,90 @@ class DataprocAsyncHook(GoogleBaseHook):
         super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
         self._cached_client: JobControllerAsyncClient | None = None
 
-    def get_cluster_client(self, region: str | None = None) -> ClusterControllerAsyncClient:
+    async def get_cluster_client(self, region: str | None = None) -> ClusterControllerAsyncClient:
         """Create a ClusterControllerAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+        sync_hook = await self.get_sync_hook()
         return ClusterControllerAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
-    def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceAsyncClient:
+    async def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceAsyncClient:
         """Create a WorkflowTemplateServiceAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+        sync_hook = await self.get_sync_hook()
         return WorkflowTemplateServiceAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
-    def get_job_client(self, region: str | None = None) -> JobControllerAsyncClient:
+    async def get_job_client(self, region: str | None = None) -> JobControllerAsyncClient:
         """Create a JobControllerAsyncClient."""
         if self._cached_client is None:
             client_options = None
             if region and region != "global":
                 client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+            sync_hook = await self.get_sync_hook()
             self._cached_client = JobControllerAsyncClient(
-                credentials=self.get_credentials(),
+                credentials=sync_hook.get_credentials(),
                 client_info=CLIENT_INFO,
                 client_options=client_options,
             )
         return self._cached_client
 
-    def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient:
+    async def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient:
         """Create a BatchControllerAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+        sync_hook = await self.get_sync_hook()
         return BatchControllerAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
-    def get_operations_client(self, region: str) -> OperationsClient:
+    async def get_operations_client(self, region: str) -> OperationsClient:
         """Create a OperationsClient."""
-        return self.get_template_client(region=region).transport.operations_client
+        template_client = await self.get_template_client(region=region)
+        return template_client.transport.operations_client
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def get_cluster(
+        self,
+        region: str,
+        cluster_name: str,
+        project_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Cluster:
+        """
+        Get a cluster.
+
+        :param region: Cloud Dataproc region in which to handle the request.
+        :param cluster_name: Name of the cluster to get.
+        :param project_id: Google Cloud project ID that the cluster belongs to.
+        :param retry: A retry object used to retry requests. If *None*, requests
+            will not be retried.
+        :param timeout: The amount of time, in seconds, to wait for the request
+            to complete. If *retry* is specified, the timeout applies to each
+            individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = await self.get_cluster_client(region=region)
+        result = await client.get_cluster(
+            request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
 
     @GoogleBaseHook.fallback_to_default_project_id
     async def create_cluster(
@@ -1390,7 +1429,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             cluster["config"] = cluster_config  # type: ignore
             cluster["labels"] = labels  # type: ignore
 
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.create_cluster(
             request={
                 "project_id": project_id,
@@ -1435,7 +1474,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.delete_cluster(
             request={
                 "project_id": project_id,
@@ -1483,7 +1522,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.diagnose_cluster(
             request={
                 "project_id": project_id,
@@ -1500,38 +1539,6 @@ class DataprocAsyncHook(GoogleBaseHook):
         )
         return result
 
-    @GoogleBaseHook.fallback_to_default_project_id
-    async def get_cluster(
-        self,
-        region: str,
-        cluster_name: str,
-        project_id: str,
-        retry: AsyncRetry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> Cluster:
-        """
-        Get the resource representation for a cluster in a project.
-
-        :param project_id: Google Cloud project ID that the cluster belongs to.
-        :param region: Cloud Dataproc region to handle the request.
-        :param cluster_name: The cluster name.
-        :param retry: A retry object used to retry requests. If *None*, requests
-            will not be retried.
-        :param timeout: The amount of time, in seconds, to wait for the request
-            to complete. If *retry* is specified, the timeout applies to each
-            individual attempt.
-        :param metadata: Additional metadata that is provided to the method.
-        """
-        client = self.get_cluster_client(region=region)
-        result = await client.get_cluster(
-            request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
-
     @GoogleBaseHook.fallback_to_default_project_id
     async def list_clusters(
         self,
@@ -1561,7 +1568,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.list_clusters(
             request={"project_id": project_id, "region": region, "filter": filter_, "page_size": page_size},
             retry=retry,
@@ -1638,7 +1645,7 @@ class DataprocAsyncHook(GoogleBaseHook):
         """
         if region is None:
             raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         operation = await client.update_cluster(
             request={
                 "project_id": project_id,
@@ -1680,10 +1687,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         parent = f"projects/{project_id}/regions/{region}"
         return await client.create_workflow_template(
             request={"parent": parent, "template": template}, retry=retry, timeout=timeout, metadata=metadata
@@ -1725,10 +1730,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         name = f"projects/{project_id}/regions/{region}/workflowTemplates/{template_name}"
         operation = await client.instantiate_workflow_template(
             request={"name": name, "version": version, "request_id": request_id, "parameters": parameters},
@@ -1767,10 +1770,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         parent = f"projects/{project_id}/regions/{region}"
         operation = await client.instantiate_inline_workflow_template(
             request={"parent": parent, "template": template, "request_id": request_id},
@@ -1781,7 +1782,8 @@ class DataprocAsyncHook(GoogleBaseHook):
         return operation
 
     async def get_operation(self, region, operation_name):
-        return await self.get_operations_client(region).get_operation(name=operation_name)
+        operations_client = await self.get_operations_client(region)
+        return await operations_client.get_operation(name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
     async def get_job(
@@ -1806,9 +1808,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
         job = await client.get_job(
             request={"project_id": project_id, "region": region, "job_id": job_id},
             retry=retry,
@@ -1845,9 +1845,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
         return await client.submit_job(
             request={"project_id": project_id, "region": region, "job": job, "request_id": request_id},
             retry=retry,
@@ -1878,7 +1876,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
        """
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
 
         job = await client.cancel_job(
             request={"project_id": project_id, "region": region, "job_id": job_id},
@@ -1920,7 +1918,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         parent = f"projects/{project_id}/regions/{region}"
 
         result = await client.create_batch(
@@ -1959,7 +1957,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"
 
         await client.delete_batch(
@@ -1994,7 +1992,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"
 
         result = await client.get_batch(
@@ -2039,7 +2037,7 @@ class DataprocAsyncHook(GoogleBaseHook):
         :param filter: Result filters as specified in ListBatchesRequest
         :param order_by: How to order results as specified in ListBatchesRequest
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         parent = f"projects/{project_id}/regions/{region}"
 
         result = await client.list_batches(
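
The hunks above turn every client accessor on DataprocAsyncHook into a coroutine and route credential lookup through the wrapped sync hook (via the new sync_hook_class), so callers now have to await both the accessor and the API call. Below is a minimal sketch of how calling code (for example, a deferrable trigger) would use the reworked hook; it is not provider code, and the connection ID, project, region, and cluster values are placeholders:

    # Hedged sketch, not taken from the provider: exercising the now-async
    # DataprocAsyncHook.get_cluster added in 17.0.0. All IDs are placeholders.
    from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook


    async def cluster_state(project_id: str, region: str, cluster_name: str) -> str:
        hook = DataprocAsyncHook(gcp_conn_id="google_cloud_default")
        # In 16.1.0 the client accessors were synchronous; now they, and get_cluster
        # itself, must be awaited.
        cluster = await hook.get_cluster(
            project_id=project_id, region=region, cluster_name=cluster_name
        )
        return cluster.status.state.name


    # e.g. asyncio.run(cluster_state("example-project", "us-central1", "example-cluster"))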

airflow/providers/google/cloud/hooks/gcs.py
@@ -31,14 +31,13 @@ from contextlib import contextmanager
 from functools import partial
 from io import BytesIO
 from tempfile import NamedTemporaryFile
-from typing import IO, TYPE_CHECKING, Any, TypeVar, cast, overload
+from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
 from urllib.parse import urlsplit
 
+# Make mypy happy by importing as aliases
+import google.cloud.storage as storage
 from gcloud.aio.storage import Storage
 from google.api_core.exceptions import GoogleAPICallError, NotFound
-
-# not sure why but mypy complains on missing `storage` but it is clearly there and is importable
-from google.cloud import storage  # type: ignore[attr-defined]
 from google.cloud.exceptions import GoogleCloudError
 from google.cloud.storage.retry import DEFAULT_RETRY
 
@@ -51,7 +50,6 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-from airflow.typing_compat import ParamSpec
 from airflow.utils import timezone
 from airflow.version import version
 

airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -30,7 +30,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.auth.transport import requests as google_requests
 
 # not sure why but mypy complains on missing `container_v1` but it is clearly there and is importable
-from google.cloud import exceptions  # type: ignore[attr-defined]
+from google.cloud import exceptions
 from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
 from google.cloud.container_v1.types import Cluster, Operation
 from kubernetes import client
@@ -498,7 +498,7 @@ class GKEKubernetesAsyncHook(GoogleBaseAsyncHook, AsyncKubernetesHook):
         )
 
     @contextlib.asynccontextmanager
-    async def get_conn(self) -> async_client.ApiClient:  # type: ignore[override]
+    async def get_conn(self) -> async_client.ApiClient:
         kube_client = None
         try:
             kube_client = await self._load_config()

airflow/providers/google/cloud/hooks/looker.py
@@ -29,11 +29,7 @@ from looker_sdk.sdk.api40 import methods as methods40
 from packaging.version import parse as parse_version
 
 from airflow.exceptions import AirflowException
-
-try:
-    from airflow.sdk import BaseHook
-except ImportError:
-    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]
+from airflow.providers.google.version_compat import BaseHook
 from airflow.version import version
 
 if TYPE_CHECKING:
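
looker.py now imports BaseHook from the provider's version_compat module instead of repeating the Airflow 2/3 fallback inline. The contents of version_compat.py (changed +10 -3 in this release) are not shown in this diff; based on the block removed above, the centralized shim presumably looks roughly like the following sketch, which may differ from the real module:

    # Hypothetical sketch of the centralized fallback, inferred from the removed
    # try/except import above; the actual version_compat.py is not shown here.
    try:
        from airflow.sdk import BaseHook  # Airflow 3 / Task SDK location
    except ImportError:
        from airflow.hooks.base import BaseHook  # Airflow 2.x fallback

    __all__ = ["BaseHook"]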

airflow/providers/google/cloud/hooks/stackdriver.py
@@ -261,8 +261,9 @@ class StackdriverHook(GoogleBaseHook):
         channel_name_map = {}
 
         for channel in channels:
+            # This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
             channel.verification_status = (
-                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED  # type: ignore[assignment]
             )
 
             if channel.name in existing_channels:
@@ -274,7 +275,7 @@ class StackdriverHook(GoogleBaseHook):
                 )
             else:
                 old_name = channel.name
-                channel.name = None
+                del channel.name
                 new_channel = channel_client.create_notification_channel(
                     request={"name": f"projects/{project_id}", "notification_channel": channel},
                     retry=retry,
@@ -284,8 +285,8 @@ class StackdriverHook(GoogleBaseHook):
                 channel_name_map[old_name] = new_channel.name
 
         for policy in policies_:
-            policy.creation_record = None
-            policy.mutation_record = None
+            del policy.creation_record
+            del policy.mutation_record
 
             for i, channel in enumerate(policy.notification_channels):
                 new_channel = channel_name_map.get(channel)
@@ -301,9 +302,9 @@ class StackdriverHook(GoogleBaseHook):
                     metadata=metadata,
                 )
             else:
-                policy.name = None
+                del policy.name
                 for condition in policy.conditions:
-                    condition.name = None
+                    del condition.name
                 policy_client.create_alert_policy(
                     request={"name": f"projects/{project_id}", "alert_policy": policy},
                     retry=retry,
@@ -531,8 +532,9 @@ class StackdriverHook(GoogleBaseHook):
             channels_list.append(NotificationChannel(**channel))
 
         for channel in channels_list:
+            # This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
             channel.verification_status = (
-                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED  # type: ignore[assignment]
             )
 
             if channel.name in existing_channels:
@@ -544,7 +546,7 @@ class StackdriverHook(GoogleBaseHook):
                 )
             else:
                 old_name = channel.name
-                channel.name = None
+                del channel.name
                 new_channel = channel_client.create_notification_channel(
                     request={"name": f"projects/{project_id}", "notification_channel": channel},
                     retry=retry,
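
The stackdriver.py hunks replace `field = None` assignments with `del field`. The monitoring_v3 objects are proto-plus messages, whose fields cannot be assigned None; `del` clears a field back to its default, which is what the export/import logic needs before re-creating channels and policies. A minimal illustration of the pattern (not provider code; the resource name is a placeholder):

    # Clearing a proto-plus field: assigning None raises, while `del` resets the field.
    from google.cloud import monitoring_v3

    channel = monitoring_v3.NotificationChannel(
        name="projects/example-project/notificationChannels/123"
    )
    # channel.name = None  # would raise a TypeError on a proto-plus message
    del channel.name  # resets the field to its default (empty string)
    assert channel.name == ""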

airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -1098,13 +1098,13 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
             raise AirflowException("AutoMLTextTrainingJob was not created")
 
         model = self._job.run(
-            dataset=dataset,  # type: ignore[arg-type]
-            training_fraction_split=training_fraction_split,  # type: ignore[call-arg]
-            validation_fraction_split=validation_fraction_split,  # type: ignore[call-arg]
+            dataset=dataset,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
             test_fraction_split=test_fraction_split,
             training_filter_split=training_filter_split,
             validation_filter_split=validation_filter_split,
-            test_filter_split=test_filter_split,  # type: ignore[call-arg]
+            test_filter_split=test_filter_split,
             model_display_name=model_display_name,
             model_labels=model_labels,
             sync=sync,

airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py (new file)
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from google.cloud import aiplatform
+from google.cloud.aiplatform.compat.types import execution_v1 as gca_execution
+
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+
+class ExperimentHook(GoogleBaseHook):
+    """Use the Vertex AI SDK for Python to manage your experiments."""
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_experiment(
+        self,
+        experiment_name: str,
+        location: str,
+        experiment_description: str = "",
+        project_id: str = PROVIDE_PROJECT_ID,
+        experiment_tensorboard: str | None = None,
+    ):
+        """
+        Create an experiment and, optionally, associate a Vertex AI TensorBoard instance using the Vertex AI SDK for Python.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_description: Optional. Description of the evaluation experiment.
+        :param experiment_tensorboard: Optional. The Vertex TensorBoard instance to use as a backing
+            TensorBoard for the provided experiment. If no TensorBoard is provided, a default Tensorboard
+            instance is created and used by this experiment.
+        """
+        aiplatform.init(
+            experiment=experiment_name,
+            experiment_description=experiment_description,
+            experiment_tensorboard=experiment_tensorboard if experiment_tensorboard else False,
+            project=project_id,
+            location=location,
+        )
+        self.log.info("Created experiment with name: %s", experiment_name)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_experiment(
+        self,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        delete_backing_tensorboard_runs: bool = False,
+    ) -> None:
+        """
+        Delete an experiment.
+
+        Deleting an experiment deletes that experiment and all experiment runs associated with the experiment.
+        The Vertex AI TensorBoard experiment associated with the experiment is not deleted.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param delete_backing_tensorboard_runs: Optional. If True will also delete the Vertex AI TensorBoard
+            runs associated with the experiment runs under this experiment that we used to store time series
+            metrics.
+        """
+        experiment = aiplatform.Experiment(
+            experiment_name=experiment_name, project=project_id, location=location
+        )
+
+        experiment.delete(delete_backing_tensorboard_runs=delete_backing_tensorboard_runs)
+
+
+class ExperimentRunHook(GoogleBaseHook):
+    """Use the Vertex AI SDK for Python to create and manage your experiment runs."""
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_experiment_run(
+        self,
+        experiment_run_name: str,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        experiment_run_tensorboard: str | None = None,
+        run_after_creation: bool = False,
+    ) -> None:
+        """
+        Create experiment run for the experiment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_run_name: Required. The specific run name or ID for this experiment.
+        :param experiment_run_tensorboard: Optional. A backing TensorBoard resource to enable and store time
+            series metrics logged to this experiment run.
+        :param run_after_creation: Optional. Responsible for state after creation of experiment run.
+            If true experiment run will be created with state RUNNING.
+        """
+        experiment_run_state = (
+            gca_execution.Execution.State.NEW
+            if not run_after_creation
+            else gca_execution.Execution.State.RUNNING
+        )
+        experiment_run = aiplatform.ExperimentRun.create(
+            run_name=experiment_run_name,
+            experiment=experiment_name,
+            project=project_id,
+            location=location,
+            state=experiment_run_state,
+            tensorboard=experiment_run_tensorboard,
+        )
+        self.log.info(
+            "Created experiment run with name: %s and status: %s",
+            experiment_run.name,
+            experiment_run.state,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_experiment_runs(
+        self,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> list[aiplatform.ExperimentRun]:
+        """
+        List experiment run for the experiment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        """
+        experiment_runs = aiplatform.ExperimentRun.list(
+            experiment=experiment_name,
+            project=project_id,
+            location=location,
+        )
+        return experiment_runs
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def update_experiment_run_state(
+        self,
+        experiment_run_name: str,
+        experiment_name: str,
+        location: str,
+        new_state: gca_execution.Execution.State,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> None:
+        """
+        Update state of the experiment run.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_run_name: Required. The specific run name or ID for this experiment.
+        :param new_state: Required. New state of the experiment run.
+        """
+        experiment_run = aiplatform.ExperimentRun(
+            run_name=experiment_run_name,
+            experiment=experiment_name,
+            project=project_id,
+            location=location,
+        )
+        self.log.info("State of the %s before update is: %s", experiment_run.name, experiment_run.state)
+
+        experiment_run.update_state(new_state)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_experiment_run(
+        self,
+        experiment_run_name: str,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        delete_backing_tensorboard_run: bool = False,
+    ) -> None:
+        """
+        Delete experiment run from the experiment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_run_name: Required. The specific run name or ID for this experiment.
+        :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
+            that stores time series metrics for this run.
+        """
+        self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
+        experiment_run = aiplatform.ExperimentRun(
+            run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
+        )
+        experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
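
The two new hooks wrap the Vertex AI SDK experiment APIs. Below is a hedged usage sketch that calls the methods defined above directly; the connection ID, project, location, and experiment names are placeholders, and the corresponding operators added in operators/vertex_ai/experiment_service.py are not shown in this excerpt:

    # Usage sketch (placeholders throughout), based only on the hook methods above.
    from airflow.providers.google.cloud.hooks.vertex_ai.experiment_service import (
        ExperimentHook,
        ExperimentRunHook,
    )

    experiment_hook = ExperimentHook(gcp_conn_id="google_cloud_default")
    experiment_hook.create_experiment(
        experiment_name="demo-experiment",
        experiment_description="Demo experiment created from a task",
        location="us-central1",
        project_id="example-project",
    )

    run_hook = ExperimentRunHook(gcp_conn_id="google_cloud_default")
    run_hook.create_experiment_run(
        experiment_run_name="run-1",
        experiment_name="demo-experiment",
        location="us-central1",
        project_id="example-project",
        run_after_creation=True,
    )
    runs = run_hook.list_experiment_runs(
        experiment_name="demo-experiment", location="us-central1", project_id="example-project"
    )
    print([run.name for run in runs])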