apache-airflow-providers-google 10.18.0rc1__py3-none-any.whl → 10.18.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,15 +25,12 @@ from __future__ import annotations
 
 import packaging.version
 
+from airflow import __version__ as airflow_version
+
 __all__ = ["__version__"]
 
 __version__ = "10.18.0"
 
-try:
-    from airflow import __version__ as airflow_version
-except ImportError:
-    from airflow.version import version as airflow_version
-
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"
 ):
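Note: a minimal sketch of how the version gate above evaluates, assuming only the `packaging` library; the installed version string and the `RuntimeError` body are hypothetical, since the hunk ends before the body of the `if`:

    import packaging.version

    airflow_version = "2.6.3+astro.1"  # hypothetical installed Airflow version
    # base_version strips local/pre-release segments: "2.6.3+astro.1" -> "2.6.3"
    base = packaging.version.parse(airflow_version).base_version
    if packaging.version.parse(base) < packaging.version.parse("2.7.0"):
        raise RuntimeError("This provider requires Apache Airflow 2.7.0+")  # hypothetical body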
@@ -640,3 +640,37 @@ class CloudAutoMLHook(GoogleBaseHook):
             metadata=metadata,
         )
         return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_dataset(
+        self,
+        dataset_id: str,
+        location: str,
+        project_id: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Dataset:
+        """
+        Retrieve the dataset for the given dataset_id.
+
+        :param dataset_id: ID of the dataset to be retrieved.
+        :param location: The location of the project.
+        :param project_id: ID of the Google Cloud project where the dataset is located. If None, the
+            default project_id is used.
+        :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            `retry` is specified, the timeout applies to each individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.cloud.automl_v1beta1.types.dataset.Dataset` instance.
+        """
+        client = self.get_conn()
+        name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+        return client.get_dataset(
+            request={"name": name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
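Note: a hedged usage sketch of the new `CloudAutoMLHook.get_dataset` method; the connection ID, project, location, and dataset ID below are hypothetical, and `project_id` may be omitted thanks to the `fallback_to_default_project_id` decorator:

    from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook

    hook = CloudAutoMLHook(gcp_conn_id="google_cloud_default")
    dataset = hook.get_dataset(
        dataset_id="TBL1234567890123456789",  # hypothetical AutoML dataset ID
        location="us-central1",
        project_id="my-project",
    )
    print(dataset.display_name)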
@@ -46,7 +46,14 @@ from google.cloud.bigquery import (
     UnknownJob,
 )
 from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference
-from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
+from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY
+from google.cloud.bigquery.table import (
+    EncryptionConfiguration,
+    Row,
+    RowIterator,
+    Table,
+    TableReference,
+)
 from google.cloud.exceptions import NotFound
 from googleapiclient.discovery import Resource, build
 from pandas_gbq import read_gbq
@@ -65,12 +72,7 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseHook,
     get_field,
 )
-
-try:
-    from airflow.utils.hashlib_wrapper import md5
-except ModuleNotFoundError:
-    # Remove when Airflow providers min Airflow version is "2.7.0"
-    from hashlib import md5
+from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -2390,6 +2392,48 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
 
         return project_id, dataset_id, table_id
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_query_results(
+        self,
+        job_id: str,
+        location: str,
+        max_results: int | None = None,
+        selected_fields: list[str] | str | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+        retry: Retry = DEFAULT_RETRY,
+        job_retry: Retry = DEFAULT_JOB_RETRY,
+    ) -> list[dict[str, Any]]:
+        """
+        Get query results given a job_id.
+
+        :param job_id: The ID of the job.
+            The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or
+            dashes (-). The maximum length is 1,024 characters.
+        :param location: The location used for the operation.
+        :param selected_fields: List of fields to return (comma-separated). If
+            unspecified, all fields are returned.
+        :param max_results: The maximum number of records (rows) to be fetched
+            from the table.
+        :param project_id: Google Cloud project where the job ran.
+        :param retry: How to retry the RPC.
+        :param job_retry: How to retry failed jobs.
+
+        :return: A list of rows, with columns filtered by selected_fields when given.
+
+        :raises: AirflowException
+        """
+        if isinstance(selected_fields, str):
+            selected_fields = selected_fields.split(",")
+        job = self.get_job(job_id=job_id, project_id=project_id, location=location)
+        if not isinstance(job, QueryJob):
+            raise AirflowException(f"Job '{job_id}' is not a query job")
+
+        if job.state != "DONE":
+            raise AirflowException(f"Job '{job_id}' is not in DONE state")
+
+        rows = [dict(row) for row in job.result(max_results=max_results, retry=retry, job_retry=job_retry)]
+        return [{k: row[k] for k in row if k in selected_fields} for row in rows] if selected_fields else rows
+
     @property
     def scopes(self) -> Sequence[str]:
         """
@@ -3421,15 +3465,25 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
             self.log.error("Failed to cancel BigQuery job %s: %s", job_id, str(e))
             raise
 
-    def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
+    # TODO: Convert get_records into an async method
+    def get_records(
+        self,
+        query_results: dict[str, Any],
+        as_dict: bool = False,
+        selected_fields: str | list[str] | None = None,
+    ) -> list[Any]:
         """Convert a response from BigQuery to records.
 
         :param query_results: the results from a SQL query
         :param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists.
+        :param selected_fields: comma-separated string or list of fields to include in the records.
         """
+        if isinstance(selected_fields, str):
+            selected_fields = selected_fields.split(",")
         buffer: list[Any] = []
         if rows := query_results.get("rows"):
             fields = query_results["schema"]["fields"]
+            fields = [field for field in fields if not selected_fields or field["name"] in selected_fields]
             fields_names = [field["name"] for field in fields]
             col_types = [field["type"] for field in fields]
             for dict_row in rows:
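Note: the `selected_fields` handling added to `get_records` above, shown in isolation with hypothetical schema fields; a comma-separated string is normalized to a list, then the schema fields are filtered by name before rows are decoded:

    selected_fields: str | list[str] | None = "name,total"
    if isinstance(selected_fields, str):
        selected_fields = selected_fields.split(",")  # -> ["name", "total"]

    fields = [{"name": "name"}, {"name": "total"}, {"name": "extra"}]
    fields = [field for field in fields if not selected_fields or field["name"] in selected_fields]
    # -> [{"name": "name"}, {"name": "total"}]; with selected_fields=None, all fields are kept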
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+from google.api_core.client_options import ClientOptions
+from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud.aiplatform_v1 import PredictionServiceClient
+
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+if TYPE_CHECKING:
+    from google.api_core.retry import Retry
+    from google.cloud.aiplatform_v1.types import PredictResponse
+
+
+class PredictionServiceHook(GoogleBaseHook):
+    """Hook for Google Cloud Vertex AI Prediction API."""
+
+    def get_prediction_service_client(self, region: str | None = None) -> PredictionServiceClient:
+        """
+        Return PredictionServiceClient object.
+
+        :param region: The ID of the Google Cloud region that the service belongs to. Default is None.
+
+        :return: `google.cloud.aiplatform_v1.services.prediction_service.client.PredictionServiceClient` instance.
+        """
+        if region and region != "global":
+            client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
+        else:
+            client_options = ClientOptions()
+
+        return PredictionServiceClient(
+            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def predict(
+        self,
+        endpoint_id: str,
+        instances: list[str],
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        parameters: dict[str, str] | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> PredictResponse:
+        """
+        Perform an online prediction and return the prediction result in the response.
+
+        :param endpoint_id: Name of the endpoint requested to serve the prediction.
+        :param instances: Required. The instances that are the input to the prediction call. A DeployedModel
+            may have an upper limit on the number of instances it supports per request, and when it is
+            exceeded the prediction call errors in case of AutoML Models, or, in case of customer created
+            Models, the behaviour is as documented by that Model.
+        :param parameters: Additional domain-specific parameters; any string must be up to 25000 characters long.
+        :param project_id: ID of the Google Cloud project where the model is located. If None, the
+            default project_id is used.
+        :param location: The location of the project.
+        :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            `retry` is specified, the timeout applies to each individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = self.get_prediction_service_client(location)
+        endpoint = f"projects/{project_id}/locations/{location}/endpoints/{endpoint_id}"
+        return client.predict(
+            request={"endpoint": endpoint, "instances": instances, "parameters": parameters},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
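Note: a hedged usage sketch of the new `PredictionServiceHook.predict`; the import path assumes the new file lands under the provider's vertex_ai hooks package (the diff does not show the file path), and the endpoint, project, and payload are hypothetical:

    from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook

    hook = PredictionServiceHook(gcp_conn_id="google_cloud_default")
    response = hook.predict(
        endpoint_id="1234567890123456789",  # hypothetical Vertex AI endpoint ID
        instances=['{"content": "what is an airflow provider?"}'],
        location="us-central1",
        project_id="my-project",
    )
    print(response.predictions)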