apache-airflow-providers-google 16.1.0rc1__py3-none-any.whl → 17.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +1 -5
- airflow/providers/google/cloud/hooks/bigquery.py +1 -130
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_run.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +5 -5
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +1 -1
- airflow/providers/google/cloud/hooks/dataflow.py +0 -85
- airflow/providers/google/cloud/hooks/datafusion.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -4
- airflow/providers/google/cloud/hooks/dataproc.py +68 -70
- airflow/providers/google/cloud/hooks/gcs.py +3 -5
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/looker.py +1 -5
- airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +4 -4
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +7 -0
- airflow/providers/google/cloud/links/kubernetes_engine.py +3 -0
- airflow/providers/google/cloud/log/gcs_task_handler.py +2 -2
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +1 -1
- airflow/providers/google/cloud/openlineage/mixins.py +7 -7
- airflow/providers/google/cloud/operators/automl.py +1 -1
- airflow/providers/google/cloud/operators/bigquery.py +8 -609
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_sql.py +1 -5
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/cloud/operators/dataproc.py +1 -1
- airflow/providers/google/cloud/operators/dlp.py +2 -2
- airflow/providers/google/cloud/operators/kubernetes_engine.py +4 -4
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +7 -1
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +7 -5
- airflow/providers/google/cloud/operators/vision.py +1 -1
- airflow/providers/google/cloud/sensors/dataflow.py +23 -6
- airflow/providers/google/cloud/sensors/datafusion.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +3 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +9 -9
- airflow/providers/google/cloud/triggers/bigquery.py +11 -13
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +1 -1
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataproc.py +10 -9
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
- airflow/providers/google/common/auth_backend/google_openid.py +2 -2
- airflow/providers/google/common/hooks/base_google.py +2 -6
- airflow/providers/google/common/utils/id_token_credentials.py +2 -2
- airflow/providers/google/get_provider_info.py +19 -16
- airflow/providers/google/leveldb/hooks/leveldb.py +1 -5
- airflow/providers/google/marketing_platform/hooks/display_video.py +47 -3
- airflow/providers/google/marketing_platform/links/analytics_admin.py +1 -1
- airflow/providers/google/marketing_platform/operators/display_video.py +64 -15
- airflow/providers/google/marketing_platform/sensors/display_video.py +9 -2
- airflow/providers/google/version_compat.py +10 -3
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/METADATA +99 -93
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/RECORD +63 -62
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/life_sciences.py +0 -30
- airflow/providers/google/cloud/operators/life_sciences.py +0 -118
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/entry_points.txt +0 -0
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -47,7 +47,7 @@ from google.cloud.dataproc_v1 import (
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 from airflow.version import version as airflow_version

 if TYPE_CHECKING:

@@ -1269,7 +1269,7 @@ class DataprocHook(GoogleBaseHook):
         return all([word in error_msg for word in key_words])


-class DataprocAsyncHook(GoogleBaseHook):
+class DataprocAsyncHook(GoogleBaseAsyncHook):
     """
     Asynchronous interaction with Google Cloud Dataproc APIs.

@@ -1277,6 +1277,8 @@ class DataprocAsyncHook(GoogleBaseHook):
     keyword arguments rather than positional.
     """

+    sync_hook_class = DataprocHook
+
     def __init__(
         self,
         gcp_conn_id: str = "google_cloud_default",
@@ -1286,53 +1288,90 @@ class DataprocAsyncHook(GoogleBaseHook):
         super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
         self._cached_client: JobControllerAsyncClient | None = None

-    def get_cluster_client(self, region: str | None = None) -> ClusterControllerAsyncClient:
+    async def get_cluster_client(self, region: str | None = None) -> ClusterControllerAsyncClient:
         """Create a ClusterControllerAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

+        sync_hook = await self.get_sync_hook()
         return ClusterControllerAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )

-    def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceAsyncClient:
+    async def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceAsyncClient:
         """Create a WorkflowTemplateServiceAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

+        sync_hook = await self.get_sync_hook()
         return WorkflowTemplateServiceAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )

-    def get_job_client(self, region: str | None = None) -> JobControllerAsyncClient:
+    async def get_job_client(self, region: str | None = None) -> JobControllerAsyncClient:
         """Create a JobControllerAsyncClient."""
         if self._cached_client is None:
             client_options = None
             if region and region != "global":
                 client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

+            sync_hook = await self.get_sync_hook()
             self._cached_client = JobControllerAsyncClient(
-                credentials=self.get_credentials(),
+                credentials=sync_hook.get_credentials(),
                 client_info=CLIENT_INFO,
                 client_options=client_options,
             )
         return self._cached_client

-    def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient:
+    async def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient:
         """Create a BatchControllerAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

+        sync_hook = await self.get_sync_hook()
         return BatchControllerAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )

-    def get_operations_client(self, region: str) -> OperationsClient:
+    async def get_operations_client(self, region: str) -> OperationsClient:
         """Create a OperationsClient."""
-        return self.get_template_client(region=region).transport.operations_client
+        template_client = await self.get_template_client(region=region)
+        return template_client.transport.operations_client
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def get_cluster(
+        self,
+        region: str,
+        cluster_name: str,
+        project_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Cluster:
+        """
+        Get a cluster.
+
+        :param region: Cloud Dataproc region in which to handle the request.
+        :param cluster_name: Name of the cluster to get.
+        :param project_id: Google Cloud project ID that the cluster belongs to.
+        :param retry: A retry object used to retry requests. If *None*, requests
+            will not be retried.
+        :param timeout: The amount of time, in seconds, to wait for the request
+            to complete. If *retry* is specified, the timeout applies to each
+            individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = await self.get_cluster_client(region=region)
+        result = await client.get_cluster(
+            request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result

     @GoogleBaseHook.fallback_to_default_project_id
     async def create_cluster(
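With the client getters now coroutines, every call site has to await them; the new `get_cluster` above is the awaitable entry point a deferrable operator's trigger would poll. A hypothetical caller sketch (the loop, interval, and state check are illustrative, not provider code):

```python
import asyncio

from google.cloud.dataproc_v1 import ClusterStatus

from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook


async def wait_for_cluster(region: str, cluster_name: str, project_id: str):
    hook = DataprocAsyncHook(gcp_conn_id="google_cloud_default")
    while True:
        cluster = await hook.get_cluster(
            region=region, cluster_name=cluster_name, project_id=project_id
        )
        # Stop polling once the cluster reaches the RUNNING state.
        if cluster.status.state == ClusterStatus.State.RUNNING:
            return cluster
        await asyncio.sleep(30)
```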
@@ -1390,7 +1429,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             cluster["config"] = cluster_config  # type: ignore
             cluster["labels"] = labels  # type: ignore

-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.create_cluster(
             request={
                 "project_id": project_id,

@@ -1435,7 +1474,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.delete_cluster(
             request={
                 "project_id": project_id,

@@ -1483,7 +1522,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.diagnose_cluster(
             request={
                 "project_id": project_id,

@@ -1500,38 +1539,6 @@ class DataprocAsyncHook(GoogleBaseHook):
         )
         return result

-    @GoogleBaseHook.fallback_to_default_project_id
-    async def get_cluster(
-        self,
-        region: str,
-        cluster_name: str,
-        project_id: str,
-        retry: AsyncRetry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> Cluster:
-        """
-        Get the resource representation for a cluster in a project.
-
-        :param project_id: Google Cloud project ID that the cluster belongs to.
-        :param region: Cloud Dataproc region to handle the request.
-        :param cluster_name: The cluster name.
-        :param retry: A retry object used to retry requests. If *None*, requests
-            will not be retried.
-        :param timeout: The amount of time, in seconds, to wait for the request
-            to complete. If *retry* is specified, the timeout applies to each
-            individual attempt.
-        :param metadata: Additional metadata that is provided to the method.
-        """
-        client = self.get_cluster_client(region=region)
-        result = await client.get_cluster(
-            request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
-
     @GoogleBaseHook.fallback_to_default_project_id
     async def list_clusters(
         self,

@@ -1561,7 +1568,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.list_clusters(
             request={"project_id": project_id, "region": region, "filter": filter_, "page_size": page_size},
             retry=retry,

@@ -1638,7 +1645,7 @@ class DataprocAsyncHook(GoogleBaseHook):
         """
         if region is None:
             raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         operation = await client.update_cluster(
             request={
                 "project_id": project_id,

@@ -1680,10 +1687,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         parent = f"projects/{project_id}/regions/{region}"
         return await client.create_workflow_template(
             request={"parent": parent, "template": template}, retry=retry, timeout=timeout, metadata=metadata

@@ -1725,10 +1730,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         name = f"projects/{project_id}/regions/{region}/workflowTemplates/{template_name}"
         operation = await client.instantiate_workflow_template(
             request={"name": name, "version": version, "request_id": request_id, "parameters": parameters},

@@ -1767,10 +1770,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         parent = f"projects/{project_id}/regions/{region}"
         operation = await client.instantiate_inline_workflow_template(
             request={"parent": parent, "template": template, "request_id": request_id},

@@ -1781,7 +1782,8 @@ class DataprocAsyncHook(GoogleBaseHook):
         return operation

     async def get_operation(self, region, operation_name):
-        return self.get_operations_client(region).get_operation(name=operation_name)
+        operations_client = await self.get_operations_client(region)
+        return await operations_client.get_operation(name=operation_name)

     @GoogleBaseHook.fallback_to_default_project_id
     async def get_job(
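`get_operation` now awaits the async operations client instead of returning a result from a synchronous one, so long-running template operations can be polled without blocking the event loop. A hypothetical polling sketch (`done` is the standard `google.longrunning` Operation field; the loop and interval are illustrative):

```python
import asyncio


async def wait_for_operation(hook, region: str, operation_name: str):
    """Poll a Dataproc long-running operation until it completes."""
    while True:
        operation = await hook.get_operation(region, operation_name)
        if operation.done:  # google.longrunning Operation message
            return operation
        await asyncio.sleep(10)
```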
@@ -1806,9 +1808,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
         job = await client.get_job(
             request={"project_id": project_id, "region": region, "job_id": job_id},
             retry=retry,

@@ -1845,9 +1845,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
         return await client.submit_job(
             request={"project_id": project_id, "region": region, "job": job, "request_id": request_id},
             retry=retry,

@@ -1878,7 +1876,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)

         job = await client.cancel_job(
             request={"project_id": project_id, "region": region, "job_id": job_id},

@@ -1920,7 +1918,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         parent = f"projects/{project_id}/regions/{region}"

         result = await client.create_batch(

@@ -1959,7 +1957,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"

         await client.delete_batch(

@@ -1994,7 +1992,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"

         result = await client.get_batch(

@@ -2039,7 +2037,7 @@ class DataprocAsyncHook(GoogleBaseHook):
         :param filter: Result filters as specified in ListBatchesRequest
         :param order_by: How to order results as specified in ListBatchesRequest
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         parent = f"projects/{project_id}/regions/{region}"

         result = await client.list_batches(
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -31,14 +31,13 @@ from contextlib import contextmanager
 from functools import partial
 from io import BytesIO
 from tempfile import NamedTemporaryFile
-from typing import IO, TYPE_CHECKING, Any, TypeVar, cast, overload
+from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
 from urllib.parse import urlsplit

+# Make mypy happy by importing as aliases
+import google.cloud.storage as storage
 from gcloud.aio.storage import Storage
 from google.api_core.exceptions import GoogleAPICallError, NotFound
-
-# not sure why but mypy complains on missing `storage` but it is clearly there and is importable
-from google.cloud import storage
 from google.cloud.exceptions import GoogleCloudError
 from google.cloud.storage.retry import DEFAULT_RETRY

@@ -51,7 +50,6 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-from airflow.typing_compat import ParamSpec
 from airflow.utils import timezone
 from airflow.version import version
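The `ParamSpec` move reflects the provider's Python floor: `typing.ParamSpec` is in the standard library from Python 3.10, so the `airflow.typing_compat` shim is no longer needed. A self-contained sketch of the decorator-typing pattern it enables (the decorator itself is illustrative, not the hook's code):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar  # stdlib since Python 3.10

P = ParamSpec("P")
R = TypeVar("R")


def log_calls(func: Callable[P, R]) -> Callable[P, R]:
    """Preserve the wrapped callable's exact signature for type checkers."""

    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper
```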
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -30,7 +30,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.auth.transport import requests as google_requests

 # not sure why but mypy complains on missing `container_v1` but it is clearly there and is importable
-from google.cloud import exceptions
+from google.cloud import exceptions
 from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
 from google.cloud.container_v1.types import Cluster, Operation
 from kubernetes import client

@@ -498,7 +498,7 @@ class GKEKubernetesAsyncHook(GoogleBaseAsyncHook, AsyncKubernetesHook):
         )

     @contextlib.asynccontextmanager
-    async def get_conn(self) -> async_client.ApiClient:
+    async def get_conn(self) -> async_client.ApiClient:
         kube_client = None
         try:
             kube_client = await self._load_config()
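`get_conn` is an async context manager, so the Kubernetes `ApiClient` is opened and closed around the caller's block. A hypothetical consumer sketch (the hook instance and namespace are assumptions; `kubernetes_asyncio` is the async client library the hook wraps):

```python
from kubernetes_asyncio import client as async_client


async def list_pods(hook, namespace: str):
    # hook is a configured GKEKubernetesAsyncHook
    async with hook.get_conn() as api_client:
        v1 = async_client.CoreV1Api(api_client)
        return await v1.list_namespaced_pod(namespace=namespace)
```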
--- a/airflow/providers/google/cloud/hooks/looker.py
+++ b/airflow/providers/google/cloud/hooks/looker.py
@@ -29,11 +29,7 @@ from looker_sdk.sdk.api40 import methods as methods40
 from packaging.version import parse as parse_version

 from airflow.exceptions import AirflowException
-
-try:
-    from airflow.sdk import BaseHook
-except ImportError:
-    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]
+from airflow.providers.google.version_compat import BaseHook
 from airflow.version import version

 if TYPE_CHECKING:
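The replaced try/except is the usual Airflow 2/3 compatibility dance; `airflow.providers.google.version_compat` now centralizes it so each hook imports `BaseHook` from one place. A sketch of what that module's re-export presumably looks like, mirroring the code removed above (shape assumed, not copied from version_compat.py):

```python
try:
    from airflow.sdk import BaseHook  # Airflow 3 Task SDK
except ImportError:
    from airflow.hooks.base import BaseHook  # Airflow 2 fallback

__all__ = ["BaseHook"]
```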
--- a/airflow/providers/google/cloud/hooks/stackdriver.py
+++ b/airflow/providers/google/cloud/hooks/stackdriver.py
@@ -261,8 +261,9 @@ class StackdriverHook(GoogleBaseHook):
         channel_name_map = {}

         for channel in channels:
+            # This field is immutable; specifying a non-default UNVERIFIED or VERIFIED value is illegal, so reset it to the default
             channel.verification_status = (
-                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED  # type: ignore[assignment]
             )

             if channel.name in existing_channels:

@@ -274,7 +275,7 @@ class StackdriverHook(GoogleBaseHook):
                 )
             else:
                 old_name = channel.name
-                channel.name
+                del channel.name
                 new_channel = channel_client.create_notification_channel(
                     request={"name": f"projects/{project_id}", "notification_channel": channel},
                     retry=retry,

@@ -284,8 +285,8 @@ class StackdriverHook(GoogleBaseHook):
             channel_name_map[old_name] = new_channel.name

         for policy in policies_:
-            policy.creation_record
-            policy.mutation_record
+            del policy.creation_record
+            del policy.mutation_record

             for i, channel in enumerate(policy.notification_channels):
                 new_channel = channel_name_map.get(channel)

@@ -301,9 +302,9 @@ class StackdriverHook(GoogleBaseHook):
                     metadata=metadata,
                 )
             else:
-                policy.name
+                del policy.name
                 for condition in policy.conditions:
-                    condition.name
+                    del condition.name
                 policy_client.create_alert_policy(
                     request={"name": f"projects/{project_id}", "alert_policy": policy},
                     retry=retry,

@@ -531,8 +532,9 @@ class StackdriverHook(GoogleBaseHook):
             channels_list.append(NotificationChannel(**channel))

         for channel in channels_list:
+            # This field is immutable; specifying a non-default UNVERIFIED or VERIFIED value is illegal, so reset it to the default
             channel.verification_status = (
-                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED  # type: ignore[assignment]
             )

             if channel.name in existing_channels:

@@ -544,7 +546,7 @@ class StackdriverHook(GoogleBaseHook):
                 )
             else:
                 old_name = channel.name
-                channel.name
+                del channel.name
                 new_channel = channel_client.create_notification_channel(
                     request={"name": f"projects/{project_id}", "notification_channel": channel},
                     retry=retry,
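The `del` fixes matter because the old bare attribute accesses were no-op statements: reading `policy.name` changes nothing, while `del message.field` actually clears a proto-plus field back to its default. A minimal self-contained illustration with the same `monitoring_v3` types (the resource name is made up):

```python
from google.cloud import monitoring_v3

policy = monitoring_v3.AlertPolicy(name="projects/p/alertPolicies/123")

policy.name       # bare read: evaluates the field and discards it, a no-op
del policy.name   # actually clears the field back to its default
assert policy.name == ""
```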
--- a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -1098,13 +1098,13 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
             raise AirflowException("AutoMLTextTrainingJob was not created")

         model = self._job.run(
-            dataset=dataset,
-            training_fraction_split=training_fraction_split,
-            validation_fraction_split=validation_fraction_split,
+            dataset=dataset,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
             test_fraction_split=test_fraction_split,
             training_filter_split=training_filter_split,
             validation_filter_split=validation_filter_split,
-            test_filter_split=test_filter_split,
+            test_filter_split=test_filter_split,
             model_display_name=model_display_name,
             model_labels=model_labels,
             sync=sync,
--- /dev/null
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from google.cloud import aiplatform
+from google.cloud.aiplatform.compat.types import execution_v1 as gca_execution
+
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+
+class ExperimentHook(GoogleBaseHook):
+    """Use the Vertex AI SDK for Python to manage your experiments."""
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_experiment(
+        self,
+        experiment_name: str,
+        location: str,
+        experiment_description: str = "",
+        project_id: str = PROVIDE_PROJECT_ID,
+        experiment_tensorboard: str | None = None,
+    ):
+        """
+        Create an experiment and, optionally, associate a Vertex AI TensorBoard instance using the Vertex AI SDK for Python.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_description: Optional. Description of the evaluation experiment.
+        :param experiment_tensorboard: Optional. The Vertex TensorBoard instance to use as a backing
+            TensorBoard for the provided experiment. If no TensorBoard is provided, a default TensorBoard
+            instance is created and used by this experiment.
+        """
+        aiplatform.init(
+            experiment=experiment_name,
+            experiment_description=experiment_description,
+            experiment_tensorboard=experiment_tensorboard if experiment_tensorboard else False,
+            project=project_id,
+            location=location,
+        )
+        self.log.info("Created experiment with name: %s", experiment_name)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_experiment(
+        self,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        delete_backing_tensorboard_runs: bool = False,
+    ) -> None:
+        """
+        Delete an experiment.
+
+        Deleting an experiment deletes that experiment and all experiment runs associated with the experiment.
+        The Vertex AI TensorBoard experiment associated with the experiment is not deleted.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param delete_backing_tensorboard_runs: Optional. If True, also delete the Vertex AI TensorBoard
+            runs associated with the experiment runs under this experiment that were used to store time
+            series metrics.
+        """
+        experiment = aiplatform.Experiment(
+            experiment_name=experiment_name, project=project_id, location=location
+        )
+
+        experiment.delete(delete_backing_tensorboard_runs=delete_backing_tensorboard_runs)
+
+
+class ExperimentRunHook(GoogleBaseHook):
+    """Use the Vertex AI SDK for Python to create and manage your experiment runs."""
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_experiment_run(
+        self,
+        experiment_run_name: str,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        experiment_run_tensorboard: str | None = None,
+        run_after_creation: bool = False,
+    ) -> None:
+        """
+        Create an experiment run for the experiment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_run_name: Required. The specific run name or ID for this experiment.
+        :param experiment_run_tensorboard: Optional. A backing TensorBoard resource to enable and store
+            time series metrics logged to this experiment run.
+        :param run_after_creation: Optional. Controls the state after creation: if True, the experiment
+            run is created in the RUNNING state.
+        """
+        experiment_run_state = (
+            gca_execution.Execution.State.NEW
+            if not run_after_creation
+            else gca_execution.Execution.State.RUNNING
+        )
+        experiment_run = aiplatform.ExperimentRun.create(
+            run_name=experiment_run_name,
+            experiment=experiment_name,
+            project=project_id,
+            location=location,
+            state=experiment_run_state,
+            tensorboard=experiment_run_tensorboard,
+        )
+        self.log.info(
+            "Created experiment run with name: %s and status: %s",
+            experiment_run.name,
+            experiment_run.state,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_experiment_runs(
+        self,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> list[aiplatform.ExperimentRun]:
+        """
+        List experiment runs for the experiment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        """
+        experiment_runs = aiplatform.ExperimentRun.list(
+            experiment=experiment_name,
+            project=project_id,
+            location=location,
+        )
+        return experiment_runs
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def update_experiment_run_state(
+        self,
+        experiment_run_name: str,
+        experiment_name: str,
+        location: str,
+        new_state: gca_execution.Execution.State,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> None:
+        """
+        Update the state of the experiment run.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_run_name: Required. The specific run name or ID for this experiment.
+        :param new_state: Required. New state of the experiment run.
+        """
+        experiment_run = aiplatform.ExperimentRun(
+            run_name=experiment_run_name,
+            experiment=experiment_name,
+            project=project_id,
+            location=location,
+        )
+        self.log.info("State of the %s before update is: %s", experiment_run.name, experiment_run.state)
+
+        experiment_run.update_state(new_state)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_experiment_run(
+        self,
+        experiment_run_name: str,
+        experiment_name: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        delete_backing_tensorboard_run: bool = False,
+    ) -> None:
+        """
+        Delete an experiment run from the experiment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param experiment_name: Required. The name of the evaluation experiment.
+        :param experiment_run_name: Required. The specific run name or ID for this experiment.
+        :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
+            that stores time series metrics for this run.
+        """
+        self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
+        experiment_run = aiplatform.ExperimentRun(
+            run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
+        )
+        experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)