apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -6
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +176 -293
- airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_build.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_composer.py +288 -15
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_run.py +18 -10
- airflow/providers/google/cloud/hooks/cloud_sql.py +102 -23
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +29 -7
- airflow/providers/google/cloud/hooks/compute.py +1 -1
- airflow/providers/google/cloud/hooks/compute_ssh.py +6 -2
- airflow/providers/google/cloud/hooks/datacatalog.py +10 -1
- airflow/providers/google/cloud/hooks/dataflow.py +72 -95
- airflow/providers/google/cloud/hooks/dataform.py +1 -1
- airflow/providers/google/cloud/hooks/datafusion.py +21 -19
- airflow/providers/google/cloud/hooks/dataplex.py +2 -2
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +73 -72
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -1
- airflow/providers/google/cloud/hooks/dlp.py +1 -1
- airflow/providers/google/cloud/hooks/functions.py +1 -1
- airflow/providers/google/cloud/hooks/gcs.py +112 -15
- airflow/providers/google/cloud/hooks/gdm.py +1 -1
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +3 -3
- airflow/providers/google/cloud/hooks/looker.py +6 -2
- airflow/providers/google/cloud/hooks/managed_kafka.py +1 -1
- airflow/providers/google/cloud/hooks/mlengine.py +4 -3
- airflow/providers/google/cloud/hooks/pubsub.py +3 -0
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +74 -9
- airflow/providers/google/cloud/hooks/stackdriver.py +11 -9
- airflow/providers/google/cloud/hooks/tasks.py +1 -1
- airflow/providers/google/cloud/hooks/translate.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +2 -210
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +3 -3
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +28 -2
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +308 -8
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +3 -3
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +58 -22
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +71 -56
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +4 -6
- airflow/providers/google/cloud/operators/bigtable.py +37 -8
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +3 -3
- airflow/providers/google/cloud/operators/cloud_build.py +76 -33
- airflow/providers/google/cloud/operators/cloud_composer.py +129 -41
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +24 -6
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -17
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +93 -12
- airflow/providers/google/cloud/operators/compute.py +9 -41
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +40 -16
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +42 -21
- airflow/providers/google/cloud/operators/dataplex.py +194 -110
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +80 -36
- airflow/providers/google/cloud/operators/dataproc_metastore.py +97 -89
- airflow/providers/google/cloud/operators/datastore.py +23 -7
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +17 -8
- airflow/providers/google/cloud/operators/gcs.py +12 -9
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +62 -100
- airflow/providers/google/cloud/operators/looker.py +2 -2
- airflow/providers/google/cloud/operators/managed_kafka.py +108 -53
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +68 -15
- airflow/providers/google/cloud/operators/spanner.py +26 -13
- airflow/providers/google/cloud/operators/speech_to_text.py +2 -3
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +2 -3
- airflow/providers/google/cloud/operators/translate.py +41 -17
- airflow/providers/google/cloud/operators/translate_speech.py +2 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +30 -10
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +55 -27
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -115
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +12 -10
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +31 -8
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/secrets/secret_manager.py +3 -2
- airflow/providers/google/cloud/sensors/bigquery.py +3 -3
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -3
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -30
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -3
- airflow/providers/google/cloud/sensors/dataflow.py +26 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -3
- airflow/providers/google/cloud/sensors/datafusion.py +4 -5
- airflow/providers/google/cloud/sensors/dataplex.py +2 -3
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -3
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -3
- airflow/providers/google/cloud/sensors/gcs.py +4 -5
- airflow/providers/google/cloud/sensors/looker.py +2 -3
- airflow/providers/google/cloud/sensors/pubsub.py +4 -5
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -3
- airflow/providers/google/cloud/sensors/workflows.py +2 -3
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +4 -3
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +10 -5
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +21 -13
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +4 -3
- airflow/providers/google/cloud/transfers/gcs_to_local.py +6 -4
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +11 -5
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +13 -7
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +14 -5
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +76 -35
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +3 -3
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +92 -2
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +123 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +47 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/triggers/vertex_ai.py +1 -1
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
- airflow/providers/google/cloud/utils/field_sanitizer.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +2 -3
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -9
- airflow/providers/google/common/hooks/operation_helpers.py +1 -1
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/hooks/firestore.py +1 -1
- airflow/providers/google/firebase/operators/firestore.py +3 -3
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +27 -2
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -1
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +2 -3
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +6 -6
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -64
- airflow/providers/google/suite/hooks/calendar.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +16 -2
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +3 -3
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/METADATA +90 -46
- apache_airflow_providers_google-19.3.0.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.3.0.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.3.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -21,14 +21,15 @@ from __future__ import annotations
|
|
|
21
21
|
|
|
22
22
|
import time
|
|
23
23
|
from datetime import timedelta
|
|
24
|
-
from typing import TYPE_CHECKING
|
|
24
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
25
25
|
|
|
26
26
|
import vertexai
|
|
27
|
+
from google.cloud import aiplatform
|
|
27
28
|
from vertexai.generative_models import GenerativeModel
|
|
28
|
-
from vertexai.language_models import TextEmbeddingModel
|
|
29
|
+
from vertexai.language_models import TextEmbeddingModel
|
|
30
|
+
from vertexai.preview import generative_models as preview_generative_model
|
|
29
31
|
from vertexai.preview.caching import CachedContent
|
|
30
32
|
from vertexai.preview.evaluation import EvalResult, EvalTask
|
|
31
|
-
from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
|
|
32
33
|
from vertexai.preview.tuning import sft
|
|
33
34
|
|
|
34
35
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
|
@@ -36,23 +37,12 @@ from airflow.providers.google.common.deprecated import deprecated
|
|
|
36
37
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
|
37
38
|
|
|
38
39
|
if TYPE_CHECKING:
|
|
39
|
-
from google.cloud.aiplatform_v1 import types as types_v1
|
|
40
40
|
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class GenerativeModelHook(GoogleBaseHook):
|
|
44
44
|
"""Hook for Google Cloud Vertex AI Generative Model APIs."""
|
|
45
45
|
|
|
46
|
-
@deprecated(
|
|
47
|
-
planned_removal_date="April 09, 2025",
|
|
48
|
-
use_instead="GenerativeModelHook.get_generative_model",
|
|
49
|
-
category=AirflowProviderDeprecationWarning,
|
|
50
|
-
)
|
|
51
|
-
def get_text_generation_model(self, pretrained_model: str):
|
|
52
|
-
"""Return a Model Garden Model object based on Text Generation."""
|
|
53
|
-
model = TextGenerationModel.from_pretrained(pretrained_model)
|
|
54
|
-
return model
|
|
55
|
-
|
|
56
46
|
def get_text_embedding_model(self, pretrained_model: str):
|
|
57
47
|
"""Return a Model Garden Model object based on Text Embedding."""
|
|
58
48
|
model = TextEmbeddingModel.from_pretrained(pretrained_model)
|
|
@@ -61,7 +51,7 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
61
51
|
def get_generative_model(
|
|
62
52
|
self,
|
|
63
53
|
pretrained_model: str,
|
|
64
|
-
system_instruction:
|
|
54
|
+
system_instruction: Any | None = None,
|
|
65
55
|
generation_config: dict | None = None,
|
|
66
56
|
safety_settings: dict | None = None,
|
|
67
57
|
tools: list | None = None,
|
|
@@ -93,66 +83,18 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
93
83
|
def get_cached_context_model(
|
|
94
84
|
self,
|
|
95
85
|
cached_content_name: str,
|
|
96
|
-
) ->
|
|
86
|
+
) -> Any:
|
|
97
87
|
"""Return a Generative Model with Cached Context."""
|
|
98
88
|
cached_content = CachedContent(cached_content_name=cached_content_name)
|
|
99
89
|
|
|
100
|
-
cached_context_model = preview_generative_model.from_cached_content(cached_content)
|
|
90
|
+
cached_context_model = preview_generative_model.GenerativeModel.from_cached_content(cached_content)
|
|
101
91
|
return cached_context_model
|
|
102
92
|
|
|
103
93
|
@deprecated(
|
|
104
|
-
planned_removal_date="
|
|
105
|
-
use_instead="
|
|
94
|
+
planned_removal_date="January 3, 2026",
|
|
95
|
+
use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.embed_content",
|
|
106
96
|
category=AirflowProviderDeprecationWarning,
|
|
107
97
|
)
|
|
108
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
|
109
|
-
def text_generation_model_predict(
|
|
110
|
-
self,
|
|
111
|
-
prompt: str,
|
|
112
|
-
pretrained_model: str,
|
|
113
|
-
temperature: float,
|
|
114
|
-
max_output_tokens: int,
|
|
115
|
-
top_p: float,
|
|
116
|
-
top_k: int,
|
|
117
|
-
location: str,
|
|
118
|
-
project_id: str = PROVIDE_PROJECT_ID,
|
|
119
|
-
) -> str:
|
|
120
|
-
"""
|
|
121
|
-
Use the Vertex AI PaLM API to generate natural language text.
|
|
122
|
-
|
|
123
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
|
124
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
|
125
|
-
:param prompt: Required. Inputs or queries that a user or a program gives
|
|
126
|
-
to the Vertex AI PaLM API, in order to elicit a specific response.
|
|
127
|
-
:param pretrained_model: A pre-trained model optimized for performing natural
|
|
128
|
-
language tasks such as classification, summarization, extraction, content
|
|
129
|
-
creation, and ideation.
|
|
130
|
-
:param temperature: Temperature controls the degree of randomness in token
|
|
131
|
-
selection.
|
|
132
|
-
:param max_output_tokens: Token limit determines the maximum amount of text
|
|
133
|
-
output.
|
|
134
|
-
:param top_p: Tokens are selected from most probable to least until the sum
|
|
135
|
-
of their probabilities equals the top_p value. Defaults to 0.8.
|
|
136
|
-
:param top_k: A top_k of 1 means the selected token is the most probable
|
|
137
|
-
among all tokens.
|
|
138
|
-
"""
|
|
139
|
-
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
|
140
|
-
|
|
141
|
-
parameters = {
|
|
142
|
-
"temperature": temperature,
|
|
143
|
-
"max_output_tokens": max_output_tokens,
|
|
144
|
-
"top_p": top_p,
|
|
145
|
-
"top_k": top_k,
|
|
146
|
-
}
|
|
147
|
-
|
|
148
|
-
model = self.get_text_generation_model(pretrained_model)
|
|
149
|
-
|
|
150
|
-
response = model.predict(
|
|
151
|
-
prompt=prompt,
|
|
152
|
-
**parameters,
|
|
153
|
-
)
|
|
154
|
-
return response.text
|
|
155
|
-
|
|
156
98
|
@GoogleBaseHook.fallback_to_default_project_id
|
|
157
99
|
def text_embedding_model_get_embeddings(
|
|
158
100
|
self,
|
|
@@ -177,16 +119,21 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
177
119
|
|
|
178
120
|
return response.values
|
|
179
121
|
|
|
122
|
+
@deprecated(
|
|
123
|
+
planned_removal_date="January 3, 2026",
|
|
124
|
+
use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.generate_content",
|
|
125
|
+
category=AirflowProviderDeprecationWarning,
|
|
126
|
+
)
|
|
180
127
|
@GoogleBaseHook.fallback_to_default_project_id
|
|
181
128
|
def generative_model_generate_content(
|
|
182
129
|
self,
|
|
183
130
|
contents: list,
|
|
184
131
|
location: str,
|
|
132
|
+
pretrained_model: str,
|
|
185
133
|
tools: list | None = None,
|
|
186
134
|
generation_config: dict | None = None,
|
|
187
135
|
safety_settings: dict | None = None,
|
|
188
136
|
system_instruction: str | None = None,
|
|
189
|
-
pretrained_model: str = "gemini-pro",
|
|
190
137
|
project_id: str = PROVIDE_PROJECT_ID,
|
|
191
138
|
) -> str:
|
|
192
139
|
"""
|
|
@@ -200,7 +147,7 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
200
147
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
|
201
148
|
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
|
202
149
|
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
|
203
|
-
:param pretrained_model:
|
|
150
|
+
:param pretrained_model: Required. Model,
|
|
204
151
|
supporting prompts with text-only input, including natural language
|
|
205
152
|
tasks, multi-turn text and code chat, and code generation. It can
|
|
206
153
|
output text and code.
|
|
@@ -219,6 +166,11 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
219
166
|
|
|
220
167
|
return response.text
|
|
221
168
|
|
|
169
|
+
@deprecated(
|
|
170
|
+
planned_removal_date="January 3, 2026",
|
|
171
|
+
use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.supervised_fine_tuning_train",
|
|
172
|
+
category=AirflowProviderDeprecationWarning,
|
|
173
|
+
)
|
|
222
174
|
@GoogleBaseHook.fallback_to_default_project_id
|
|
223
175
|
def supervised_fine_tuning_train(
|
|
224
176
|
self,
|
|
@@ -228,10 +180,10 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
228
180
|
tuned_model_display_name: str | None = None,
|
|
229
181
|
validation_dataset: str | None = None,
|
|
230
182
|
epochs: int | None = None,
|
|
231
|
-
adapter_size:
|
|
183
|
+
adapter_size: Literal[1, 4, 8, 16] | None = None,
|
|
232
184
|
learning_rate_multiplier: float | None = None,
|
|
233
185
|
project_id: str = PROVIDE_PROJECT_ID,
|
|
234
|
-
) ->
|
|
186
|
+
) -> Any:
|
|
235
187
|
"""
|
|
236
188
|
Use the Supervised Fine Tuning API to create a tuning job.
|
|
237
189
|
|
|
@@ -272,12 +224,17 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
272
224
|
|
|
273
225
|
return sft_tuning_job
|
|
274
226
|
|
|
227
|
+
@deprecated(
|
|
228
|
+
planned_removal_date="January 3, 2026",
|
|
229
|
+
use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.count_tokens",
|
|
230
|
+
category=AirflowProviderDeprecationWarning,
|
|
231
|
+
)
|
|
275
232
|
@GoogleBaseHook.fallback_to_default_project_id
|
|
276
233
|
def count_tokens(
|
|
277
234
|
self,
|
|
278
235
|
contents: list,
|
|
279
236
|
location: str,
|
|
280
|
-
pretrained_model: str
|
|
237
|
+
pretrained_model: str,
|
|
281
238
|
project_id: str = PROVIDE_PROJECT_ID,
|
|
282
239
|
) -> types_v1beta1.CountTokensResponse:
|
|
283
240
|
"""
|
|
@@ -287,7 +244,7 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
287
244
|
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
|
288
245
|
:param contents: Required. The multi-part content of a message that a user or a program
|
|
289
246
|
gives to the generative model, in order to elicit a specific response.
|
|
290
|
-
:param pretrained_model:
|
|
247
|
+
:param pretrained_model: Required. Model,
|
|
291
248
|
supporting prompts with text-only input, including natural language
|
|
292
249
|
tasks, multi-turn text and code chat, and code generation. It can
|
|
293
250
|
output text and code.
|
|
@@ -359,13 +316,18 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
359
316
|
|
|
360
317
|
return eval_result
|
|
361
318
|
|
|
319
|
+
@deprecated(
|
|
320
|
+
planned_removal_date="January 3, 2026",
|
|
321
|
+
use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.create_cached_content",
|
|
322
|
+
category=AirflowProviderDeprecationWarning,
|
|
323
|
+
)
|
|
362
324
|
def create_cached_content(
|
|
363
325
|
self,
|
|
364
326
|
model_name: str,
|
|
365
327
|
location: str,
|
|
366
328
|
ttl_hours: float = 1,
|
|
367
|
-
system_instruction:
|
|
368
|
-
contents: list | None = None,
|
|
329
|
+
system_instruction: Any | None = None,
|
|
330
|
+
contents: list[Any] | None = None,
|
|
369
331
|
display_name: str | None = None,
|
|
370
332
|
project_id: str = PROVIDE_PROJECT_ID,
|
|
371
333
|
) -> str:
|
|
@@ -393,6 +355,11 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
393
355
|
|
|
394
356
|
return response.name
|
|
395
357
|
|
|
358
|
+
@deprecated(
|
|
359
|
+
planned_removal_date="January 3, 2026",
|
|
360
|
+
use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.generate_content",
|
|
361
|
+
category=AirflowProviderDeprecationWarning,
|
|
362
|
+
)
|
|
396
363
|
def generate_from_cached_content(
|
|
397
364
|
self,
|
|
398
365
|
location: str,
|
|
@@ -413,6 +380,9 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
413
380
|
:param generation_config: Optional. Generation configuration settings.
|
|
414
381
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
|
415
382
|
"""
|
|
383
|
+
# During run of the system test it was found out that names from xcom, e.g. 3402922389 can be
|
|
384
|
+
# treated as int and throw an error TypeError: expected string or bytes-like object, got 'int'
|
|
385
|
+
cached_content_name = str(cached_content_name)
|
|
416
386
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
|
417
387
|
|
|
418
388
|
cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)
|
|
@@ -424,3 +394,37 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
|
424
394
|
)
|
|
425
395
|
|
|
426
396
|
return response.text
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
@deprecated(
|
|
400
|
+
planned_removal_date="January 3, 2026",
|
|
401
|
+
use_instead="airflow.providers.google.cloud.hooks.vertex_ai.experiment_service.ExperimentRunHook",
|
|
402
|
+
category=AirflowProviderDeprecationWarning,
|
|
403
|
+
)
|
|
404
|
+
class ExperimentRunHook(GoogleBaseHook):
|
|
405
|
+
"""Use the Vertex AI SDK for Python to create and manage your experiment runs."""
|
|
406
|
+
|
|
407
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
|
408
|
+
def delete_experiment_run(
|
|
409
|
+
self,
|
|
410
|
+
experiment_run_name: str,
|
|
411
|
+
experiment_name: str,
|
|
412
|
+
location: str,
|
|
413
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
|
414
|
+
delete_backing_tensorboard_run: bool = False,
|
|
415
|
+
) -> None:
|
|
416
|
+
"""
|
|
417
|
+
Delete experiment run from the experiment.
|
|
418
|
+
|
|
419
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
|
420
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
|
421
|
+
:param experiment_name: Required. The name of the evaluation experiment.
|
|
422
|
+
:param experiment_run_name: Required. The specific run name or ID for this experiment.
|
|
423
|
+
:param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
|
|
424
|
+
that stores time series metrics for this run.
|
|
425
|
+
"""
|
|
426
|
+
self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
|
|
427
|
+
experiment_run = aiplatform.ExperimentRun(
|
|
428
|
+
run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
|
|
429
|
+
)
|
|
430
|
+
experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
|
|
@@ -34,7 +34,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
|
34
34
|
from google.cloud.aiplatform import CustomJob, HyperparameterTuningJob, gapic, hyperparameter_tuning
|
|
35
35
|
from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient, JobState, types
|
|
36
36
|
|
|
37
|
-
from airflow.
|
|
37
|
+
from airflow.providers.common.compat.sdk import AirflowException
|
|
38
38
|
from airflow.providers.google.common.consts import CLIENT_INFO
|
|
39
39
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
|
|
40
40
|
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
|
|
@@ -27,7 +27,7 @@ from google.api_core.client_options import ClientOptions
|
|
|
27
27
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
28
28
|
from google.cloud.aiplatform_v1 import ModelServiceClient
|
|
29
29
|
|
|
30
|
-
from airflow.
|
|
30
|
+
from airflow.providers.common.compat.sdk import AirflowException
|
|
31
31
|
from airflow.providers.google.common.consts import CLIENT_INFO
|
|
32
32
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
|
33
33
|
|
|
@@ -39,7 +39,7 @@ from google.cloud.aiplatform_v1 import (
|
|
|
39
39
|
types,
|
|
40
40
|
)
|
|
41
41
|
|
|
42
|
-
from airflow.
|
|
42
|
+
from airflow.providers.common.compat.sdk import AirflowException
|
|
43
43
|
from airflow.providers.google.common.consts import CLIENT_INFO
|
|
44
44
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
|
|
45
45
|
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
|
3
|
+
# or more contributor license agreements. See the NOTICE file
|
|
4
|
+
# distributed with this work for additional information
|
|
5
|
+
# regarding copyright ownership. The ASF licenses this file
|
|
6
|
+
# to you under the Apache License, Version 2.0 (the
|
|
7
|
+
# "License"); you may not use this file except in compliance
|
|
8
|
+
# with the License. You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing,
|
|
13
|
+
# software distributed under the License is distributed on an
|
|
14
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
15
|
+
# KIND, either express or implied. See the License for the
|
|
16
|
+
# specific language governing permissions and limitations
|
|
17
|
+
# under the License.
|
|
18
|
+
"""This module contains a Google Cloud Vertex AI hook."""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import dataclasses
|
|
23
|
+
from collections.abc import MutableMapping
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
import vertex_ray
|
|
27
|
+
from google.cloud import aiplatform
|
|
28
|
+
from google.cloud.aiplatform.vertex_ray.util import resources
|
|
29
|
+
from google.cloud.aiplatform_v1 import (
|
|
30
|
+
PersistentResourceServiceClient,
|
|
31
|
+
)
|
|
32
|
+
from proto.marshal.collections.repeated import Repeated
|
|
33
|
+
|
|
34
|
+
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class RayHook(GoogleBaseHook):
    """Hook for Google Cloud Vertex AI Ray APIs."""

    def extract_cluster_id(self, cluster_path: str) -> str:
        """
        Extract the cluster ID from a persistent resource path.

        :param cluster_path: Full persistent resource name, in the same format that
            ``PersistentResourceServiceClient.persistent_resource_path`` produces.
        :return: The ``persistent_resource`` component of the path.
        """
        cluster_id = PersistentResourceServiceClient.parse_persistent_resource_path(cluster_path)[
            "persistent_resource"
        ]
        return cluster_id

    def serialize_cluster_obj(self, cluster_obj: resources.Cluster) -> dict:
        """
        Serialize a ``Cluster`` dataclass to a dict.

        Each dataclass field is encoded as follows:

        * lists and proto ``Repeated`` containers become lists, with every element encoded;
        * mappings that are not plain ``dict`` become dicts, with every value encoded;
        * nested dataclass instances are converted with :func:`dataclasses.asdict`;
        * all other values (including plain dicts) are returned unchanged.

        :param cluster_obj: The cluster dataclass instance to serialize.
        :return: Mapping of each dataclass field name to its encoded value.
        """

        def _encode_value(value: Any) -> Any:
            if isinstance(value, (list, Repeated)):
                return [_encode_value(nested_value) for nested_value in value]
            # Plain dicts pass through untouched; other mapping types are copied into a dict.
            if not isinstance(value, dict) and isinstance(value, MutableMapping):
                return {key: _encode_value(nested_value) for key, nested_value in dict(value).items()}
            # ``is_dataclass`` is also true for dataclass *classes*; only convert instances.
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                return dataclasses.asdict(value)
            return value

        return {
            field.name: _encode_value(getattr(cluster_obj, field.name))
            for field in dataclasses.fields(cluster_obj)
        }

    @GoogleBaseHook.fallback_to_default_project_id
    def create_ray_cluster(
        self,
        project_id: str,
        location: str,
        head_node_type: resources.Resources | None = None,
        python_version: str = "3.10",
        ray_version: str = "2.33",
        network: str | None = None,
        service_account: str | None = None,
        cluster_name: str | None = None,
        worker_node_types: list[resources.Resources] | None = None,
        custom_images: resources.NodeImages | None = None,
        enable_metrics_collection: bool = True,
        enable_logging: bool = True,
        psc_interface_config: resources.PscIConfig | None = None,
        reserved_ip_ranges: list[str] | None = None,
        labels: dict[str, str] | None = None,
    ) -> str:
        """
        Create a Ray cluster on the Vertex AI.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param head_node_type: The head node resource. Resources.node_count must be 1. If not set, a new
            default ``Resources()`` instance is created for this call.
        :param python_version: Python version for the ray cluster.
        :param ray_version: Ray version for the ray cluster. Default is 2.33.
        :param network: Virtual private cloud (VPC) network. For Ray Client, VPC peering is required to
            connect to the Ray Cluster managed in the Vertex API service. For Ray Job API, VPC network is not
            required because Ray Cluster connection can be accessed through dashboard address.
        :param service_account: Service account to be used for running Ray programs on the cluster.
        :param cluster_name: This value may be up to 63 characters, and valid characters are `[a-z0-9_-]`.
            The first character cannot be a number or hyphen.
        :param worker_node_types: The list of Resources of the worker nodes. The same Resources object should
            not appear multiple times in the list.
        :param custom_images: The NodeImages which specifies head node and worker nodes images. All the
            workers will share the same image. If each Resource has a specific custom image, use
            `Resources.custom_image` for head/worker_node_type(s). Note that configuring
            `Resources.custom_image` will override `custom_images` here. Allowlist only.
        :param enable_metrics_collection: Enable Ray metrics collection for visualization.
        :param enable_logging: Enable exporting Ray logs to Cloud Logging.
        :param psc_interface_config: PSC-I config.
        :param reserved_ip_ranges: A list of names for the reserved IP ranges under the VPC network that can
            be used for this cluster. If set, we will deploy the cluster within the provided IP ranges.
            Otherwise, the cluster is deployed to any IP ranges under the provided VPC network.
            Example: ["vertex-ai-ip-range"].
        :param labels: The labels with user-defined metadata to organize Ray cluster.
            Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
            lowercase letters, numeric characters, underscores and dashes. International characters are allowed.
            See https://goo.gl/xmQnxf for more information and examples of labels.
        :return: The resource path of the created cluster, as returned by ``vertex_ray.create_ray_cluster``.
        """
        if head_node_type is None:
            # Build a fresh default per call instead of sharing one mutable dataclass instance
            # created once when the method was defined.
            head_node_type = resources.Resources()
        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
        cluster_path = vertex_ray.create_ray_cluster(
            head_node_type=head_node_type,
            python_version=python_version,
            ray_version=ray_version,
            network=network,
            service_account=service_account,
            cluster_name=cluster_name,
            worker_node_types=worker_node_types,
            custom_images=custom_images,
            enable_metrics_collection=enable_metrics_collection,
            enable_logging=enable_logging,
            psc_interface_config=psc_interface_config,
            reserved_ip_ranges=reserved_ip_ranges,
            labels=labels,
        )
        return cluster_path

    @GoogleBaseHook.fallback_to_default_project_id
    def list_ray_clusters(
        self,
        project_id: str,
        location: str,
    ) -> list[resources.Cluster]:
        """
        List Ray clusters in the given project and location.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :return: The clusters returned by ``vertex_ray.list_ray_clusters``.
        """
        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
        ray_clusters = vertex_ray.list_ray_clusters()
        return ray_clusters

    @GoogleBaseHook.fallback_to_default_project_id
    def get_ray_cluster(
        self,
        project_id: str,
        location: str,
        cluster_id: str,
    ) -> resources.Cluster:
        """
        Get Ray cluster.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param cluster_id: Cluster resource ID.
        :return: The cluster returned by ``vertex_ray.get_ray_cluster``.
        """
        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
            project=project_id,
            location=location,
            persistent_resource=cluster_id,
        )
        ray_cluster = vertex_ray.get_ray_cluster(
            cluster_resource_name=ray_cluster_name,
        )
        return ray_cluster

    @GoogleBaseHook.fallback_to_default_project_id
    def update_ray_cluster(
        self,
        project_id: str,
        location: str,
        cluster_id: str,
        worker_node_types: list[resources.Resources],
    ) -> str:
        """
        Update Ray cluster (currently support resizing node counts for worker nodes).

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param cluster_id: Cluster resource ID.
        :param worker_node_types: The list of Resources of the resized worker nodes. The same Resources
            object should not appear multiple times in the list.
        :return: The cluster resource name returned by ``vertex_ray.update_ray_cluster``.
        """
        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
            project=project_id,
            location=location,
            persistent_resource=cluster_id,
        )
        updated_ray_cluster_name = vertex_ray.update_ray_cluster(
            cluster_resource_name=ray_cluster_name, worker_node_types=worker_node_types
        )
        return updated_ray_cluster_name

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_ray_cluster(
        self,
        project_id: str,
        location: str,
        cluster_id: str,
    ) -> None:
        """
        Delete Ray cluster.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param cluster_id: Cluster resource ID.
        """
        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())
        ray_cluster_name = PersistentResourceServiceClient.persistent_resource_path(
            project=project_id,
            location=location,
            persistent_resource=cluster_id,
        )
        vertex_ray.delete_ray_cluster(cluster_resource_name=ray_cluster_name)
|
|
@@ -19,10 +19,10 @@
|
|
|
19
19
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
|
-
from collections.abc import Sequence
|
|
22
|
+
from collections.abc import Callable, Sequence
|
|
23
23
|
from copy import deepcopy
|
|
24
24
|
from functools import cached_property
|
|
25
|
-
from typing import TYPE_CHECKING, Any
|
|
25
|
+
from typing import TYPE_CHECKING, Any
|
|
26
26
|
|
|
27
27
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
28
28
|
from google.cloud.vision_v1 import (
|
|
@@ -36,7 +36,7 @@ from google.cloud.vision_v1 import (
|
|
|
36
36
|
)
|
|
37
37
|
from google.protobuf.json_format import MessageToDict
|
|
38
38
|
|
|
39
|
-
from airflow.
|
|
39
|
+
from airflow.providers.common.compat.sdk import AirflowException
|
|
40
40
|
from airflow.providers.google.common.consts import CLIENT_INFO
|
|
41
41
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
|
42
42
|
|
|
@@ -65,7 +65,7 @@ class WorkflowsHook(GoogleBaseHook):
|
|
|
65
65
|
Create a new workflow.
|
|
66
66
|
|
|
67
67
|
If a workflow with the specified name already exists in the
|
|
68
|
-
specified project and location, the long
|
|
68
|
+
specified project and location, the long-running operation will
|
|
69
69
|
return [ALREADY_EXISTS][google.rpc.Code.ALREADY_EXISTS] error.
|
|
70
70
|
|
|
71
71
|
:param workflow: Required. Workflow to be created.
|
|
@@ -19,14 +19,8 @@
|
|
|
19
19
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
|
-
from typing import TYPE_CHECKING
|
|
23
|
-
|
|
24
22
|
from airflow.providers.google.cloud.links.base import BaseGoogleLink
|
|
25
23
|
|
|
26
|
-
if TYPE_CHECKING:
|
|
27
|
-
from airflow.models import BaseOperator
|
|
28
|
-
from airflow.utils.context import Context
|
|
29
|
-
|
|
30
24
|
ALLOY_DB_BASE_LINK = "/alloydb"
|
|
31
25
|
ALLOY_DB_CLUSTER_LINK = (
|
|
32
26
|
ALLOY_DB_BASE_LINK + "/locations/{location_id}/clusters/{cluster_id}?project={project_id}"
|
|
@@ -44,20 +38,6 @@ class AlloyDBClusterLink(BaseGoogleLink):
|
|
|
44
38
|
key = "alloy_db_cluster"
|
|
45
39
|
format_str = ALLOY_DB_CLUSTER_LINK
|
|
46
40
|
|
|
47
|
-
@staticmethod
|
|
48
|
-
def persist(
|
|
49
|
-
context: Context,
|
|
50
|
-
task_instance: BaseOperator,
|
|
51
|
-
location_id: str,
|
|
52
|
-
cluster_id: str,
|
|
53
|
-
project_id: str | None,
|
|
54
|
-
):
|
|
55
|
-
task_instance.xcom_push(
|
|
56
|
-
context,
|
|
57
|
-
key=AlloyDBClusterLink.key,
|
|
58
|
-
value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
|
|
59
|
-
)
|
|
60
|
-
|
|
61
41
|
|
|
62
42
|
class AlloyDBUsersLink(BaseGoogleLink):
|
|
63
43
|
"""Helper class for constructing AlloyDB users Link."""
|
|
@@ -66,20 +46,6 @@ class AlloyDBUsersLink(BaseGoogleLink):
|
|
|
66
46
|
key = "alloy_db_users"
|
|
67
47
|
format_str = ALLOY_DB_USERS_LINK
|
|
68
48
|
|
|
69
|
-
@staticmethod
|
|
70
|
-
def persist(
|
|
71
|
-
context: Context,
|
|
72
|
-
task_instance: BaseOperator,
|
|
73
|
-
location_id: str,
|
|
74
|
-
cluster_id: str,
|
|
75
|
-
project_id: str | None,
|
|
76
|
-
):
|
|
77
|
-
task_instance.xcom_push(
|
|
78
|
-
context,
|
|
79
|
-
key=AlloyDBUsersLink.key,
|
|
80
|
-
value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
|
|
81
|
-
)
|
|
82
|
-
|
|
83
49
|
|
|
84
50
|
class AlloyDBBackupsLink(BaseGoogleLink):
|
|
85
51
|
"""Helper class for constructing AlloyDB backups Link."""
|
|
@@ -87,15 +53,3 @@ class AlloyDBBackupsLink(BaseGoogleLink):
|
|
|
87
53
|
name = "AlloyDB Backups"
|
|
88
54
|
key = "alloy_db_backups"
|
|
89
55
|
format_str = ALLOY_DB_BACKUPS_LINK
|
|
90
|
-
|
|
91
|
-
@staticmethod
|
|
92
|
-
def persist(
|
|
93
|
-
context: Context,
|
|
94
|
-
task_instance: BaseOperator,
|
|
95
|
-
project_id: str | None,
|
|
96
|
-
):
|
|
97
|
-
task_instance.xcom_push(
|
|
98
|
-
context,
|
|
99
|
-
key=AlloyDBBackupsLink.key,
|
|
100
|
-
value={"project_id": project_id},
|
|
101
|
-
)
|