apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 16.0.0a1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -32,8 +32,8 @@ __all__ = ["__version__"]
 __version__ = "15.1.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
-    "2.9.0"
+    "2.10.0"
 ):
     raise RuntimeError(
-        f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.9.0+"
+        f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.10.0+"
     )
@@ -185,7 +185,67 @@ class DataflowJobType:
     JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING"
 
 
-class _DataflowJobsController(LoggingMixin):
+class DataflowJobTerminalStateHelper(LoggingMixin):
+    """Helper to define and validate the dataflow job terminal state."""
+
+    @staticmethod
+    def expected_terminal_state_is_allowed(expected_terminal_state):
+        job_allowed_terminal_states = DataflowJobStatus.TERMINAL_STATES | {
+            DataflowJobStatus.JOB_STATE_RUNNING
+        }
+        if expected_terminal_state not in job_allowed_terminal_states:
+            raise AirflowException(
+                f"Google Cloud Dataflow job's expected terminal state "
+                f"'{expected_terminal_state}' is invalid."
+                f" The value should be any of the following: {job_allowed_terminal_states}"
+            )
+        return True
+
+    @staticmethod
+    def expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming: bool):
+        if is_streaming:
+            invalid_terminal_state = DataflowJobStatus.JOB_STATE_DONE
+            job_type = "streaming"
+        else:
+            invalid_terminal_state = DataflowJobStatus.JOB_STATE_DRAINED
+            job_type = "batch"
+
+        if expected_terminal_state == invalid_terminal_state:
+            raise AirflowException(
+                f"Google Cloud Dataflow job's expected terminal state cannot be {invalid_terminal_state} while it is a {job_type} job"
+            )
+        return True
+
+    def job_reached_terminal_state(self, job, wait_until_finished=None, custom_terminal_state=None) -> bool:
+        """
+        Check the job reached terminal state, if job failed raise exception.
+
+        :return: True if job is done.
+        :raise: Exception
+        """
+        current_state = job["currentState"]
+        is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
+        expected_terminal_state = (
+            DataflowJobStatus.JOB_STATE_RUNNING if is_streaming else DataflowJobStatus.JOB_STATE_DONE
+        )
+        if custom_terminal_state is not None:
+            expected_terminal_state = custom_terminal_state
+        self.expected_terminal_state_is_allowed(expected_terminal_state)
+        self.expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming=is_streaming)
+        if current_state == expected_terminal_state:
+            if expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING and wait_until_finished:
+                return False
+            return True
+        if current_state in DataflowJobStatus.AWAITING_STATES:
+            return wait_until_finished is False
+        self.log.debug("Current job: %s", job)
+        raise AirflowException(
+            f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
+            f"expected terminal state: {expected_terminal_state}"
+        )
+
+
+class _DataflowJobsController(DataflowJobTerminalStateHelper):
     """
     Interface for communication with Google Cloud Dataflow API.
 
@@ -462,7 +522,10 @@ class _DataflowJobsController(LoggingMixin):
         """Wait for result of submitted job."""
         self.log.info("Start waiting for done.")
         self._refresh_jobs()
-        while self._jobs and not all(self._check_dataflow_job_state(job) for job in self._jobs):
+        while self._jobs and not all(
+            self.job_reached_terminal_state(job, self._wait_until_finished, self._expected_terminal_state)
+            for job in self._jobs
+        ):
             self.log.info("Waiting for done. Sleep %s s", self._poll_sleep)
             time.sleep(self._poll_sleep)
             self._refresh_jobs()
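
The two hunks above move terminal-state evaluation out of the private `_DataflowJobsController._check_dataflow_job_state` into the reusable `DataflowJobTerminalStateHelper`, which the controller (and `AsyncDataflowHook`, below) now inherit. A minimal sketch of how the helper evaluates job payloads, assuming it is importable from the dataflow hooks module shown above; the job dicts are hypothetical examples of a Dataflow jobs.get response:

    from airflow.providers.google.cloud.hooks.dataflow import (
        DataflowJobStatus,
        DataflowJobTerminalStateHelper,
    )

    helper = DataflowJobTerminalStateHelper()

    # Batch job that has reached the default expected terminal state (JOB_STATE_DONE).
    batch_job = {"name": "example-batch", "type": "JOB_TYPE_BATCH", "currentState": "JOB_STATE_DONE"}
    assert helper.job_reached_terminal_state(batch_job) is True

    # Streaming jobs default to JOB_STATE_RUNNING as the expected terminal state, so a
    # running streaming job counts as "done" unless wait_until_finished is requested.
    streaming_job = {"name": "example-streaming", "type": "JOB_TYPE_STREAMING", "currentState": "JOB_STATE_RUNNING"}
    assert helper.job_reached_terminal_state(streaming_job) is True
    assert helper.job_reached_terminal_state(streaming_job, wait_until_finished=True) is False
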
@@ -1295,8 +1358,7 @@ class DataflowHook(GoogleBaseHook):
             location=location,
         )
         job = job_controller.fetch_job_by_id(job_id)
-
-        return job_controller._check_dataflow_job_state(job)
+        return job_controller.job_reached_terminal_state(job)
 
     @GoogleBaseHook.fallback_to_default_project_id
     def create_data_pipeline(
@@ -1425,7 +1487,7 @@ class DataflowHook(GoogleBaseHook):
         return f"projects/{project_id}/locations/{location}"
 
 
-class AsyncDataflowHook(GoogleBaseAsyncHook):
+class AsyncDataflowHook(GoogleBaseAsyncHook, DataflowJobTerminalStateHelper):
    """Async hook class for dataflow service."""

    sync_hook_class = DataflowHook
@@ -0,0 +1,223 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Cloud Vertex AI hook."""
+
+from __future__ import annotations
+
+import dataclasses
+from typing import Any
+
+import vertex_ray
+from google._upb._message import ScalarMapContainer
+from google.cloud import aiplatform
+from google.cloud.aiplatform.vertex_ray.util import resources
+from google.cloud.aiplatform_v1 import (
+    PersistentResourceServiceClient,
+)
+from proto.marshal.collections.repeated import Repeated
+
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
+
+class RayHook(GoogleBaseHook):
+    """Hook for Google Cloud Vertex AI Ray APIs."""
+
+    def extract_cluster_id(self, cluster_path) -> str:
+        """Extract cluster_id from cluster_path."""
+        cluster_id = PersistentResourceServiceClient.parse_persistent_resource_path(cluster_path)[
+            "persistent_resource"
+        ]
+        return cluster_id
+
+    def serialize_cluster_obj(self, cluster_obj: resources.Cluster) -> dict:
+        """Serialize Cluster dataclass to dict."""
+
+        def __encode_value(value: Any) -> Any:
+            if isinstance(value, (list, Repeated)):
+                return [__encode_value(nested_value) for nested_value in value]
+            if isinstance(value, ScalarMapContainer):
+                return {key: __encode_value(nested_value) for key, nested_value in dict(value).items()}
+            if dataclasses.is_dataclass(value):
+                return dataclasses.asdict(value)
+            return value
+
+        return {
+            field.name: __encode_value(getattr(cluster_obj, field.name))
+            for field in dataclasses.fields(cluster_obj)
+        }
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        head_node_type: resources.Resources = resources.Resources(),
+        python_version: str = "3.10",
+        ray_version: str = "2.33",
+        network: str | None = None,
+        service_account: str | None = None,
+        cluster_name: str | None = None,
+        worker_node_types: list[resources.Resources] | None = None,
+        custom_images: resources.NodeImages | None = None,
+        enable_metrics_collection: bool = True,
+        enable_logging: bool = True,
+        psc_interface_config: resources.PscIConfig | None = None,
+        reserved_ip_ranges: list[str] | None = None,
+        labels: dict[str, str] | None = None,
+    ) -> str:
+        """
+        Create a Ray cluster on the Vertex AI.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param head_node_type: The head node resource. Resources.node_count must be 1. If not set, default
+            value of Resources() class will be used.
+        :param python_version: Python version for the ray cluster.
+        :param ray_version: Ray version for the ray cluster. Default is 2.33.0.
+        :param network: Virtual private cloud (VPC) network. For Ray Client, VPC peering is required to
+            connect to the Ray Cluster managed in the Vertex API service. For Ray Job API, VPC network is not
+            required because Ray Cluster connection can be accessed through dashboard address.
+        :param service_account: Service account to be used for running Ray programs on the cluster.
+        :param cluster_name: This value may be up to 63 characters, and valid characters are `[a-z0-9_-]`.
+            The first character cannot be a number or hyphen.
+        :param worker_node_types: The list of Resources of the worker nodes. The same Resources object should
+            not appear multiple times in the list.
+        :param custom_images: The NodeImages which specifies head node and worker nodes images. All the
+            workers will share the same image. If each Resource has a specific custom image, use
+            `Resources.custom_image` for head/worker_node_type(s). Note that configuring
+            `Resources.custom_image` will override `custom_images` here. Allowlist only.
+        :param enable_metrics_collection: Enable Ray metrics collection for visualization.
+        :param enable_logging: Enable exporting Ray logs to Cloud Logging.
+        :param psc_interface_config: PSC-I config.
+        :param reserved_ip_ranges: A list of names for the reserved IP ranges under the VPC network that can
+            be used for this cluster. If set, we will deploy the cluster within the provided IP ranges.
+            Otherwise, the cluster is deployed to any IP ranges under the provided VPC network.
+            Example: ["vertex-ai-ip-range"].
+        :param labels: The labels with user-defined metadata to organize Ray cluster.
+            Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+            lowercase letters, numeric characters, underscores and dashes. International characters are allowed.
+            See https://goo.gl/xmQnxf for more information and examples of labels.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        cluster_path = vertex_ray.create_ray_cluster(
+            head_node_type=head_node_type,
+            python_version=python_version,
+            ray_version=ray_version,
+            network=network,
+            service_account=service_account,
+            cluster_name=cluster_name,
+            worker_node_types=worker_node_types,
+            custom_images=custom_images,
+            enable_metrics_collection=enable_metrics_collection,
+            enable_logging=enable_logging,
+            psc_interface_config=psc_interface_config,
+            reserved_ip_ranges=reserved_ip_ranges,
+            labels=labels,
+        )
+        return cluster_path
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_ray_clusters(
+        self,
+        project_id: str,
+        location: str,
+    ) -> list[resources.Cluster]:
+        """
+        List Ray clusters under the currently authenticated project.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_clusters = vertex_ray.list_ray_clusters()
+        return ray_clusters
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        cluster_id: str,
+    ) -> resources.Cluster:
+        """
+        Get Ray cluster.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cluster_id: Cluster resource ID.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
+            project=project_id,
+            location=location,
+            persistent_resource=cluster_id,
+        )
+        ray_cluster = vertex_ray.get_ray_cluster(
+            cluster_resource_name=ray_cluster_name,
+        )
+        return ray_cluster
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def update_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        cluster_id: str,
+        worker_node_types: list[resources.Resources],
+    ) -> str:
+        """
+        Update Ray cluster (currently support resizing node counts for worker nodes).
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cluster_id: Cluster resource ID.
+        :param worker_node_types: The list of Resources of the resized worker nodes. The same Resources
+            object should not appear multiple times in the list.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
+            project=project_id,
+            location=location,
+            persistent_resource=cluster_id,
+        )
+        updated_ray_cluster_name = vertex_ray.update_ray_cluster(
+            cluster_resource_name=ray_cluster_name, worker_node_types=worker_node_types
+        )
+        return updated_ray_cluster_name
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_ray_cluster(
+        self,
+        project_id: str,
+        location: str,
+        cluster_id: str,
+    ) -> None:
+        """
+        Delete Ray cluster.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param cluster_id: Cluster resource ID.
+        """
+        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
+        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
+            project=project_id,
+            location=location,
+            persistent_resource=cluster_id,
+        )
+        vertex_ray.delete_ray_cluster(cluster_resource_name=ray_cluster_name)
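
The new `RayHook` is a thin wrapper around the `vertex_ray` client from google-cloud-aiplatform. A hedged usage sketch (not part of the diff), assuming the file ships as `airflow.providers.google.cloud.hooks.vertex_ai.ray`; the connection id, project, region, cluster name, and machine shapes are placeholders:

    from airflow.providers.google.cloud.hooks.vertex_ai.ray import RayHook
    from google.cloud.aiplatform.vertex_ray.util import resources

    hook = RayHook(gcp_conn_id="google_cloud_default")

    # Create a small cluster: one head node and two workers (illustrative shapes only).
    cluster_path = hook.create_ray_cluster(
        project_id="my-project",
        location="us-central1",
        head_node_type=resources.Resources(machine_type="n1-standard-16", node_count=1),
        worker_node_types=[resources.Resources(machine_type="n1-standard-8", node_count=2)],
        cluster_name="example-ray-cluster",
    )

    # The create call returns the full persistent-resource path; the hook can split out the id
    # and serialize the Cluster dataclass for XCom-friendly output.
    cluster_id = hook.extract_cluster_id(cluster_path)
    cluster = hook.get_ray_cluster(project_id="my-project", location="us-central1", cluster_id=cluster_id)
    print(hook.serialize_cluster_obj(cluster))
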
@@ -54,6 +54,10 @@ VERTEX_AI_PIPELINE_JOB_LINK = (
     VERTEX_AI_BASE_LINK + "/locations/{region}/pipelines/runs/{pipeline_id}?project={project_id}"
 )
 VERTEX_AI_PIPELINE_JOB_LIST_LINK = VERTEX_AI_BASE_LINK + "/pipelines/runs?project={project_id}"
+VERTEX_AI_RAY_CLUSTER_LINK = (
+    VERTEX_AI_BASE_LINK + "/locations/{location}/ray-clusters/{cluster_id}?project={project_id}"
+)
+VERTEX_AI_RAY_CLUSTER_LIST_LINK = VERTEX_AI_BASE_LINK + "/ray?project={project_id}"
 
 
 class VertexAIModelLink(BaseGoogleLink):
@@ -369,3 +373,48 @@ class VertexAIPipelineJobListLink(BaseGoogleLink):
             "project_id": task_instance.project_id,
         },
     )
+
+
+class VertexAIRayClusterLink(BaseGoogleLink):
+    """Helper class for constructing Vertex AI Ray Cluster link."""
+
+    name = "Ray Cluster"
+    key = "ray_cluster_conf"
+    format_str = VERTEX_AI_RAY_CLUSTER_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        cluster_id: str,
+    ):
+        task_instance.xcom_push(
+            context=context,
+            key=VertexAIRayClusterLink.key,
+            value={
+                "location": task_instance.location,
+                "cluster_id": cluster_id,
+                "project_id": task_instance.project_id,
+            },
+        )
+
+
+class VertexAIRayClusterListLink(BaseGoogleLink):
+    """Helper class for constructing Vertex AI Ray Cluster List link."""
+
+    name = "Ray Cluster List"
+    key = "ray_cluster_list_conf"
+    format_str = VERTEX_AI_RAY_CLUSTER_LIST_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+    ):
+        task_instance.xcom_push(
+            context=context,
+            key=VertexAIRayClusterListLink.key,
+            value={
+                "project_id": task_instance.project_id,
+            },
+        )
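
The two new link classes follow the usual extra-link pattern: `persist()` pushes the parameters needed to fill `format_str` into XCom, and the link reads `project_id` and `location` off the task instance. A hedged sketch (not from the diff) of a hypothetical operator wiring up `VertexAIRayClusterLink`, assuming the links live in `airflow.providers.google.cloud.links.vertex_ai`:

    from __future__ import annotations

    from airflow.models import BaseOperator
    from airflow.providers.google.cloud.links.vertex_ai import VertexAIRayClusterLink


    class ExampleRayClusterLinkOperator(BaseOperator):
        """Hypothetical operator that only demonstrates persisting the new link."""

        operator_extra_links = (VertexAIRayClusterLink(),)

        def __init__(self, project_id: str, location: str, cluster_id: str, **kwargs):
            super().__init__(**kwargs)
            # persist() reads project_id and location from the task instance, so they must be attributes.
            self.project_id = project_id
            self.location = location
            self.cluster_id = cluster_id

        def execute(self, context):
            # Stores the link payload under VertexAIRayClusterLink.key ("ray_cluster_conf") in XCom,
            # which the UI uses to render the "Ray Cluster" button on the task instance.
            VertexAIRayClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id)
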
@@ -61,13 +61,15 @@ class GCSRemoteLogIO(LoggingMixin):  # noqa: D101
     remote_base: str
     base_log_folder: Path = attrs.field(converter=Path)
     delete_local_copy: bool
+    project_id: str
 
     gcp_key_path: str | None
     gcp_keyfile_dict: dict | None
     scopes: Collection[str] | None
-    project_id: str
 
-    def upload(self, path: os.PathLike, ti: RuntimeTI):
+    processors = ()
+
+    def upload(self, path: os.PathLike | str, ti: RuntimeTI):
         """Upload the given log path to the remote storage."""
         path = Path(path)
         if path.is_absolute():
@@ -265,7 +265,16 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
     :param deferrable: Run the operator in deferrable mode.
     """
 
-    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name", "overrides")
+    template_fields = (
+        "project_id",
+        "region",
+        "gcp_conn_id",
+        "impersonation_chain",
+        "job_name",
+        "overrides",
+        "polling_period_seconds",
+        "timeout_seconds",
+    )
 
     def __init__(
         self,
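
With `polling_period_seconds` and `timeout_seconds` added to `template_fields`, both can now be rendered from the task context. An illustrative sketch (not from the diff; project, region, job name, and params are placeholders, and the rendered Jinja values arrive as strings, so plain numeric arguments remain the simpler choice when templating is not needed):

    from airflow.providers.google.cloud.operators.cloud_run import CloudRunExecuteJobOperator

    execute_job = CloudRunExecuteJobOperator(
        task_id="execute_cloud_run_job",
        project_id="my-project",
        region="europe-west1",
        job_name="my-job",
        deferrable=True,
        # The two newly templated fields can be driven by Jinja, e.g. from DAG params.
        polling_period_seconds="{{ params.poll_seconds }}",
        timeout_seconds="{{ params.timeout_seconds }}",
    )
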
@@ -636,7 +636,7 @@ class GKEStartPodOperator(GKEOperatorMixin, KubernetesPodOperator):
     """
 
     template_fields: Sequence[str] = tuple(
-        {"on_finish_action", "deferrable"}
+        {"deferrable"}
         | (set(KubernetesPodOperator.template_fields) - {"is_delete_operator_pod", "regional"})
        | set(GKEOperatorMixin.template_fields)
    )