apache-airflow-providers-google 15.1.0__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +34 -0
- airflow/providers/google/cloud/hooks/bigquery.py +63 -76
- airflow/providers/google/cloud/hooks/dataflow.py +67 -5
- airflow/providers/google/cloud/hooks/gcs.py +3 -3
- airflow/providers/google/cloud/hooks/looker.py +5 -0
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +1 -66
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/links/cloud_run.py +59 -0
- airflow/providers/google/cloud/links/vertex_ai.py +49 -0
- airflow/providers/google/cloud/log/gcs_task_handler.py +7 -5
- airflow/providers/google/cloud/operators/bigquery.py +49 -10
- airflow/providers/google/cloud/operators/cloud_run.py +20 -2
- airflow/providers/google/cloud/operators/gcs.py +1 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +4 -86
- airflow/providers/google/cloud/operators/pubsub.py +2 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +0 -92
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +4 -0
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +388 -0
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +9 -5
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +2 -0
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
- airflow/providers/google/cloud/triggers/bigquery.py +32 -5
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/dataproc.py +62 -10
- airflow/providers/google/get_provider_info.py +18 -5
- airflow/providers/google/leveldb/hooks/leveldb.py +25 -0
- airflow/providers/google/version_compat.py +0 -1
- {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/METADATA +91 -84
- {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/RECORD +35 -32
- airflow/providers/google/cloud/links/automl.py +0 -193
- {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/hooks/vertex_ai/ray.py (new file)
@@ -0,0 +1,223 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Cloud Vertex AI hook."""
+
+from __future__ import annotations
+
+import dataclasses
+from typing import Any
+
+import vertex_ray
+from google._upb._message import ScalarMapContainer  # type: ignore[attr-defined]
+from google.cloud import aiplatform
+from google.cloud.aiplatform.vertex_ray.util import resources
+from google.cloud.aiplatform_v1 import (
+    PersistentResourceServiceClient,
+)
+from proto.marshal.collections.repeated import Repeated
+
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
+
+class RayHook(GoogleBaseHook):
+    """Hook for Google Cloud Vertex AI Ray APIs."""
+
+    def extract_cluster_id(self, cluster_path) -> str:
+        """Extract cluster_id from cluster_path."""
+        cluster_id = PersistentResourceServiceClient.parse_persistent_resource_path(cluster_path)[
+            "persistent_resource"
+        ]
+        return cluster_id
+
+    def serialize_cluster_obj(self, cluster_obj: resources.Cluster) -> dict:
+        """Serialize Cluster dataclass to dict."""
+
+        def __encode_value(value: Any) -> Any:
+            if isinstance(value, (list, Repeated)):
+                return [__encode_value(nested_value) for nested_value in value]
+            if isinstance(value, ScalarMapContainer):
+                return {key: __encode_value(nested_value) for key, nested_value in dict(value).items()}
+            if dataclasses.is_dataclass(value):
+                return dataclasses.asdict(value)
+            return value
+
+        return {
+            field.name: __encode_value(getattr(cluster_obj, field.name))
+            for field in dataclasses.fields(cluster_obj)
+        }
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        head_node_type: resources.Resources = resources.Resources(),
+        python_version: str = "3.10",
+        ray_version: str = "2.33",
+        network: str | None = None,
+        service_account: str | None = None,
+        cluster_name: str | None = None,
+        worker_node_types: list[resources.Resources] | None = None,
+        custom_images: resources.NodeImages | None = None,
+        enable_metrics_collection: bool = True,
+        enable_logging: bool = True,
+        psc_interface_config: resources.PscIConfig | None = None,
+        reserved_ip_ranges: list[str] | None = None,
+        labels: dict[str, str] | None = None,
+    ) -> str:
+        """
+        Create a Ray cluster on the Vertex AI.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param head_node_type: The head node resource. Resources.node_count must be 1. If not set, default
+            value of Resources() class will be used.
+        :param python_version: Python version for the ray cluster.
+        :param ray_version: Ray version for the ray cluster. Default is 2.33.0.
+        :param network: Virtual private cloud (VPC) network. For Ray Client, VPC peering is required to
+            connect to the Ray Cluster managed in the Vertex API service. For Ray Job API, VPC network is not
+            required because Ray Cluster connection can be accessed through dashboard address.
+        :param service_account: Service account to be used for running Ray programs on the cluster.
+        :param cluster_name: This value may be up to 63 characters, and valid characters are `[a-z0-9_-]`.
+            The first character cannot be a number or hyphen.
+        :param worker_node_types: The list of Resources of the worker nodes. The same Resources object should
+            not appear multiple times in the list.
+        :param custom_images: The NodeImages which specifies head node and worker nodes images. All the
+            workers will share the same image. If each Resource has a specific custom image, use
+            `Resources.custom_image` for head/worker_node_type(s). Note that configuring
+            `Resources.custom_image` will override `custom_images` here. Allowlist only.
+        :param enable_metrics_collection: Enable Ray metrics collection for visualization.
+        :param enable_logging: Enable exporting Ray logs to Cloud Logging.
+        :param psc_interface_config: PSC-I config.
+        :param reserved_ip_ranges: A list of names for the reserved IP ranges under the VPC network that can
+            be used for this cluster. If set, we will deploy the cluster within the provided IP ranges.
+            Otherwise, the cluster is deployed to any IP ranges under the provided VPC network.
+            Example: ["vertex-ai-ip-range"].
+        :param labels: The labels with user-defined metadata to organize Ray cluster.
+            Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+            lowercase letters, numeric characters, underscores and dashes. International characters are allowed.
+            See https://goo.gl/xmQnxf for more information and examples of labels.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        cluster_path = vertex_ray.create_ray_cluster(
+            head_node_type=head_node_type,
+            python_version=python_version,
+            ray_version=ray_version,
+            network=network,
+            service_account=service_account,
+            cluster_name=cluster_name,
+            worker_node_types=worker_node_types,
+            custom_images=custom_images,
+            enable_metrics_collection=enable_metrics_collection,
+            enable_logging=enable_logging,
+            psc_interface_config=psc_interface_config,
+            reserved_ip_ranges=reserved_ip_ranges,
+            labels=labels,
+        )
+        return cluster_path
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_ray_clusters(
+        self,
+        project_id: str,
+        location: str,
+    ) -> list[resources.Cluster]:
+        """
+        List Ray clusters under the currently authenticated project.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_clusters = vertex_ray.list_ray_clusters()
+        return ray_clusters
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        cluster_id: str,
+    ) -> resources.Cluster:
+        """
+        Get Ray cluster.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cluster_id: Cluster resource ID.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
+            project=project_id,
+            location=location,
+            persistent_resource=cluster_id,
+        )
+        ray_cluster = vertex_ray.get_ray_cluster(
+            cluster_resource_name=ray_cluster_name,
+        )
+        return ray_cluster
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def update_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        cluster_id: str,
+        worker_node_types: list[resources.Resources],
+    ) -> str:
+        """
+        Update Ray cluster (currently support resizing node counts for worker nodes).
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cluster_id: Cluster resource ID.
+        :param worker_node_types: The list of Resources of the resized worker nodes. The same Resources
+            object should not appear multiple times in the list.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
+            project=project_id,
+            location=location,
+            persistent_resource=cluster_id,
+        )
+        updated_ray_cluster_name = vertex_ray.update_ray_cluster(
+            cluster_resource_name=ray_cluster_name, worker_node_types=worker_node_types
+        )
+        return updated_ray_cluster_name
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        cluster_id: str,
+    ) -> None:
+        """
+        Delete Ray cluster.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cluster_id: Cluster resource ID.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
+            project=project_id,
+            location=location,
+            persistent_resource=cluster_id,
+        )
+        vertex_ray.delete_ray_cluster(cluster_resource_name=ray_cluster_name)
airflow/providers/google/cloud/links/cloud_run.py (new file)
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.google.cloud.links.base import BaseGoogleLink
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
+
+if TYPE_CHECKING:
+    from airflow.models import BaseOperator
+    from airflow.models.taskinstancekey import TaskInstanceKey
+    from airflow.utils.context import Context
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk.execution_time.xcom import XCom
+else:
+    from airflow.models.xcom import XCom  # type: ignore[no-redef]
+
+
+class CloudRunJobLoggingLink(BaseGoogleLink):
+    """Helper class for constructing Cloud Run Job Logging link."""
+
+    name = "Cloud Run Job Logging"
+    key = "log_uri"
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance: BaseOperator,
+        log_uri: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=CloudRunJobLoggingLink.key,
+            value=log_uri,
+        )
+
+    def get_link(
+        self,
+        operator: BaseOperator,
+        *,
+        ti_key: TaskInstanceKey,
+    ) -> str:
+        return XCom.get_value(key=self.key, ti_key=ti_key)
airflow/providers/google/cloud/links/vertex_ai.py
@@ -54,6 +54,10 @@ VERTEX_AI_PIPELINE_JOB_LINK = (
     VERTEX_AI_BASE_LINK + "/locations/{region}/pipelines/runs/{pipeline_id}?project={project_id}"
 )
 VERTEX_AI_PIPELINE_JOB_LIST_LINK = VERTEX_AI_BASE_LINK + "/pipelines/runs?project={project_id}"
+VERTEX_AI_RAY_CLUSTER_LINK = (
+    VERTEX_AI_BASE_LINK + "/locations/{location}/ray-clusters/{cluster_id}?project={project_id}"
+)
+VERTEX_AI_RAY_CLUSTER_LIST_LINK = VERTEX_AI_BASE_LINK + "/ray?project={project_id}"
 
 
 class VertexAIModelLink(BaseGoogleLink):
@@ -369,3 +373,48 @@ class VertexAIPipelineJobListLink(BaseGoogleLink):
                 "project_id": task_instance.project_id,
             },
         )
+
+
+class VertexAIRayClusterLink(BaseGoogleLink):
+    """Helper class for constructing Vertex AI Ray Cluster link."""
+
+    name = "Ray Cluster"
+    key = "ray_cluster_conf"
+    format_str = VERTEX_AI_RAY_CLUSTER_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        cluster_id: str,
+    ):
+        task_instance.xcom_push(
+            context=context,
+            key=VertexAIRayClusterLink.key,
+            value={
+                "location": task_instance.location,
+                "cluster_id": cluster_id,
+                "project_id": task_instance.project_id,
+            },
+        )
+
+
+class VertexAIRayClusterListLink(BaseGoogleLink):
+    """Helper class for constructing Vertex AI Ray Cluster List link."""
+
+    name = "Ray Cluster List"
+    key = "ray_cluster_list_conf"
+    format_str = VERTEX_AI_RAY_CLUSTER_LIST_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+    ):
+        task_instance.xcom_push(
+            context=context,
+            key=VertexAIRayClusterListLink.key,
+            value={
+                "project_id": task_instance.project_id,
+            },
+        )
airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -61,13 +61,15 @@ class GCSRemoteLogIO(LoggingMixin):  # noqa: D101
     remote_base: str
     base_log_folder: Path = attrs.field(converter=Path)
     delete_local_copy: bool
+    project_id: str | None = None
 
-    gcp_key_path: str | None
-    gcp_keyfile_dict: dict | None
-    scopes: Collection[str] | None
-    project_id: str
+    gcp_key_path: str | None = None
+    gcp_keyfile_dict: dict | None = None
+    scopes: Collection[str] | None = _DEFAULT_SCOPESS
 
-
+    processors = ()
+
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI):
         """Upload the given log path to the remote storage."""
         path = Path(path)
         if path.is_absolute():
airflow/providers/google/cloud/operators/bigquery.py
@@ -93,16 +93,32 @@ class IfExistAction(enum.Enum):
     SKIP = "skip"
 
 
+class _BigQueryHookWithFlexibleProjectId(BigQueryHook):
+    @property
+    def project_id(self) -> str:
+        _, project_id = self.get_credentials_and_project_id()
+        return project_id or PROVIDE_PROJECT_ID
+
+    @project_id.setter
+    def project_id(self, value: str) -> None:
+        cached_creds, _ = self.get_credentials_and_project_id()
+        self._cached_project_id = value or PROVIDE_PROJECT_ID
+        self._cached_credntials = cached_creds
+
+
 class _BigQueryDbHookMixin:
-    def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook:
+    def get_db_hook(self: BigQueryCheckOperator) -> _BigQueryHookWithFlexibleProjectId:  # type:ignore[misc]
         """Get BigQuery DB Hook."""
-        return BigQueryHook(
+        hook = _BigQueryHookWithFlexibleProjectId(
             gcp_conn_id=self.gcp_conn_id,
             use_legacy_sql=self.use_legacy_sql,
             location=self.location,
             impersonation_chain=self.impersonation_chain,
             labels=self.labels,
         )
+        if self.project_id:
+            hook.project_id = self.project_id
+        return hook
 
 
 class _BigQueryOperatorsEncryptionConfigurationMixin:
airflow/providers/google/cloud/operators/bigquery.py (continued)
@@ -190,6 +206,7 @@ class BigQueryCheckOperator(
         https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs.
         For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' },
         'parameterValue': { 'value': 'romeoandjuliet' } }]. (templated)
+    :param project_id: Google Cloud Project where the job is running
     """
 
     template_fields: Sequence[str] = (
@@ -208,6 +225,7 @@ class BigQueryCheckOperator(
         *,
         sql: str,
         gcp_conn_id: str = "google_cloud_default",
+        project_id: str = PROVIDE_PROJECT_ID,
         use_legacy_sql: bool = True,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -228,6 +246,7 @@ class BigQueryCheckOperator(
         self.deferrable = deferrable
         self.poll_interval = poll_interval
         self.query_params = query_params
+        self.project_id = project_id
 
     def _submit_job(
         self,
@@ -243,7 +262,7 @@ class BigQueryCheckOperator(
 
         return hook.insert_job(
             configuration=configuration,
-            project_id=hook.project_id,
+            project_id=self.project_id,
             location=self.location,
             job_id=job_id,
             nowait=True,
@@ -257,6 +276,8 @@ class BigQueryCheckOperator(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        if self.project_id is None:
+            self.project_id = hook.project_id
         job = self._submit_job(hook, job_id="")
         context["ti"].xcom_push(key="job_id", value=job.job_id)
         if job.running():
@@ -265,7 +286,7 @@ class BigQueryCheckOperator(
                 trigger=BigQueryCheckTrigger(
                     conn_id=self.gcp_conn_id,
                     job_id=job.job_id,
-                    project_id=hook.project_id,
+                    project_id=self.project_id,
                     location=self.location or hook.location,
                     poll_interval=self.poll_interval,
                     impersonation_chain=self.impersonation_chain,
@@ -342,6 +363,7 @@ class BigQueryValueCheckOperator(
     :param deferrable: Run operator in the deferrable mode.
     :param poll_interval: (Deferrable mode only) polling period in seconds to
         check for the status of job.
+    :param project_id: Google Cloud Project where the job is running
     """
 
     template_fields: Sequence[str] = (
@@ -363,6 +385,7 @@ class BigQueryValueCheckOperator(
         tolerance: Any = None,
         encryption_configuration: dict | None = None,
         gcp_conn_id: str = "google_cloud_default",
+        project_id: str = PROVIDE_PROJECT_ID,
         use_legacy_sql: bool = True,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -380,6 +403,7 @@ class BigQueryValueCheckOperator(
         self.labels = labels
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.project_id = project_id
 
     def _submit_job(
         self,
@@ -398,7 +422,7 @@ class BigQueryValueCheckOperator(
 
         return hook.insert_job(
             configuration=configuration,
-            project_id=hook.project_id,
+            project_id=self.project_id,
             location=self.location,
             job_id=job_id,
             nowait=True,
@@ -409,7 +433,8 @@ class BigQueryValueCheckOperator(
             super().execute(context=context)
         else:
             hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
-
+            if self.project_id is None:
+                self.project_id = hook.project_id
             job = self._submit_job(hook, job_id="")
             context["ti"].xcom_push(key="job_id", value=job.job_id)
             if job.running():
@@ -418,7 +443,7 @@ class BigQueryValueCheckOperator(
                     trigger=BigQueryValueCheckTrigger(
                         conn_id=self.gcp_conn_id,
                         job_id=job.job_id,
-                        project_id=hook.project_id,
+                        project_id=self.project_id,
                         location=self.location or hook.location,
                         sql=self.sql,
                         pass_value=self.pass_value,
@@ -575,6 +600,9 @@ class BigQueryIntervalCheckOperator(
         hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         self.log.info("Using ratio formula: %s", self.ratio_formula)
 
+        if self.project_id is None:
+            self.project_id = hook.project_id
+
         self.log.info("Executing SQL check: %s", self.sql1)
         job_1 = self._submit_job(hook, sql=self.sql1, job_id="")
         context["ti"].xcom_push(key="job_id", value=job_1.job_id)
@@ -587,7 +615,7 @@ class BigQueryIntervalCheckOperator(
                     conn_id=self.gcp_conn_id,
                     first_job_id=job_1.job_id,
                     second_job_id=job_2.job_id,
-                    project_id=hook.project_id,
+                    project_id=self.project_id,
                     table=self.table,
                     location=self.location or hook.location,
                     metrics_thresholds=self.metrics_thresholds,
@@ -654,6 +682,7 @@ class BigQueryColumnCheckOperator(
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :param labels: a dictionary containing labels for the table, passed to BigQuery
+    :param project_id: Google Cloud Project where the job is running
     """
 
     template_fields: Sequence[str] = tuple(set(SQLColumnCheckOperator.template_fields) | {"gcp_conn_id"})
@@ -670,6 +699,7 @@ class BigQueryColumnCheckOperator(
         accept_none: bool = True,
         encryption_configuration: dict | None = None,
         gcp_conn_id: str = "google_cloud_default",
+        project_id: str = PROVIDE_PROJECT_ID,
         use_legacy_sql: bool = True,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -695,6 +725,7 @@ class BigQueryColumnCheckOperator(
         self.location = location
         self.impersonation_chain = impersonation_chain
         self.labels = labels
+        self.project_id = project_id
 
     def _submit_job(
         self,
@@ -706,7 +737,7 @@ class BigQueryColumnCheckOperator(
         self.include_encryption_configuration(configuration, "query")
         return hook.insert_job(
             configuration=configuration,
-            project_id=hook.project_id,
+            project_id=self.project_id,
             location=self.location,
             job_id=job_id,
             nowait=False,
@@ -715,6 +746,9 @@ class BigQueryColumnCheckOperator(
     def execute(self, context=None):
         """Perform checks on the given columns."""
         hook = self.get_db_hook()
+
+        if self.project_id is None:
+            self.project_id = hook.project_id
         failed_tests = []
 
         job = self._submit_job(hook, job_id="")
@@ -786,6 +820,7 @@ class BigQueryTableCheckOperator(
         account from the list granting this role to the originating account (templated).
     :param labels: a dictionary containing labels for the table, passed to BigQuery
     :param encryption_configuration: (Optional) Custom encryption configuration (e.g., Cloud KMS keys).
+    :param project_id: Google Cloud Project where the job is running
 
     .. code-block:: python
 
@@ -805,6 +840,7 @@ class BigQueryTableCheckOperator(
         checks: dict,
         partition_clause: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
+        project_id: str = PROVIDE_PROJECT_ID,
         use_legacy_sql: bool = True,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -819,6 +855,7 @@ class BigQueryTableCheckOperator(
         self.impersonation_chain = impersonation_chain
         self.labels = labels
         self.encryption_configuration = encryption_configuration
+        self.project_id = project_id
 
     def _submit_job(
         self,
@@ -832,7 +869,7 @@ class BigQueryTableCheckOperator(
 
         return hook.insert_job(
             configuration=configuration,
-            project_id=hook.project_id,
+            project_id=self.project_id,
             location=self.location,
             job_id=job_id,
             nowait=False,
@@ -841,6 +878,8 @@ class BigQueryTableCheckOperator(
     def execute(self, context=None):
         """Execute the given checks on the table."""
         hook = self.get_db_hook()
+        if self.project_id is None:
+            self.project_id = hook.project_id
         job = self._submit_job(hook, job_id="")
         context["ti"].xcom_push(key="job_id", value=job.job_id)
         records = job.result().to_dataframe()
airflow/providers/google/cloud/operators/cloud_run.py
@@ -27,6 +27,7 @@ from google.cloud.run_v2 import Job, Service
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.cloud_run import CloudRunHook, CloudRunServiceHook
+from airflow.providers.google.cloud.links.cloud_run import CloudRunJobLoggingLink
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.cloud.triggers.cloud_run import CloudRunJobFinishedTrigger, RunJobStatus
 
@@ -248,7 +249,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
 
     :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
     :param region: Required. The ID of the Google Cloud region that the service belongs to.
-    :param job_name: Required. The name of the job to
+    :param job_name: Required. The name of the job to execute.
     :param overrides: Optional map of override values.
     :param gcp_conn_id: The connection ID used to connect to Google Cloud.
     :param polling_period_seconds: Optional. Control the rate of the poll for the result of deferrable run.
@@ -265,7 +266,17 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
     :param deferrable: Run the operator in deferrable mode.
     """
 
-
+    operator_extra_links = (CloudRunJobLoggingLink(),)
+    template_fields = (
+        "project_id",
+        "region",
+        "gcp_conn_id",
+        "impersonation_chain",
+        "job_name",
+        "overrides",
+        "polling_period_seconds",
+        "timeout_seconds",
+    )
 
     def __init__(
         self,
@@ -303,6 +314,13 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
         if self.operation is None:
             raise AirflowException("Operation is None")
 
+        if self.operation.metadata.log_uri:
+            CloudRunJobLoggingLink.persist(
+                context=context,
+                task_instance=self,
+                log_uri=self.operation.metadata.log_uri,
+            )
+
         if not self.deferrable:
             result: Execution = self._wait_for_operation(self.operation)
             self._fail_if_execution_failed(result)
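Putting the Cloud Run changes together: the operator gains templated fields and, whenever the launch operation reports a `log_uri`, persists it so the UI shows a "Cloud Run Job Logging" button backed by XCom. A hedged usage sketch; the project, region, and job names are placeholders:

```python
from airflow.providers.google.cloud.operators.cloud_run import CloudRunExecuteJobOperator

execute_job = CloudRunExecuteJobOperator(
    task_id="execute_cloud_run_job",
    project_id="my-project",  # placeholder; templated as of 16.0.0
    region="us-central1",  # placeholder
    job_name="{{ params.job_name }}",  # job_name is now in template_fields
    deferrable=False,
)
# After execution, the log URI sits in XCom under the "log_uri" key,
# which CloudRunJobLoggingLink.get_link() reads to render the extra link.
```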
|