apache-airflow-providers-google 10.16.0rc1__py3-none-any.whl → 10.17.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +5 -4
- airflow/providers/google/ads/operators/ads.py +1 -0
- airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +1 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +1 -0
- airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +1 -0
- airflow/providers/google/cloud/example_dags/example_looker.py +1 -0
- airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +1 -0
- airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py +1 -0
- airflow/providers/google/cloud/fs/gcs.py +1 -2
- airflow/providers/google/cloud/hooks/automl.py +1 -0
- airflow/providers/google/cloud/hooks/bigquery.py +87 -24
- airflow/providers/google/cloud/hooks/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/hooks/bigtable.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_build.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_sql.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +9 -4
- airflow/providers/google/cloud/hooks/compute.py +1 -0
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -2
- airflow/providers/google/cloud/hooks/dataflow.py +6 -5
- airflow/providers/google/cloud/hooks/datafusion.py +1 -0
- airflow/providers/google/cloud/hooks/datapipeline.py +1 -0
- airflow/providers/google/cloud/hooks/dataplex.py +1 -0
- airflow/providers/google/cloud/hooks/dataprep.py +1 -0
- airflow/providers/google/cloud/hooks/dataproc.py +3 -2
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -0
- airflow/providers/google/cloud/hooks/datastore.py +1 -0
- airflow/providers/google/cloud/hooks/dlp.py +1 -0
- airflow/providers/google/cloud/hooks/functions.py +1 -0
- airflow/providers/google/cloud/hooks/gcs.py +12 -5
- airflow/providers/google/cloud/hooks/kms.py +1 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +178 -300
- airflow/providers/google/cloud/hooks/life_sciences.py +1 -0
- airflow/providers/google/cloud/hooks/looker.py +1 -0
- airflow/providers/google/cloud/hooks/mlengine.py +1 -0
- airflow/providers/google/cloud/hooks/natural_language.py +1 -0
- airflow/providers/google/cloud/hooks/os_login.py +1 -0
- airflow/providers/google/cloud/hooks/pubsub.py +1 -0
- airflow/providers/google/cloud/hooks/secret_manager.py +1 -0
- airflow/providers/google/cloud/hooks/spanner.py +1 -0
- airflow/providers/google/cloud/hooks/speech_to_text.py +1 -0
- airflow/providers/google/cloud/hooks/stackdriver.py +1 -0
- airflow/providers/google/cloud/hooks/text_to_speech.py +1 -0
- airflow/providers/google/cloud/hooks/translate.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +255 -3
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +197 -0
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +9 -9
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +231 -12
- airflow/providers/google/cloud/hooks/video_intelligence.py +1 -0
- airflow/providers/google/cloud/hooks/vision.py +1 -0
- airflow/providers/google/cloud/links/automl.py +1 -0
- airflow/providers/google/cloud/links/bigquery.py +1 -0
- airflow/providers/google/cloud/links/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/links/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/links/cloud_sql.py +1 -0
- airflow/providers/google/cloud/links/cloud_tasks.py +1 -0
- airflow/providers/google/cloud/links/compute.py +1 -0
- airflow/providers/google/cloud/links/datacatalog.py +1 -0
- airflow/providers/google/cloud/links/dataflow.py +1 -0
- airflow/providers/google/cloud/links/dataform.py +1 -0
- airflow/providers/google/cloud/links/datafusion.py +1 -0
- airflow/providers/google/cloud/links/dataplex.py +1 -0
- airflow/providers/google/cloud/links/dataproc.py +1 -0
- airflow/providers/google/cloud/links/kubernetes_engine.py +28 -0
- airflow/providers/google/cloud/links/mlengine.py +1 -0
- airflow/providers/google/cloud/links/pubsub.py +1 -0
- airflow/providers/google/cloud/links/spanner.py +1 -0
- airflow/providers/google/cloud/links/stackdriver.py +1 -0
- airflow/providers/google/cloud/links/workflows.py +1 -0
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +18 -4
- airflow/providers/google/cloud/operators/automl.py +1 -0
- airflow/providers/google/cloud/operators/bigquery.py +21 -0
- airflow/providers/google/cloud/operators/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/operators/bigtable.py +1 -0
- airflow/providers/google/cloud/operators/cloud_base.py +1 -0
- airflow/providers/google/cloud/operators/cloud_build.py +1 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/operators/cloud_sql.py +1 -0
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +11 -5
- airflow/providers/google/cloud/operators/compute.py +1 -0
- airflow/providers/google/cloud/operators/dataflow.py +1 -0
- airflow/providers/google/cloud/operators/datafusion.py +1 -0
- airflow/providers/google/cloud/operators/datapipeline.py +1 -0
- airflow/providers/google/cloud/operators/dataprep.py +1 -0
- airflow/providers/google/cloud/operators/dataproc.py +3 -2
- airflow/providers/google/cloud/operators/dataproc_metastore.py +1 -0
- airflow/providers/google/cloud/operators/datastore.py +1 -0
- airflow/providers/google/cloud/operators/functions.py +1 -0
- airflow/providers/google/cloud/operators/gcs.py +1 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +600 -4
- airflow/providers/google/cloud/operators/life_sciences.py +1 -0
- airflow/providers/google/cloud/operators/looker.py +1 -0
- airflow/providers/google/cloud/operators/mlengine.py +283 -259
- airflow/providers/google/cloud/operators/natural_language.py +1 -0
- airflow/providers/google/cloud/operators/pubsub.py +1 -0
- airflow/providers/google/cloud/operators/spanner.py +1 -0
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -0
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -0
- airflow/providers/google/cloud/operators/translate.py +1 -0
- airflow/providers/google/cloud/operators/translate_speech.py +1 -0
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +14 -7
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +67 -13
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +26 -8
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +1 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +306 -0
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +29 -48
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +52 -17
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -0
- airflow/providers/google/cloud/operators/vision.py +1 -0
- airflow/providers/google/cloud/secrets/secret_manager.py +1 -0
- airflow/providers/google/cloud/sensors/bigquery.py +1 -0
- airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/sensors/bigtable.py +1 -0
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -0
- airflow/providers/google/cloud/sensors/dataflow.py +1 -0
- airflow/providers/google/cloud/sensors/dataform.py +1 -0
- airflow/providers/google/cloud/sensors/datafusion.py +1 -0
- airflow/providers/google/cloud/sensors/dataplex.py +1 -0
- airflow/providers/google/cloud/sensors/dataprep.py +1 -0
- airflow/providers/google/cloud/sensors/dataproc.py +1 -0
- airflow/providers/google/cloud/sensors/gcs.py +1 -0
- airflow/providers/google/cloud/sensors/looker.py +1 -0
- airflow/providers/google/cloud/sensors/pubsub.py +1 -0
- airflow/providers/google/cloud/sensors/tasks.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +3 -2
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +19 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -5
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +4 -2
- airflow/providers/google/cloud/triggers/bigquery.py +4 -3
- airflow/providers/google/cloud/triggers/cloud_batch.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -0
- airflow/providers/google/cloud/triggers/cloud_sql.py +2 -0
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +14 -2
- airflow/providers/google/cloud/triggers/dataplex.py +1 -0
- airflow/providers/google/cloud/triggers/dataproc.py +1 -0
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +72 -2
- airflow/providers/google/cloud/triggers/mlengine.py +2 -0
- airflow/providers/google/cloud/triggers/pubsub.py +3 -3
- airflow/providers/google/cloud/triggers/vertex_ai.py +107 -15
- airflow/providers/google/cloud/utils/field_sanitizer.py +2 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -0
- airflow/providers/google/cloud/utils/helpers.py +1 -0
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -0
- airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -0
- airflow/providers/google/cloud/utils/openlineage.py +1 -0
- airflow/providers/google/common/auth_backend/google_openid.py +1 -0
- airflow/providers/google/common/hooks/base_google.py +2 -1
- airflow/providers/google/common/hooks/discovery_api.py +1 -0
- airflow/providers/google/common/links/storage.py +1 -0
- airflow/providers/google/common/utils/id_token_credentials.py +1 -0
- airflow/providers/google/firebase/hooks/firestore.py +1 -0
- airflow/providers/google/get_provider_info.py +9 -3
- airflow/providers/google/go_module_utils.py +1 -0
- airflow/providers/google/leveldb/hooks/leveldb.py +8 -7
- airflow/providers/google/marketing_platform/example_dags/example_display_video.py +1 -0
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +1 -0
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/hooks/display_video.py +1 -0
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -0
- airflow/providers/google/marketing_platform/operators/analytics.py +1 -0
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +4 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/operators/display_video.py +1 -0
- airflow/providers/google/marketing_platform/operators/search_ads.py +1 -0
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/sensors/display_video.py +2 -1
- airflow/providers/google/marketing_platform/sensors/search_ads.py +1 -0
- airflow/providers/google/suite/hooks/calendar.py +1 -0
- airflow/providers/google/suite/hooks/drive.py +1 -0
- airflow/providers/google/suite/hooks/sheets.py +1 -0
- airflow/providers/google/suite/sensors/drive.py +1 -0
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +7 -0
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +4 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +1 -0
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/METADATA +18 -13
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/RECORD +196 -194
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/entry_points.txt +0 -0
--- 10.16.0rc1/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py
+++ 10.17.0/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py
@@ -21,22 +21,30 @@
 
     aiplatform
 """
+
 from __future__ import annotations
 
+import asyncio
 from typing import TYPE_CHECKING, Any, Sequence
 
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform import PipelineJob
-from google.cloud.aiplatform_v1 import
+from google.cloud.aiplatform_v1 import (
+    PipelineServiceAsyncClient,
+    PipelineServiceClient,
+    PipelineState,
+    types,
+)
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
-    from google.api_core.retry import Retry
+    from google.api_core.retry import AsyncRetry, Retry
+    from google.auth.credentials import Credentials
     from google.cloud.aiplatform.metadata import experiment_resources
     from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ListPipelineJobsPager
 
@@ -101,11 +109,6 @@ class PipelineJobHook(GoogleBaseHook):
             failure_policy=failure_policy,
         )
 
-    @staticmethod
-    def extract_pipeline_job_id(obj: dict) -> str:
-        """Return unique id of the pipeline_job."""
-        return obj["name"].rpartition("/")[-1]
-
     def wait_for_operation(self, operation: Operation, timeout: float | None = None):
        """Wait for long-lasting operation to complete."""
        try:
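This hunk only removes `extract_pipeline_job_id` from its old position; the final hunk of this file re-adds it at the end of the class. Its behavior is unchanged — a minimal sketch of what it computes:

```python
# extract_pipeline_job_id keeps only the trailing path segment of the resource
# name: str.rpartition("/") splits on the last "/" and [-1] is what follows it.
obj = {"name": "projects/123/locations/us-central1/pipelineJobs/my-job-id"}  # hypothetical resource name
assert obj["name"].rpartition("/")[-1] == "my-job-id"
```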
@@ -129,7 +132,7 @@ class PipelineJobHook(GoogleBaseHook):
         retry: Retry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-    ) -> PipelineJob:
+    ) -> types.PipelineJob:
         """
         Create a PipelineJob. A PipelineJob will run immediately when created.
 
@@ -182,7 +185,7 @@
         # END: run param
     ) -> PipelineJob:
         """
-
+        Create and run a PipelineJob until its completion.
 
         :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -243,7 +246,103 @@
             location=region,
             failure_policy=failure_policy,
         )
+        self._pipeline_job.submit(
+            service_account=service_account,
+            network=network,
+            create_request_timeout=create_request_timeout,
+            experiment=experiment,
+        )
+        self._pipeline_job.wait()
+
+        return self._pipeline_job
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def submit_pipeline_job(
+        self,
+        project_id: str,
+        region: str,
+        display_name: str,
+        template_path: str,
+        job_id: str | None = None,
+        pipeline_root: str | None = None,
+        parameter_values: dict[str, Any] | None = None,
+        input_artifacts: dict[str, str] | None = None,
+        enable_caching: bool | None = None,
+        encryption_spec_key_name: str | None = None,
+        labels: dict[str, str] | None = None,
+        failure_policy: str | None = None,
+        # START: run param
+        service_account: str | None = None,
+        network: str | None = None,
+        create_request_timeout: float | None = None,
+        experiment: str | experiment_resources.Experiment | None = None,
+        # END: run param
+    ) -> PipelineJob:
+        """
+        Create and start a PipelineJob run.
 
+        For more info about the client method please see:
+        https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.PipelineJob#google_cloud_aiplatform_PipelineJob_submit
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param display_name: Required. The user-defined name of this Pipeline.
+        :param template_path: Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It can be
+            a local path, a Google Cloud Storage URI (e.g. "gs://project.name"), an Artifact Registry URI
+            (e.g. "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
+        :param job_id: Optional. The unique ID of the job run. If not specified, pipeline name + timestamp
+            will be used.
+        :param pipeline_root: Optional. The root of the pipeline outputs. If not set, the staging bucket set
+            in aiplatform.init will be used. If that's not set a pipeline-specific artifacts bucket will be
+            used.
+        :param parameter_values: Optional. The mapping from runtime parameter names to its values that
+            control the pipeline run.
+        :param input_artifacts: Optional. The mapping from the runtime parameter name for this artifact to
+            its resource id. For example: "vertex_model":"456". Note: full resource name
+            ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
+        :param enable_caching: Optional. Whether to turn on caching for the run.
+            If this is not set, defaults to the compile time settings, which are True for all tasks by
+            default, while users may specify different caching options for individual tasks.
+            If this is set, the setting applies to all tasks in the pipeline. Overrides the compile time
+            settings.
+        :param encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed
+            encryption key used to protect the job. Has the form:
+            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+            The key needs to be in the same region as where the compute resource is created. If this is set,
+            then all resources created by the PipelineJob will be encrypted with the provided encryption key.
+            Overrides encryption_spec_key_name set in aiplatform.init.
+        :param labels: Optional. The user defined metadata to organize PipelineJob.
+        :param failure_policy: Optional. The failure policy - "slow" or "fast". Currently, the default of a
+            pipeline is that the pipeline will continue to run until no more tasks can be executed, also
+            known as PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow"). However, if a pipeline is set
+            to PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"), it will stop scheduling any new
+            tasks when a task has failed. Any scheduled tasks will continue to completion.
+        :param service_account: Optional. Specifies the service account for workload run-as account. Users
+            submitting jobs must have act-as permission on this run-as account.
+        :param network: Optional. The full name of the Compute Engine network to which the job should be
+            peered. For example, projects/12345/global/networks/myVPC.
+            Private services access must already be configured for the network. If left unspecified, the
+            network set in aiplatform.init will be used. Otherwise, the job is not peered with any network.
+        :param create_request_timeout: Optional. The timeout for the create request in seconds.
+        :param experiment: Optional. The Vertex AI experiment name or instance to associate to this PipelineJob.
+            Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as metrics
+            to the current Experiment Run. Pipeline parameters will be associated as parameters to
+            the current Experiment Run.
+        """
+        self._pipeline_job = self.get_pipeline_job_object(
+            display_name=display_name,
+            template_path=template_path,
+            job_id=job_id,
+            pipeline_root=pipeline_root,
+            parameter_values=parameter_values,
+            input_artifacts=input_artifacts,
+            enable_caching=enable_caching,
+            encryption_spec_key_name=encryption_spec_key_name,
+            labels=labels,
+            project=project_id,
+            location=region,
+            failure_policy=failure_policy,
+        )
         self._pipeline_job.submit(
             service_account=service_account,
             network=network,
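The new `submit_pipeline_job` starts the run without the `wait()` call that `run_pipeline_job` performs, which is what the deferrable path builds on. A hedged usage sketch, assuming a configured `google_cloud_default` connection; the project, region, and template values are placeholders:

```python
from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook

hook = PipelineJobHook(gcp_conn_id="google_cloud_default")

# Returns the google.cloud.aiplatform.PipelineJob object as soon as the job is
# submitted; completion can then be awaited elsewhere (e.g. by a trigger).
pipeline_job = hook.submit_pipeline_job(
    project_id="my-project",                       # placeholder
    region="us-central1",                          # placeholder
    display_name="example-pipeline",
    template_path="gs://my-bucket/pipeline.yaml",  # placeholder
)
```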
@@ -251,7 +350,6 @@ class PipelineJobHook(GoogleBaseHook):
             experiment=experiment,
         )
 
-        self._pipeline_job.wait()
         return self._pipeline_job
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -263,7 +361,7 @@ class PipelineJobHook(GoogleBaseHook):
         retry: Retry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-    ) -> PipelineJob:
+    ) -> types.PipelineJob:
         """
         Get a PipelineJob.
 
@@ -407,3 +505,124 @@
             metadata=metadata,
         )
         return result
+
+    @staticmethod
+    def extract_pipeline_job_id(obj: dict) -> str:
+        """Return unique id of a pipeline job from its name."""
+        return obj["name"].rpartition("/")[-1]
+
+
+class PipelineJobAsyncHook(GoogleBaseAsyncHook):
+    """Asynchronous hook for Google Cloud Vertex AI Pipeline Job APIs."""
+
+    sync_hook_class = PipelineJobHook
+    PIPELINE_COMPLETE_STATES = (
+        PipelineState.PIPELINE_STATE_CANCELLED,
+        PipelineState.PIPELINE_STATE_FAILED,
+        PipelineState.PIPELINE_STATE_PAUSED,
+        PipelineState.PIPELINE_STATE_SUCCEEDED,
+    )
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+            **kwargs,
+        )
+
+    async def get_credentials(self) -> Credentials:
+        sync_hook = await self.get_sync_hook()
+        return sync_hook.get_credentials()
+
+    async def get_project_id(self) -> str:
+        sync_hook = await self.get_sync_hook()
+        return sync_hook.project_id
+
+    async def get_location(self) -> str:
+        sync_hook = await self.get_sync_hook()
+        return sync_hook.location
+
+    async def get_pipeline_service_client(
+        self,
+        region: str | None = None,
+    ) -> PipelineServiceAsyncClient:
+        if region and region != "global":
+            client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
+        else:
+            client_options = ClientOptions()
+        return PipelineServiceAsyncClient(
+            credentials=await self.get_credentials(),
+            client_info=CLIENT_INFO,
+            client_options=client_options,
+        )
+
+    async def get_pipeline_job(
+        self,
+        project_id: str,
+        location: str,
+        job_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | _MethodDefault | None = DEFAULT,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> types.PipelineJob:
+        """
+        Get a PipelineJob proto message from PipelineServiceAsyncClient.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud region that the service belongs to.
+        :param job_id: Required. The ID of the PipelineJob resource.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = await self.get_pipeline_service_client(region=location)
+        pipeline_job_name = client.pipeline_job_path(
+            project=project_id,
+            location=location,
+            pipeline_job=job_id,
+        )
+        pipeline_job: types.PipelineJob = await client.get_pipeline_job(
+            request={"name": pipeline_job_name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return pipeline_job
+
+    async def wait_for_pipeline_job(
+        self,
+        project_id: str,
+        location: str,
+        job_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        poll_interval: int = 10,
+    ) -> types.PipelineJob:
+        """Wait until the pipeline job is in a complete state and return it."""
+        while True:
+            try:
+                self.log.info("Requesting a pipeline job with id %s", job_id)
+                job: types.PipelineJob = await self.get_pipeline_job(
+                    project_id=project_id,
+                    location=location,
+                    job_id=job_id,
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata,
+                )
+            except Exception as ex:
+                self.log.exception("Exception occurred while requesting pipeline job %s", job_id)
+                raise AirflowException(ex)
+
+            self.log.info("Status of the pipeline job %s is %s", job.name, job.state.name)
+            if job.state in self.PIPELINE_COMPLETE_STATES:
+                return job
+
+            self.log.info("Sleeping for %s seconds.", poll_interval)
+            await asyncio.sleep(poll_interval)
--- 10.16.0rc1/airflow/providers/google/cloud/links/kubernetes_engine.py
+++ 10.17.0/airflow/providers/google/cloud/links/kubernetes_engine.py
@@ -38,6 +38,10 @@ KUBERNETES_JOB_LINK = (
     KUBERNETES_BASE_LINK
     + "/job/{location}/{cluster_name}/{namespace}/{job_name}/details?project={project_id}"
 )
+KUBERNETES_WORKLOADS_LINK = (
+    KUBERNETES_BASE_LINK + '/workload/overview?project={project_id}&pageState=("savedViews":'
+    '("c":%5B"gke%2F{location}%2F{cluster_name}"%5D,"n":%5B"{namespace}"%5D))'
+)
 
 
 class KubernetesEngineClusterLink(BaseGoogleLink):
@@ -111,3 +115,27 @@ class KubernetesEngineJobLink(BaseGoogleLink):
                 "project_id": task_instance.project_id,
             },
         )
+
+
+class KubernetesEngineWorkloadsLink(BaseGoogleLink):
+    """Helper class for constructing Kubernetes Engine Workloads Link."""
+
+    name = "Kubernetes Workloads"
+    key = "kubernetes_workloads_conf"
+    format_str = KUBERNETES_WORKLOADS_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+    ):
+        task_instance.xcom_push(
+            context=context,
+            key=KubernetesEngineWorkloadsLink.key,
+            value={
+                "location": task_instance.location,
+                "cluster_name": task_instance.cluster_name,
+                "namespace": task_instance.namespace,
+                "project_id": task_instance.project_id,
+            },
+        )
--- 10.16.0rc1/airflow/providers/google/cloud/log/stackdriver_task_handler.py
+++ 10.17.0/airflow/providers/google/cloud/log/stackdriver_task_handler.py
@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 """Handler that integrates with Stackdriver."""
+
 from __future__ import annotations
 
 import logging
+import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Collection
 from urllib.parse import urlencode
@@ -28,9 +30,11 @@ from google.cloud.logging.handlers.transports import BackgroundThreadTransport,
 from google.cloud.logging_v2.services.logging_service_v2 import LoggingServiceV2Client
 from google.cloud.logging_v2.types import ListLogEntriesRequest, ListLogEntriesResponse
 
+from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.utils.log.trigger_handler import ctx_indiv_trigger
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from google.auth.credentials import Credentials
@@ -91,15 +95,25 @@ class StackdriverTaskHandler(logging.Handler):
         self,
         gcp_key_path: str | None = None,
         scopes: Collection[str] | None = _DEFAULT_SCOPESS,
-        name: str =
+        name: str | ArgNotSet = NOTSET,
         transport: type[Transport] = BackgroundThreadTransport,
         resource: Resource = _GLOBAL_RESOURCE,
         labels: dict[str, str] | None = None,
+        gcp_log_name: str = DEFAULT_LOGGER_NAME,
     ):
+        if name is not NOTSET:
+            warnings.warn(
+                "Param `name` is deprecated and will be removed in a future release. "
+                "Please use `gcp_log_name` instead. ",
+                RemovedInAirflow3Warning,
+                stacklevel=2,
+            )
+            gcp_log_name = str(name)
+
         super().__init__()
         self.gcp_key_path: str | None = gcp_key_path
         self.scopes: Collection[str] | None = scopes
-        self.
+        self.gcp_log_name: str = gcp_log_name
         self.transport_type: type[Transport] = transport
         self.resource: Resource = resource
         self.labels: dict[str, str] | None = labels
@@ -139,7 +153,7 @@ class StackdriverTaskHandler(logging.Handler):
         """Object responsible for sending data to Stackdriver."""
         # The Transport object is badly defined (no init) but in the docs client/name as constructor
         # arguments are a requirement for any class that derives from Transport class, hence ignore:
-        return self.transport_type(self._client, self.
+        return self.transport_type(self._client, self.gcp_log_name)  # type: ignore[call-arg]
 
     def _get_labels(self, task_instance=None):
         if task_instance:
@@ -244,7 +258,7 @@ class StackdriverTaskHandler(logging.Handler):
         _, project = self._credentials_and_project
         log_filters = [
             f"resource.type={escale_label_value(self.resource.type)}",
-            f'logName="projects/{project}/logs/{self.
+            f'logName="projects/{project}/logs/{self.gcp_log_name}"',
         ]
 
         for key, value in self.resource.labels.items():