apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89):
  1. airflow/providers/google/__init__.py +3 -3
  2. airflow/providers/google/cloud/hooks/automl.py +1 -1
  3. airflow/providers/google/cloud/hooks/bigquery.py +64 -33
  4. airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
  5. airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
  6. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
  7. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +246 -32
  9. airflow/providers/google/cloud/hooks/dataplex.py +6 -2
  10. airflow/providers/google/cloud/hooks/dlp.py +14 -14
  11. airflow/providers/google/cloud/hooks/gcs.py +6 -2
  12. airflow/providers/google/cloud/hooks/gdm.py +2 -2
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/mlengine.py +8 -4
  15. airflow/providers/google/cloud/hooks/pubsub.py +1 -1
  16. airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
  18. airflow/providers/google/cloud/links/vertex_ai.py +2 -1
  19. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
  20. airflow/providers/google/cloud/operators/automl.py +13 -12
  21. airflow/providers/google/cloud/operators/bigquery.py +36 -22
  22. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
  23. airflow/providers/google/cloud/operators/bigtable.py +7 -6
  24. airflow/providers/google/cloud/operators/cloud_build.py +12 -11
  25. airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
  26. airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
  27. airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
  29. airflow/providers/google/cloud/operators/compute.py +12 -11
  30. airflow/providers/google/cloud/operators/datacatalog.py +21 -20
  31. airflow/providers/google/cloud/operators/dataflow.py +59 -42
  32. airflow/providers/google/cloud/operators/datafusion.py +11 -10
  33. airflow/providers/google/cloud/operators/datapipeline.py +3 -2
  34. airflow/providers/google/cloud/operators/dataprep.py +5 -4
  35. airflow/providers/google/cloud/operators/dataproc.py +19 -16
  36. airflow/providers/google/cloud/operators/datastore.py +8 -7
  37. airflow/providers/google/cloud/operators/dlp.py +31 -30
  38. airflow/providers/google/cloud/operators/functions.py +4 -3
  39. airflow/providers/google/cloud/operators/gcs.py +66 -41
  40. airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
  41. airflow/providers/google/cloud/operators/life_sciences.py +2 -1
  42. airflow/providers/google/cloud/operators/mlengine.py +11 -10
  43. airflow/providers/google/cloud/operators/pubsub.py +6 -5
  44. airflow/providers/google/cloud/operators/spanner.py +7 -6
  45. airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
  46. airflow/providers/google/cloud/operators/stackdriver.py +11 -10
  47. airflow/providers/google/cloud/operators/tasks.py +14 -13
  48. airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
  49. airflow/providers/google/cloud/operators/translate_speech.py +2 -1
  50. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
  51. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
  52. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
  53. airflow/providers/google/cloud/operators/vision.py +13 -12
  54. airflow/providers/google/cloud/operators/workflows.py +10 -9
  55. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  56. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
  57. airflow/providers/google/cloud/sensors/bigtable.py +2 -1
  58. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
  59. airflow/providers/google/cloud/sensors/dataflow.py +239 -52
  60. airflow/providers/google/cloud/sensors/datafusion.py +2 -1
  61. airflow/providers/google/cloud/sensors/dataproc.py +3 -2
  62. airflow/providers/google/cloud/sensors/gcs.py +14 -12
  63. airflow/providers/google/cloud/sensors/tasks.py +2 -1
  64. airflow/providers/google/cloud/sensors/workflows.py +2 -1
  65. airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
  66. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
  67. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
  68. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  69. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
  70. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  71. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
  72. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
  73. airflow/providers/google/cloud/triggers/bigquery.py +14 -3
  74. airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
  75. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
  76. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
  77. airflow/providers/google/cloud/triggers/dataflow.py +504 -4
  78. airflow/providers/google/cloud/triggers/dataproc.py +110 -26
  79. airflow/providers/google/cloud/triggers/mlengine.py +2 -1
  80. airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
  81. airflow/providers/google/common/hooks/base_google.py +45 -7
  82. airflow/providers/google/firebase/hooks/firestore.py +2 -2
  83. airflow/providers/google/firebase/operators/firestore.py +2 -1
  84. airflow/providers/google/get_provider_info.py +3 -2
  85. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
  86. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
  87. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
  88. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
  89. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
@@ -19,7 +19,9 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
- from typing import TYPE_CHECKING, Sequence
22
+ import warnings
23
+ from functools import cached_property
24
+ from typing import TYPE_CHECKING, Any, Sequence
23
25
 
24
26
  from deprecated import deprecated
25
27
  from google.api_core.exceptions import NotFound
@@ -28,7 +30,8 @@ from google.cloud.aiplatform.models import Model
28
30
  from google.cloud.aiplatform_v1.types.dataset import Dataset
29
31
  from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
30
32
 
31
- from airflow.exceptions import AirflowProviderDeprecationWarning
33
+ from airflow.configuration import conf
34
+ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
32
35
  from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
33
36
  from airflow.providers.google.cloud.links.vertex_ai import (
34
37
  VertexAIModelLink,
@@ -36,9 +39,19 @@ from airflow.providers.google.cloud.links.vertex_ai import (
36
39
  VertexAITrainingPipelinesLink,
37
40
  )
38
41
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
42
+ from airflow.providers.google.cloud.triggers.vertex_ai import (
43
+ CustomContainerTrainingJobTrigger,
44
+ CustomPythonPackageTrainingJobTrigger,
45
+ CustomTrainingJobTrigger,
46
+ )
39
47
 
40
48
  if TYPE_CHECKING:
41
49
  from google.api_core.retry import Retry
50
+ from google.cloud.aiplatform import (
51
+ CustomContainerTrainingJob,
52
+ CustomPythonPackageTrainingJob,
53
+ CustomTrainingJob,
54
+ )
42
55
 
43
56
  from airflow.utils.context import Context
44
57
 
@@ -160,6 +173,13 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
160
173
  self.gcp_conn_id = gcp_conn_id
161
174
  self.impersonation_chain = impersonation_chain
162
175
 
176
+ def execute(self, context: Context) -> None:
177
+ warnings.warn(
178
+ "The 'sync' parameter is deprecated and will be removed after 01.10.2024.",
179
+ AirflowProviderDeprecationWarning,
180
+ stacklevel=2,
181
+ )
182
+
163
183
 
164
184
  class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
165
185
  """Create Custom Container Training job.
@@ -421,9 +441,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
421
441
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
422
442
  For more information on configuring your service account please visit:
423
443
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
424
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
425
- will be executed in concurrent Future and any downstream object will
426
- be immediately returned and synced when the Future has completed.
427
444
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
428
445
  :param impersonation_chain: Optional service account to impersonate using short-term
429
446
  credentials, or chained list of accounts required to get the access_token
@@ -433,6 +450,9 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
433
450
  If set as a sequence, the identities from the list must grant
434
451
  Service Account Token Creator IAM role to the directly preceding identity, with first
435
452
  account from the list granting this role to the originating account (templated).
453
+ :param deferrable: If True, run the task in the deferrable mode.
454
+ :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
455
+ The default is 60 seconds.
436
456
  """
437
457
 
438
458
  template_fields = (
@@ -442,7 +462,10 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
442
462
  "dataset_id",
443
463
  "impersonation_chain",
444
464
  )
445
- operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
465
+ operator_extra_links = (
466
+ VertexAIModelLink(),
467
+ VertexAITrainingLink(),
468
+ )
446
469
 
447
470
  def __init__(
448
471
  self,
@@ -452,6 +475,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
452
475
  parent_model: str | None = None,
453
476
  impersonation_chain: str | Sequence[str] | None = None,
454
477
  dataset_id: str | None = None,
478
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
479
+ poll_interval: int = 60,
455
480
  **kwargs,
456
481
  ) -> None:
457
482
  super().__init__(
@@ -462,12 +487,15 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
462
487
  **kwargs,
463
488
  )
464
489
  self.command = command
490
+ self.deferrable = deferrable
491
+ self.poll_interval = poll_interval
465
492
 
466
493
  def execute(self, context: Context):
467
- self.hook = CustomJobHook(
468
- gcp_conn_id=self.gcp_conn_id,
469
- impersonation_chain=self.impersonation_chain,
470
- )
494
+ super().execute(context)
495
+
496
+ if self.deferrable:
497
+ self.invoke_defer(context=context)
498
+
471
499
  model, training_id, custom_job_id = self.hook.create_custom_container_training_job(
472
500
  project_id=self.project_id,
473
501
  region=self.region,
@@ -539,6 +567,94 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
539
567
  if self.hook:
540
568
  self.hook.cancel_job()
541
569
 
570
+ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
571
+ if event["status"] == "error":
572
+ raise AirflowException(event["message"])
573
+ result = event["job"]
574
+ model_id = self.hook.extract_model_id_from_training_pipeline(result)
575
+ custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result)
576
+ self.xcom_push(context, key="model_id", value=model_id)
577
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
578
+ # push custom_job_id to xcom so it could be pulled by other tasks
579
+ self.xcom_push(context, key="custom_job_id", value=custom_job_id)
580
+ return result
581
+
582
+ def invoke_defer(self, context: Context) -> None:
583
+ custom_container_training_job_obj: CustomContainerTrainingJob = self.hook.submit_custom_container_training_job(
584
+ project_id=self.project_id,
585
+ region=self.region,
586
+ display_name=self.display_name,
587
+ command=self.command,
588
+ container_uri=self.container_uri,
589
+ model_serving_container_image_uri=self.model_serving_container_image_uri,
590
+ model_serving_container_predict_route=self.model_serving_container_predict_route,
591
+ model_serving_container_health_route=self.model_serving_container_health_route,
592
+ model_serving_container_command=self.model_serving_container_command,
593
+ model_serving_container_args=self.model_serving_container_args,
594
+ model_serving_container_environment_variables=self.model_serving_container_environment_variables,
595
+ model_serving_container_ports=self.model_serving_container_ports,
596
+ model_description=self.model_description,
597
+ model_instance_schema_uri=self.model_instance_schema_uri,
598
+ model_parameters_schema_uri=self.model_parameters_schema_uri,
599
+ model_prediction_schema_uri=self.model_prediction_schema_uri,
600
+ parent_model=self.parent_model,
601
+ is_default_version=self.is_default_version,
602
+ model_version_aliases=self.model_version_aliases,
603
+ model_version_description=self.model_version_description,
604
+ labels=self.labels,
605
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
606
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
607
+ staging_bucket=self.staging_bucket,
608
+ # RUN
609
+ dataset=Dataset(name=self.dataset_id) if self.dataset_id else None,
610
+ annotation_schema_uri=self.annotation_schema_uri,
611
+ model_display_name=self.model_display_name,
612
+ model_labels=self.model_labels,
613
+ base_output_dir=self.base_output_dir,
614
+ service_account=self.service_account,
615
+ network=self.network,
616
+ bigquery_destination=self.bigquery_destination,
617
+ args=self.args,
618
+ environment_variables=self.environment_variables,
619
+ replica_count=self.replica_count,
620
+ machine_type=self.machine_type,
621
+ accelerator_type=self.accelerator_type,
622
+ accelerator_count=self.accelerator_count,
623
+ boot_disk_type=self.boot_disk_type,
624
+ boot_disk_size_gb=self.boot_disk_size_gb,
625
+ training_fraction_split=self.training_fraction_split,
626
+ validation_fraction_split=self.validation_fraction_split,
627
+ test_fraction_split=self.test_fraction_split,
628
+ training_filter_split=self.training_filter_split,
629
+ validation_filter_split=self.validation_filter_split,
630
+ test_filter_split=self.test_filter_split,
631
+ predefined_split_column_name=self.predefined_split_column_name,
632
+ timestamp_split_column_name=self.timestamp_split_column_name,
633
+ tensorboard=self.tensorboard,
634
+ )
635
+ custom_container_training_job_obj.wait_for_resource_creation()
636
+ training_pipeline_id: str = custom_container_training_job_obj.name
637
+ self.xcom_push(context, key="training_id", value=training_pipeline_id)
638
+ VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
639
+ self.defer(
640
+ trigger=CustomContainerTrainingJobTrigger(
641
+ conn_id=self.gcp_conn_id,
642
+ project_id=self.project_id,
643
+ location=self.region,
644
+ job_id=training_pipeline_id,
645
+ poll_interval=self.poll_interval,
646
+ impersonation_chain=self.impersonation_chain,
647
+ ),
648
+ method_name="execute_complete",
649
+ )
650
+
651
+ @cached_property
652
+ def hook(self) -> CustomJobHook:
653
+ return CustomJobHook(
654
+ gcp_conn_id=self.gcp_conn_id,
655
+ impersonation_chain=self.impersonation_chain,
656
+ )
657
+
542
658
 
543
659
  class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
544
660
  """Create Custom Python Package Training job.
@@ -800,9 +916,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
800
916
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
801
917
  For more information on configuring your service account please visit:
802
918
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
803
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
804
- will be executed in concurrent Future and any downstream object will
805
- be immediately returned and synced when the Future has completed.
806
919
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
807
920
  :param impersonation_chain: Optional service account to impersonate using short-term
808
921
  credentials, or chained list of accounts required to get the access_token
@@ -812,6 +925,9 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
812
925
  If set as a sequence, the identities from the list must grant
813
926
  Service Account Token Creator IAM role to the directly preceding identity, with first
814
927
  account from the list granting this role to the originating account (templated).
928
+ :param deferrable: If True, run the task in the deferrable mode.
929
+ :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
930
+ The default is 60 seconds.
815
931
  """
816
932
 
817
933
  template_fields = (
@@ -831,6 +947,8 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
831
947
  parent_model: str | None = None,
832
948
  impersonation_chain: str | Sequence[str] | None = None,
833
949
  dataset_id: str | None = None,
950
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
951
+ poll_interval: int = 60,
834
952
  **kwargs,
835
953
  ) -> None:
836
954
  super().__init__(
@@ -842,12 +960,15 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
842
960
  )
843
961
  self.python_package_gcs_uri = python_package_gcs_uri
844
962
  self.python_module_name = python_module_name
963
+ self.deferrable = deferrable
964
+ self.poll_interval = poll_interval
845
965
 
846
966
  def execute(self, context: Context):
847
- self.hook = CustomJobHook(
848
- gcp_conn_id=self.gcp_conn_id,
849
- impersonation_chain=self.impersonation_chain,
850
- )
967
+ super().execute(context)
968
+
969
+ if self.deferrable:
970
+ self.invoke_defer(context=context)
971
+
851
972
  model, training_id, custom_job_id = self.hook.create_custom_python_package_training_job(
852
973
  project_id=self.project_id,
853
974
  region=self.region,
@@ -920,9 +1041,98 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
920
1041
  if self.hook:
921
1042
  self.hook.cancel_job()
922
1043
 
1044
+ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
1045
+ if event["status"] == "error":
1046
+ raise AirflowException(event["message"])
1047
+ result = event["job"]
1048
+ model_id = self.hook.extract_model_id_from_training_pipeline(result)
1049
+ custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result)
1050
+ self.xcom_push(context, key="model_id", value=model_id)
1051
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1052
+ # push custom_job_id to xcom so it could be pulled by other tasks
1053
+ self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1054
+ return result
1055
+
1056
+ def invoke_defer(self, context: Context) -> None:
1057
+ custom_python_training_job_obj: CustomPythonPackageTrainingJob = self.hook.submit_custom_python_package_training_job(
1058
+ project_id=self.project_id,
1059
+ region=self.region,
1060
+ display_name=self.display_name,
1061
+ python_package_gcs_uri=self.python_package_gcs_uri,
1062
+ python_module_name=self.python_module_name,
1063
+ container_uri=self.container_uri,
1064
+ model_serving_container_image_uri=self.model_serving_container_image_uri,
1065
+ model_serving_container_predict_route=self.model_serving_container_predict_route,
1066
+ model_serving_container_health_route=self.model_serving_container_health_route,
1067
+ model_serving_container_command=self.model_serving_container_command,
1068
+ model_serving_container_args=self.model_serving_container_args,
1069
+ model_serving_container_environment_variables=self.model_serving_container_environment_variables,
1070
+ model_serving_container_ports=self.model_serving_container_ports,
1071
+ model_description=self.model_description,
1072
+ model_instance_schema_uri=self.model_instance_schema_uri,
1073
+ model_parameters_schema_uri=self.model_parameters_schema_uri,
1074
+ model_prediction_schema_uri=self.model_prediction_schema_uri,
1075
+ parent_model=self.parent_model,
1076
+ is_default_version=self.is_default_version,
1077
+ model_version_aliases=self.model_version_aliases,
1078
+ model_version_description=self.model_version_description,
1079
+ labels=self.labels,
1080
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
1081
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
1082
+ staging_bucket=self.staging_bucket,
1083
+ # RUN
1084
+ dataset=Dataset(name=self.dataset_id) if self.dataset_id else None,
1085
+ annotation_schema_uri=self.annotation_schema_uri,
1086
+ model_display_name=self.model_display_name,
1087
+ model_labels=self.model_labels,
1088
+ base_output_dir=self.base_output_dir,
1089
+ service_account=self.service_account,
1090
+ network=self.network,
1091
+ bigquery_destination=self.bigquery_destination,
1092
+ args=self.args,
1093
+ environment_variables=self.environment_variables,
1094
+ replica_count=self.replica_count,
1095
+ machine_type=self.machine_type,
1096
+ accelerator_type=self.accelerator_type,
1097
+ accelerator_count=self.accelerator_count,
1098
+ boot_disk_type=self.boot_disk_type,
1099
+ boot_disk_size_gb=self.boot_disk_size_gb,
1100
+ training_fraction_split=self.training_fraction_split,
1101
+ validation_fraction_split=self.validation_fraction_split,
1102
+ test_fraction_split=self.test_fraction_split,
1103
+ training_filter_split=self.training_filter_split,
1104
+ validation_filter_split=self.validation_filter_split,
1105
+ test_filter_split=self.test_filter_split,
1106
+ predefined_split_column_name=self.predefined_split_column_name,
1107
+ timestamp_split_column_name=self.timestamp_split_column_name,
1108
+ tensorboard=self.tensorboard,
1109
+ )
1110
+ custom_python_training_job_obj.wait_for_resource_creation()
1111
+ training_pipeline_id: str = custom_python_training_job_obj.name
1112
+ self.xcom_push(context, key="training_id", value=training_pipeline_id)
1113
+ VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1114
+ self.defer(
1115
+ trigger=CustomPythonPackageTrainingJobTrigger(
1116
+ conn_id=self.gcp_conn_id,
1117
+ project_id=self.project_id,
1118
+ location=self.region,
1119
+ job_id=training_pipeline_id,
1120
+ poll_interval=self.poll_interval,
1121
+ impersonation_chain=self.impersonation_chain,
1122
+ ),
1123
+ method_name="execute_complete",
1124
+ )
1125
+
1126
+ @cached_property
1127
+ def hook(self) -> CustomJobHook:
1128
+ return CustomJobHook(
1129
+ gcp_conn_id=self.gcp_conn_id,
1130
+ impersonation_chain=self.impersonation_chain,
1131
+ )
1132
+
923
1133
 
924
1134
  class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
925
- """Create Custom Training job.
1135
+ """Create a Custom Training Job pipeline.
926
1136
 
927
1137
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
928
1138
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -1181,9 +1391,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1181
1391
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
1182
1392
  For more information on configuring your service account please visit:
1183
1393
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
1184
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
1185
- will be executed in concurrent Future and any downstream object will
1186
- be immediately returned and synced when the Future has completed.
1187
1394
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
1188
1395
  :param impersonation_chain: Optional service account to impersonate using short-term
1189
1396
  credentials, or chained list of accounts required to get the access_token
@@ -1193,6 +1400,9 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1193
1400
  If set as a sequence, the identities from the list must grant
1194
1401
  Service Account Token Creator IAM role to the directly preceding identity, with first
1195
1402
  account from the list granting this role to the originating account (templated).
1403
+ :param deferrable: If True, run the task in the deferrable mode.
1404
+ :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
1405
+ The default is 60 seconds.
1196
1406
  """
1197
1407
 
1198
1408
  template_fields = (
@@ -1203,7 +1413,10 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1203
1413
  "dataset_id",
1204
1414
  "impersonation_chain",
1205
1415
  )
1206
- operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
1416
+ operator_extra_links = (
1417
+ VertexAIModelLink(),
1418
+ VertexAITrainingLink(),
1419
+ )
1207
1420
 
1208
1421
  def __init__(
1209
1422
  self,
@@ -1214,6 +1427,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1214
1427
  parent_model: str | None = None,
1215
1428
  impersonation_chain: str | Sequence[str] | None = None,
1216
1429
  dataset_id: str | None = None,
1430
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
1431
+ poll_interval: int = 60,
1217
1432
  **kwargs,
1218
1433
  ) -> None:
1219
1434
  super().__init__(
@@ -1225,12 +1440,15 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1225
1440
  )
1226
1441
  self.requirements = requirements
1227
1442
  self.script_path = script_path
1443
+ self.deferrable = deferrable
1444
+ self.poll_interval = poll_interval
1228
1445
 
1229
1446
  def execute(self, context: Context):
1230
- self.hook = CustomJobHook(
1231
- gcp_conn_id=self.gcp_conn_id,
1232
- impersonation_chain=self.impersonation_chain,
1233
- )
1447
+ super().execute(context)
1448
+
1449
+ if self.deferrable:
1450
+ self.invoke_defer(context=context)
1451
+
1234
1452
  model, training_id, custom_job_id = self.hook.create_custom_training_job(
1235
1453
  project_id=self.project_id,
1236
1454
  region=self.region,
@@ -1303,6 +1521,95 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1303
1521
  if self.hook:
1304
1522
  self.hook.cancel_job()
1305
1523
 
1524
+ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
1525
+ if event["status"] == "error":
1526
+ raise AirflowException(event["message"])
1527
+ result = event["job"]
1528
+ model_id = self.hook.extract_model_id_from_training_pipeline(result)
1529
+ custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result)
1530
+ self.xcom_push(context, key="model_id", value=model_id)
1531
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1532
+ # push custom_job_id to xcom so it could be pulled by other tasks
1533
+ self.xcom_push(context, key="custom_job_id", value=custom_job_id)
1534
+ return result
1535
+
1536
+ def invoke_defer(self, context: Context) -> None:
1537
+ custom_training_job_obj: CustomTrainingJob = self.hook.submit_custom_training_job(
1538
+ project_id=self.project_id,
1539
+ region=self.region,
1540
+ display_name=self.display_name,
1541
+ script_path=self.script_path,
1542
+ container_uri=self.container_uri,
1543
+ requirements=self.requirements,
1544
+ model_serving_container_image_uri=self.model_serving_container_image_uri,
1545
+ model_serving_container_predict_route=self.model_serving_container_predict_route,
1546
+ model_serving_container_health_route=self.model_serving_container_health_route,
1547
+ model_serving_container_command=self.model_serving_container_command,
1548
+ model_serving_container_args=self.model_serving_container_args,
1549
+ model_serving_container_environment_variables=self.model_serving_container_environment_variables,
1550
+ model_serving_container_ports=self.model_serving_container_ports,
1551
+ model_description=self.model_description,
1552
+ model_instance_schema_uri=self.model_instance_schema_uri,
1553
+ model_parameters_schema_uri=self.model_parameters_schema_uri,
1554
+ model_prediction_schema_uri=self.model_prediction_schema_uri,
1555
+ parent_model=self.parent_model,
1556
+ is_default_version=self.is_default_version,
1557
+ model_version_aliases=self.model_version_aliases,
1558
+ model_version_description=self.model_version_description,
1559
+ labels=self.labels,
1560
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
1561
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
1562
+ staging_bucket=self.staging_bucket,
1563
+ # RUN
1564
+ dataset=Dataset(name=self.dataset_id) if self.dataset_id else None,
1565
+ annotation_schema_uri=self.annotation_schema_uri,
1566
+ model_display_name=self.model_display_name,
1567
+ model_labels=self.model_labels,
1568
+ base_output_dir=self.base_output_dir,
1569
+ service_account=self.service_account,
1570
+ network=self.network,
1571
+ bigquery_destination=self.bigquery_destination,
1572
+ args=self.args,
1573
+ environment_variables=self.environment_variables,
1574
+ replica_count=self.replica_count,
1575
+ machine_type=self.machine_type,
1576
+ accelerator_type=self.accelerator_type,
1577
+ accelerator_count=self.accelerator_count,
1578
+ boot_disk_type=self.boot_disk_type,
1579
+ boot_disk_size_gb=self.boot_disk_size_gb,
1580
+ training_fraction_split=self.training_fraction_split,
1581
+ validation_fraction_split=self.validation_fraction_split,
1582
+ test_fraction_split=self.test_fraction_split,
1583
+ training_filter_split=self.training_filter_split,
1584
+ validation_filter_split=self.validation_filter_split,
1585
+ test_filter_split=self.test_filter_split,
1586
+ predefined_split_column_name=self.predefined_split_column_name,
1587
+ timestamp_split_column_name=self.timestamp_split_column_name,
1588
+ tensorboard=self.tensorboard,
1589
+ )
1590
+ custom_training_job_obj.wait_for_resource_creation()
1591
+ training_pipeline_id: str = custom_training_job_obj.name
1592
+ self.xcom_push(context, key="training_id", value=training_pipeline_id)
1593
+ VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
1594
+ self.defer(
1595
+ trigger=CustomTrainingJobTrigger(
1596
+ conn_id=self.gcp_conn_id,
1597
+ project_id=self.project_id,
1598
+ location=self.region,
1599
+ job_id=training_pipeline_id,
1600
+ poll_interval=self.poll_interval,
1601
+ impersonation_chain=self.impersonation_chain,
1602
+ ),
1603
+ method_name="execute_complete",
1604
+ )
1605
+
1606
+ @cached_property
1607
+ def hook(self) -> CustomJobHook:
1608
+ return CustomJobHook(
1609
+ gcp_conn_id=self.gcp_conn_id,
1610
+ impersonation_chain=self.impersonation_chain,
1611
+ )
1612
+
1306
1613
 
1307
1614
  class DeleteCustomTrainingJobOperator(GoogleCloudBaseOperator):
1308
1615
  """
@@ -33,11 +33,11 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator):
33
33
  Uses the Vertex AI PaLM API to generate natural language text.
34
34
 
35
35
  :param project_id: Required. The ID of the Google Cloud project that the
36
- service belongs to.
36
+ service belongs to (templated).
37
37
  :param location: Required. The ID of the Google Cloud location that the
38
- service belongs to.
38
+ service belongs to (templated).
39
39
  :param prompt: Required. Inputs or queries that a user or a program gives
40
- to the Vertex AI PaLM API, in order to elicit a specific response.
40
+ to the Vertex AI PaLM API, in order to elicit a specific response (templated).
41
41
  :param pretrained_model: By default uses the pre-trained model `text-bison`,
42
42
  optimized for performing natural language tasks such as classification,
43
43
  summarization, extraction, content creation, and ideation.
@@ -60,6 +60,8 @@ class PromptLanguageModelOperator(GoogleCloudBaseOperator):
60
60
  account from the list granting this role to the originating account (templated).
61
61
  """
62
62
 
63
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt")
64
+
63
65
  def __init__(
64
66
  self,
65
67
  *,
@@ -116,11 +118,11 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
116
118
  Uses the Vertex AI PaLM API to generate natural language text.
117
119
 
118
120
  :param project_id: Required. The ID of the Google Cloud project that the
119
- service belongs to.
121
+ service belongs to (templated).
120
122
  :param location: Required. The ID of the Google Cloud location that the
121
- service belongs to.
123
+ service belongs to (templated).
122
124
  :param prompt: Required. Inputs or queries that a user or a program gives
123
- to the Vertex AI PaLM API, in order to elicit a specific response.
125
+ to the Vertex AI PaLM API, in order to elicit a specific response (templated).
124
126
  :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
125
127
  optimized for performing text embeddings.
126
128
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
@@ -134,6 +136,8 @@ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
134
136
  account from the list granting this role to the originating account (templated).
135
137
  """
136
138
 
139
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt")
140
+
137
141
  def __init__(
138
142
  self,
139
143
  *,
@@ -178,11 +182,11 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
178
182
  Use the Vertex AI Gemini Pro foundation model to generate natural language text.
179
183
 
180
184
  :param project_id: Required. The ID of the Google Cloud project that the
181
- service belongs to.
185
+ service belongs to (templated).
182
186
  :param location: Required. The ID of the Google Cloud location that the
183
- service belongs to.
187
+ service belongs to (templated).
184
188
  :param prompt: Required. Inputs or queries that a user or a program gives
185
- to the Multi-modal model, in order to elicit a specific response.
189
+ to the Multi-modal model, in order to elicit a specific response (templated).
186
190
  :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
187
191
  supporting prompts with text-only input, including natural language
188
192
  tasks, multi-turn text and code chat, and code generation. It can
@@ -198,6 +202,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
198
202
  account from the list granting this role to the originating account (templated).
199
203
  """
200
204
 
205
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt")
206
+
201
207
  def __init__(
202
208
  self,
203
209
  *,
@@ -240,11 +246,11 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
240
246
  Use the Vertex AI Gemini Pro foundation model to generate natural language text.
241
247
 
242
248
  :param project_id: Required. The ID of the Google Cloud project that the
243
- service belongs to.
249
+ service belongs to (templated).
244
250
  :param location: Required. The ID of the Google Cloud location that the
245
- service belongs to.
251
+ service belongs to (templated).
246
252
  :param prompt: Required. Inputs or queries that a user or a program gives
247
- to the Multi-modal model, in order to elicit a specific response.
253
+ to the Multi-modal model, in order to elicit a specific response (templated).
248
254
  :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
249
255
  supporting prompts with text-only input, including natural language
250
256
  tasks, multi-turn text and code chat, and code generation. It can
@@ -263,6 +269,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
263
269
  account from the list granting this role to the originating account (templated).
264
270
  """
265
271
 
272
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt")
273
+
266
274
  def __init__(
267
275
  self,
268
276
  *,
@@ -102,7 +102,6 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
102
102
  Service Account Token Creator IAM role to the directly preceding identity, with first
103
103
  account from the list granting this role to the originating account (templated).
104
104
  :param deferrable: If True, run the task in the deferrable mode.
105
- Note that it requires calling the operator with `sync=False` parameter.
106
105
  :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
107
106
  The default is 300 seconds.
108
107
  """