apache-airflow-providers-google 16.1.0rc1__py3-none-any.whl → 17.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +1 -5
- airflow/providers/google/cloud/hooks/bigquery.py +1 -130
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_run.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +5 -5
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +1 -1
- airflow/providers/google/cloud/hooks/dataflow.py +0 -85
- airflow/providers/google/cloud/hooks/datafusion.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -4
- airflow/providers/google/cloud/hooks/dataproc.py +68 -70
- airflow/providers/google/cloud/hooks/gcs.py +3 -5
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/looker.py +1 -5
- airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +4 -4
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +7 -0
- airflow/providers/google/cloud/links/kubernetes_engine.py +3 -0
- airflow/providers/google/cloud/log/gcs_task_handler.py +2 -2
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +1 -1
- airflow/providers/google/cloud/openlineage/mixins.py +7 -7
- airflow/providers/google/cloud/operators/automl.py +1 -1
- airflow/providers/google/cloud/operators/bigquery.py +8 -609
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_sql.py +1 -5
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/cloud/operators/dataproc.py +1 -1
- airflow/providers/google/cloud/operators/dlp.py +2 -2
- airflow/providers/google/cloud/operators/kubernetes_engine.py +4 -4
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +7 -1
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +7 -5
- airflow/providers/google/cloud/operators/vision.py +1 -1
- airflow/providers/google/cloud/sensors/dataflow.py +23 -6
- airflow/providers/google/cloud/sensors/datafusion.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +3 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +9 -9
- airflow/providers/google/cloud/triggers/bigquery.py +11 -13
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +1 -1
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataproc.py +10 -9
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
- airflow/providers/google/common/auth_backend/google_openid.py +2 -2
- airflow/providers/google/common/hooks/base_google.py +2 -6
- airflow/providers/google/common/utils/id_token_credentials.py +2 -2
- airflow/providers/google/get_provider_info.py +19 -16
- airflow/providers/google/leveldb/hooks/leveldb.py +1 -5
- airflow/providers/google/marketing_platform/hooks/display_video.py +47 -3
- airflow/providers/google/marketing_platform/links/analytics_admin.py +1 -1
- airflow/providers/google/marketing_platform/operators/display_video.py +64 -15
- airflow/providers/google/marketing_platform/sensors/display_video.py +9 -2
- airflow/providers/google/version_compat.py +10 -3
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/METADATA +99 -93
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/RECORD +63 -62
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/life_sciences.py +0 -30
- airflow/providers/google/cloud/operators/life_sciences.py +0 -118
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-16.1.0rc1.dist-info → apache_airflow_providers_google-17.0.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py (new file)
@@ -0,0 +1,435 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from google.api_core import exceptions
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.vertex_ai.experiment_service import (
+    ExperimentHook,
+    ExperimentRunHook,
+)
+from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class CreateExperimentOperator(GoogleCloudBaseOperator):
+    """
+    Use the Vertex AI SDK to create experiment.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param experiment_name: Required. The name of the evaluation experiment.
+    :param experiment_description: Optional. Description of the evaluation experiment.
+    :param experiment_tensorboard: Optional. The Vertex TensorBoard instance to use as a backing
+        TensorBoard for the provided experiment. If no TensorBoard is provided, a default TensorBoard
+        instance is created and used by this experiment.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "experiment_name",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        experiment_name: str,
+        experiment_description: str = "",
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        experiment_tensorboard: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.experiment_name = experiment_name
+        self.experiment_description = experiment_description
+        self.experiment_tensorboard = experiment_tensorboard
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context) -> None:
+        self.hook = ExperimentHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            self.hook.create_experiment(
+                project_id=self.project_id,
+                location=self.location,
+                experiment_name=self.experiment_name,
+                experiment_description=self.experiment_description,
+                experiment_tensorboard=self.experiment_tensorboard,
+            )
+        except exceptions.AlreadyExists:
+            raise AirflowException(f"Experiment with name {self.experiment_name} already exist")
+
+        self.log.info("Created experiment: %s", self.experiment_name)
+
+
+class DeleteExperimentOperator(GoogleCloudBaseOperator):
+    """
+    Use the Vertex AI SDK to delete experiment.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param experiment_name: Required. The name of the evaluation experiment.
+    :param delete_backing_tensorboard_runs: Optional. If True will also delete the Vertex AI TensorBoard
+        runs associated with the experiment runs under this experiment that we used to store time series
+        metrics.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "experiment_name",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        experiment_name: str,
+        delete_backing_tensorboard_runs: bool = False,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.experiment_name = experiment_name
+        self.delete_backing_tensorboard_runs = delete_backing_tensorboard_runs
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context) -> None:
+        self.hook = ExperimentHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            self.hook.delete_experiment(
+                project_id=self.project_id,
+                location=self.location,
+                experiment_name=self.experiment_name,
+                delete_backing_tensorboard_runs=self.delete_backing_tensorboard_runs,
+            )
+        except exceptions.NotFound:
+            raise AirflowException(f"Experiment with name {self.experiment_name} not found")
+
+        self.log.info("Deleted experiment: %s", self.experiment_name)
+
+
+class CreateExperimentRunOperator(GoogleCloudBaseOperator):
+    """
+    Use the Vertex AI SDK to create experiment run.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param experiment_name: Required. The name of the evaluation experiment.
+    :param experiment_run_name: Required. The specific run name or ID for this experiment.
+    :param experiment_run_tensorboard: Optional. A backing TensorBoard resource to enable and store time series
+        metrics logged to this experiment run using log_time_series_metrics.
+    :param run_after_creation: Optional. If True experiment run will be created with state running.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "experiment_name",
+        "experiment_run_name",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        experiment_name: str,
+        experiment_run_name: str,
+        experiment_run_tensorboard: str | None = None,
+        run_after_creation: bool = False,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.experiment_name = experiment_name
+        self.experiment_run_name = experiment_run_name
+        self.experiment_run_tensorboard = experiment_run_tensorboard
+        self.run_after_creation = run_after_creation
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context) -> None:
+        self.hook = ExperimentRunHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            self.hook.create_experiment_run(
+                project_id=self.project_id,
+                location=self.location,
+                experiment_name=self.experiment_name,
+                experiment_run_name=self.experiment_run_name,
+                experiment_run_tensorboard=self.experiment_run_tensorboard,
+                run_after_creation=self.run_after_creation,
+            )
+        except exceptions.AlreadyExists:
+            raise AirflowException(f"Experiment Run with name {self.experiment_run_name} already exist")
+
+        self.log.info("Created experiment run: %s", self.experiment_run_name)
+
+
+class ListExperimentRunsOperator(GoogleCloudBaseOperator):
+    """
+    Use the Vertex AI SDK to list experiment runs in experiment.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param experiment_name: Required. The name of the evaluation experiment.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "experiment_name",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        experiment_name: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.experiment_name = experiment_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context) -> list[str]:
+        self.hook = ExperimentRunHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            experiment_runs = self.hook.list_experiment_runs(
+                project_id=self.project_id, experiment_name=self.experiment_name, location=self.location
+            )
+        except exceptions.NotFound:
+            raise AirflowException("Experiment %s not found", self.experiment_name)
+
+        return [er.name for er in experiment_runs]
+
+
+class UpdateExperimentRunStateOperator(GoogleCloudBaseOperator):
+    """
+    Use the Vertex AI SDK to update state of the experiment run.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param experiment_name: Required. The name of the evaluation experiment.
+    :param experiment_run_name: Required. The specific run name or ID for this experiment.
+    :param new_state: Required. The specific state of experiment run.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "experiment_name",
+        "experiment_run_name",
+        "new_state",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        experiment_name: str,
+        experiment_run_name: str,
+        new_state: int,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.experiment_name = experiment_name
+        self.experiment_run_name = experiment_run_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.new_state = new_state
+
+    def execute(self, context: Context) -> None:
+        self.hook = ExperimentRunHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            self.hook.update_experiment_run_state(
+                project_id=self.project_id,
+                experiment_name=self.experiment_name,
+                experiment_run_name=self.experiment_run_name,
+                new_state=self.new_state,
+                location=self.location,
+            )
+            self.log.info("New state of the %s is: %s", self.experiment_run_name, self.new_state)
+        except exceptions.NotFound:
+            raise AirflowException("Experiment or experiment run not found")
+
+
+class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
+    """
+    Use the Vertex AI SDK to delete experiment run.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param experiment_name: Required. The name of the evaluation experiment.
+    :param experiment_run_name: Required. The specific run name or ID for this experiment.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = (
+        "location",
+        "project_id",
+        "impersonation_chain",
+        "experiment_name",
+        "experiment_run_name",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        experiment_name: str,
+        experiment_run_name: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.experiment_name = experiment_name
+        self.experiment_run_name = experiment_run_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context) -> None:
+        self.hook = ExperimentRunHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        try:
+            self.hook.delete_experiment_run(
+                project_id=self.project_id,
+                location=self.location,
+                experiment_name=self.experiment_name,
+                experiment_run_name=self.experiment_run_name,
+            )
+        except exceptions.NotFound:
+            raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
+
+        self.log.info("Deleted experiment run: %s", self.experiment_run_name)
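To make the new module concrete, here is a minimal sketch of how these operators could be chained in a DAG. It is illustrative only: the project ID, region, experiment names, and task IDs are placeholders, not values from the release.

# Illustrative usage of the new experiment_service operators (placeholder values).
from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.experiment_service import (
    CreateExperimentOperator,
    CreateExperimentRunOperator,
    DeleteExperimentOperator,
)

PROJECT_ID = "my-project"  # placeholder
REGION = "us-central1"  # placeholder

with DAG(dag_id="vertex_ai_experiment_demo", schedule=None) as dag:
    create_experiment = CreateExperimentOperator(
        task_id="create_experiment",
        project_id=PROJECT_ID,
        location=REGION,
        experiment_name="demo-experiment",
    )
    create_run = CreateExperimentRunOperator(
        task_id="create_experiment_run",
        project_id=PROJECT_ID,
        location=REGION,
        experiment_name="demo-experiment",
        experiment_run_name="run-1",
    )
    delete_experiment = DeleteExperimentOperator(
        task_id="delete_experiment",
        project_id=PROJECT_ID,
        location=REGION,
        experiment_name="demo-experiment",
        delete_backing_tensorboard_runs=True,
    )
    create_experiment >> create_run >> delete_experiment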
airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -24,12 +24,13 @@ from typing import TYPE_CHECKING, Any, Literal
 
 from google.api_core import exceptions
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
     ExperimentRunHook,
     GenerativeModelHook,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.common.deprecated import deprecated
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -587,6 +588,11 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
         return cached_content_text
 
 
+@deprecated(
+    planned_removal_date="January 3, 2026",
+    use_instead="airflow.providers.google.cloud.operators.vertex_ai.experiment_service.DeleteExperimentRunOperator",
+    category=AirflowProviderDeprecationWarning,
+)
 class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
     """
     Use the Rapid Evaluation API to evaluate a model.
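Because the deprecation notice names experiment_service.DeleteExperimentRunOperator as the replacement, migrating a DAG should amount to swapping the import path. A sketch:

# Before (emits AirflowProviderDeprecationWarning from 17.0.0 onward):
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    DeleteExperimentRunOperator,
)

# After (the replacement named in the deprecation notice):
from airflow.providers.google.cloud.operators.vertex_ai.experiment_service import (
    DeleteExperimentRunOperator,
)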
airflow/providers/google/cloud/operators/vertex_ai/ray.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 from google.api_core.exceptions import NotFound
 from google.cloud.aiplatform.vertex_ray.util import resources
@@ -93,8 +93,10 @@ class CreateRayClusterOperator(RayBaseOperator):
     :param location: Required. The ID of the Google Cloud region that the service belongs to.
     :param head_node_type: The head node resource. Resources.node_count must be 1. If not set, default
         value of Resources() class will be used.
-    :param python_version: Python version for the ray cluster.
-    :param ray_version: Ray version for the ray cluster.
+    :param python_version: Required. Python version for the ray cluster.
+    :param ray_version: Required. Ray version for the ray cluster.
+        Currently only 3 version are available: 2.9.3, 2.33, 2.42. For more information please refer to
+        https://github.com/googleapis/python-aiplatform/blob/main/setup.py#L101
     :param network: Virtual private cloud (VPC) network. For Ray Client, VPC peering is required to
         connect to the Ray Cluster managed in the Vertex API service. For Ray Job API, VPC network is not
         required because Ray Cluster connection can be accessed through dashboard address.
@@ -136,9 +138,9 @@ class CreateRayClusterOperator(RayBaseOperator):
 
     def __init__(
        self,
+        python_version: str,
+        ray_version: Literal["2.9.3", "2.33", "2.42"],
         head_node_type: resources.Resources = resources.Resources(),
-        python_version: str = "3.10",
-        ray_version: str = "2.33",
         network: str | None = None,
         service_account: str | None = None,
         cluster_name: str | None = None,
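Note the breaking change here: python_version and ray_version lose their defaults ("3.10" and "2.33") and become required, so existing DAGs must now pass both explicitly. A sketch of the updated call; the project and region values are placeholders:

from airflow.providers.google.cloud.operators.vertex_ai.ray import CreateRayClusterOperator

create_ray_cluster = CreateRayClusterOperator(
    task_id="create_ray_cluster",
    project_id="my-project",  # placeholder
    location="us-central1",  # placeholder
    python_version="3.10",  # previously the default; now must be passed explicitly
    ray_version="2.33",  # must be one of "2.9.3", "2.33", "2.42"
)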
airflow/providers/google/cloud/operators/vision.py
@@ -693,7 +693,7 @@ class CloudVisionUpdateProductOperator(GoogleCloudBaseOperator):
             location=self.location,
             product_id=self.product_id,
             project_id=self.project_id,
-            update_mask=self.update_mask,
+            update_mask=self.update_mask,
             retry=self.retry,
             timeout=self.timeout,
             metadata=self.metadata,
airflow/providers/google/cloud/sensors/dataflow.py
@@ -37,7 +37,7 @@ from airflow.providers.google.cloud.triggers.dataflow import (
     DataflowJobStatusTrigger,
 )
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
-from airflow.providers.google.version_compat import BaseSensorOperator
+from airflow.providers.google.version_compat import BaseSensorOperator, PokeReturnValue
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -342,7 +342,7 @@ class DataflowJobMessagesSensor(BaseSensorOperator):
         self.deferrable = deferrable
         self.poll_interval = poll_interval
 
-    def poke(self, context: Context) -> bool:
+    def poke(self, context: Context) -> PokeReturnValue | bool:
         if self.fail_on_terminal_state:
             job = self.hook.get_job(
                 job_id=self.job_id,
@@ -359,8 +359,17 @@ class DataflowJobMessagesSensor(BaseSensorOperator):
             project_id=self.project_id,
             location=self.location,
         )
+        result = result if self.callback is None else self.callback(result)
+
+        if isinstance(result, PokeReturnValue):
+            return result
 
-
+        if bool(result):
+            return PokeReturnValue(
+                is_done=True,
+                xcom_value=result,
+            )
+        return False
 
     def execute(self, context: Context) -> Any:
         """Airflow runs this method on the worker and defers using the trigger."""
@@ -464,7 +473,7 @@ class DataflowJobAutoScalingEventsSensor(BaseSensorOperator):
         self.deferrable = deferrable
         self.poll_interval = poll_interval
 
-    def poke(self, context: Context) -> bool:
+    def poke(self, context: Context) -> PokeReturnValue | bool:
         if self.fail_on_terminal_state:
             job = self.hook.get_job(
                 job_id=self.job_id,
@@ -481,8 +490,16 @@ class DataflowJobAutoScalingEventsSensor(BaseSensorOperator):
             project_id=self.project_id,
             location=self.location,
         )
-
-
+        result = result if self.callback is None else self.callback(result)
+        if isinstance(result, PokeReturnValue):
+            return result
+
+        if bool(result):
+            return PokeReturnValue(
+                is_done=True,
+                xcom_value=result,
+            )
+        return False
 
     def execute(self, context: Context) -> Any:
         """Airflow runs this method on the worker and defers using the trigger."""
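With these changes, a truthy poke result (or whatever the callback returns) is wrapped in PokeReturnValue, so the value is pushed to XCom when the sensor completes instead of being collapsed to a bare bool. A sketch of a callback under the new behavior; the job-message field name and job ID source are assumptions for illustration, not taken from the diff:

from airflow.providers.google.cloud.sensors.dataflow import DataflowJobMessagesSensor

def has_error_message(messages: list[dict]) -> bool:
    # Sensor is done once any job message mentions an error; the truthy
    # result is wrapped in PokeReturnValue and pushed to XCom by poke().
    return any("error" in (m.get("messageText") or "").lower() for m in messages)

wait_for_job_messages = DataflowJobMessagesSensor(
    task_id="wait_for_job_messages",
    job_id="{{ ti.xcom_pull(task_ids='start_dataflow_job') }}",  # placeholder
    location="us-central1",  # placeholder
    callback=has_error_message,
)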
airflow/providers/google/cloud/sensors/datafusion.py
@@ -115,7 +115,7 @@ class CloudDataFusionPipelineStateSensor(BaseSensorOperator):
                 pipeline_id=self.pipeline_id,
                 namespace=self.namespace,
             )
-            pipeline_status = pipeline_workflow
+            pipeline_status = pipeline_workflow.get("status")
         except AirflowNotFoundException:
             message = "Specified Pipeline ID was not found."
             raise AirflowException(message)
@@ -132,4 +132,4 @@ class CloudDataFusionPipelineStateSensor(BaseSensorOperator):
         self.log.debug(
             "Current status of the pipeline workflow for %s: %s.", self.pipeline_id, pipeline_status
         )
-        return pipeline_status in self.expected_statuses
+        return pipeline_status is not None and pipeline_status in self.expected_statuses
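The corrected sensor reads the status field out of the workflow document returned by the hook rather than using the document itself, and guards against the field being absent. A small illustration of the new comparison, using a made-up response shape:

# Hypothetical shape of the hook's pipeline-workflow response:
pipeline_workflow = {"runid": "abc123", "status": "RUNNING"}
expected_statuses = {"COMPLETED"}

pipeline_status = pipeline_workflow.get("status")  # None when the field is absent
# The added `is not None` guard keeps a missing field from ever matching:
done = pipeline_status is not None and pipeline_status in expected_statuses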
airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -429,7 +429,6 @@ class GCSToBigQueryOperator(BaseOperator):
         table = job_configuration[job_type][table_prop]
         persist_kwargs = {
             "context": context,
-            "task_instance": self,
             "table_id": table,
         }
         if not isinstance(table, str):
@@ -581,7 +580,7 @@ class GCSToBigQueryOperator(BaseOperator):
         table_obj_api_repr = table.to_api_repr()
 
         self.log.info("Creating external table: %s", self.destination_project_dataset_table)
-        self.hook.
+        self.hook.create_table(
             table_resource=table_obj_api_repr,
             project_id=self.project_id or self.hook.project_id,
             location=self.location,
airflow/providers/google/cloud/transfers/gcs_to_local.py
@@ -20,13 +20,15 @@ from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowException
-from airflow.models.xcom import MAX_XCOM_SIZE
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.version_compat import BaseOperator
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
+# MAX XCOM Size is 48KB, check discussion: https://github.com/apache/airflow/pull/1618#discussion_r68249677
+MAX_XCOM_SIZE = 49344
+
 
 class GCSToLocalFilesystemOperator(BaseOperator):
     """
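Vendoring the constant drops the import from airflow.models.xcom, presumably so the operator's XCom size guard keeps working on Airflow versions where that constant is no longer importable. A sketch of the kind of check the constant supports; the helper below is hypothetical, not part of the operator:

MAX_XCOM_SIZE = 49344  # 48KB, as vendored in the diff

def push_file_to_xcom(ti, key: str, file_bytes: bytes) -> None:
    # Hypothetical helper: refuse to push payloads that exceed the XCom limit.
    if len(file_bytes) >= MAX_XCOM_SIZE:
        raise ValueError("The size of the downloaded file is too large to push to XCom!")
    ti.xcom_push(key=key, value=file_bytes.decode("utf-8"))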
airflow/providers/google/cloud/transfers/oracle_to_gcs.py
@@ -46,15 +46,15 @@ class OracleToGCSOperator(BaseSQLToGCSOperator):
     ui_color = "#a0e08c"
 
     type_map = {
-        oracledb.DB_TYPE_BINARY_DOUBLE: "DECIMAL",
-        oracledb.DB_TYPE_BINARY_FLOAT: "DECIMAL",
-        oracledb.DB_TYPE_BINARY_INTEGER: "INTEGER",
-        oracledb.DB_TYPE_BOOLEAN: "BOOLEAN",
-        oracledb.DB_TYPE_DATE: "TIMESTAMP",
-        oracledb.DB_TYPE_NUMBER: "NUMERIC",
-        oracledb.DB_TYPE_TIMESTAMP: "TIMESTAMP",
-        oracledb.DB_TYPE_TIMESTAMP_LTZ: "TIMESTAMP",
-        oracledb.DB_TYPE_TIMESTAMP_TZ: "TIMESTAMP",
+        oracledb.DB_TYPE_BINARY_DOUBLE: "DECIMAL",
+        oracledb.DB_TYPE_BINARY_FLOAT: "DECIMAL",
+        oracledb.DB_TYPE_BINARY_INTEGER: "INTEGER",
+        oracledb.DB_TYPE_BOOLEAN: "BOOLEAN",
+        oracledb.DB_TYPE_DATE: "TIMESTAMP",
+        oracledb.DB_TYPE_NUMBER: "NUMERIC",
+        oracledb.DB_TYPE_TIMESTAMP: "TIMESTAMP",
+        oracledb.DB_TYPE_TIMESTAMP_LTZ: "TIMESTAMP",
+        oracledb.DB_TYPE_TIMESTAMP_TZ: "TIMESTAMP",
     }
 
     def __init__(self, *, oracle_conn_id="oracle_default", ensure_utc=False, **kwargs):
|