apache-airflow-providers-google 16.0.0a1__py3-none-any.whl → 16.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +34 -0
- airflow/providers/google/cloud/hooks/bigquery.py +63 -76
- airflow/providers/google/cloud/hooks/gcs.py +3 -3
- airflow/providers/google/cloud/hooks/looker.py +5 -0
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +1 -66
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +1 -1
- airflow/providers/google/cloud/links/cloud_run.py +59 -0
- airflow/providers/google/cloud/log/gcs_task_handler.py +4 -4
- airflow/providers/google/cloud/operators/bigquery.py +49 -10
- airflow/providers/google/cloud/operators/cloud_run.py +10 -1
- airflow/providers/google/cloud/operators/gcs.py +1 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +3 -85
- airflow/providers/google/cloud/operators/pubsub.py +2 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +0 -92
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +4 -0
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +9 -5
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
- airflow/providers/google/cloud/triggers/bigquery.py +32 -5
- airflow/providers/google/cloud/triggers/dataproc.py +62 -10
- airflow/providers/google/get_provider_info.py +14 -5
- airflow/providers/google/leveldb/hooks/leveldb.py +25 -0
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.0.0rc1.dist-info}/METADATA +23 -22
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.0.0rc1.dist-info}/RECORD +29 -28
- airflow/providers/google/cloud/links/automl.py +0 -193
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.0.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.0.0rc1.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
|
|
29
29
|
|
30
30
|
__all__ = ["__version__"]
|
31
31
|
|
32
|
-
__version__ = "
|
32
|
+
__version__ = "16.0.0"
|
33
33
|
|
34
34
|
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
|
35
35
|
"2.10.0"
|
@@ -101,6 +101,40 @@ class GoogleAdsHook(BaseHook):
|
|
101
101
|
:param api_version: The Google Ads API version to use.
|
102
102
|
"""
|
103
103
|
|
104
|
+
conn_name_attr = "google_ads_conn_id"
|
105
|
+
default_conn_name = "google_ads_default"
|
106
|
+
conn_type = "google_ads"
|
107
|
+
hook_name = "Google Ads"
|
108
|
+
|
109
|
+
@classmethod
|
110
|
+
def get_connection_form_widgets(cls) -> dict[str, Any]:
|
111
|
+
"""Return connection widgets to add to Google Ads connection form."""
|
112
|
+
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
|
113
|
+
from flask_babel import lazy_gettext
|
114
|
+
from wtforms import PasswordField, StringField
|
115
|
+
|
116
|
+
return {
|
117
|
+
"developer_token": StringField(lazy_gettext("Developer token"), widget=BS3TextFieldWidget()),
|
118
|
+
"client_id": StringField(lazy_gettext("OAuth2 Client ID"), widget=BS3TextFieldWidget()),
|
119
|
+
"client_secret": PasswordField(
|
120
|
+
lazy_gettext("OAuth2 Client Secret"), widget=BS3PasswordFieldWidget()
|
121
|
+
),
|
122
|
+
"refresh_token": PasswordField(
|
123
|
+
lazy_gettext("OAuth2 Refresh Token"), widget=BS3PasswordFieldWidget()
|
124
|
+
),
|
125
|
+
}
|
126
|
+
|
127
|
+
@classmethod
|
128
|
+
def get_ui_field_behaviour(cls) -> dict[str, Any]:
|
129
|
+
"""Return custom UI field behaviour for Google Ads connection."""
|
130
|
+
return {
|
131
|
+
"hidden_fields": ["host", "login", "schema", "port"],
|
132
|
+
"relabeling": {},
|
133
|
+
"placeholders": {
|
134
|
+
"password": "Leave blank (optional)",
|
135
|
+
},
|
136
|
+
}
|
137
|
+
|
104
138
|
def __init__(
|
105
139
|
self,
|
106
140
|
api_version: str | None = None,
|
@@ -29,7 +29,7 @@ import uuid
|
|
29
29
|
from collections.abc import Iterable, Mapping, Sequence
|
30
30
|
from copy import deepcopy
|
31
31
|
from datetime import datetime, timedelta
|
32
|
-
from typing import TYPE_CHECKING, Any, NoReturn, Union, cast
|
32
|
+
from typing import TYPE_CHECKING, Any, NoReturn, Union, cast, overload
|
33
33
|
|
34
34
|
from aiohttp import ClientSession as ClientSession
|
35
35
|
from gcloud.aio.bigquery import Job, Table as Table_async
|
@@ -57,8 +57,13 @@ from googleapiclient.discovery import build
|
|
57
57
|
from pandas_gbq import read_gbq
|
58
58
|
from pandas_gbq.gbq import GbqConnector # noqa: F401 used in ``airflow.contrib.hooks.bigquery``
|
59
59
|
from sqlalchemy import create_engine
|
60
|
+
from typing_extensions import Literal
|
60
61
|
|
61
|
-
from airflow.exceptions import
|
62
|
+
from airflow.exceptions import (
|
63
|
+
AirflowException,
|
64
|
+
AirflowOptionalProviderFeatureException,
|
65
|
+
AirflowProviderDeprecationWarning,
|
66
|
+
)
|
62
67
|
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
|
63
68
|
from airflow.providers.common.sql.hooks.sql import DbApiHook
|
64
69
|
from airflow.providers.google.cloud.utils.bigquery import bq_cast
|
@@ -77,6 +82,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
|
|
77
82
|
|
78
83
|
if TYPE_CHECKING:
|
79
84
|
import pandas as pd
|
85
|
+
import polars as pl
|
80
86
|
from google.api_core.page_iterator import HTTPIterator
|
81
87
|
from google.api_core.retry import Retry
|
82
88
|
from requests import Session
|
@@ -275,15 +281,57 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
|
|
275
281
|
"""
|
276
282
|
raise NotImplementedError()
|
277
283
|
|
278
|
-
def
|
284
|
+
def _get_pandas_df(
|
279
285
|
self,
|
280
286
|
sql: str,
|
281
287
|
parameters: Iterable | Mapping[str, Any] | None = None,
|
282
288
|
dialect: str | None = None,
|
283
289
|
**kwargs,
|
284
290
|
) -> pd.DataFrame:
|
291
|
+
if dialect is None:
|
292
|
+
dialect = "legacy" if self.use_legacy_sql else "standard"
|
293
|
+
|
294
|
+
credentials, project_id = self.get_credentials_and_project_id()
|
295
|
+
|
296
|
+
return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
|
297
|
+
|
298
|
+
def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame:
|
299
|
+
try:
|
300
|
+
import polars as pl
|
301
|
+
except ImportError:
|
302
|
+
raise AirflowOptionalProviderFeatureException(
|
303
|
+
"Polars is not installed. Please install it with `pip install polars`."
|
304
|
+
)
|
305
|
+
|
306
|
+
if dialect is None:
|
307
|
+
dialect = "legacy" if self.use_legacy_sql else "standard"
|
308
|
+
|
309
|
+
credentials, project_id = self.get_credentials_and_project_id()
|
310
|
+
|
311
|
+
pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
|
312
|
+
return pl.from_pandas(pandas_df)
|
313
|
+
|
314
|
+
@overload
|
315
|
+
def get_df(
|
316
|
+
self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs
|
317
|
+
) -> pd.DataFrame: ...
|
318
|
+
|
319
|
+
@overload
|
320
|
+
def get_df(
|
321
|
+
self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs
|
322
|
+
) -> pl.DataFrame: ...
|
323
|
+
|
324
|
+
def get_df(
|
325
|
+
self,
|
326
|
+
sql,
|
327
|
+
parameters=None,
|
328
|
+
dialect=None,
|
329
|
+
*,
|
330
|
+
df_type: Literal["pandas", "polars"] = "pandas",
|
331
|
+
**kwargs,
|
332
|
+
) -> pd.DataFrame | pl.DataFrame:
|
285
333
|
"""
|
286
|
-
Get a
|
334
|
+
Get a DataFrame for the BigQuery results.
|
287
335
|
|
288
336
|
The DbApiHook method must be overridden because Pandas doesn't support
|
289
337
|
PEP 249 connections, except for SQLite.
|
@@ -299,12 +347,19 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
|
|
299
347
|
defaults to use `self.use_legacy_sql` if not specified
|
300
348
|
:param kwargs: (optional) passed into pandas_gbq.read_gbq method
|
301
349
|
"""
|
302
|
-
if
|
303
|
-
|
350
|
+
if df_type == "polars":
|
351
|
+
return self._get_polars_df(sql, parameters, dialect, **kwargs)
|
304
352
|
|
305
|
-
|
353
|
+
if df_type == "pandas":
|
354
|
+
return self._get_pandas_df(sql, parameters, dialect, **kwargs)
|
306
355
|
|
307
|
-
|
356
|
+
@deprecated(
|
357
|
+
planned_removal_date="November 30, 2025",
|
358
|
+
use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_df",
|
359
|
+
category=AirflowProviderDeprecationWarning,
|
360
|
+
)
|
361
|
+
def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs):
|
362
|
+
return self._get_pandas_df(sql, parameters, dialect, **kwargs)
|
308
363
|
|
309
364
|
@GoogleBaseHook.fallback_to_default_project_id
|
310
365
|
def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool:
|
@@ -1937,74 +1992,6 @@ def _escape(s: str) -> str:
|
|
1937
1992
|
return e
|
1938
1993
|
|
1939
1994
|
|
1940
|
-
@deprecated(
|
1941
|
-
planned_removal_date="April 01, 2025",
|
1942
|
-
use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.split_tablename",
|
1943
|
-
category=AirflowProviderDeprecationWarning,
|
1944
|
-
)
|
1945
|
-
def split_tablename(
|
1946
|
-
table_input: str, default_project_id: str, var_name: str | None = None
|
1947
|
-
) -> tuple[str, str, str]:
|
1948
|
-
if "." not in table_input:
|
1949
|
-
raise ValueError(f"Expected table name in the format of <dataset>.<table>. Got: {table_input}")
|
1950
|
-
|
1951
|
-
if not default_project_id:
|
1952
|
-
raise ValueError("INTERNAL: No default project is specified")
|
1953
|
-
|
1954
|
-
def var_print(var_name):
|
1955
|
-
if var_name is None:
|
1956
|
-
return ""
|
1957
|
-
return f"Format exception for {var_name}: "
|
1958
|
-
|
1959
|
-
if table_input.count(".") + table_input.count(":") > 3:
|
1960
|
-
raise ValueError(f"{var_print(var_name)}Use either : or . to specify project got {table_input}")
|
1961
|
-
cmpt = table_input.rsplit(":", 1)
|
1962
|
-
project_id = None
|
1963
|
-
rest = table_input
|
1964
|
-
if len(cmpt) == 1:
|
1965
|
-
project_id = None
|
1966
|
-
rest = cmpt[0]
|
1967
|
-
elif len(cmpt) == 2 and cmpt[0].count(":") <= 1:
|
1968
|
-
if cmpt[-1].count(".") != 2:
|
1969
|
-
project_id = cmpt[0]
|
1970
|
-
rest = cmpt[1]
|
1971
|
-
else:
|
1972
|
-
raise ValueError(
|
1973
|
-
f"{var_print(var_name)}Expect format of (<project:)<dataset>.<table>, got {table_input}"
|
1974
|
-
)
|
1975
|
-
|
1976
|
-
cmpt = rest.split(".")
|
1977
|
-
if len(cmpt) == 3:
|
1978
|
-
if project_id:
|
1979
|
-
raise ValueError(f"{var_print(var_name)}Use either : or . to specify project")
|
1980
|
-
project_id = cmpt[0]
|
1981
|
-
dataset_id = cmpt[1]
|
1982
|
-
table_id = cmpt[2]
|
1983
|
-
|
1984
|
-
elif len(cmpt) == 2:
|
1985
|
-
dataset_id = cmpt[0]
|
1986
|
-
table_id = cmpt[1]
|
1987
|
-
else:
|
1988
|
-
raise ValueError(
|
1989
|
-
f"{var_print(var_name)}Expect format of (<project.|<project:)<dataset>.<table>, got {table_input}"
|
1990
|
-
)
|
1991
|
-
|
1992
|
-
# Exclude partition from the table name
|
1993
|
-
table_id = table_id.split("$")[0]
|
1994
|
-
|
1995
|
-
if project_id is None:
|
1996
|
-
if var_name is not None:
|
1997
|
-
log.info(
|
1998
|
-
'Project is not included in %s: %s; using project "%s"',
|
1999
|
-
var_name,
|
2000
|
-
table_input,
|
2001
|
-
default_project_id,
|
2002
|
-
)
|
2003
|
-
project_id = default_project_id
|
2004
|
-
|
2005
|
-
return project_id, dataset_id, table_id
|
2006
|
-
|
2007
|
-
|
2008
1995
|
def _cleanse_time_partitioning(
|
2009
1996
|
destination_dataset_table: str | None, time_partitioning_in: dict | None
|
2010
1997
|
) -> dict: # if it is a partitioned table ($ is in the table name) add partition load option
|
@@ -549,13 +549,13 @@ class GCSHook(GoogleBaseHook):
|
|
549
549
|
if cache_control:
|
550
550
|
blob.cache_control = cache_control
|
551
551
|
|
552
|
-
if filename and data:
|
552
|
+
if filename is not None and data is not None:
|
553
553
|
raise ValueError(
|
554
554
|
"'filename' and 'data' parameter provided. Please "
|
555
555
|
"specify a single parameter, either 'filename' for "
|
556
556
|
"local file uploads or 'data' for file content uploads."
|
557
557
|
)
|
558
|
-
if filename:
|
558
|
+
if filename is not None:
|
559
559
|
if not mime_type:
|
560
560
|
mime_type = "application/octet-stream"
|
561
561
|
if gzip:
|
@@ -575,7 +575,7 @@ class GCSHook(GoogleBaseHook):
|
|
575
575
|
if gzip:
|
576
576
|
os.remove(filename)
|
577
577
|
self.log.info("File %s uploaded to %s in %s bucket", filename, object_name, bucket_name)
|
578
|
-
elif data:
|
578
|
+
elif data is not None:
|
579
579
|
if not mime_type:
|
580
580
|
mime_type = "text/plain"
|
581
581
|
if gzip:
|
@@ -39,6 +39,11 @@ if TYPE_CHECKING:
|
|
39
39
|
class LookerHook(BaseHook):
|
40
40
|
"""Hook for Looker APIs."""
|
41
41
|
|
42
|
+
conn_name_attr = "looker_conn_id"
|
43
|
+
default_conn_name = "looker_default"
|
44
|
+
conn_type = "gcp_looker"
|
45
|
+
hook_name = "Google Looker"
|
46
|
+
|
42
47
|
def __init__(
|
43
48
|
self,
|
44
49
|
looker_conn_id: str,
|
@@ -185,42 +185,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
|
|
185
185
|
model_encryption_spec_key_name=model_encryption_spec_key_name,
|
186
186
|
)
|
187
187
|
|
188
|
-
@deprecated(
|
189
|
-
planned_removal_date="June 15, 2025",
|
190
|
-
category=AirflowProviderDeprecationWarning,
|
191
|
-
reason="Deprecation of AutoMLText API",
|
192
|
-
)
|
193
|
-
def get_auto_ml_text_training_job(
|
194
|
-
self,
|
195
|
-
display_name: str,
|
196
|
-
prediction_type: str,
|
197
|
-
multi_label: bool = False,
|
198
|
-
sentiment_max: int = 10,
|
199
|
-
project: str | None = None,
|
200
|
-
location: str | None = None,
|
201
|
-
labels: dict[str, str] | None = None,
|
202
|
-
training_encryption_spec_key_name: str | None = None,
|
203
|
-
model_encryption_spec_key_name: str | None = None,
|
204
|
-
) -> AutoMLTextTrainingJob:
|
205
|
-
"""
|
206
|
-
Return AutoMLTextTrainingJob object.
|
207
|
-
|
208
|
-
WARNING: Text creation API is deprecated since September 15, 2024
|
209
|
-
(https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
|
210
|
-
"""
|
211
|
-
return AutoMLTextTrainingJob(
|
212
|
-
display_name=display_name,
|
213
|
-
prediction_type=prediction_type,
|
214
|
-
multi_label=multi_label,
|
215
|
-
sentiment_max=sentiment_max,
|
216
|
-
project=project,
|
217
|
-
location=location,
|
218
|
-
credentials=self.get_credentials(),
|
219
|
-
labels=labels,
|
220
|
-
training_encryption_spec_key_name=training_encryption_spec_key_name,
|
221
|
-
model_encryption_spec_key_name=model_encryption_spec_key_name,
|
222
|
-
)
|
223
|
-
|
224
188
|
def get_auto_ml_video_training_job(
|
225
189
|
self,
|
226
190
|
display_name: str,
|
@@ -25,14 +25,12 @@ from typing import TYPE_CHECKING
|
|
25
25
|
|
26
26
|
import vertexai
|
27
27
|
from vertexai.generative_models import GenerativeModel
|
28
|
-
from vertexai.language_models import TextEmbeddingModel
|
28
|
+
from vertexai.language_models import TextEmbeddingModel
|
29
29
|
from vertexai.preview.caching import CachedContent
|
30
30
|
from vertexai.preview.evaluation import EvalResult, EvalTask
|
31
31
|
from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
|
32
32
|
from vertexai.preview.tuning import sft
|
33
33
|
|
34
|
-
from airflow.exceptions import AirflowProviderDeprecationWarning
|
35
|
-
from airflow.providers.google.common.deprecated import deprecated
|
36
34
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
37
35
|
|
38
36
|
if TYPE_CHECKING:
|
@@ -43,16 +41,6 @@ if TYPE_CHECKING:
|
|
43
41
|
class GenerativeModelHook(GoogleBaseHook):
|
44
42
|
"""Hook for Google Cloud Vertex AI Generative Model APIs."""
|
45
43
|
|
46
|
-
@deprecated(
|
47
|
-
planned_removal_date="April 09, 2025",
|
48
|
-
use_instead="GenerativeModelHook.get_generative_model",
|
49
|
-
category=AirflowProviderDeprecationWarning,
|
50
|
-
)
|
51
|
-
def get_text_generation_model(self, pretrained_model: str):
|
52
|
-
"""Return a Model Garden Model object based on Text Generation."""
|
53
|
-
model = TextGenerationModel.from_pretrained(pretrained_model)
|
54
|
-
return model
|
55
|
-
|
56
44
|
def get_text_embedding_model(self, pretrained_model: str):
|
57
45
|
"""Return a Model Garden Model object based on Text Embedding."""
|
58
46
|
model = TextEmbeddingModel.from_pretrained(pretrained_model)
|
@@ -100,59 +88,6 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
100
88
|
cached_context_model = preview_generative_model.from_cached_content(cached_content)
|
101
89
|
return cached_context_model
|
102
90
|
|
103
|
-
@deprecated(
|
104
|
-
planned_removal_date="April 09, 2025",
|
105
|
-
use_instead="GenerativeModelHook.generative_model_generate_content",
|
106
|
-
category=AirflowProviderDeprecationWarning,
|
107
|
-
)
|
108
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
109
|
-
def text_generation_model_predict(
|
110
|
-
self,
|
111
|
-
prompt: str,
|
112
|
-
pretrained_model: str,
|
113
|
-
temperature: float,
|
114
|
-
max_output_tokens: int,
|
115
|
-
top_p: float,
|
116
|
-
top_k: int,
|
117
|
-
location: str,
|
118
|
-
project_id: str = PROVIDE_PROJECT_ID,
|
119
|
-
) -> str:
|
120
|
-
"""
|
121
|
-
Use the Vertex AI PaLM API to generate natural language text.
|
122
|
-
|
123
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
124
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
125
|
-
:param prompt: Required. Inputs or queries that a user or a program gives
|
126
|
-
to the Vertex AI PaLM API, in order to elicit a specific response.
|
127
|
-
:param pretrained_model: A pre-trained model optimized for performing natural
|
128
|
-
language tasks such as classification, summarization, extraction, content
|
129
|
-
creation, and ideation.
|
130
|
-
:param temperature: Temperature controls the degree of randomness in token
|
131
|
-
selection.
|
132
|
-
:param max_output_tokens: Token limit determines the maximum amount of text
|
133
|
-
output.
|
134
|
-
:param top_p: Tokens are selected from most probable to least until the sum
|
135
|
-
of their probabilities equals the top_p value. Defaults to 0.8.
|
136
|
-
:param top_k: A top_k of 1 means the selected token is the most probable
|
137
|
-
among all tokens.
|
138
|
-
"""
|
139
|
-
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
140
|
-
|
141
|
-
parameters = {
|
142
|
-
"temperature": temperature,
|
143
|
-
"max_output_tokens": max_output_tokens,
|
144
|
-
"top_p": top_p,
|
145
|
-
"top_k": top_k,
|
146
|
-
}
|
147
|
-
|
148
|
-
model = self.get_text_generation_model(pretrained_model)
|
149
|
-
|
150
|
-
response = model.predict(
|
151
|
-
prompt=prompt,
|
152
|
-
**parameters,
|
153
|
-
)
|
154
|
-
return response.text
|
155
|
-
|
156
91
|
@GoogleBaseHook.fallback_to_default_project_id
|
157
92
|
def text_embedding_model_get_embeddings(
|
158
93
|
self,
|
@@ -23,7 +23,7 @@ import dataclasses
|
|
23
23
|
from typing import Any
|
24
24
|
|
25
25
|
import vertex_ray
|
26
|
-
from google._upb._message import ScalarMapContainer
|
26
|
+
from google._upb._message import ScalarMapContainer # type: ignore[attr-defined]
|
27
27
|
from google.cloud import aiplatform
|
28
28
|
from google.cloud.aiplatform.vertex_ray.util import resources
|
29
29
|
from google.cloud.aiplatform_v1 import (
|
@@ -0,0 +1,59 @@
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
3
|
+
# distributed with this work for additional information
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
6
|
+
# "License"); you may not use this file except in compliance
|
7
|
+
# with the License. You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
12
|
+
# software distributed under the License is distributed on an
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
14
|
+
# KIND, either express or implied. See the License for the
|
15
|
+
# specific language governing permissions and limitations
|
16
|
+
# under the License.
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from typing import TYPE_CHECKING
|
20
|
+
|
21
|
+
from airflow.providers.google.cloud.links.base import BaseGoogleLink
|
22
|
+
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
|
23
|
+
|
24
|
+
if TYPE_CHECKING:
|
25
|
+
from airflow.models import BaseOperator
|
26
|
+
from airflow.models.taskinstancekey import TaskInstanceKey
|
27
|
+
from airflow.utils.context import Context
|
28
|
+
|
29
|
+
if AIRFLOW_V_3_0_PLUS:
|
30
|
+
from airflow.sdk.execution_time.xcom import XCom
|
31
|
+
else:
|
32
|
+
from airflow.models.xcom import XCom # type: ignore[no-redef]
|
33
|
+
|
34
|
+
|
35
|
+
class CloudRunJobLoggingLink(BaseGoogleLink):
|
36
|
+
"""Helper class for constructing Cloud Run Job Logging link."""
|
37
|
+
|
38
|
+
name = "Cloud Run Job Logging"
|
39
|
+
key = "log_uri"
|
40
|
+
|
41
|
+
@staticmethod
|
42
|
+
def persist(
|
43
|
+
context: Context,
|
44
|
+
task_instance: BaseOperator,
|
45
|
+
log_uri: str,
|
46
|
+
):
|
47
|
+
task_instance.xcom_push(
|
48
|
+
context,
|
49
|
+
key=CloudRunJobLoggingLink.key,
|
50
|
+
value=log_uri,
|
51
|
+
)
|
52
|
+
|
53
|
+
def get_link(
|
54
|
+
self,
|
55
|
+
operator: BaseOperator,
|
56
|
+
*,
|
57
|
+
ti_key: TaskInstanceKey,
|
58
|
+
) -> str:
|
59
|
+
return XCom.get_value(key=self.key, ti_key=ti_key)
|
@@ -61,11 +61,11 @@ class GCSRemoteLogIO(LoggingMixin): # noqa: D101
|
|
61
61
|
remote_base: str
|
62
62
|
base_log_folder: Path = attrs.field(converter=Path)
|
63
63
|
delete_local_copy: bool
|
64
|
-
project_id: str
|
64
|
+
project_id: str | None = None
|
65
65
|
|
66
|
-
gcp_key_path: str | None
|
67
|
-
gcp_keyfile_dict: dict | None
|
68
|
-
scopes: Collection[str] | None
|
66
|
+
gcp_key_path: str | None = None
|
67
|
+
gcp_keyfile_dict: dict | None = None
|
68
|
+
scopes: Collection[str] | None = _DEFAULT_SCOPESS
|
69
69
|
|
70
70
|
processors = ()
|
71
71
|
|