apache-airflow-providers-google 10.19.0__py3-none-any.whl → 10.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/LICENSE +4 -4
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +4 -4
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +26 -0
- airflow/providers/google/cloud/hooks/dataflow.py +132 -1
- airflow/providers/google/cloud/hooks/datapipeline.py +22 -73
- airflow/providers/google/cloud/hooks/gcs.py +21 -0
- airflow/providers/google/cloud/hooks/pubsub.py +10 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +15 -3
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
- airflow/providers/google/cloud/links/dataflow.py +25 -0
- airflow/providers/google/cloud/openlineage/mixins.py +271 -0
- airflow/providers/google/cloud/openlineage/utils.py +5 -218
- airflow/providers/google/cloud/operators/bigquery.py +74 -20
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +76 -0
- airflow/providers/google/cloud/operators/dataflow.py +235 -1
- airflow/providers/google/cloud/operators/datapipeline.py +29 -121
- airflow/providers/google/cloud/operators/dataplex.py +1 -1
- airflow/providers/google/cloud/operators/dataproc_metastore.py +17 -6
- airflow/providers/google/cloud/operators/kubernetes_engine.py +9 -6
- airflow/providers/google/cloud/operators/pubsub.py +18 -0
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +6 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +16 -0
- airflow/providers/google/cloud/sensors/cloud_composer.py +171 -2
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +13 -0
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +56 -1
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +6 -12
- airflow/providers/google/cloud/triggers/cloud_composer.py +115 -0
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -0
- airflow/providers/google/cloud/utils/credentials_provider.py +81 -6
- airflow/providers/google/cloud/utils/external_token_supplier.py +175 -0
- airflow/providers/google/common/hooks/base_google.py +35 -1
- airflow/providers/google/common/utils/id_token_credentials.py +1 -1
- airflow/providers/google/get_provider_info.py +19 -14
- {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/METADATA +41 -35
- {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/RECORD +39 -37
- {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.19.0.dist-info → apache_airflow_providers_google-10.20.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/LICENSE
CHANGED
@@ -215,7 +215,7 @@ Third party Apache 2.0 licenses
 
 The following components are provided under the Apache 2.0 License.
 See project link for details. The text of each license is also included
-at licenses/LICENSE-[project].txt.
+at 3rd-party-licenses/LICENSE-[project].txt.
 
 (ALv2 License) hue v4.3.0 (https://github.com/cloudera/hue/)
 (ALv2 License) jqclock v2.3.0 (https://github.com/JohnRDOrazio/jQuery-Clock-Plugin)
@@ -227,7 +227,7 @@ MIT licenses
 ========================================================================
 
 The following components are provided under the MIT License. See project link for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (MIT License) jquery v3.5.1 (https://jquery.org/license/)
 (MIT License) dagre-d3 v0.6.4 (https://github.com/cpettitt/dagre-d3)
@@ -243,11 +243,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
 BSD 3-Clause licenses
 ========================================================================
 The following components are provided under the BSD 3-Clause license. See project links for details.
-The text of each license is also included at licenses/LICENSE-[project].txt.
+The text of each license is also included at 3rd-party-licenses/LICENSE-[project].txt.
 
 (BSD 3 License) d3 v5.16.0 (https://d3js.org)
 (BSD 3 License) d3-shape v2.1.0 (https://github.com/d3/d3-shape)
 (BSD 3 License) cgroupspy 0.2.1 (https://github.com/cloudsigma/cgroupspy)
 
 ========================================================================
-See licenses/LICENSES-ui.txt for packages used in `/airflow/www`
+See 3rd-party-licenses/LICENSES-ui.txt for packages used in `/airflow/www`
airflow/providers/google/__init__.py
CHANGED
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "10.19.0"
+__version__ = "10.20.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"
airflow/providers/google/ads/hooks/ads.py
CHANGED
@@ -32,9 +32,9 @@ from airflow.hooks.base import BaseHook
 from airflow.providers.google.common.hooks.base_google import get_field
 
 if TYPE_CHECKING:
-    from google.ads.googleads.
-    from google.ads.googleads.
-    from google.ads.googleads.
+    from google.ads.googleads.v17.services.services.customer_service import CustomerServiceClient
+    from google.ads.googleads.v17.services.services.google_ads_service import GoogleAdsServiceClient
+    from google.ads.googleads.v17.services.types.google_ads_service import GoogleAdsRow
     from google.api_core.page_iterator import GRPCIterator
 
 
@@ -100,7 +100,7 @@ class GoogleAdsHook(BaseHook):
     :param api_version: The Google Ads API version to use.
     """
 
-    default_api_version = "
+    default_api_version = "v17"
 
     def __init__(
         self,
airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
CHANGED
@@ -344,6 +344,32 @@ class CloudDataTransferServiceHook(GoogleBaseHook):
             .execute(num_retries=self.num_retries)
         )
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def run_transfer_job(self, job_name: str, project_id: str) -> dict:
+        """Run Google Storage Transfer Service job.
+
+        :param job_name: (Required) Name of the job to be fetched
+        :param project_id: (Optional) the ID of the project that owns the Transfer
+            Job. If set to None or missing, the default project_id from the Google Cloud
+            connection is used.
+        :return: If successful, Operation. See:
+            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/Operation
+
+        .. seealso:: https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/run
+
+        """
+        return (
+            self.get_conn()
+            .transferJobs()
+            .run(
+                jobName=job_name,
+                body={
+                    PROJECT_ID: project_id,
+                },
+            )
+            .execute(num_retries=self.num_retries)
+        )
+
     def cancel_transfer_operation(self, operation_name: str) -> None:
         """Cancel a transfer operation in Google Storage Transfer Service.
 
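Usage sketch (not part of the diff): the new run_transfer_job method lets a task trigger an existing transfer job on demand. Connection ID, job name, and project ID below are placeholders.

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceHook,
)

hook = CloudDataTransferServiceHook(gcp_conn_id="google_cloud_default")
# Returns the Operation resource describing the started transfer run.
operation = hook.run_transfer_job(
    job_name="transferJobs/123456789",  # placeholder: name of an existing transfer job
    project_id="my-gcp-project",
)
print(operation["name"])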
@@ -71,7 +71,7 @@ DEFAULT_DATAFLOW_LOCATION = "us-central1"
|
|
71
71
|
|
72
72
|
|
73
73
|
JOB_ID_PATTERN = re.compile(
|
74
|
-
r"Submitted job: (?P<job_id_java
|
74
|
+
r"Submitted job: (?P<job_id_java>[^\"\n]*)|Created job with id: \[(?P<job_id_python>[^\"\n]*)\]"
|
75
75
|
)
|
76
76
|
|
77
77
|
T = TypeVar("T", bound=Callable)
|
@@ -582,6 +582,11 @@ class DataflowHook(GoogleBaseHook):
|
|
582
582
|
http_authorized = self._authorize()
|
583
583
|
return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
|
584
584
|
|
585
|
+
def get_pipelines_conn(self) -> build:
|
586
|
+
"""Return a Google Cloud Data Pipelines service object."""
|
587
|
+
http_authorized = self._authorize()
|
588
|
+
return build("datapipelines", "v1", http=http_authorized, cache_discovery=False)
|
589
|
+
|
585
590
|
@_fallback_to_location_from_variables
|
586
591
|
@_fallback_to_project_id_from_variables
|
587
592
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -1351,6 +1356,132 @@ class DataflowHook(GoogleBaseHook):
|
|
1351
1356
|
|
1352
1357
|
return job_controller._check_dataflow_job_state(job)
|
1353
1358
|
|
1359
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
1360
|
+
def create_data_pipeline(
|
1361
|
+
self,
|
1362
|
+
body: dict,
|
1363
|
+
project_id: str,
|
1364
|
+
location: str = DEFAULT_DATAFLOW_LOCATION,
|
1365
|
+
):
|
1366
|
+
"""
|
1367
|
+
Create a new Dataflow Data Pipelines instance.
|
1368
|
+
|
1369
|
+
:param body: The request body (contains instance of Pipeline). See:
|
1370
|
+
https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
|
1371
|
+
:param project_id: The ID of the GCP project that owns the job.
|
1372
|
+
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
|
1373
|
+
|
1374
|
+
Returns the created Data Pipelines instance in JSON representation.
|
1375
|
+
"""
|
1376
|
+
parent = self.build_parent_name(project_id, location)
|
1377
|
+
service = self.get_pipelines_conn()
|
1378
|
+
request = (
|
1379
|
+
service.projects()
|
1380
|
+
.locations()
|
1381
|
+
.pipelines()
|
1382
|
+
.create(
|
1383
|
+
parent=parent,
|
1384
|
+
body=body,
|
1385
|
+
)
|
1386
|
+
)
|
1387
|
+
response = request.execute(num_retries=self.num_retries)
|
1388
|
+
return response
|
1389
|
+
|
1390
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
1391
|
+
def get_data_pipeline(
|
1392
|
+
self,
|
1393
|
+
pipeline_name: str,
|
1394
|
+
project_id: str,
|
1395
|
+
location: str = DEFAULT_DATAFLOW_LOCATION,
|
1396
|
+
) -> dict:
|
1397
|
+
"""
|
1398
|
+
Retrieve a new Dataflow Data Pipelines instance.
|
1399
|
+
|
1400
|
+
:param pipeline_name: The display name of the pipeline. In example
|
1401
|
+
projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
|
1402
|
+
:param project_id: The ID of the GCP project that owns the job.
|
1403
|
+
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
|
1404
|
+
|
1405
|
+
Returns the created Data Pipelines instance in JSON representation.
|
1406
|
+
"""
|
1407
|
+
parent = self.build_parent_name(project_id, location)
|
1408
|
+
service = self.get_pipelines_conn()
|
1409
|
+
request = (
|
1410
|
+
service.projects()
|
1411
|
+
.locations()
|
1412
|
+
.pipelines()
|
1413
|
+
.get(
|
1414
|
+
name=f"{parent}/pipelines/{pipeline_name}",
|
1415
|
+
)
|
1416
|
+
)
|
1417
|
+
response = request.execute(num_retries=self.num_retries)
|
1418
|
+
return response
|
1419
|
+
|
1420
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
1421
|
+
def run_data_pipeline(
|
1422
|
+
self,
|
1423
|
+
pipeline_name: str,
|
1424
|
+
project_id: str,
|
1425
|
+
location: str = DEFAULT_DATAFLOW_LOCATION,
|
1426
|
+
) -> dict:
|
1427
|
+
"""
|
1428
|
+
Run a Dataflow Data Pipeline Instance.
|
1429
|
+
|
1430
|
+
:param pipeline_name: The display name of the pipeline. In example
|
1431
|
+
projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
|
1432
|
+
:param project_id: The ID of the GCP project that owns the job.
|
1433
|
+
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
|
1434
|
+
|
1435
|
+
Returns the created Job in JSON representation.
|
1436
|
+
"""
|
1437
|
+
parent = self.build_parent_name(project_id, location)
|
1438
|
+
service = self.get_pipelines_conn()
|
1439
|
+
request = (
|
1440
|
+
service.projects()
|
1441
|
+
.locations()
|
1442
|
+
.pipelines()
|
1443
|
+
.run(
|
1444
|
+
name=f"{parent}/pipelines/{pipeline_name}",
|
1445
|
+
body={},
|
1446
|
+
)
|
1447
|
+
)
|
1448
|
+
response = request.execute(num_retries=self.num_retries)
|
1449
|
+
return response
|
1450
|
+
|
1451
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
1452
|
+
def delete_data_pipeline(
|
1453
|
+
self,
|
1454
|
+
pipeline_name: str,
|
1455
|
+
project_id: str,
|
1456
|
+
location: str = DEFAULT_DATAFLOW_LOCATION,
|
1457
|
+
) -> dict | None:
|
1458
|
+
"""
|
1459
|
+
Delete a Dataflow Data Pipelines Instance.
|
1460
|
+
|
1461
|
+
:param pipeline_name: The display name of the pipeline. In example
|
1462
|
+
projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
|
1463
|
+
:param project_id: The ID of the GCP project that owns the job.
|
1464
|
+
:param location: The location to direct the Data Pipelines instance to (for example us-central1).
|
1465
|
+
|
1466
|
+
Returns the created Job in JSON representation.
|
1467
|
+
"""
|
1468
|
+
parent = self.build_parent_name(project_id, location)
|
1469
|
+
service = self.get_pipelines_conn()
|
1470
|
+
request = (
|
1471
|
+
service.projects()
|
1472
|
+
.locations()
|
1473
|
+
.pipelines()
|
1474
|
+
.delete(
|
1475
|
+
name=f"{parent}/pipelines/{pipeline_name}",
|
1476
|
+
)
|
1477
|
+
)
|
1478
|
+
response = request.execute(num_retries=self.num_retries)
|
1479
|
+
return response
|
1480
|
+
|
1481
|
+
@staticmethod
|
1482
|
+
def build_parent_name(project_id: str, location: str):
|
1483
|
+
return f"projects/{project_id}/locations/{location}"
|
1484
|
+
|
1354
1485
|
|
1355
1486
|
class AsyncDataflowHook(GoogleBaseAsyncHook):
|
1356
1487
|
"""Async hook class for dataflow service."""
|
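Usage sketch (not part of the diff): with these additions, DataflowHook covers the Data Pipelines API directly. Project, location, pipeline name, and the request body below are placeholders; the body follows the Data Pipelines Pipeline resource schema.

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")

# Create a pipeline (skeleton body; the workload section would hold a launch template request).
hook.create_data_pipeline(
    body={
        "name": "projects/my-gcp-project/locations/us-central1/pipelines/my-pipeline",
        "type": "PIPELINE_TYPE_BATCH",
        "workload": {},
    },
    project_id="my-gcp-project",
    location="us-central1",
)

# Inspect, run, and delete it; each call returns the JSON representation from the API.
hook.get_data_pipeline(pipeline_name="my-pipeline", project_id="my-gcp-project")
hook.run_data_pipeline(pipeline_name="my-pipeline", project_id="my-gcp-project")
hook.delete_data_pipeline(pipeline_name="my-pipeline", project_id="my-gcp-project")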
airflow/providers/google/cloud/hooks/datapipeline.py
CHANGED
@@ -19,40 +19,30 @@
 
 from __future__ import annotations
 
-from typing import
+from typing import TYPE_CHECKING
 
-from
+from deprecated import deprecated
 
-from airflow.
-
-
-
-DEFAULT_DATAPIPELINE_LOCATION = "us-central1"
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.dataflow import DataflowHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 
+if TYPE_CHECKING:
+    from googleapiclient.discovery import build
 
-
-    """
-    Hook for Google Data Pipelines.
+DEFAULT_DATAPIPELINE_LOCATION = "us-central1"
 
-    All the methods in the hook where project_id is used must be called with
-    keyword arguments rather than positional.
-    """
 
-
-
-
-
-
-
-        super().__init__(
-            gcp_conn_id=gcp_conn_id,
-            impersonation_chain=impersonation_chain,
-        )
+@deprecated(
+    reason="This hook is deprecated and will be removed after 01.12.2024. Please use `DataflowHook`.",
+    category=AirflowProviderDeprecationWarning,
+)
+class DataPipelineHook(DataflowHook):
+    """Hook for Google Data Pipelines."""
 
     def get_conn(self) -> build:
         """Return a Google Cloud Data Pipelines service object."""
-
-        return build("datapipelines", "v1", http=http_authorized, cache_discovery=False)
+        return super().get_pipelines_conn()
 
     @GoogleBaseHook.fallback_to_default_project_id
     def create_data_pipeline(
@@ -60,31 +50,9 @@ class DataPipelineHook(GoogleBaseHook):
         body: dict,
         project_id: str,
         location: str = DEFAULT_DATAPIPELINE_LOCATION,
-    ) ->
-        """
-
-
-        :param body: The request body (contains instance of Pipeline). See:
-            https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/create#request-body
-        :param project_id: The ID of the GCP project that owns the job.
-        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-
-        Returns the created Data Pipelines instance in JSON representation.
-        """
-        parent = self.build_parent_name(project_id, location)
-        service = self.get_conn()
-        self.log.info(dir(service.projects().locations()))
-        request = (
-            service.projects()
-            .locations()
-            .pipelines()
-            .create(
-                parent=parent,
-                body=body,
-            )
-        )
-        response = request.execute(num_retries=self.num_retries)
-        return response
+    ) -> dict:
+        """Create a new Data Pipelines instance from the Data Pipelines API."""
+        return super().create_data_pipeline(body=body, project_id=project_id, location=location)
 
     @GoogleBaseHook.fallback_to_default_project_id
     def run_data_pipeline(
@@ -92,30 +60,11 @@ class DataPipelineHook(GoogleBaseHook):
         data_pipeline_name: str,
         project_id: str,
         location: str = DEFAULT_DATAPIPELINE_LOCATION,
-    ) ->
-        """
-
-
-        :param data_pipeline_name: The display name of the pipeline. In example
-            projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
-        :param project_id: The ID of the GCP project that owns the job.
-        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
-
-        Returns the created Job in JSON representation.
-        """
-        parent = self.build_parent_name(project_id, location)
-        service = self.get_conn()
-        request = (
-            service.projects()
-            .locations()
-            .pipelines()
-            .run(
-                name=f"{parent}/pipelines/{data_pipeline_name}",
-                body={},
-            )
+    ) -> dict:
+        """Run a Data Pipelines Instance using the Data Pipelines API."""
+        return super().run_data_pipeline(
+            pipeline_name=data_pipeline_name, project_id=project_id, location=location
         )
-        response = request.execute(num_retries=self.num_retries)
-        return response
 
     @staticmethod
     def build_parent_name(project_id: str, location: str):
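Migration sketch (not part of the diff): the deprecated hook now only forwards to DataflowHook, so existing callers mainly rename the hook and, for run_data_pipeline, the data_pipeline_name argument. Names below are placeholders.

# Before (now emits AirflowProviderDeprecationWarning):
from airflow.providers.google.cloud.hooks.datapipeline import DataPipelineHook

DataPipelineHook().run_data_pipeline(
    data_pipeline_name="my-pipeline", project_id="my-gcp-project", location="us-central1"
)

# After:
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

DataflowHook().run_data_pipeline(
    pipeline_name="my-pipeline", project_id="my-gcp-project", location="us-central1"
)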
airflow/providers/google/cloud/hooks/gcs.py
CHANGED
@@ -1010,6 +1010,27 @@ class GCSHook(GoogleBaseHook):
         self.log.info("The md5Hash of %s is %s", object_name, blob_md5hash)
         return blob_md5hash
 
+    def get_metadata(self, bucket_name: str, object_name: str) -> dict | None:
+        """
+        Get the metadata of an object in Google Cloud Storage.
+
+        :param bucket_name: Name of the Google Cloud Storage bucket where the object is.
+        :param object_name: The name of the object containing the desired metadata
+        :return: The metadata associated with the object
+        """
+        self.log.info("Retrieving the metadata dict of object (%s) in bucket (%s)", object_name, bucket_name)
+        client = self.get_conn()
+        bucket = client.bucket(bucket_name)
+        blob = bucket.get_blob(blob_name=object_name)
+        if blob is None:
+            raise ValueError("Object (%s) not found in bucket (%s)", object_name, bucket_name)
+        blob_metadata = blob.metadata
+        if blob_metadata:
+            self.log.info("Retrieved metadata of object (%s) with %s fields", object_name, len(blob_metadata))
+        else:
+            self.log.info("Metadata of object (%s) is empty or it does not exist", object_name)
+        return blob_metadata
+
     @GoogleBaseHook.fallback_to_default_project_id
     def create_bucket(
         self,
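Usage sketch (not part of the diff): get_metadata returns the object's custom metadata dict, or None when no metadata is set, and raises ValueError when the object is missing. Bucket, object, and metadata key below are placeholders.

from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")
metadata = hook.get_metadata(bucket_name="my-bucket", object_name="data/report.csv")
if metadata:
    print(metadata.get("uploaded-by"))  # placeholder custom metadata key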
airflow/providers/google/cloud/hooks/pubsub.py
CHANGED
@@ -36,6 +36,7 @@ from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.exceptions import NotFound
 from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
+from google.cloud.pubsub_v1.types import PublisherOptions
 from google.pubsub_v1.services.subscriber.async_client import SubscriberAsyncClient
 from googleapiclient.errors import HttpError
 
@@ -79,6 +80,7 @@ class PubSubHook(GoogleBaseHook):
         self,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        enable_message_ordering: bool = False,
         **kwargs,
     ) -> None:
         if kwargs.get("delegate_to") is not None:
@@ -90,6 +92,7 @@ class PubSubHook(GoogleBaseHook):
             gcp_conn_id=gcp_conn_id,
             impersonation_chain=impersonation_chain,
         )
+        self.enable_message_ordering = enable_message_ordering
         self._client = None
 
     def get_conn(self) -> PublisherClient:
@@ -99,7 +102,13 @@ class PubSubHook(GoogleBaseHook):
         :return: Google Cloud Pub/Sub client object.
         """
         if not self._client:
-            self._client = PublisherClient(
+            self._client = PublisherClient(
+                credentials=self.get_credentials(),
+                client_info=CLIENT_INFO,
+                publisher_options=PublisherOptions(
+                    enable_message_ordering=self.enable_message_ordering,
+                ),
+            )
         return self._client
 
     @cached_property
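Usage sketch (not part of the diff): with enable_message_ordering=True the hook builds its PublisherClient with PublisherOptions(enable_message_ordering=True); ordered delivery additionally requires an ordering key on each message and ordering enabled on the subscription. Topic path and key below are placeholders.

from airflow.providers.google.cloud.hooks.pubsub import PubSubHook

hook = PubSubHook(enable_message_ordering=True)
client = hook.get_conn()  # PublisherClient configured for ordered publishing
topic_path = "projects/my-gcp-project/topics/my-topic"
for payload in (b"first", b"second"):
    client.publish(topic_path, data=payload, ordering_key="customer-42").result()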
airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
CHANGED
@@ -551,6 +551,8 @@ class AutoMLHook(GoogleBaseHook):
         is_default_version: bool | None = None,
         model_version_aliases: list[str] | None = None,
         model_version_description: str | None = None,
+        window_stride_length: int | None = None,
+        window_max_count: int | None = None,
     ) -> tuple[models.Model | None, str]:
         """
         Create an AutoML Forecasting Training Job.
@@ -703,6 +705,10 @@ class AutoMLHook(GoogleBaseHook):
         :param sync: Whether to execute this method synchronously. If False, this method will be executed in
             concurrent Future and any downstream object will be immediately returned and synced when the
             Future has completed.
+        :param window_stride_length: Optional. Step length used to generate input examples. Every
+            ``window_stride_length`` rows will be used to generate a sliding window.
+        :param window_max_count: Optional. Number of rows that should be used to generate input examples. If the
+            total row count is larger than this number, the input data will be randomly sampled to hit the count.
         """
         if column_transformations:
             warnings.warn(
@@ -758,6 +764,8 @@ class AutoMLHook(GoogleBaseHook):
             is_default_version=is_default_version,
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
+            window_stride_length=window_stride_length,
+            window_max_count=window_max_count,
         )
         training_id = self.extract_training_id(self._job.resource_name)
         if model:
airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
CHANGED
@@ -141,6 +141,8 @@ class GenerativeModelHook(GoogleBaseHook):
         self,
         prompt: str,
         location: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro",
         project_id: str = PROVIDE_PROJECT_ID,
     ) -> str:
@@ -149,17 +151,21 @@ class GenerativeModelHook(GoogleBaseHook):
 
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Multi-modal model, in order to elicit a specific response.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per request settings for blocking unsafe content.
         :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
             output text and code.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
         :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         """
         vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
 
         model = self.get_generative_model(pretrained_model)
-        response = model.generate_content(
+        response = model.generate_content(
+            contents=[prompt], generation_config=generation_config, safety_settings=safety_settings
+        )
 
         return response.text
 
@@ -170,6 +176,8 @@ class GenerativeModelHook(GoogleBaseHook):
         location: str,
         media_gcs_path: str,
         mime_type: str,
+        generation_config: dict | None = None,
+        safety_settings: dict | None = None,
         pretrained_model: str = "gemini-pro-vision",
         project_id: str = PROVIDE_PROJECT_ID,
     ) -> str:
@@ -178,6 +186,8 @@ class GenerativeModelHook(GoogleBaseHook):
 
         :param prompt: Required. Inputs or queries that a user or a program gives
             to the Multi-modal model, in order to elicit a specific response.
+        :param generation_config: Optional. Generation configuration settings.
+        :param safety_settings: Optional. Per request settings for blocking unsafe content.
         :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
             supporting prompts with text-only input, including natural language
             tasks, multi-turn text and code chat, and code generation. It can
@@ -192,6 +202,8 @@ class GenerativeModelHook(GoogleBaseHook):
 
         model = self.get_generative_model(pretrained_model)
         part = self.get_generative_model_part(media_gcs_path, mime_type)
-        response = model.generate_content(
+        response = model.generate_content(
+            contents=[prompt, part], generation_config=generation_config, safety_settings=safety_settings
+        )
 
         return response.text
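Usage sketch (not part of the diff): the two new parameters are forwarded to the Vertex AI SDK call shown above. A minimal illustration of the dict shapes, calling the SDK directly; project, location, prompt, and the particular settings are placeholders.

import vertexai
from vertexai.generative_models import GenerativeModel, HarmBlockThreshold, HarmCategory

vertexai.init(project="my-gcp-project", location="us-central1")
model = GenerativeModel("gemini-pro")
response = model.generate_content(
    contents=["Summarize Apache Airflow in one sentence."],
    generation_config={"temperature": 0.2, "top_p": 0.95, "max_output_tokens": 256},
    safety_settings={HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
)
print(response.text)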
airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py
CHANGED
@@ -126,7 +126,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
             base_output_dir=base_output_dir,
             project=project,
             location=location,
-            credentials=self.get_credentials,
+            credentials=self.get_credentials(),
             labels=labels,
             encryption_spec_key_name=encryption_spec_key_name,
             staging_bucket=staging_bucket,
airflow/providers/google/cloud/links/dataflow.py
CHANGED
@@ -30,6 +30,9 @@ if TYPE_CHECKING:
 DATAFLOW_BASE_LINK = "/dataflow/jobs"
 DATAFLOW_JOB_LINK = DATAFLOW_BASE_LINK + "/{region}/{job_id}?project={project_id}"
 
+DATAFLOW_PIPELINE_BASE_LINK = "/dataflow/pipelines"
+DATAFLOW_PIPELINE_LINK = DATAFLOW_PIPELINE_BASE_LINK + "/{location}/{pipeline_name}?project={project_id}"
+
 
 class DataflowJobLink(BaseGoogleLink):
     """Helper class for constructing Dataflow Job Link."""
@@ -51,3 +54,25 @@ class DataflowJobLink(BaseGoogleLink):
             key=DataflowJobLink.key,
             value={"project_id": project_id, "region": region, "job_id": job_id},
         )
+
+
+class DataflowPipelineLink(BaseGoogleLink):
+    """Helper class for constructing Dataflow Pipeline Link."""
+
+    name = "Dataflow Pipeline"
+    key = "dataflow_pipeline_config"
+    format_str = DATAFLOW_PIPELINE_LINK
+
+    @staticmethod
+    def persist(
+        operator_instance: BaseOperator,
+        context: Context,
+        project_id: str | None,
+        location: str | None,
+        pipeline_name: str | None,
+    ):
+        operator_instance.xcom_push(
+            context,
+            key=DataflowPipelineLink.key,
+            value={"project_id": project_id, "location": location, "pipeline_name": pipeline_name},
+        )
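Usage sketch (not part of the diff): an operator exposes the new console link by listing it in operator_extra_links and persisting the template values during execution. The operator below is hypothetical; only the persist signature comes from the diff.

from airflow.models import BaseOperator
from airflow.providers.google.cloud.links.dataflow import DataflowPipelineLink


class MyDataPipelineOperator(BaseOperator):
    """Hypothetical operator wired up to the new Dataflow Pipeline link."""

    operator_extra_links = (DataflowPipelineLink(),)

    def execute(self, context):
        # Persist the values the link template needs; the values shown are placeholders.
        DataflowPipelineLink.persist(
            operator_instance=self,
            context=context,
            project_id="my-gcp-project",
            location="us-central1",
            pipeline_name="my-pipeline",
        )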