apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/cloud/hooks/automl.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +64 -33
- airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
- airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
- airflow/providers/google/cloud/hooks/dataflow.py +246 -32
- airflow/providers/google/cloud/hooks/dataplex.py +6 -2
- airflow/providers/google/cloud/hooks/dlp.py +14 -14
- airflow/providers/google/cloud/hooks/gcs.py +6 -2
- airflow/providers/google/cloud/hooks/gdm.py +2 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/mlengine.py +8 -4
- airflow/providers/google/cloud/hooks/pubsub.py +1 -1
- airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
- airflow/providers/google/cloud/links/vertex_ai.py +2 -1
- airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
- airflow/providers/google/cloud/operators/automl.py +13 -12
- airflow/providers/google/cloud/operators/bigquery.py +36 -22
- airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
- airflow/providers/google/cloud/operators/bigtable.py +7 -6
- airflow/providers/google/cloud/operators/cloud_build.py +12 -11
- airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
- airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
- airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
- airflow/providers/google/cloud/operators/compute.py +12 -11
- airflow/providers/google/cloud/operators/datacatalog.py +21 -20
- airflow/providers/google/cloud/operators/dataflow.py +59 -42
- airflow/providers/google/cloud/operators/datafusion.py +11 -10
- airflow/providers/google/cloud/operators/datapipeline.py +3 -2
- airflow/providers/google/cloud/operators/dataprep.py +5 -4
- airflow/providers/google/cloud/operators/dataproc.py +19 -16
- airflow/providers/google/cloud/operators/datastore.py +8 -7
- airflow/providers/google/cloud/operators/dlp.py +31 -30
- airflow/providers/google/cloud/operators/functions.py +4 -3
- airflow/providers/google/cloud/operators/gcs.py +66 -41
- airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
- airflow/providers/google/cloud/operators/life_sciences.py +2 -1
- airflow/providers/google/cloud/operators/mlengine.py +11 -10
- airflow/providers/google/cloud/operators/pubsub.py +6 -5
- airflow/providers/google/cloud/operators/spanner.py +7 -6
- airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
- airflow/providers/google/cloud/operators/stackdriver.py +11 -10
- airflow/providers/google/cloud/operators/tasks.py +14 -13
- airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
- airflow/providers/google/cloud/operators/translate_speech.py +2 -1
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
- airflow/providers/google/cloud/operators/vision.py +13 -12
- airflow/providers/google/cloud/operators/workflows.py +10 -9
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/sensors/bigtable.py +2 -1
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/sensors/dataflow.py +239 -52
- airflow/providers/google/cloud/sensors/datafusion.py +2 -1
- airflow/providers/google/cloud/sensors/dataproc.py +3 -2
- airflow/providers/google/cloud/sensors/gcs.py +14 -12
- airflow/providers/google/cloud/sensors/tasks.py +2 -1
- airflow/providers/google/cloud/sensors/workflows.py +2 -1
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
- airflow/providers/google/cloud/triggers/bigquery.py +14 -3
- airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
- airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/triggers/dataflow.py +504 -4
- airflow/providers/google/cloud/triggers/dataproc.py +110 -26
- airflow/providers/google/cloud/triggers/mlengine.py +2 -1
- airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
- airflow/providers/google/common/hooks/base_google.py +45 -7
- airflow/providers/google/firebase/hooks/firestore.py +2 -2
- airflow/providers/google/firebase/operators/firestore.py +2 -1
- airflow/providers/google/get_provider_info.py +3 -2
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
- airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/operators/vertex_ai/custom_job.py

@@ -19,7 +19,9 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Sequence
+import warnings
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated import deprecated
 from google.api_core.exceptions import NotFound
@@ -28,7 +30,8 @@ from google.cloud.aiplatform.models import Model
 from google.cloud.aiplatform_v1.types.dataset import Dataset
 from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
 
-from airflow.
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
 from airflow.providers.google.cloud.links.vertex_ai import (
     VertexAIModelLink,
@@ -36,9 +39,19 @@ from airflow.providers.google.cloud.links.vertex_ai import (
     VertexAITrainingPipelinesLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.vertex_ai import (
+    CustomContainerTrainingJobTrigger,
+    CustomPythonPackageTrainingJobTrigger,
+    CustomTrainingJobTrigger,
+)
 
 if TYPE_CHECKING:
     from google.api_core.retry import Retry
+    from google.cloud.aiplatform import (
+        CustomContainerTrainingJob,
+        CustomPythonPackageTrainingJob,
+        CustomTrainingJob,
+    )
 
     from airflow.utils.context import Context
 
@@ -160,6 +173,13 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    def execute(self, context: Context) -> None:
+        warnings.warn(
+            "The 'sync' parameter is deprecated and will be removed after 01.10.2024.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+
 
 class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
     """Create Custom Container Training job.
@@ -421,9 +441,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
         ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
         For more information on configuring your service account please visit:
         https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
-    :param sync: Whether to execute the AI Platform job synchronously. If False, this method
-        will be executed in concurrent Future and any downstream object will
-        be immediately returned and synced when the Future has completed.
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -433,6 +450,9 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: If True, run the task in the deferrable mode.
+    :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
+        The default is 60 seconds.
     """
 
     template_fields = (
@@ -442,7 +462,10 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
         "dataset_id",
         "impersonation_chain",
     )
-    operator_extra_links = (
+    operator_extra_links = (
+        VertexAIModelLink(),
+        VertexAITrainingLink(),
+    )
 
     def __init__(
         self,
@@ -452,6 +475,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
         parent_model: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         dataset_id: str | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 60,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -462,12 +487,15 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
             **kwargs,
         )
         self.command = command
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
 
     def execute(self, context: Context):
-
-
-
-
+        super().execute(context)
+
+        if self.deferrable:
+            self.invoke_defer(context=context)
+
         model, training_id, custom_job_id = self.hook.create_custom_container_training_job(
             project_id=self.project_id,
             region=self.region,
@@ -539,6 +567,94 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
         if self.hook:
             self.hook.cancel_job()
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+        result = event["job"]
+        model_id = self.hook.extract_model_id_from_training_pipeline(result)
+        custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result)
+        self.xcom_push(context, key="model_id", value=model_id)
+        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        # push custom_job_id to xcom so it could be pulled by other tasks
+        self.xcom_push(context, key="custom_job_id", value=custom_job_id)
+        return result
+
+    def invoke_defer(self, context: Context) -> None:
+        custom_container_training_job_obj: CustomContainerTrainingJob = self.hook.submit_custom_container_training_job(
+            project_id=self.project_id,
+            region=self.region,
+            display_name=self.display_name,
+            command=self.command,
+            container_uri=self.container_uri,
+            model_serving_container_image_uri=self.model_serving_container_image_uri,
+            model_serving_container_predict_route=self.model_serving_container_predict_route,
+            model_serving_container_health_route=self.model_serving_container_health_route,
+            model_serving_container_command=self.model_serving_container_command,
+            model_serving_container_args=self.model_serving_container_args,
+            model_serving_container_environment_variables=self.model_serving_container_environment_variables,
+            model_serving_container_ports=self.model_serving_container_ports,
+            model_description=self.model_description,
+            model_instance_schema_uri=self.model_instance_schema_uri,
+            model_parameters_schema_uri=self.model_parameters_schema_uri,
+            model_prediction_schema_uri=self.model_prediction_schema_uri,
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
+            labels=self.labels,
+            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+            staging_bucket=self.staging_bucket,
+            # RUN
+            dataset=Dataset(name=self.dataset_id) if self.dataset_id else None,
+            annotation_schema_uri=self.annotation_schema_uri,
+            model_display_name=self.model_display_name,
+            model_labels=self.model_labels,
+            base_output_dir=self.base_output_dir,
+            service_account=self.service_account,
+            network=self.network,
+            bigquery_destination=self.bigquery_destination,
+            args=self.args,
+            environment_variables=self.environment_variables,
+            replica_count=self.replica_count,
+            machine_type=self.machine_type,
+            accelerator_type=self.accelerator_type,
+            accelerator_count=self.accelerator_count,
+            boot_disk_type=self.boot_disk_type,
+            boot_disk_size_gb=self.boot_disk_size_gb,
+            training_fraction_split=self.training_fraction_split,
+            validation_fraction_split=self.validation_fraction_split,
+            test_fraction_split=self.test_fraction_split,
+            training_filter_split=self.training_filter_split,
+            validation_filter_split=self.validation_filter_split,
+            test_filter_split=self.test_filter_split,
+            predefined_split_column_name=self.predefined_split_column_name,
+            timestamp_split_column_name=self.timestamp_split_column_name,
+            tensorboard=self.tensorboard,
+        )
+        custom_container_training_job_obj.wait_for_resource_creation()
+        training_pipeline_id: str = custom_container_training_job_obj.name
+        self.xcom_push(context, key="training_id", value=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
+        self.defer(
+            trigger=CustomContainerTrainingJobTrigger(
+                conn_id=self.gcp_conn_id,
+                project_id=self.project_id,
+                location=self.region,
+                job_id=training_pipeline_id,
+                poll_interval=self.poll_interval,
+                impersonation_chain=self.impersonation_chain,
+            ),
+            method_name="execute_complete",
+        )
+
+    @cached_property
+    def hook(self) -> CustomJobHook:
+        return CustomJobHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
 
 class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
     """Create Custom Python Package Training job.
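For orientation, a minimal sketch of how the deferrable path added above to CreateCustomContainerTrainingJobOperator might be wired into a DAG. The operator and parameter names come from this diff; the DAG id, project, region, image URIs, and bucket are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
    CreateCustomContainerTrainingJobOperator,
)

with DAG("example_vertex_container_training_deferrable", start_date=datetime(2024, 1, 1), schedule=None):
    # With deferrable=True the operator submits the training pipeline, then defers to
    # CustomContainerTrainingJobTrigger instead of blocking a worker while the job runs.
    create_container_training_job = CreateCustomContainerTrainingJobOperator(
        task_id="create_custom_container_training_job",
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        display_name="custom-container-training",
        container_uri="us-docker.pkg.dev/my-project/trainers/trainer:latest",  # placeholder
        command=["python3", "task.py"],
        model_serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",  # placeholder
        staging_bucket="gs://my-staging-bucket",  # placeholder
        deferrable=True,  # added in 10.18.0
        poll_interval=60,  # added in 10.18.0, seconds between status checks
    )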
@@ -800,9 +916,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
         ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
         For more information on configuring your service account please visit:
         https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
-    :param sync: Whether to execute the AI Platform job synchronously. If False, this method
-        will be executed in concurrent Future and any downstream object will
-        be immediately returned and synced when the Future has completed.
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -812,6 +925,9 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: If True, run the task in the deferrable mode.
+    :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
+        The default is 60 seconds.
     """
 
     template_fields = (
@@ -831,6 +947,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
         parent_model: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         dataset_id: str | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 60,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -842,12 +960,15 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
         )
         self.python_package_gcs_uri = python_package_gcs_uri
         self.python_module_name = python_module_name
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
 
     def execute(self, context: Context):
-
-
-
-
+        super().execute(context)
+
+        if self.deferrable:
+            self.invoke_defer(context=context)
+
         model, training_id, custom_job_id = self.hook.create_custom_python_package_training_job(
             project_id=self.project_id,
             region=self.region,
@@ -920,9 +1041,98 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
         if self.hook:
             self.hook.cancel_job()
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+        result = event["job"]
+        model_id = self.hook.extract_model_id_from_training_pipeline(result)
+        custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result)
+        self.xcom_push(context, key="model_id", value=model_id)
+        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        # push custom_job_id to xcom so it could be pulled by other tasks
+        self.xcom_push(context, key="custom_job_id", value=custom_job_id)
+        return result
+
+    def invoke_defer(self, context: Context) -> None:
+        custom_python_training_job_obj: CustomPythonPackageTrainingJob = self.hook.submit_custom_python_package_training_job(
+            project_id=self.project_id,
+            region=self.region,
+            display_name=self.display_name,
+            python_package_gcs_uri=self.python_package_gcs_uri,
+            python_module_name=self.python_module_name,
+            container_uri=self.container_uri,
+            model_serving_container_image_uri=self.model_serving_container_image_uri,
+            model_serving_container_predict_route=self.model_serving_container_predict_route,
+            model_serving_container_health_route=self.model_serving_container_health_route,
+            model_serving_container_command=self.model_serving_container_command,
+            model_serving_container_args=self.model_serving_container_args,
+            model_serving_container_environment_variables=self.model_serving_container_environment_variables,
+            model_serving_container_ports=self.model_serving_container_ports,
+            model_description=self.model_description,
+            model_instance_schema_uri=self.model_instance_schema_uri,
+            model_parameters_schema_uri=self.model_parameters_schema_uri,
+            model_prediction_schema_uri=self.model_prediction_schema_uri,
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
+            labels=self.labels,
+            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+            staging_bucket=self.staging_bucket,
+            # RUN
+            dataset=Dataset(name=self.dataset_id) if self.dataset_id else None,
+            annotation_schema_uri=self.annotation_schema_uri,
+            model_display_name=self.model_display_name,
+            model_labels=self.model_labels,
+            base_output_dir=self.base_output_dir,
+            service_account=self.service_account,
+            network=self.network,
+            bigquery_destination=self.bigquery_destination,
+            args=self.args,
+            environment_variables=self.environment_variables,
+            replica_count=self.replica_count,
+            machine_type=self.machine_type,
+            accelerator_type=self.accelerator_type,
+            accelerator_count=self.accelerator_count,
+            boot_disk_type=self.boot_disk_type,
+            boot_disk_size_gb=self.boot_disk_size_gb,
+            training_fraction_split=self.training_fraction_split,
+            validation_fraction_split=self.validation_fraction_split,
+            test_fraction_split=self.test_fraction_split,
+            training_filter_split=self.training_filter_split,
+            validation_filter_split=self.validation_filter_split,
+            test_filter_split=self.test_filter_split,
+            predefined_split_column_name=self.predefined_split_column_name,
+            timestamp_split_column_name=self.timestamp_split_column_name,
+            tensorboard=self.tensorboard,
+        )
+        custom_python_training_job_obj.wait_for_resource_creation()
+        training_pipeline_id: str = custom_python_training_job_obj.name
+        self.xcom_push(context, key="training_id", value=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
+        self.defer(
+            trigger=CustomPythonPackageTrainingJobTrigger(
+                conn_id=self.gcp_conn_id,
+                project_id=self.project_id,
+                location=self.region,
+                job_id=training_pipeline_id,
+                poll_interval=self.poll_interval,
+                impersonation_chain=self.impersonation_chain,
+            ),
+            method_name="execute_complete",
+        )
+
+    @cached_property
+    def hook(self) -> CustomJobHook:
+        return CustomJobHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
 
 class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
-    """Create Custom Training
+    """Create a Custom Training Job pipeline.
 
     :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
     :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -1181,9 +1391,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
         For more information on configuring your service account please visit:
         https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
-    :param sync: Whether to execute the AI Platform job synchronously. If False, this method
-        will be executed in concurrent Future and any downstream object will
-        be immediately returned and synced when the Future has completed.
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -1193,6 +1400,9 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: If True, run the task in the deferrable mode.
+    :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
+        The default is 60 seconds.
     """
 
     template_fields = (
@@ -1203,7 +1413,10 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         "dataset_id",
         "impersonation_chain",
     )
-    operator_extra_links = (
+    operator_extra_links = (
+        VertexAIModelLink(),
+        VertexAITrainingLink(),
+    )
 
     def __init__(
         self,
@@ -1214,6 +1427,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         parent_model: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         dataset_id: str | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 60,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -1225,12 +1440,15 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         )
         self.requirements = requirements
         self.script_path = script_path
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
 
     def execute(self, context: Context):
-
-
-
-
+        super().execute(context)
+
+        if self.deferrable:
+            self.invoke_defer(context=context)
+
         model, training_id, custom_job_id = self.hook.create_custom_training_job(
             project_id=self.project_id,
             region=self.region,
@@ -1303,6 +1521,95 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         if self.hook:
             self.hook.cancel_job()
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+        result = event["job"]
+        model_id = self.hook.extract_model_id_from_training_pipeline(result)
+        custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result)
+        self.xcom_push(context, key="model_id", value=model_id)
+        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        # push custom_job_id to xcom so it could be pulled by other tasks
+        self.xcom_push(context, key="custom_job_id", value=custom_job_id)
+        return result
+
+    def invoke_defer(self, context: Context) -> None:
+        custom_training_job_obj: CustomTrainingJob = self.hook.submit_custom_training_job(
+            project_id=self.project_id,
+            region=self.region,
+            display_name=self.display_name,
+            script_path=self.script_path,
+            container_uri=self.container_uri,
+            requirements=self.requirements,
+            model_serving_container_image_uri=self.model_serving_container_image_uri,
+            model_serving_container_predict_route=self.model_serving_container_predict_route,
+            model_serving_container_health_route=self.model_serving_container_health_route,
+            model_serving_container_command=self.model_serving_container_command,
+            model_serving_container_args=self.model_serving_container_args,
+            model_serving_container_environment_variables=self.model_serving_container_environment_variables,
+            model_serving_container_ports=self.model_serving_container_ports,
+            model_description=self.model_description,
+            model_instance_schema_uri=self.model_instance_schema_uri,
+            model_parameters_schema_uri=self.model_parameters_schema_uri,
+            model_prediction_schema_uri=self.model_prediction_schema_uri,
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
+            labels=self.labels,
+            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+            staging_bucket=self.staging_bucket,
+            # RUN
+            dataset=Dataset(name=self.dataset_id) if self.dataset_id else None,
+            annotation_schema_uri=self.annotation_schema_uri,
+            model_display_name=self.model_display_name,
+            model_labels=self.model_labels,
+            base_output_dir=self.base_output_dir,
+            service_account=self.service_account,
+            network=self.network,
+            bigquery_destination=self.bigquery_destination,
+            args=self.args,
+            environment_variables=self.environment_variables,
+            replica_count=self.replica_count,
+            machine_type=self.machine_type,
+            accelerator_type=self.accelerator_type,
+            accelerator_count=self.accelerator_count,
+            boot_disk_type=self.boot_disk_type,
+            boot_disk_size_gb=self.boot_disk_size_gb,
+            training_fraction_split=self.training_fraction_split,
+            validation_fraction_split=self.validation_fraction_split,
+            test_fraction_split=self.test_fraction_split,
+            training_filter_split=self.training_filter_split,
+            validation_filter_split=self.validation_filter_split,
+            test_filter_split=self.test_filter_split,
+            predefined_split_column_name=self.predefined_split_column_name,
+            timestamp_split_column_name=self.timestamp_split_column_name,
+            tensorboard=self.tensorboard,
+        )
+        custom_training_job_obj.wait_for_resource_creation()
+        training_pipeline_id: str = custom_training_job_obj.name
+        self.xcom_push(context, key="training_id", value=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
+        self.defer(
+            trigger=CustomTrainingJobTrigger(
+                conn_id=self.gcp_conn_id,
+                project_id=self.project_id,
+                location=self.region,
+                job_id=training_pipeline_id,
+                poll_interval=self.poll_interval,
+                impersonation_chain=self.impersonation_chain,
+            ),
+            method_name="execute_complete",
+        )
+
+    @cached_property
+    def hook(self) -> CustomJobHook:
+        return CustomJobHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
 
 class DeleteCustomTrainingJobOperator(GoogleCloudBaseOperator):
     """
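All three Create*TrainingJobOperator classes now push training_id, model_id, and custom_job_id to XCom from invoke_defer and execute_complete. A hedged sketch of reading those values downstream; the upstream task_id matches the example above and is otherwise arbitrary:

from airflow.decorators import task

@task
def report_training_result(**context):
    ti = context["ti"]
    # Keys pushed by the deferrable path of the create-training-job operators.
    training_id = ti.xcom_pull(task_ids="create_custom_container_training_job", key="training_id")
    model_id = ti.xcom_pull(task_ids="create_custom_container_training_job", key="model_id")
    custom_job_id = ti.xcom_pull(task_ids="create_custom_container_training_job", key="custom_job_id")
    print(f"Training pipeline {training_id} produced model {model_id} via custom job {custom_job_id}")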
airflow/providers/google/cloud/operators/vertex_ai/generative_model.py

@@ -33,11 +33,11 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator):
     Uses the Vertex AI PaLM API to generate natural language text.
 
     :param project_id: Required. The ID of the Google Cloud project that the
-        service belongs to.
+        service belongs to (templated).
     :param location: Required. The ID of the Google Cloud location that the
-        service belongs to.
+        service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
-        to the Vertex AI PaLM API, in order to elicit a specific response.
+        to the Vertex AI PaLM API, in order to elicit a specific response (templated).
     :param pretrained_model: By default uses the pre-trained model `text-bison`,
         optimized for performing natural language tasks such as classification,
         summarization, extraction, content creation, and ideation.
@@ -60,6 +60,8 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     """
 
+    template_fields = ("location", "project_id", "impersonation_chain", "prompt")
+
     def __init__(
         self,
         *,
@@ -116,11 +118,11 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
     Uses the Vertex AI PaLM API to generate natural language text.
 
     :param project_id: Required. The ID of the Google Cloud project that the
-        service belongs to.
+        service belongs to (templated).
     :param location: Required. The ID of the Google Cloud location that the
-        service belongs to.
+        service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
-        to the Vertex AI PaLM API, in order to elicit a specific response.
+        to the Vertex AI PaLM API, in order to elicit a specific response (templated).
     :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
         optimized for performing text embeddings.
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
@@ -134,6 +136,8 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     """
 
+    template_fields = ("location", "project_id", "impersonation_chain", "prompt")
+
     def __init__(
         self,
         *,
@@ -178,11 +182,11 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
     Use the Vertex AI Gemini Pro foundation model to generate natural language text.
 
     :param project_id: Required. The ID of the Google Cloud project that the
-        service belongs to.
+        service belongs to (templated).
     :param location: Required. The ID of the Google Cloud location that the
-        service belongs to.
+        service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
-        to the Multi-modal model, in order to elicit a specific response.
+        to the Multi-modal model, in order to elicit a specific response (templated).
     :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -198,6 +202,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     """
 
+    template_fields = ("location", "project_id", "impersonation_chain", "prompt")
+
     def __init__(
         self,
         *,
@@ -240,11 +246,11 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
     Use the Vertex AI Gemini Pro foundation model to generate natural language text.
 
     :param project_id: Required. The ID of the Google Cloud project that the
-        service belongs to.
+        service belongs to (templated).
     :param location: Required. The ID of the Google Cloud location that the
-        service belongs to.
+        service belongs to (templated).
     :param prompt: Required. Inputs or queries that a user or a program gives
-        to the Multi-modal model, in order to elicit a specific response.
+        to the Multi-modal model, in order to elicit a specific response (templated).
     :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
         supporting prompts with text-only input, including natural language
         tasks, multi-turn text and code chat, and code generation. It can
@@ -263,6 +269,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     """
 
+    template_fields = ("location", "project_id", "impersonation_chain", "prompt")
+
     def __init__(
         self,
         *,
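Since location, project_id, impersonation_chain, and prompt are now declared in template_fields, they accept Jinja templating. An illustrative sketch; the Airflow Variable name and location value are placeholders:

from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    PromptLanguageModelOperator,
)

# "prompt" is rendered at runtime now that it is listed in template_fields.
summarize_run = PromptLanguageModelOperator(
    task_id="summarize_run",
    project_id="{{ var.value.gcp_project }}",  # placeholder Airflow Variable
    location="us-central1",  # placeholder
    prompt="Write a one-line status summary for the DAG run of {{ ds }}.",
)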
airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py

@@ -102,7 +102,6 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :param deferrable: If True, run the task in the deferrable mode.
-        Note that it requires calling the operator with `sync=False` parameter.
     :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
         The default is 300 seconds.
     """
|