apache-airflow-providers-google 10.18.0rc1__py3-none-any.whl → 10.18.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +2 -5
- airflow/providers/google/cloud/hooks/automl.py +34 -0
- airflow/providers/google/cloud/hooks/bigquery.py +62 -8
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +91 -0
- airflow/providers/google/cloud/operators/automl.py +230 -25
- airflow/providers/google/cloud/operators/bigquery.py +128 -40
- airflow/providers/google/cloud/operators/dataproc.py +1 -1
- airflow/providers/google/cloud/operators/kubernetes_engine.py +24 -37
- airflow/providers/google/cloud/operators/workflows.py +2 -5
- airflow/providers/google/cloud/triggers/bigquery.py +64 -6
- airflow/providers/google/cloud/triggers/dataproc.py +82 -3
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -3
- airflow/providers/google/get_provider_info.py +3 -2
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/METADATA +7 -7
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/RECORD +17 -16
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.18.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc2.dist-info}/entry_points.txt +0 -0
@@ -25,15 +25,12 @@ from __future__ import annotations
|
|
25
25
|
|
26
26
|
import packaging.version
|
27
27
|
|
28
|
+
from airflow import __version__ as airflow_version
|
29
|
+
|
28
30
|
__all__ = ["__version__"]
|
29
31
|
|
30
32
|
__version__ = "10.18.0"
|
31
33
|
|
32
|
-
try:
|
33
|
-
from airflow import __version__ as airflow_version
|
34
|
-
except ImportError:
|
35
|
-
from airflow.version import version as airflow_version
|
36
|
-
|
37
34
|
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
|
38
35
|
"2.7.0"
|
39
36
|
):
|
@@ -640,3 +640,37 @@ class CloudAutoMLHook(GoogleBaseHook):
|
|
640
640
|
metadata=metadata,
|
641
641
|
)
|
642
642
|
return result
|
643
|
+
|
644
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
645
|
+
def get_dataset(
|
646
|
+
self,
|
647
|
+
dataset_id: str,
|
648
|
+
location: str,
|
649
|
+
project_id: str,
|
650
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
651
|
+
timeout: float | None = None,
|
652
|
+
metadata: Sequence[tuple[str, str]] = (),
|
653
|
+
) -> Dataset:
|
654
|
+
"""
|
655
|
+
Retrieve the dataset for the given dataset_id.
|
656
|
+
|
657
|
+
:param dataset_id: ID of dataset to be retrieved.
|
658
|
+
:param location: The location of the project.
|
659
|
+
:param project_id: ID of the Google Cloud project where dataset is located if None then
|
660
|
+
default project_id is used.
|
661
|
+
:param retry: A retry object used to retry requests. If `None` is specified, requests will not be
|
662
|
+
retried.
|
663
|
+
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
|
664
|
+
`retry` is specified, the timeout applies to each individual attempt.
|
665
|
+
:param metadata: Additional metadata that is provided to the method.
|
666
|
+
|
667
|
+
:return: `google.cloud.automl_v1beta1.types.dataset.Dataset` instance.
|
668
|
+
"""
|
669
|
+
client = self.get_conn()
|
670
|
+
name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
|
671
|
+
return client.get_dataset(
|
672
|
+
request={"name": name},
|
673
|
+
retry=retry,
|
674
|
+
timeout=timeout,
|
675
|
+
metadata=metadata,
|
676
|
+
)
|
@@ -46,7 +46,14 @@ from google.cloud.bigquery import (
|
|
46
46
|
UnknownJob,
|
47
47
|
)
|
48
48
|
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference
|
49
|
-
from google.cloud.bigquery.
|
49
|
+
from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY
|
50
|
+
from google.cloud.bigquery.table import (
|
51
|
+
EncryptionConfiguration,
|
52
|
+
Row,
|
53
|
+
RowIterator,
|
54
|
+
Table,
|
55
|
+
TableReference,
|
56
|
+
)
|
50
57
|
from google.cloud.exceptions import NotFound
|
51
58
|
from googleapiclient.discovery import Resource, build
|
52
59
|
from pandas_gbq import read_gbq
|
@@ -65,12 +72,7 @@ from airflow.providers.google.common.hooks.base_google import (
|
|
65
72
|
GoogleBaseHook,
|
66
73
|
get_field,
|
67
74
|
)
|
68
|
-
|
69
|
-
try:
|
70
|
-
from airflow.utils.hashlib_wrapper import md5
|
71
|
-
except ModuleNotFoundError:
|
72
|
-
# Remove when Airflow providers min Airflow version is "2.7.0"
|
73
|
-
from hashlib import md5
|
75
|
+
from airflow.utils.hashlib_wrapper import md5
|
74
76
|
from airflow.utils.helpers import convert_camel_to_snake
|
75
77
|
from airflow.utils.log.logging_mixin import LoggingMixin
|
76
78
|
|
@@ -2390,6 +2392,48 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
|
|
2390
2392
|
|
2391
2393
|
return project_id, dataset_id, table_id
|
2392
2394
|
|
2395
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
2396
|
+
def get_query_results(
|
2397
|
+
self,
|
2398
|
+
job_id: str,
|
2399
|
+
location: str,
|
2400
|
+
max_results: int | None = None,
|
2401
|
+
selected_fields: list[str] | str | None = None,
|
2402
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
2403
|
+
retry: Retry = DEFAULT_RETRY,
|
2404
|
+
job_retry: Retry = DEFAULT_JOB_RETRY,
|
2405
|
+
) -> list[dict[str, Any]]:
|
2406
|
+
"""
|
2407
|
+
Get query results given a job_id.
|
2408
|
+
|
2409
|
+
:param job_id: The ID of the job.
|
2410
|
+
The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or
|
2411
|
+
dashes (-). The maximum length is 1,024 characters.
|
2412
|
+
:param location: The location used for the operation.
|
2413
|
+
:param selected_fields: List of fields to return (comma-separated). If
|
2414
|
+
unspecified, all fields are returned.
|
2415
|
+
:param max_results: The maximum number of records (rows) to be fetched
|
2416
|
+
from the table.
|
2417
|
+
:param project_id: Google Cloud Project where the job ran.
|
2418
|
+
:param retry: How to retry the RPC.
|
2419
|
+
:param job_retry: How to retry failed jobs.
|
2420
|
+
|
2421
|
+
:return: List of rows where columns are filtered by selected fields, when given
|
2422
|
+
|
2423
|
+
:raises: AirflowException
|
2424
|
+
"""
|
2425
|
+
if isinstance(selected_fields, str):
|
2426
|
+
selected_fields = selected_fields.split(",")
|
2427
|
+
job = self.get_job(job_id=job_id, project_id=project_id, location=location)
|
2428
|
+
if not isinstance(job, QueryJob):
|
2429
|
+
raise AirflowException(f"Job '{job_id}' is not a query job")
|
2430
|
+
|
2431
|
+
if job.state != "DONE":
|
2432
|
+
raise AirflowException(f"Job '{job_id}' is not in DONE state")
|
2433
|
+
|
2434
|
+
rows = [dict(row) for row in job.result(max_results=max_results, retry=retry, job_retry=job_retry)]
|
2435
|
+
return [{k: row[k] for k in row if k in selected_fields} for row in rows] if selected_fields else rows
|
2436
|
+
|
2393
2437
|
@property
|
2394
2438
|
def scopes(self) -> Sequence[str]:
|
2395
2439
|
"""
|
@@ -3421,15 +3465,25 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
|
|
3421
3465
|
self.log.error("Failed to cancel BigQuery job %s: %s", job_id, str(e))
|
3422
3466
|
raise
|
3423
3467
|
|
3424
|
-
|
3468
|
+
# TODO: Convert get_records into an async method
|
3469
|
+
def get_records(
|
3470
|
+
self,
|
3471
|
+
query_results: dict[str, Any],
|
3472
|
+
as_dict: bool = False,
|
3473
|
+
selected_fields: str | list[str] | None = None,
|
3474
|
+
) -> list[Any]:
|
3425
3475
|
"""Convert a response from BigQuery to records.
|
3426
3476
|
|
3427
3477
|
:param query_results: the results from a SQL query
|
3428
3478
|
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists.
|
3479
|
+
:param selected_fields:
|
3429
3480
|
"""
|
3481
|
+
if isinstance(selected_fields, str):
|
3482
|
+
selected_fields = selected_fields.split(",")
|
3430
3483
|
buffer: list[Any] = []
|
3431
3484
|
if rows := query_results.get("rows"):
|
3432
3485
|
fields = query_results["schema"]["fields"]
|
3486
|
+
fields = [field for field in fields if not selected_fields or field["name"] in selected_fields]
|
3433
3487
|
fields_names = [field["name"] for field in fields]
|
3434
3488
|
col_types = [field["type"] for field in fields]
|
3435
3489
|
for dict_row in rows:
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
3
|
+
# distributed with this work for additional information
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
6
|
+
# "License"); you may not use this file except in compliance
|
7
|
+
# with the License. You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
12
|
+
# software distributed under the License is distributed on an
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
14
|
+
# KIND, either express or implied. See the License for the
|
15
|
+
# specific language governing permissions and limitations
|
16
|
+
# under the License.
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
from typing import TYPE_CHECKING, Sequence
|
21
|
+
|
22
|
+
from google.api_core.client_options import ClientOptions
|
23
|
+
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
24
|
+
from google.cloud.aiplatform_v1 import PredictionServiceClient
|
25
|
+
|
26
|
+
from airflow.providers.google.common.consts import CLIENT_INFO
|
27
|
+
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
28
|
+
|
29
|
+
if TYPE_CHECKING:
|
30
|
+
from google.api_core.retry import Retry
|
31
|
+
from google.cloud.aiplatform_v1.types import PredictResponse
|
32
|
+
|
33
|
+
|
34
|
+
class PredictionServiceHook(GoogleBaseHook):
|
35
|
+
"""Hook for Google Cloud Vertex AI Prediction API."""
|
36
|
+
|
37
|
+
def get_prediction_service_client(self, region: str | None = None) -> PredictionServiceClient:
|
38
|
+
"""
|
39
|
+
Return PredictionServiceClient object.
|
40
|
+
|
41
|
+
:param region: The ID of the Google Cloud region that the service belongs to. Default is None.
|
42
|
+
|
43
|
+
:return: `google.cloud.aiplatform_v1.services.prediction_service.client.PredictionServiceClient` instance.
|
44
|
+
"""
|
45
|
+
if region and region != "global":
|
46
|
+
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
|
47
|
+
else:
|
48
|
+
client_options = ClientOptions()
|
49
|
+
|
50
|
+
return PredictionServiceClient(
|
51
|
+
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
|
52
|
+
)
|
53
|
+
|
54
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
55
|
+
def predict(
|
56
|
+
self,
|
57
|
+
endpoint_id: str,
|
58
|
+
instances: list[str],
|
59
|
+
location: str,
|
60
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
61
|
+
parameters: dict[str, str] | None = None,
|
62
|
+
retry: Retry | _MethodDefault = DEFAULT,
|
63
|
+
timeout: float | None = None,
|
64
|
+
metadata: Sequence[tuple[str, str]] = (),
|
65
|
+
) -> PredictResponse:
|
66
|
+
"""
|
67
|
+
Perform an online prediction and returns the prediction result in the response.
|
68
|
+
|
69
|
+
:param endpoint_id: Name of the endpoint_id requested to serve the prediction.
|
70
|
+
:param instances: Required. The instances that are the input to the prediction call. A DeployedModel
|
71
|
+
may have an upper limit on the number of instances it supports per request, and when it is
|
72
|
+
exceeded the prediction call errors in case of AutoML Models, or, in case of customer created
|
73
|
+
Models, the behaviour is as documented by that Model.
|
74
|
+
:param parameters: Additional domain-specific parameters, any string must be up to 25000 characters long.
|
75
|
+
:param project_id: ID of the Google Cloud project where model is located if None then
|
76
|
+
default project_id is used.
|
77
|
+
:param location: The location of the project.
|
78
|
+
:param retry: A retry object used to retry requests. If `None` is specified, requests will not be
|
79
|
+
retried.
|
80
|
+
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
|
81
|
+
`retry` is specified, the timeout applies to each individual attempt.
|
82
|
+
:param metadata: Additional metadata that is provided to the method.
|
83
|
+
"""
|
84
|
+
client = self.get_prediction_service_client(location)
|
85
|
+
endpoint = f"projects/{project_id}/locations/{location}/endpoints/{endpoint_id}"
|
86
|
+
return client.predict(
|
87
|
+
request={"endpoint": endpoint, "instances": instances, "parameters": parameters},
|
88
|
+
retry=retry,
|
89
|
+
timeout=timeout,
|
90
|
+
metadata=metadata,
|
91
|
+
)
|