apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. airflow/providers/google/__init__.py +3 -3
  2. airflow/providers/google/cloud/hooks/automl.py +1 -1
  3. airflow/providers/google/cloud/hooks/bigquery.py +64 -33
  4. airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
  5. airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
  6. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
  7. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +246 -32
  9. airflow/providers/google/cloud/hooks/dataplex.py +6 -2
  10. airflow/providers/google/cloud/hooks/dlp.py +14 -14
  11. airflow/providers/google/cloud/hooks/gcs.py +6 -2
  12. airflow/providers/google/cloud/hooks/gdm.py +2 -2
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/mlengine.py +8 -4
  15. airflow/providers/google/cloud/hooks/pubsub.py +1 -1
  16. airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
  18. airflow/providers/google/cloud/links/vertex_ai.py +2 -1
  19. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
  20. airflow/providers/google/cloud/operators/automl.py +13 -12
  21. airflow/providers/google/cloud/operators/bigquery.py +36 -22
  22. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
  23. airflow/providers/google/cloud/operators/bigtable.py +7 -6
  24. airflow/providers/google/cloud/operators/cloud_build.py +12 -11
  25. airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
  26. airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
  27. airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
  29. airflow/providers/google/cloud/operators/compute.py +12 -11
  30. airflow/providers/google/cloud/operators/datacatalog.py +21 -20
  31. airflow/providers/google/cloud/operators/dataflow.py +59 -42
  32. airflow/providers/google/cloud/operators/datafusion.py +11 -10
  33. airflow/providers/google/cloud/operators/datapipeline.py +3 -2
  34. airflow/providers/google/cloud/operators/dataprep.py +5 -4
  35. airflow/providers/google/cloud/operators/dataproc.py +19 -16
  36. airflow/providers/google/cloud/operators/datastore.py +8 -7
  37. airflow/providers/google/cloud/operators/dlp.py +31 -30
  38. airflow/providers/google/cloud/operators/functions.py +4 -3
  39. airflow/providers/google/cloud/operators/gcs.py +66 -41
  40. airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
  41. airflow/providers/google/cloud/operators/life_sciences.py +2 -1
  42. airflow/providers/google/cloud/operators/mlengine.py +11 -10
  43. airflow/providers/google/cloud/operators/pubsub.py +6 -5
  44. airflow/providers/google/cloud/operators/spanner.py +7 -6
  45. airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
  46. airflow/providers/google/cloud/operators/stackdriver.py +11 -10
  47. airflow/providers/google/cloud/operators/tasks.py +14 -13
  48. airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
  49. airflow/providers/google/cloud/operators/translate_speech.py +2 -1
  50. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
  51. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
  52. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
  53. airflow/providers/google/cloud/operators/vision.py +13 -12
  54. airflow/providers/google/cloud/operators/workflows.py +10 -9
  55. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  56. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
  57. airflow/providers/google/cloud/sensors/bigtable.py +2 -1
  58. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
  59. airflow/providers/google/cloud/sensors/dataflow.py +239 -52
  60. airflow/providers/google/cloud/sensors/datafusion.py +2 -1
  61. airflow/providers/google/cloud/sensors/dataproc.py +3 -2
  62. airflow/providers/google/cloud/sensors/gcs.py +14 -12
  63. airflow/providers/google/cloud/sensors/tasks.py +2 -1
  64. airflow/providers/google/cloud/sensors/workflows.py +2 -1
  65. airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
  66. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
  67. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
  68. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  69. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
  70. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  71. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
  72. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
  73. airflow/providers/google/cloud/triggers/bigquery.py +14 -3
  74. airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
  75. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
  76. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
  77. airflow/providers/google/cloud/triggers/dataflow.py +504 -4
  78. airflow/providers/google/cloud/triggers/dataproc.py +110 -26
  79. airflow/providers/google/cloud/triggers/mlengine.py +2 -1
  80. airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
  81. airflow/providers/google/common/hooks/base_google.py +45 -7
  82. airflow/providers/google/firebase/hooks/firestore.py +2 -2
  83. airflow/providers/google/firebase/operators/firestore.py +2 -1
  84. airflow/providers/google/get_provider_info.py +3 -2
  85. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
  86. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
  87. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
  88. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
  89. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
@@ -19,7 +19,8 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
- from typing import TYPE_CHECKING, Sequence
22
+ import asyncio
23
+ from typing import TYPE_CHECKING, Any, Sequence
23
24
 
24
25
  from deprecated import deprecated
25
26
  from google.api_core.client_options import ClientOptions
@@ -31,15 +32,24 @@ from google.cloud.aiplatform import (
31
32
  datasets,
32
33
  models,
33
34
  )
34
- from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
35
+ from google.cloud.aiplatform_v1 import (
36
+ JobServiceAsyncClient,
37
+ JobServiceClient,
38
+ JobState,
39
+ PipelineServiceAsyncClient,
40
+ PipelineServiceClient,
41
+ PipelineState,
42
+ types,
43
+ )
35
44
 
36
45
  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
37
46
  from airflow.providers.google.common.consts import CLIENT_INFO
38
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
47
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
39
48
 
40
49
  if TYPE_CHECKING:
41
50
  from google.api_core.operation import Operation
42
- from google.api_core.retry import Retry
51
+ from google.api_core.retry import AsyncRetry, Retry
52
+ from google.auth.credentials import Credentials
43
53
  from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager
44
54
  from google.cloud.aiplatform_v1.services.pipeline_service.pagers import (
45
55
  ListPipelineJobsPager,
@@ -101,7 +111,7 @@ class CustomJobHook(GoogleBaseHook):
101
111
  self,
102
112
  display_name: str,
103
113
  container_uri: str,
104
- command: Sequence[str] = [],
114
+ command: Sequence[str] = (),
105
115
  model_serving_container_image_uri: str | None = None,
106
116
  model_serving_container_predict_route: str | None = None,
107
117
  model_serving_container_health_route: str | None = None,
@@ -168,7 +178,7 @@ class CustomJobHook(GoogleBaseHook):
168
178
  training_encryption_spec_key_name: str | None = None,
169
179
  model_encryption_spec_key_name: str | None = None,
170
180
  staging_bucket: str | None = None,
171
- ):
181
+ ) -> CustomPythonPackageTrainingJob:
172
182
  """Return CustomPythonPackageTrainingJob object."""
173
183
  return CustomPythonPackageTrainingJob(
174
184
  display_name=display_name,
@@ -218,7 +228,7 @@ class CustomJobHook(GoogleBaseHook):
218
228
  training_encryption_spec_key_name: str | None = None,
219
229
  model_encryption_spec_key_name: str | None = None,
220
230
  staging_bucket: str | None = None,
221
- ):
231
+ ) -> CustomTrainingJob:
222
232
  """Return CustomTrainingJob object."""
223
233
  return CustomTrainingJob(
224
234
  display_name=display_name,
@@ -246,10 +256,15 @@ class CustomJobHook(GoogleBaseHook):
246
256
  )
247
257
 
248
258
  @staticmethod
249
- def extract_model_id(obj: dict) -> str:
259
+ def extract_model_id(obj: dict[str, Any]) -> str:
250
260
  """Return unique id of the Model."""
251
261
  return obj["name"].rpartition("/")[-1]
252
262
 
263
+ @staticmethod
264
+ def extract_model_id_from_training_pipeline(training_pipeline: dict[str, Any]) -> str:
265
+ """Return a unique Model id from a serialized TrainingPipeline proto."""
266
+ return training_pipeline["model_to_upload"]["name"].rpartition("/")[-1]
267
+
253
268
  @staticmethod
254
269
  def extract_training_id(resource_name: str) -> str:
255
270
  """Return unique id of the Training pipeline."""
@@ -260,6 +275,11 @@ class CustomJobHook(GoogleBaseHook):
260
275
  """Return unique id of the Custom Job pipeline."""
261
276
  return custom_job_name.rpartition("/")[-1]
262
277
 
278
+ @staticmethod
279
+ def extract_custom_job_id_from_training_pipeline(training_pipeline: dict[str, Any]) -> str:
280
+ """Return a unique Custom Job id from a serialized TrainingPipeline proto."""
281
+ return training_pipeline["training_task_metadata"]["backingCustomJob"].rpartition("/")[-1]
282
+
263
283
  def wait_for_operation(self, operation: Operation, timeout: float | None = None):
264
284
  """Wait for long-lasting operation to complete."""
265
285
  try:
@@ -310,7 +330,7 @@ class CustomJobHook(GoogleBaseHook):
310
330
  model_version_aliases: list[str] | None = None,
311
331
  model_version_description: str | None = None,
312
332
  ) -> tuple[models.Model | None, str, str]:
313
- """Run Job for training pipeline."""
333
+ """Run a training pipeline job and wait until its completion."""
314
334
  model = job.run(
315
335
  dataset=dataset,
316
336
  annotation_schema_uri=annotation_schema_uri,
@@ -609,7 +629,7 @@ class CustomJobHook(GoogleBaseHook):
609
629
  region: str,
610
630
  display_name: str,
611
631
  container_uri: str,
612
- command: Sequence[str] = [],
632
+ command: Sequence[str] = (),
613
633
  model_serving_container_image_uri: str | None = None,
614
634
  model_serving_container_predict_route: str | None = None,
615
635
  model_serving_container_health_route: str | None = None,
@@ -1754,80 +1774,1180 @@ class CustomJobHook(GoogleBaseHook):
1754
1774
  return model, training_id, custom_job_id
1755
1775
 
1756
1776
  @GoogleBaseHook.fallback_to_default_project_id
1757
- @deprecated(
1758
- reason="Please use `PipelineJobHook.delete_pipeline_job`",
1759
- category=AirflowProviderDeprecationWarning,
1760
- )
1761
- def delete_pipeline_job(
1777
+ def submit_custom_container_training_job(
1762
1778
  self,
1779
+ *,
1763
1780
  project_id: str,
1764
1781
  region: str,
1765
- pipeline_job: str,
1766
- retry: Retry | _MethodDefault = DEFAULT,
1767
- timeout: float | None = None,
1768
- metadata: Sequence[tuple[str, str]] = (),
1769
- ) -> Operation:
1782
+ display_name: str,
1783
+ container_uri: str,
1784
+ command: Sequence[str] = (),
1785
+ model_serving_container_image_uri: str | None = None,
1786
+ model_serving_container_predict_route: str | None = None,
1787
+ model_serving_container_health_route: str | None = None,
1788
+ model_serving_container_command: Sequence[str] | None = None,
1789
+ model_serving_container_args: Sequence[str] | None = None,
1790
+ model_serving_container_environment_variables: dict[str, str] | None = None,
1791
+ model_serving_container_ports: Sequence[int] | None = None,
1792
+ model_description: str | None = None,
1793
+ model_instance_schema_uri: str | None = None,
1794
+ model_parameters_schema_uri: str | None = None,
1795
+ model_prediction_schema_uri: str | None = None,
1796
+ parent_model: str | None = None,
1797
+ is_default_version: bool | None = None,
1798
+ model_version_aliases: list[str] | None = None,
1799
+ model_version_description: str | None = None,
1800
+ labels: dict[str, str] | None = None,
1801
+ training_encryption_spec_key_name: str | None = None,
1802
+ model_encryption_spec_key_name: str | None = None,
1803
+ staging_bucket: str | None = None,
1804
+ # RUN
1805
+ dataset: None
1806
+ | (
1807
+ datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset
1808
+ ) = None,
1809
+ annotation_schema_uri: str | None = None,
1810
+ model_display_name: str | None = None,
1811
+ model_labels: dict[str, str] | None = None,
1812
+ base_output_dir: str | None = None,
1813
+ service_account: str | None = None,
1814
+ network: str | None = None,
1815
+ bigquery_destination: str | None = None,
1816
+ args: list[str | float | int] | None = None,
1817
+ environment_variables: dict[str, str] | None = None,
1818
+ replica_count: int = 1,
1819
+ machine_type: str = "n1-standard-4",
1820
+ accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
1821
+ accelerator_count: int = 0,
1822
+ boot_disk_type: str = "pd-ssd",
1823
+ boot_disk_size_gb: int = 100,
1824
+ training_fraction_split: float | None = None,
1825
+ validation_fraction_split: float | None = None,
1826
+ test_fraction_split: float | None = None,
1827
+ training_filter_split: str | None = None,
1828
+ validation_filter_split: str | None = None,
1829
+ test_filter_split: str | None = None,
1830
+ predefined_split_column_name: str | None = None,
1831
+ timestamp_split_column_name: str | None = None,
1832
+ tensorboard: str | None = None,
1833
+ ) -> CustomContainerTrainingJob:
1770
1834
  """
1771
- Delete a PipelineJob.
1772
-
1773
- This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method.
1835
+ Create and submit a Custom Container Training Job pipeline, then exit without waiting for it to complete.
1774
1836
 
1775
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
1776
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
1777
- :param pipeline_job: Required. The name of the PipelineJob resource to be deleted.
1778
- :param retry: Designation of what errors, if any, should be retried.
1779
- :param timeout: The timeout for this request.
1780
- :param metadata: Strings which should be sent along with the request as metadata.
1781
- """
1782
- client = self.get_pipeline_service_client(region)
1783
- name = client.pipeline_job_path(project_id, region, pipeline_job)
1837
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
1838
+ :param command: The command to be invoked when the container is started.
1839
+ It overrides the entrypoint instruction in Dockerfile when provided
1840
+ :param container_uri: Required: Uri of the training container image in the GCR.
1841
+ :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
1842
+ of the Model serving container suitable for serving the model produced by the
1843
+ training script.
1844
+ :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
1845
+ HTTP path to send prediction requests to the container, and which must be supported
1846
+ by it. If not specified a default HTTP path will be used by Vertex AI.
1847
+ :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
1848
+ HTTP path to send health check requests to the container, and which must be supported
1849
+ by it. If not specified a standard HTTP path will be used by AI Platform.
1850
+ :param model_serving_container_command: The command with which the container is run. Not executed
1851
+ within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
1852
+ Variable references $(VAR_NAME) are expanded using the container's
1853
+ environment. If a variable cannot be resolved, the reference in the
1854
+ input string will be unchanged. The $(VAR_NAME) syntax can be escaped
1855
+ with a double $$, ie: $$(VAR_NAME). Escaped references will never be
1856
+ expanded, regardless of whether the variable exists or not.
1857
+ :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
1858
+ this is not provided. Variable references $(VAR_NAME) are expanded using the
1859
+ container's environment. If a variable cannot be resolved, the reference
1860
+ in the input string will be unchanged. The $(VAR_NAME) syntax can be
1861
+ escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
1862
+ never be expanded, regardless of whether the variable exists or not.
1863
+ :param model_serving_container_environment_variables: The environment variables that are to be
1864
+ present in the container. Should be a dictionary where keys are environment variable names
1865
+ and values are environment variable values for those names.
1866
+ :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
1867
+ field is primarily informational, it gives Vertex AI information about the
1868
+ network connections the container uses. Listing or not a port here has
1869
+ no impact on whether the port is actually exposed, any port listening on
1870
+ the default "0.0.0.0" address inside a container will be accessible from
1871
+ the network.
1872
+ :param model_description: The description of the Model.
1873
+ :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
1874
+ Storage describing the format of a single instance, which
1875
+ are used in
1876
+ ``PredictRequest.instances``,
1877
+ ``ExplainRequest.instances``
1878
+ and
1879
+ ``BatchPredictionJob.input_config``.
1880
+ The schema is defined as an OpenAPI 3.0.2 `Schema
1881
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
1882
+ AutoML Models always have this field populated by AI
1883
+ Platform. Note: The URI given on output will be immutable
1884
+ and probably different, including the URI scheme, than the
1885
+ one given on input. The output URI will point to a location
1886
+ where the user only has a read access.
1887
+ :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
1888
+ Storage describing the parameters of prediction and
1889
+ explanation via
1890
+ ``PredictRequest.parameters``,
1891
+ ``ExplainRequest.parameters``
1892
+ and
1893
+ ``BatchPredictionJob.model_parameters``.
1894
+ The schema is defined as an OpenAPI 3.0.2 `Schema
1895
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
1896
+ AutoML Models always have this field populated by AI
1897
+ Platform, if no parameters are supported it is set to an
1898
+ empty string. Note: The URI given on output will be
1899
+ immutable and probably different, including the URI scheme,
1900
+ than the one given on input. The output URI will point to a
1901
+ location where the user only has a read access.
1902
+ :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
1903
+ Storage describing the format of a single prediction
1904
+ produced by this Model, which are returned via
1905
+ ``PredictResponse.predictions``,
1906
+ ``ExplainResponse.explanations``,
1907
+ and
1908
+ ``BatchPredictionJob.output_config``.
1909
+ The schema is defined as an OpenAPI 3.0.2 `Schema
1910
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
1911
+ AutoML Models always have this field populated by AI
1912
+ Platform. Note: The URI given on output will be immutable
1913
+ and probably different, including the URI scheme, than the
1914
+ one given on input. The output URI will point to a location
1915
+ where the user only has a read access.
1916
+ :param parent_model: Optional. The resource name or model ID of an existing model.
1917
+ The new model uploaded by this job will be a version of `parent_model`.
1918
+ Only set this field when training a new version of an existing model.
1919
+ :param is_default_version: Optional. When set to True, the newly uploaded model version will
1920
+ automatically have alias "default" included. Subsequent uses of
1921
+ the model produced by this job without a version specified will
1922
+ use this "default" version.
1923
+ When set to False, the "default" alias will not be moved.
1924
+ Actions targeting the model version produced by this job will need
1925
+ to specifically reference this version by ID or alias.
1926
+ New model uploads, i.e. version 1, will always be "default" aliased.
1927
+ :param model_version_aliases: Optional. User provided version aliases so that the model version
1928
+ uploaded by this job can be referenced via alias instead of
1929
+ auto-generated version ID. A default version alias will be created
1930
+ for the first version of the model.
1931
+ The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
1932
+ :param model_version_description: Optional. The description of the model version
1933
+ being uploaded by this job.
1934
+ :param project_id: Project to run training in.
1935
+ :param region: Location to run training in.
1936
+ :param labels: Optional. The labels with user-defined metadata to
1937
+ organize TrainingPipelines.
1938
+ Label keys and values can be no longer than 64
1939
+ characters, can only
1940
+ contain lowercase letters, numeric characters,
1941
+ underscores and dashes. International characters
1942
+ are allowed.
1943
+ See https://goo.gl/xmQnxf for more information
1944
+ and examples of labels.
1945
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
1946
+ managed encryption key used to protect the training pipeline. Has the
1947
+ form:
1948
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1949
+ The key needs to be in the same region as where the compute
1950
+ resource is created.
1784
1951
 
1785
- result = client.delete_pipeline_job(
1786
- request={
1787
- "name": name,
1788
- },
1789
- retry=retry,
1790
- timeout=timeout,
1791
- metadata=metadata,
1792
- )
1793
- return result
1952
+ If set, this TrainingPipeline will be secured by this key.
1794
1953
 
1795
- @GoogleBaseHook.fallback_to_default_project_id
1796
- def delete_training_pipeline(
1797
- self,
1798
- project_id: str,
1799
- region: str,
1800
- training_pipeline: str,
1801
- retry: Retry | _MethodDefault = DEFAULT,
1802
- timeout: float | None = None,
1803
- metadata: Sequence[tuple[str, str]] = (),
1804
- ) -> Operation:
1805
- """
1806
- Delete a TrainingPipeline.
1954
+ Note: Model trained by this TrainingPipeline is also secured
1955
+ by this key if ``model_to_upload`` is not set separately.
1956
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
1957
+ managed encryption key used to protect the model. Has the
1958
+ form:
1959
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1960
+ The key needs to be in the same region as where the compute
1961
+ resource is created.
1807
1962
 
1808
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
1809
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
1810
- :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted.
1811
- :param retry: Designation of what errors, if any, should be retried.
1812
- :param timeout: The timeout for this request.
1813
- :param metadata: Strings which should be sent along with the request as metadata.
1814
- """
1815
- client = self.get_pipeline_service_client(region)
1816
- name = client.training_pipeline_path(project_id, region, training_pipeline)
1963
+ If set, the trained Model will be secured by this key.
1964
+ :param staging_bucket: Bucket used to stage source and training artifacts.
1965
+ :param dataset: Vertex AI to fit this training against.
1966
+ :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
1967
+ annotation schema. The schema is defined as an OpenAPI 3.0.2
1968
+ [Schema Object]
1969
+ (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
1817
1970
 
1818
- result = client.delete_training_pipeline(
1819
- request={
1820
- "name": name,
1821
- },
1822
- retry=retry,
1823
- timeout=timeout,
1824
- metadata=metadata,
1825
- )
1826
- return result
1971
+ Only Annotations that both match this schema and belong to
1972
+ DataItems not ignored by the split method are used in
1973
+ respectively training, validation or test role, depending on
1974
+ the role of the DataItem they are on.
1827
1975
 
1828
- @GoogleBaseHook.fallback_to_default_project_id
1829
- def delete_custom_job(
1830
- self,
1976
+ When used in conjunction with
1977
+ ``annotations_filter``,
1978
+ the Annotations used for training are filtered by both
1979
+ ``annotations_filter``
1980
+ and
1981
+ ``annotation_schema_uri``.
1982
+ :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
1983
+ the Model. The name can be up to 128 characters long and can be consist
1984
+ of any UTF-8 characters.
1985
+
1986
+ If not provided upon creation, the job's display_name is used.
1987
+ :param model_labels: Optional. The labels with user-defined metadata to
1988
+ organize your Models.
1989
+ Label keys and values can be no longer than 64
1990
+ characters, can only
1991
+ contain lowercase letters, numeric characters,
1992
+ underscores and dashes. International characters
1993
+ are allowed.
1994
+ See https://goo.gl/xmQnxf for more information
1995
+ and examples of labels.
1996
+ :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
1997
+ staging directory will be used.
1998
+
1999
+ Vertex AI sets the following environment variables when it runs your training code:
2000
+
2001
+ - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
2002
+ i.e. <base_output_dir>/model/
2003
+ - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
2004
+ i.e. <base_output_dir>/checkpoints/
2005
+ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
2006
+ logs, i.e. <base_output_dir>/logs/
2007
+
2008
+ :param service_account: Specifies the service account for workload run-as account.
2009
+ Users submitting jobs must have act-as permission on this run-as account.
2010
+ :param network: The full name of the Compute Engine network to which the job
2011
+ should be peered.
2012
+ Private services access must already be configured for the network.
2013
+ If left unspecified, the job is not peered with any network.
2014
+ :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
2015
+ The BigQuery project location where the training data is to
2016
+ be written to. In the given project a new dataset is created
2017
+ with name
2018
+ ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
2019
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
2020
+ training input data will be written into that dataset. In
2021
+ the dataset three tables will be created, ``training``,
2022
+ ``validation`` and ``test``.
2023
+
2024
+ - AIP_DATA_FORMAT = "bigquery".
2025
+ - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
2026
+ - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
2027
+ - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
2028
+ :param args: Command line arguments to be passed to the Python script.
2029
+ :param environment_variables: Environment variables to be passed to the container.
2030
+ Should be a dictionary where keys are environment variable names
2031
+ and values are environment variable values for those names.
2032
+ At most 10 environment variables can be specified.
2033
+ The Name of the environment variable must be unique.
2034
+ :param replica_count: The number of worker replicas. If replica count = 1 then one chief
2035
+ replica will be provisioned. If replica_count > 1 the remainder will be
2036
+ provisioned as a worker replica pool.
2037
+ :param machine_type: The type of machine to use for training.
2038
+ :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
2039
+ NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
2040
+ NVIDIA_TESLA_T4
2041
+ :param accelerator_count: The number of accelerators to attach to a worker replica.
2042
+ :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
2043
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
2044
+ `pd-standard` (Persistent Disk Hard Disk Drive).
2045
+ :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
2046
+ boot disk size must be within the range of [100, 64000].
2047
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
2048
+ the Model. This is ignored if Dataset is not provided.
2049
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
2050
+ validate the Model. This is ignored if Dataset is not provided.
2051
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
2052
+ the Model. This is ignored if Dataset is not provided.
2053
+ :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2054
+ this filter are used to train the Model. A filter with same syntax
2055
+ as the one used in DatasetService.ListDataItems may be used. If a
2056
+ single DataItem is matched by more than one of the FilterSplit filters,
2057
+ then it is assigned to the first set that applies to it in the training,
2058
+ validation, test order. This is ignored if Dataset is not provided.
2059
+ :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2060
+ this filter are used to validate the Model. A filter with same syntax
2061
+ as the one used in DatasetService.ListDataItems may be used. If a
2062
+ single DataItem is matched by more than one of the FilterSplit filters,
2063
+ then it is assigned to the first set that applies to it in the training,
2064
+ validation, test order. This is ignored if Dataset is not provided.
2065
+ :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2066
+ this filter are used to test the Model. A filter with same syntax
2067
+ as the one used in DatasetService.ListDataItems may be used. If a
2068
+ single DataItem is matched by more than one of the FilterSplit filters,
2069
+ then it is assigned to the first set that applies to it in the training,
2070
+ validation, test order. This is ignored if Dataset is not provided.
2071
+ :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
2072
+ columns. The value of the key (either the label's value or
2073
+ value in the column) must be one of {``training``,
2074
+ ``validation``, ``test``}, and it defines to which set the
2075
+ given piece of data is assigned. If for a piece of data the
2076
+ key is not present or has an invalid value, that piece is
2077
+ ignored by the pipeline.
2078
+
2079
+ Supported only for tabular and time series Datasets.
2080
+ :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
2081
+ columns. The value of the key values of the key (the values in
2082
+ the column) must be in RFC 3339 `date-time` format, where
2083
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
2084
+ piece of data the key is not present or has an invalid value,
2085
+ that piece is ignored by the pipeline.
2086
+
2087
+ Supported only for tabular and time series Datasets.
2088
+ :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
2089
+ logs. Format:
2090
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
2091
+ For more information on configuring your service account please visit:
2092
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
2093
+ """
2094
+ self._job = self.get_custom_container_training_job(
2095
+ project=project_id,
2096
+ location=region,
2097
+ display_name=display_name,
2098
+ container_uri=container_uri,
2099
+ command=command,
2100
+ model_serving_container_image_uri=model_serving_container_image_uri,
2101
+ model_serving_container_predict_route=model_serving_container_predict_route,
2102
+ model_serving_container_health_route=model_serving_container_health_route,
2103
+ model_serving_container_command=model_serving_container_command,
2104
+ model_serving_container_args=model_serving_container_args,
2105
+ model_serving_container_environment_variables=model_serving_container_environment_variables,
2106
+ model_serving_container_ports=model_serving_container_ports,
2107
+ model_description=model_description,
2108
+ model_instance_schema_uri=model_instance_schema_uri,
2109
+ model_parameters_schema_uri=model_parameters_schema_uri,
2110
+ model_prediction_schema_uri=model_prediction_schema_uri,
2111
+ labels=labels,
2112
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
2113
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
2114
+ staging_bucket=staging_bucket,
2115
+ )
2116
+
2117
+ if not self._job:
2118
+ raise AirflowException("CustomContainerTrainingJob instance creation failed.")
2119
+
2120
+ self._job.submit(
2121
+ dataset=dataset,
2122
+ annotation_schema_uri=annotation_schema_uri,
2123
+ model_display_name=model_display_name,
2124
+ model_labels=model_labels,
2125
+ base_output_dir=base_output_dir,
2126
+ service_account=service_account,
2127
+ network=network,
2128
+ bigquery_destination=bigquery_destination,
2129
+ args=args,
2130
+ environment_variables=environment_variables,
2131
+ replica_count=replica_count,
2132
+ machine_type=machine_type,
2133
+ accelerator_type=accelerator_type,
2134
+ accelerator_count=accelerator_count,
2135
+ boot_disk_type=boot_disk_type,
2136
+ boot_disk_size_gb=boot_disk_size_gb,
2137
+ training_fraction_split=training_fraction_split,
2138
+ validation_fraction_split=validation_fraction_split,
2139
+ test_fraction_split=test_fraction_split,
2140
+ training_filter_split=training_filter_split,
2141
+ validation_filter_split=validation_filter_split,
2142
+ test_filter_split=test_filter_split,
2143
+ predefined_split_column_name=predefined_split_column_name,
2144
+ timestamp_split_column_name=timestamp_split_column_name,
2145
+ tensorboard=tensorboard,
2146
+ parent_model=parent_model,
2147
+ is_default_version=is_default_version,
2148
+ model_version_aliases=model_version_aliases,
2149
+ model_version_description=model_version_description,
2150
+ sync=False,
2151
+ )
2152
+ return self._job
2153
+
2154
+ @GoogleBaseHook.fallback_to_default_project_id
2155
+ def submit_custom_python_package_training_job(
2156
+ self,
2157
+ *,
2158
+ project_id: str,
2159
+ region: str,
2160
+ display_name: str,
2161
+ python_package_gcs_uri: str,
2162
+ python_module_name: str,
2163
+ container_uri: str,
2164
+ model_serving_container_image_uri: str | None = None,
2165
+ model_serving_container_predict_route: str | None = None,
2166
+ model_serving_container_health_route: str | None = None,
2167
+ model_serving_container_command: Sequence[str] | None = None,
2168
+ model_serving_container_args: Sequence[str] | None = None,
2169
+ model_serving_container_environment_variables: dict[str, str] | None = None,
2170
+ model_serving_container_ports: Sequence[int] | None = None,
2171
+ model_description: str | None = None,
2172
+ model_instance_schema_uri: str | None = None,
2173
+ model_parameters_schema_uri: str | None = None,
2174
+ model_prediction_schema_uri: str | None = None,
2175
+ labels: dict[str, str] | None = None,
2176
+ training_encryption_spec_key_name: str | None = None,
2177
+ model_encryption_spec_key_name: str | None = None,
2178
+ staging_bucket: str | None = None,
2179
+ # RUN
2180
+ dataset: None
2181
+ | (
2182
+ datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset
2183
+ ) = None,
2184
+ annotation_schema_uri: str | None = None,
2185
+ model_display_name: str | None = None,
2186
+ model_labels: dict[str, str] | None = None,
2187
+ base_output_dir: str | None = None,
2188
+ service_account: str | None = None,
2189
+ network: str | None = None,
2190
+ bigquery_destination: str | None = None,
2191
+ args: list[str | float | int] | None = None,
2192
+ environment_variables: dict[str, str] | None = None,
2193
+ replica_count: int = 1,
2194
+ machine_type: str = "n1-standard-4",
2195
+ accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
2196
+ accelerator_count: int = 0,
2197
+ boot_disk_type: str = "pd-ssd",
2198
+ boot_disk_size_gb: int = 100,
2199
+ training_fraction_split: float | None = None,
2200
+ validation_fraction_split: float | None = None,
2201
+ test_fraction_split: float | None = None,
2202
+ training_filter_split: str | None = None,
2203
+ validation_filter_split: str | None = None,
2204
+ test_filter_split: str | None = None,
2205
+ predefined_split_column_name: str | None = None,
2206
+ timestamp_split_column_name: str | None = None,
2207
+ tensorboard: str | None = None,
2208
+ parent_model: str | None = None,
2209
+ is_default_version: bool | None = None,
2210
+ model_version_aliases: list[str] | None = None,
2211
+ model_version_description: str | None = None,
2212
+ ) -> CustomPythonPackageTrainingJob:
2213
+ """
2214
+ Create and submit a Custom Python Package Training Job pipeline, then exit without waiting for it to complete.
2215
+
2216
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
2217
+ :param python_package_gcs_uri: Required: GCS location of the training python package.
2218
+ :param python_module_name: Required: The module name of the training python package.
2219
+ :param container_uri: Required: Uri of the training container image in the GCR.
2220
+ :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
2221
+ of the Model serving container suitable for serving the model produced by the
2222
+ training script.
2223
+ :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
2224
+ HTTP path to send prediction requests to the container, and which must be supported
2225
+ by it. If not specified a default HTTP path will be used by Vertex AI.
2226
+ :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
2227
+ HTTP path to send health check requests to the container, and which must be supported
2228
+ by it. If not specified a standard HTTP path will be used by AI Platform.
2229
+ :param model_serving_container_command: The command with which the container is run. Not executed
2230
+ within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
2231
+ Variable references $(VAR_NAME) are expanded using the container's
2232
+ environment. If a variable cannot be resolved, the reference in the
2233
+ input string will be unchanged. The $(VAR_NAME) syntax can be escaped
2234
+ with a double $$, ie: $$(VAR_NAME). Escaped references will never be
2235
+ expanded, regardless of whether the variable exists or not.
2236
+ :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
2237
+ this is not provided. Variable references $(VAR_NAME) are expanded using the
2238
+ container's environment. If a variable cannot be resolved, the reference
2239
+ in the input string will be unchanged. The $(VAR_NAME) syntax can be
2240
+ escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
2241
+ never be expanded, regardless of whether the variable exists or not.
2242
+ :param model_serving_container_environment_variables: The environment variables that are to be
2243
+ present in the container. Should be a dictionary where keys are environment variable names
2244
+ and values are environment variable values for those names.
2245
+ :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
2246
+ field is primarily informational, it gives Vertex AI information about the
2247
+ network connections the container uses. Listing or not a port here has
2248
+ no impact on whether the port is actually exposed, any port listening on
2249
+ the default "0.0.0.0" address inside a container will be accessible from
2250
+ the network.
2251
+ :param model_description: The description of the Model.
2252
+ :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
2253
+ Storage describing the format of a single instance, which
2254
+ are used in
2255
+ ``PredictRequest.instances``,
2256
+ ``ExplainRequest.instances``
2257
+ and
2258
+ ``BatchPredictionJob.input_config``.
2259
+ The schema is defined as an OpenAPI 3.0.2 `Schema
2260
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
2261
+ AutoML Models always have this field populated by AI
2262
+ Platform. Note: The URI given on output will be immutable
2263
+ and probably different, including the URI scheme, than the
2264
+ one given on input. The output URI will point to a location
2265
+ where the user only has a read access.
2266
+ :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
2267
+ Storage describing the parameters of prediction and
2268
+ explanation via
2269
+ ``PredictRequest.parameters``,
2270
+ ``ExplainRequest.parameters``
2271
+ and
2272
+ ``BatchPredictionJob.model_parameters``.
2273
+ The schema is defined as an OpenAPI 3.0.2 `Schema
2274
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
2275
+ AutoML Models always have this field populated by AI
2276
+ Platform, if no parameters are supported it is set to an
2277
+ empty string. Note: The URI given on output will be
2278
+ immutable and probably different, including the URI scheme,
2279
+ than the one given on input. The output URI will point to a
2280
+ location where the user only has a read access.
2281
+ :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
2282
+ Storage describing the format of a single prediction
2283
+ produced by this Model, which are returned via
2284
+ ``PredictResponse.predictions``,
2285
+ ``ExplainResponse.explanations``,
2286
+ and
2287
+ ``BatchPredictionJob.output_config``.
2288
+ The schema is defined as an OpenAPI 3.0.2 `Schema
2289
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
2290
+ AutoML Models always have this field populated by AI
2291
+ Platform. Note: The URI given on output will be immutable
2292
+ and probably different, including the URI scheme, than the
2293
+ one given on input. The output URI will point to a location
2294
+ where the user only has a read access.
2295
+ :param parent_model: Optional. The resource name or model ID of an existing model.
2296
+ The new model uploaded by this job will be a version of `parent_model`.
2297
+ Only set this field when training a new version of an existing model.
2298
+ :param is_default_version: Optional. When set to True, the newly uploaded model version will
2299
+ automatically have alias "default" included. Subsequent uses of
2300
+ the model produced by this job without a version specified will
2301
+ use this "default" version.
2302
+ When set to False, the "default" alias will not be moved.
2303
+ Actions targeting the model version produced by this job will need
2304
+ to specifically reference this version by ID or alias.
2305
+ New model uploads, i.e. version 1, will always be "default" aliased.
2306
+ :param model_version_aliases: Optional. User provided version aliases so that the model version
2307
+ uploaded by this job can be referenced via alias instead of
2308
+ auto-generated version ID. A default version alias will be created
2309
+ for the first version of the model.
2310
+ The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
2311
+ :param model_version_description: Optional. The description of the model version
2312
+ being uploaded by this job.
2313
+ :param project_id: Project to run training in.
2314
+ :param region: Location to run training in.
2315
+ :param labels: Optional. The labels with user-defined metadata to
2316
+ organize TrainingPipelines.
2317
+ Label keys and values can be no longer than 64
2318
+ characters, can only
2319
+ contain lowercase letters, numeric characters,
2320
+ underscores and dashes. International characters
2321
+ are allowed.
2322
+ See https://goo.gl/xmQnxf for more information
2323
+ and examples of labels.
2324
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
2325
+ managed encryption key used to protect the training pipeline. Has the
2326
+ form:
2327
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
2328
+ The key needs to be in the same region as where the compute
2329
+ resource is created.
2330
+
2331
+ If set, this TrainingPipeline will be secured by this key.
2332
+
2333
+ Note: Model trained by this TrainingPipeline is also secured
2334
+ by this key if ``model_to_upload`` is not set separately.
2335
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
2336
+ managed encryption key used to protect the model. Has the
2337
+ form:
2338
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
2339
+ The key needs to be in the same region as where the compute
2340
+ resource is created.
2341
+
2342
+ If set, the trained Model will be secured by this key.
2343
+ :param staging_bucket: Bucket used to stage source and training artifacts.
2344
+ :param dataset: Vertex AI to fit this training against.
2345
+ :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
2346
+ annotation schema. The schema is defined as an OpenAPI 3.0.2
2347
+ [Schema Object]
2348
+ (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
2349
+
2350
+ Only Annotations that both match this schema and belong to
2351
+ DataItems not ignored by the split method are used in
2352
+ respectively training, validation or test role, depending on
2353
+ the role of the DataItem they are on.
2354
+
2355
+ When used in conjunction with
2356
+ ``annotations_filter``,
2357
+ the Annotations used for training are filtered by both
2358
+ ``annotations_filter``
2359
+ and
2360
+ ``annotation_schema_uri``.
2361
+ :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
2362
+ the Model. The name can be up to 128 characters long and can be consist
2363
+ of any UTF-8 characters.
2364
+
2365
+ If not provided upon creation, the job's display_name is used.
2366
+ :param model_labels: Optional. The labels with user-defined metadata to
2367
+ organize your Models.
2368
+ Label keys and values can be no longer than 64
2369
+ characters, can only
2370
+ contain lowercase letters, numeric characters,
2371
+ underscores and dashes. International characters
2372
+ are allowed.
2373
+ See https://goo.gl/xmQnxf for more information
2374
+ and examples of labels.
2375
+ :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
2376
+ staging directory will be used.
2377
+
2378
+ Vertex AI sets the following environment variables when it runs your training code:
2379
+
2380
+ - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
2381
+ i.e. <base_output_dir>/model/
2382
+ - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
2383
+ i.e. <base_output_dir>/checkpoints/
2384
+ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
2385
+ logs, i.e. <base_output_dir>/logs/
2386
+ :param service_account: Specifies the service account for workload run-as account.
2387
+ Users submitting jobs must have act-as permission on this run-as account.
2388
+ :param network: The full name of the Compute Engine network to which the job
2389
+ should be peered.
2390
+ Private services access must already be configured for the network.
2391
+ If left unspecified, the job is not peered with any network.
2392
+ :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
2393
+ The BigQuery project location where the training data is to
2394
+ be written to. In the given project a new dataset is created
2395
+ with name
2396
+ ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
2397
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
2398
+ training input data will be written into that dataset. In
2399
+ the dataset three tables will be created, ``training``,
2400
+ ``validation`` and ``test``.
2401
+
2402
+ - AIP_DATA_FORMAT = "bigquery".
2403
+ - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
2404
+ - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
2405
+ - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
2406
+ :param args: Command line arguments to be passed to the Python script.
2407
+ :param environment_variables: Environment variables to be passed to the container.
2408
+ Should be a dictionary where keys are environment variable names
2409
+ and values are environment variable values for those names.
2410
+ At most 10 environment variables can be specified.
2411
+ The Name of the environment variable must be unique.
2412
+ :param replica_count: The number of worker replicas. If replica count = 1 then one chief
2413
+ replica will be provisioned. If replica_count > 1 the remainder will be
2414
+ provisioned as a worker replica pool.
2415
+ :param machine_type: The type of machine to use for training.
2416
+ :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
2417
+ NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
2418
+ NVIDIA_TESLA_T4
2419
+ :param accelerator_count: The number of accelerators to attach to a worker replica.
2420
+ :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
2421
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
2422
+ `pd-standard` (Persistent Disk Hard Disk Drive).
2423
+ :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
2424
+ boot disk size must be within the range of [100, 64000].
2425
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
2426
+ the Model. This is ignored if Dataset is not provided.
2427
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
2428
+ validate the Model. This is ignored if Dataset is not provided.
2429
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
2430
+ the Model. This is ignored if Dataset is not provided.
2431
+ :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2432
+ this filter are used to train the Model. A filter with same syntax
2433
+ as the one used in DatasetService.ListDataItems may be used. If a
2434
+ single DataItem is matched by more than one of the FilterSplit filters,
2435
+ then it is assigned to the first set that applies to it in the training,
2436
+ validation, test order. This is ignored if Dataset is not provided.
2437
+ :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2438
+ this filter are used to validate the Model. A filter with same syntax
2439
+ as the one used in DatasetService.ListDataItems may be used. If a
2440
+ single DataItem is matched by more than one of the FilterSplit filters,
2441
+ then it is assigned to the first set that applies to it in the training,
2442
+ validation, test order. This is ignored if Dataset is not provided.
2443
+ :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2444
+ this filter are used to test the Model. A filter with same syntax
2445
+ as the one used in DatasetService.ListDataItems may be used. If a
2446
+ single DataItem is matched by more than one of the FilterSplit filters,
2447
+ then it is assigned to the first set that applies to it in the training,
2448
+ validation, test order. This is ignored if Dataset is not provided.
2449
+ :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
2450
+ columns. The value of the key (either the label's value or
2451
+ value in the column) must be one of {``training``,
2452
+ ``validation``, ``test``}, and it defines to which set the
2453
+ given piece of data is assigned. If for a piece of data the
2454
+ key is not present or has an invalid value, that piece is
2455
+ ignored by the pipeline.
2456
+
2457
+ Supported only for tabular and time series Datasets.
2458
+ :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
2459
+ columns. The value of the key values of the key (the values in
2460
+ the column) must be in RFC 3339 `date-time` format, where
2461
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
2462
+ piece of data the key is not present or has an invalid value,
2463
+ that piece is ignored by the pipeline.
2464
+
2465
+ Supported only for tabular and time series Datasets.
2466
+ :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
2467
+ logs. Format:
2468
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
2469
+ For more information on configuring your service account please visit:
2470
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
2471
+ """
2472
+ self._job = self.get_custom_python_package_training_job(
2473
+ project=project_id,
2474
+ location=region,
2475
+ display_name=display_name,
2476
+ python_package_gcs_uri=python_package_gcs_uri,
2477
+ python_module_name=python_module_name,
2478
+ container_uri=container_uri,
2479
+ model_serving_container_image_uri=model_serving_container_image_uri,
2480
+ model_serving_container_predict_route=model_serving_container_predict_route,
2481
+ model_serving_container_health_route=model_serving_container_health_route,
2482
+ model_serving_container_command=model_serving_container_command,
2483
+ model_serving_container_args=model_serving_container_args,
2484
+ model_serving_container_environment_variables=model_serving_container_environment_variables,
2485
+ model_serving_container_ports=model_serving_container_ports,
2486
+ model_description=model_description,
2487
+ model_instance_schema_uri=model_instance_schema_uri,
2488
+ model_parameters_schema_uri=model_parameters_schema_uri,
2489
+ model_prediction_schema_uri=model_prediction_schema_uri,
2490
+ labels=labels,
2491
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
2492
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
2493
+ staging_bucket=staging_bucket,
2494
+ )
2495
+
2496
+ if not self._job:
2497
+ raise AirflowException("CustomPythonPackageTrainingJob instance creation failed.")
2498
+
2499
+ self._job.run(
2500
+ dataset=dataset,
2501
+ annotation_schema_uri=annotation_schema_uri,
2502
+ model_display_name=model_display_name,
2503
+ model_labels=model_labels,
2504
+ base_output_dir=base_output_dir,
2505
+ service_account=service_account,
2506
+ network=network,
2507
+ bigquery_destination=bigquery_destination,
2508
+ args=args,
2509
+ environment_variables=environment_variables,
2510
+ replica_count=replica_count,
2511
+ machine_type=machine_type,
2512
+ accelerator_type=accelerator_type,
2513
+ accelerator_count=accelerator_count,
2514
+ boot_disk_type=boot_disk_type,
2515
+ boot_disk_size_gb=boot_disk_size_gb,
2516
+ training_fraction_split=training_fraction_split,
2517
+ validation_fraction_split=validation_fraction_split,
2518
+ test_fraction_split=test_fraction_split,
2519
+ training_filter_split=training_filter_split,
2520
+ validation_filter_split=validation_filter_split,
2521
+ test_filter_split=test_filter_split,
2522
+ predefined_split_column_name=predefined_split_column_name,
2523
+ timestamp_split_column_name=timestamp_split_column_name,
2524
+ tensorboard=tensorboard,
2525
+ parent_model=parent_model,
2526
+ is_default_version=is_default_version,
2527
+ model_version_aliases=model_version_aliases,
2528
+ model_version_description=model_version_description,
2529
+ sync=False,
2530
+ )
2531
+
2532
+ return self._job
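The hunk above finishes the hook's submit variant for pre-built Python training packages: the job wrapper is built via get_custom_python_package_training_job and then started with run(sync=False), so the call returns as soon as the pipeline is submitted. A minimal usage sketch, assuming the enclosing method is named submit_custom_python_package_training_job (mirroring submit_custom_training_job below) and using placeholder project, bucket, and package values:

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")

# Submit a pre-built training package; the call does not block on completion.
job = hook.submit_custom_python_package_training_job(
    project_id="my-project",                                     # placeholder
    region="us-central1",
    display_name="package-training-example",
    python_package_gcs_uri="gs://my-bucket/trainer-0.1.tar.gz",   # placeholder package
    python_module_name="trainer.task",                            # placeholder module
    container_uri="gcr.io/my-project/my-training-image:latest",   # placeholder image
    staging_bucket="gs://my-staging-bucket",                      # placeholder
)
# `job` is the aiplatform CustomPythonPackageTrainingJob wrapper; the backing
# TrainingPipeline and CustomJob resources are created asynchronously by Vertex AI.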
2533
+
2534
+ @GoogleBaseHook.fallback_to_default_project_id
2535
+ def submit_custom_training_job(
2536
+ self,
2537
+ *,
2538
+ project_id: str,
2539
+ region: str,
2540
+ display_name: str,
2541
+ script_path: str,
2542
+ container_uri: str,
2543
+ requirements: Sequence[str] | None = None,
2544
+ model_serving_container_image_uri: str | None = None,
2545
+ model_serving_container_predict_route: str | None = None,
2546
+ model_serving_container_health_route: str | None = None,
2547
+ model_serving_container_command: Sequence[str] | None = None,
2548
+ model_serving_container_args: Sequence[str] | None = None,
2549
+ model_serving_container_environment_variables: dict[str, str] | None = None,
2550
+ model_serving_container_ports: Sequence[int] | None = None,
2551
+ model_description: str | None = None,
2552
+ model_instance_schema_uri: str | None = None,
2553
+ model_parameters_schema_uri: str | None = None,
2554
+ model_prediction_schema_uri: str | None = None,
2555
+ parent_model: str | None = None,
2556
+ is_default_version: bool | None = None,
2557
+ model_version_aliases: list[str] | None = None,
2558
+ model_version_description: str | None = None,
2559
+ labels: dict[str, str] | None = None,
2560
+ training_encryption_spec_key_name: str | None = None,
2561
+ model_encryption_spec_key_name: str | None = None,
2562
+ staging_bucket: str | None = None,
2563
+ # RUN
2564
+ dataset: None
2565
+ | (
2566
+ datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset
2567
+ ) = None,
2568
+ annotation_schema_uri: str | None = None,
2569
+ model_display_name: str | None = None,
2570
+ model_labels: dict[str, str] | None = None,
2571
+ base_output_dir: str | None = None,
2572
+ service_account: str | None = None,
2573
+ network: str | None = None,
2574
+ bigquery_destination: str | None = None,
2575
+ args: list[str | float | int] | None = None,
2576
+ environment_variables: dict[str, str] | None = None,
2577
+ replica_count: int = 1,
2578
+ machine_type: str = "n1-standard-4",
2579
+ accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
2580
+ accelerator_count: int = 0,
2581
+ boot_disk_type: str = "pd-ssd",
2582
+ boot_disk_size_gb: int = 100,
2583
+ training_fraction_split: float | None = None,
2584
+ validation_fraction_split: float | None = None,
2585
+ test_fraction_split: float | None = None,
2586
+ training_filter_split: str | None = None,
2587
+ validation_filter_split: str | None = None,
2588
+ test_filter_split: str | None = None,
2589
+ predefined_split_column_name: str | None = None,
2590
+ timestamp_split_column_name: str | None = None,
2591
+ tensorboard: str | None = None,
2592
+ ) -> CustomTrainingJob:
2593
+ """
2594
+ Create and submit a Custom Training Job pipeline, then exit without waiting for it to complete.
2595
+
2596
+ Neither the training model nor the backing custom job is available at the moment when the training
2597
+ pipeline is submitted, both are created only after a period of time. Therefore, it is not possible
2598
+ to extract and return them in this method; this should be done with a separate client request.
2599
+
2600
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
2601
+ :param script_path: Required. Local path to training script.
2602
+ :param container_uri: Required. URI of the training container image in GCR.
2603
+ :param requirements: List of python packages dependencies of script.
2604
+ :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
2605
+ of the Model serving container suitable for serving the model produced by the
2606
+ training script.
2607
+ :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, an
2608
+ HTTP path to send prediction requests to the container, and which must be supported
2609
+ by it. If not specified a default HTTP path will be used by Vertex AI.
2610
+ :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
2611
+ HTTP path to send health check requests to the container, and which must be supported
2612
+ by it. If not specified a standard HTTP path will be used by AI Platform.
2613
+ :param model_serving_container_command: The command with which the container is run. Not executed
2614
+ within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
2615
+ Variable references $(VAR_NAME) are expanded using the container's
2616
+ environment. If a variable cannot be resolved, the reference in the
2617
+ input string will be unchanged. The $(VAR_NAME) syntax can be escaped
2618
+ with a double $$, i.e. $$(VAR_NAME). Escaped references will never be
2619
+ expanded, regardless of whether the variable exists or not.
2620
+ :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
2621
+ this is not provided. Variable references $(VAR_NAME) are expanded using the
2622
+ container's environment. If a variable cannot be resolved, the reference
2623
+ in the input string will be unchanged. The $(VAR_NAME) syntax can be
2624
+ escaped with a double $$, i.e. $$(VAR_NAME). Escaped references will
2625
+ never be expanded, regardless of whether the variable exists or not.
2626
+ :param model_serving_container_environment_variables: The environment variables that are to be
2627
+ present in the container. Should be a dictionary where keys are environment variable names
2628
+ and values are environment variable values for those names.
2629
+ :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
2630
+ field is primarily informational, it gives Vertex AI information about the
2631
+ network connections the container uses. Whether or not a port is listed here has
2632
+ no impact on whether the port is actually exposed; any port listening on
2633
+ the default "0.0.0.0" address inside a container will be accessible from
2634
+ the network.
2635
+ :param model_description: The description of the Model.
2636
+ :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
2637
+ Storage describing the format of a single instance, which
2638
+ are used in
2639
+ ``PredictRequest.instances``,
2640
+ ``ExplainRequest.instances``
2641
+ and
2642
+ ``BatchPredictionJob.input_config``.
2643
+ The schema is defined as an OpenAPI 3.0.2 `Schema
2644
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
2645
+ AutoML Models always have this field populated by AI
2646
+ Platform. Note: The URI given on output will be immutable
2647
+ and probably different, including the URI scheme, than the
2648
+ one given on input. The output URI will point to a location
2649
+ where the user only has read access.
2650
+ :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
2651
+ Storage describing the parameters of prediction and
2652
+ explanation via
2653
+ ``PredictRequest.parameters``,
2654
+ ``ExplainRequest.parameters``
2655
+ and
2656
+ ``BatchPredictionJob.model_parameters``.
2657
+ The schema is defined as an OpenAPI 3.0.2 `Schema
2658
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
2659
+ AutoML Models always have this field populated by AI
2660
+ Platform; if no parameters are supported, it is set to an
2661
+ empty string. Note: The URI given on output will be
2662
+ immutable and probably different, including the URI scheme,
2663
+ than the one given on input. The output URI will point to a
2664
+ location where the user only has read access.
2665
+ :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
2666
+ Storage describing the format of a single prediction
2667
+ produced by this Model, which are returned via
2668
+ ``PredictResponse.predictions``,
2669
+ ``ExplainResponse.explanations``,
2670
+ and
2671
+ ``BatchPredictionJob.output_config``.
2672
+ The schema is defined as an OpenAPI 3.0.2 `Schema
2673
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
2674
+ AutoML Models always have this field populated by AI
2675
+ Platform. Note: The URI given on output will be immutable
2676
+ and probably different, including the URI scheme, than the
2677
+ one given on input. The output URI will point to a location
2678
+ where the user only has read access.
2679
+ :param parent_model: Optional. The resource name or model ID of an existing model.
2680
+ The new model uploaded by this job will be a version of `parent_model`.
2681
+ Only set this field when training a new version of an existing model.
2682
+ :param is_default_version: Optional. When set to True, the newly uploaded model version will
2683
+ automatically have alias "default" included. Subsequent uses of
2684
+ the model produced by this job without a version specified will
2685
+ use this "default" version.
2686
+ When set to False, the "default" alias will not be moved.
2687
+ Actions targeting the model version produced by this job will need
2688
+ to specifically reference this version by ID or alias.
2689
+ New model uploads, i.e. version 1, will always be "default" aliased.
2690
+ :param model_version_aliases: Optional. User provided version aliases so that the model version
2691
+ uploaded by this job can be referenced via alias instead of
2692
+ auto-generated version ID. A default version alias will be created
2693
+ for the first version of the model.
2694
+ The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
2695
+ :param model_version_description: Optional. The description of the model version
2696
+ being uploaded by this job.
2697
+ :param project_id: Project to run training in.
2698
+ :param region: Location to run training in.
2699
+ :param labels: Optional. The labels with user-defined metadata to
2700
+ organize TrainingPipelines.
2701
+ Label keys and values can be no longer than 64
2702
+ characters, can only
2703
+ contain lowercase letters, numeric characters,
2704
+ underscores and dashes. International characters
2705
+ are allowed.
2706
+ See https://goo.gl/xmQnxf for more information
2707
+ and examples of labels.
2708
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
2709
+ managed encryption key used to protect the training pipeline. Has the
2710
+ form:
2711
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
2712
+ The key needs to be in the same region as where the compute
2713
+ resource is created.
2714
+
2715
+ If set, this TrainingPipeline will be secured by this key.
2716
+
2717
+ Note: Model trained by this TrainingPipeline is also secured
2718
+ by this key if ``model_to_upload`` is not set separately.
2719
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
2720
+ managed encryption key used to protect the model. Has the
2721
+ form:
2722
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
2723
+ The key needs to be in the same region as where the compute
2724
+ resource is created.
2725
+
2726
+ If set, the trained Model will be secured by this key.
2727
+ :param staging_bucket: Bucket used to stage source and training artifacts.
2728
+ :param dataset: Vertex AI Dataset to fit this training against.
2729
+ :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
2730
+ annotation schema. The schema is defined as an OpenAPI 3.0.2
2731
+ [Schema Object]
2732
+ (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
2733
+
2734
+ Only Annotations that both match this schema and belong to
2735
+ DataItems not ignored by the split method are used in
2736
+ respectively training, validation or test role, depending on
2737
+ the role of the DataItem they are on.
2738
+
2739
+ When used in conjunction with
2740
+ ``annotations_filter``,
2741
+ the Annotations used for training are filtered by both
2742
+ ``annotations_filter``
2743
+ and
2744
+ ``annotation_schema_uri``.
2745
+ :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
2746
+ the Model. The name can be up to 128 characters long and can consist
2747
+ of any UTF-8 characters.
2748
+
2749
+ If not provided upon creation, the job's display_name is used.
2750
+ :param model_labels: Optional. The labels with user-defined metadata to
2751
+ organize your Models.
2752
+ Label keys and values can be no longer than 64
2753
+ characters, can only
2754
+ contain lowercase letters, numeric characters,
2755
+ underscores and dashes. International characters
2756
+ are allowed.
2757
+ See https://goo.gl/xmQnxf for more information
2758
+ and examples of labels.
2759
+ :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
2760
+ staging directory will be used.
2761
+
2762
+ Vertex AI sets the following environment variables when it runs your training code:
2763
+
2764
+ - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
2765
+ i.e. <base_output_dir>/model/
2766
+ - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
2767
+ i.e. <base_output_dir>/checkpoints/
2768
+ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
2769
+ logs, i.e. <base_output_dir>/logs/
2770
+ :param service_account: Specifies the service account to use as the workload run-as account.
2771
+ Users submitting jobs must have act-as permission on this run-as account.
2772
+ :param network: The full name of the Compute Engine network to which the job
2773
+ should be peered.
2774
+ Private services access must already be configured for the network.
2775
+ If left unspecified, the job is not peered with any network.
2776
+ :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
2777
+ The BigQuery project location where the training data is to
2778
+ be written to. In the given project a new dataset is created
2779
+ with name
2780
+ ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
2781
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
2782
+ training input data will be written into that dataset. In
2783
+ the dataset three tables will be created, ``training``,
2784
+ ``validation`` and ``test``.
2785
+
2786
+ - AIP_DATA_FORMAT = "bigquery".
2787
+ - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
2788
+ - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
2789
+ - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
2790
+ :param args: Command line arguments to be passed to the Python script.
2791
+ :param environment_variables: Environment variables to be passed to the container.
2792
+ Should be a dictionary where keys are environment variable names
2793
+ and values are environment variable values for those names.
2794
+ At most 10 environment variables can be specified.
2795
+ The name of each environment variable must be unique.
2796
+ :param replica_count: The number of worker replicas. If replica_count = 1 then one chief
2797
+ replica will be provisioned. If replica_count > 1 the remainder will be
2798
+ provisioned as a worker replica pool.
2799
+ :param machine_type: The type of machine to use for training.
2800
+ :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
2801
+ NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
2802
+ NVIDIA_TESLA_T4
2803
+ :param accelerator_count: The number of accelerators to attach to a worker replica.
2804
+ :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
2805
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
2806
+ `pd-standard` (Persistent Disk Hard Disk Drive).
2807
+ :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
2808
+ boot disk size must be within the range of [100, 64000].
2809
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
2810
+ the Model. This is ignored if Dataset is not provided.
2811
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
2812
+ validate the Model. This is ignored if Dataset is not provided.
2813
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
2814
+ the Model. This is ignored if Dataset is not provided.
2815
+ :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2816
+ this filter are used to train the Model. A filter with same syntax
2817
+ as the one used in DatasetService.ListDataItems may be used. If a
2818
+ single DataItem is matched by more than one of the FilterSplit filters,
2819
+ then it is assigned to the first set that applies to it in the training,
2820
+ validation, test order. This is ignored if Dataset is not provided.
2821
+ :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2822
+ this filter are used to validate the Model. A filter with same syntax
2823
+ as the one used in DatasetService.ListDataItems may be used. If a
2824
+ single DataItem is matched by more than one of the FilterSplit filters,
2825
+ then it is assigned to the first set that applies to it in the training,
2826
+ validation, test order. This is ignored if Dataset is not provided.
2827
+ :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
2828
+ this filter are used to test the Model. A filter with same syntax
2829
+ as the one used in DatasetService.ListDataItems may be used. If a
2830
+ single DataItem is matched by more than one of the FilterSplit filters,
2831
+ then it is assigned to the first set that applies to it in the training,
2832
+ validation, test order. This is ignored if Dataset is not provided.
2833
+ :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
2834
+ columns. The value of the key (either the label's value or
2835
+ value in the column) must be one of {``training``,
2836
+ ``validation``, ``test``}, and it defines to which set the
2837
+ given piece of data is assigned. If for a piece of data the
2838
+ key is not present or has an invalid value, that piece is
2839
+ ignored by the pipeline.
2840
+
2841
+ Supported only for tabular and time series Datasets.
2842
+ :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
2843
+ columns. The values of the key (the values in
2844
+ the column) must be in RFC 3339 `date-time` format, where
2845
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
2846
+ piece of data the key is not present or has an invalid value,
2847
+ that piece is ignored by the pipeline.
2848
+
2849
+ Supported only for tabular and time series Datasets.
2850
+ :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
2851
+ logs. Format:
2852
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
2853
+ For more information on configuring your service account please visit:
2854
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
2855
+ """
2856
+ self._job = self.get_custom_training_job(
2857
+ project=project_id,
2858
+ location=region,
2859
+ display_name=display_name,
2860
+ script_path=script_path,
2861
+ container_uri=container_uri,
2862
+ requirements=requirements,
2863
+ model_serving_container_image_uri=model_serving_container_image_uri,
2864
+ model_serving_container_predict_route=model_serving_container_predict_route,
2865
+ model_serving_container_health_route=model_serving_container_health_route,
2866
+ model_serving_container_command=model_serving_container_command,
2867
+ model_serving_container_args=model_serving_container_args,
2868
+ model_serving_container_environment_variables=model_serving_container_environment_variables,
2869
+ model_serving_container_ports=model_serving_container_ports,
2870
+ model_description=model_description,
2871
+ model_instance_schema_uri=model_instance_schema_uri,
2872
+ model_parameters_schema_uri=model_parameters_schema_uri,
2873
+ model_prediction_schema_uri=model_prediction_schema_uri,
2874
+ labels=labels,
2875
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
2876
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
2877
+ staging_bucket=staging_bucket,
2878
+ )
2879
+
2880
+ if not self._job:
2881
+ raise AirflowException("CustomTrainingJob instance creation failed.")
2882
+
2883
+ self._job.submit(
2884
+ dataset=dataset,
2885
+ annotation_schema_uri=annotation_schema_uri,
2886
+ model_display_name=model_display_name,
2887
+ model_labels=model_labels,
2888
+ base_output_dir=base_output_dir,
2889
+ service_account=service_account,
2890
+ network=network,
2891
+ bigquery_destination=bigquery_destination,
2892
+ args=args,
2893
+ environment_variables=environment_variables,
2894
+ replica_count=replica_count,
2895
+ machine_type=machine_type,
2896
+ accelerator_type=accelerator_type,
2897
+ accelerator_count=accelerator_count,
2898
+ boot_disk_type=boot_disk_type,
2899
+ boot_disk_size_gb=boot_disk_size_gb,
2900
+ training_fraction_split=training_fraction_split,
2901
+ validation_fraction_split=validation_fraction_split,
2902
+ test_fraction_split=test_fraction_split,
2903
+ training_filter_split=training_filter_split,
2904
+ validation_filter_split=validation_filter_split,
2905
+ test_filter_split=test_filter_split,
2906
+ predefined_split_column_name=predefined_split_column_name,
2907
+ timestamp_split_column_name=timestamp_split_column_name,
2908
+ tensorboard=tensorboard,
2909
+ parent_model=parent_model,
2910
+ is_default_version=is_default_version,
2911
+ model_version_aliases=model_version_aliases,
2912
+ model_version_description=model_version_description,
2913
+ sync=False,
2914
+ )
2915
+ return self._job
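Because submit() is called with sync=False, submit_custom_training_job returns as soon as the pipeline is queued; as the docstring notes, neither the model nor the backing CustomJob exists at that point. A minimal call sketch using the signature shown above, with placeholder project, script, and bucket values:

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")

training_job = hook.submit_custom_training_job(
    project_id="my-project",                                      # placeholder
    region="us-central1",
    display_name="script-training-example",
    script_path="/opt/airflow/dags/trainer/task.py",              # placeholder local script
    container_uri="gcr.io/my-project/my-training-image:latest",   # placeholder image
    requirements=["scikit-learn"],                                 # optional pip dependencies of the script
    staging_bucket="gs://my-staging-bucket",                       # placeholder
    replica_count=1,
    machine_type="n1-standard-4",
)
# The returned CustomTrainingJob is only a local wrapper; the TrainingPipeline and
# CustomJob resource IDs must be fetched with a separate request once Vertex AI has
# created them (see the async hook added further down in this diff).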
2916
+
2917
+ @GoogleBaseHook.fallback_to_default_project_id
2918
+ def delete_training_pipeline(
2919
+ self,
2920
+ project_id: str,
2921
+ region: str,
2922
+ training_pipeline: str,
2923
+ retry: Retry | _MethodDefault = DEFAULT,
2924
+ timeout: float | None = None,
2925
+ metadata: Sequence[tuple[str, str]] = (),
2926
+ ) -> Operation:
2927
+ """
2928
+ Delete a TrainingPipeline.
2929
+
2930
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
2931
+ :param region: Required. The ID of the Google Cloud region that the service belongs to.
2932
+ :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted.
2933
+ :param retry: Designation of what errors, if any, should be retried.
2934
+ :param timeout: The timeout for this request.
2935
+ :param metadata: Strings which should be sent along with the request as metadata.
2936
+ """
2937
+ client = self.get_pipeline_service_client(region)
2938
+ name = client.training_pipeline_path(project_id, region, training_pipeline)
2939
+
2940
+ result = client.delete_training_pipeline(
2941
+ request={"name": name},
2942
+ retry=retry,
2943
+ timeout=timeout,
2944
+ metadata=metadata,
2945
+ )
2946
+ return result
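delete_training_pipeline returns a long-running Operation from the pipeline service client, so the caller decides whether to block on it. A short sketch with placeholder identifiers:

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")

operation = hook.delete_training_pipeline(
    project_id="my-project",          # placeholder
    region="us-central1",
    training_pipeline="1234567890",   # placeholder TrainingPipeline ID
)
operation.result()  # block until the delete operation completes (raises on failure)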
2947
+
2948
+ @GoogleBaseHook.fallback_to_default_project_id
2949
+ def delete_custom_job(
2950
+ self,
1831
2951
  project_id: str,
1832
2952
  region: str,
1833
2953
  custom_job: str,
@@ -2178,3 +3298,240 @@ class CustomJobHook(GoogleBaseHook):
2178
3298
  metadata=metadata,
2179
3299
  )
2180
3300
  return result
3301
+
3302
+ @GoogleBaseHook.fallback_to_default_project_id
3303
+ @deprecated(
3304
+ reason="Please use `PipelineJobHook.delete_pipeline_job`",
3305
+ category=AirflowProviderDeprecationWarning,
3306
+ )
3307
+ def delete_pipeline_job(
3308
+ self,
3309
+ project_id: str,
3310
+ region: str,
3311
+ pipeline_job: str,
3312
+ retry: Retry | _MethodDefault = DEFAULT,
3313
+ timeout: float | None = None,
3314
+ metadata: Sequence[tuple[str, str]] = (),
3315
+ ) -> Operation:
3316
+ """
3317
+ Delete a PipelineJob.
3318
+
3319
+ This method is deprecated; please use the `PipelineJobHook.delete_pipeline_job` method instead.
3320
+
3321
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
3322
+ :param region: Required. The ID of the Google Cloud region that the service belongs to.
3323
+ :param pipeline_job: Required. The name of the PipelineJob resource to be deleted.
3324
+ :param retry: Designation of what errors, if any, should be retried.
3325
+ :param timeout: The timeout for this request.
3326
+ :param metadata: Strings which should be sent along with the request as metadata.
3327
+ """
3328
+ client = self.get_pipeline_service_client(region)
3329
+ name = client.pipeline_job_path(project_id, region, pipeline_job)
3330
+
3331
+ result = client.delete_pipeline_job(
3332
+ request={"name": name},
3333
+ retry=retry,
3334
+ timeout=timeout,
3335
+ metadata=metadata,
3336
+ )
3337
+ return result
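Since this method is now deprecated in favour of PipelineJobHook.delete_pipeline_job, new code should call the dedicated hook instead. A sketch under the assumption that PipelineJobHook lives in the provider's vertex_ai.pipeline_job module and accepts the same project/region/ID arguments:

from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook  # assumed module path

hook = PipelineJobHook(gcp_conn_id="google_cloud_default")
operation = hook.delete_pipeline_job(
    project_id="my-project",         # placeholder
    region="us-central1",
    pipeline_job="my-pipeline-job",  # placeholder PipelineJob ID
)
operation.result()  # block until the delete operation completes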
3338
+
3339
+
3340
+ class CustomJobAsyncHook(GoogleBaseAsyncHook):
3341
+ """Async hook for Custom Job Service Client."""
3342
+
3343
+ sync_hook_class = CustomJobHook
3344
+ JOB_COMPLETE_STATES = {
3345
+ JobState.JOB_STATE_CANCELLED,
3346
+ JobState.JOB_STATE_FAILED,
3347
+ JobState.JOB_STATE_PAUSED,
3348
+ JobState.JOB_STATE_SUCCEEDED,
3349
+ }
3350
+ PIPELINE_COMPLETE_STATES = (
3351
+ PipelineState.PIPELINE_STATE_CANCELLED,
3352
+ PipelineState.PIPELINE_STATE_FAILED,
3353
+ PipelineState.PIPELINE_STATE_PAUSED,
3354
+ PipelineState.PIPELINE_STATE_SUCCEEDED,
3355
+ )
3356
+
3357
+ def __init__(
3358
+ self,
3359
+ gcp_conn_id: str = "google_cloud_default",
3360
+ impersonation_chain: str | Sequence[str] | None = None,
3361
+ **kwargs,
3362
+ ):
3363
+ super().__init__(
3364
+ gcp_conn_id=gcp_conn_id,
3365
+ impersonation_chain=impersonation_chain,
3366
+ **kwargs,
3367
+ )
3368
+ self._job: None | (
3369
+ CustomContainerTrainingJob | CustomPythonPackageTrainingJob | CustomTrainingJob
3370
+ ) = None
3371
+
3372
+ async def get_credentials(self) -> Credentials:
3373
+ return (await self.get_sync_hook()).get_credentials()
3374
+
3375
+ async def get_job_service_client(
3376
+ self,
3377
+ region: str | None = None,
3378
+ ) -> JobServiceAsyncClient:
3379
+ """Retrieve Vertex AI JobServiceAsyncClient object."""
3380
+ if region and region != "global":
3381
+ client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
3382
+ else:
3383
+ client_options = ClientOptions()
3384
+ return JobServiceAsyncClient(
3385
+ credentials=(await self.get_credentials()),
3386
+ client_info=CLIENT_INFO,
3387
+ client_options=client_options,
3388
+ )
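get_job_service_client (and the analogous pipeline client getter below) targets a regional api_endpoint unless the region is unset or "global". A small sketch of obtaining the client from a coroutine, with a placeholder region:

import asyncio

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook

async def get_client():
    hook = CustomJobAsyncHook(gcp_conn_id="google_cloud_default")
    # region="us-central1" resolves to us-central1-aiplatform.googleapis.com:443;
    # region=None or "global" falls back to the library's default endpoint.
    return await hook.get_job_service_client(region="us-central1")

client = asyncio.run(get_client())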
3389
+
3390
+ async def get_pipeline_service_client(
3391
+ self,
3392
+ region: str | None = None,
3393
+ ) -> PipelineServiceAsyncClient:
3394
+ """Retrieve Vertex AI PipelineServiceAsyncClient object."""
3395
+ if region and region != "global":
3396
+ client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
3397
+ else:
3398
+ client_options = ClientOptions()
3399
+ return PipelineServiceAsyncClient(
3400
+ credentials=(await self.get_credentials()),
3401
+ client_info=CLIENT_INFO,
3402
+ client_options=client_options,
3403
+ )
3404
+
3405
+ async def get_custom_job(
3406
+ self,
3407
+ project_id: str,
3408
+ location: str,
3409
+ job_id: str,
3410
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
3411
+ timeout: float | _MethodDefault | None = DEFAULT,
3412
+ metadata: Sequence[tuple[str, str]] = (),
3413
+ client: JobServiceAsyncClient | None = None,
3414
+ ) -> types.CustomJob:
3415
+ """
3416
+ Get a CustomJob proto message from JobServiceAsyncClient.
3417
+
3418
+ :param project_id: Required. The ID of the Google Cloud project that the job belongs to.
3419
+ :param location: Required. The ID of the Google Cloud region that the job belongs to.
3420
+ :param job_id: Required. The custom job id.
3421
+ :param retry: Designation of what errors, if any, should be retried.
3422
+ :param timeout: The timeout for this request.
3423
+ :param metadata: Strings which should be sent along with the request as metadata.
3424
+ :param client: The async job service client.
3425
+ """
3426
+ if not client:
3427
+ client = await self.get_job_service_client(region=location)
3428
+ job_name = client.custom_job_path(project_id, location, job_id)
3429
+ result: types.CustomJob = await client.get_custom_job(
3430
+ request={"name": job_name},
3431
+ retry=retry,
3432
+ timeout=timeout,
3433
+ metadata=metadata,
3434
+ )
3435
+ return result
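get_custom_job is a thin async wrapper around JobServiceAsyncClient.get_custom_job, keyed by the numeric CustomJob ID. A sketch with placeholder identifiers:

import asyncio

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook

async def show_job_state():
    hook = CustomJobAsyncHook(gcp_conn_id="google_cloud_default")
    job = await hook.get_custom_job(
        project_id="my-project",   # placeholder
        location="us-central1",
        job_id="1234567890",       # placeholder CustomJob ID
    )
    print(job.display_name, job.state.name)

asyncio.run(show_job_state())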
3436
+
3437
+ async def get_training_pipeline(
3438
+ self,
3439
+ project_id: str,
3440
+ location: str,
3441
+ pipeline_id: str,
3442
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
3443
+ timeout: float | _MethodDefault | None = DEFAULT,
3444
+ metadata: Sequence[tuple[str, str]] = (),
3445
+ client: PipelineServiceAsyncClient | None = None,
3446
+ ) -> types.TrainingPipeline:
3447
+ """
3448
+ Get a TrainingPipeline proto message from PipelineServiceAsyncClient.
3449
+
3450
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
3451
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
3452
+ :param pipeline_id: Required. The ID of the PipelineJob resource.
3453
+ :param retry: Designation of what errors, if any, should be retried.
3454
+ :param timeout: The timeout for this request.
3455
+ :param metadata: Strings which should be sent along with the request as metadata.
3456
+ :param client: The async pipeline service client.
3457
+ """
3458
+ if not client:
3459
+ client = await self.get_pipeline_service_client(region=location)
3460
+ pipeline_name = client.training_pipeline_path(
3461
+ project=project_id,
3462
+ location=location,
3463
+ training_pipeline=pipeline_id,
3464
+ )
3465
+ response: types.TrainingPipeline = await client.get_training_pipeline(
3466
+ request={"name": pipeline_name},
3467
+ retry=retry,
3468
+ timeout=timeout,
3469
+ metadata=metadata,
3470
+ )
3471
+ return response
3472
+
3473
+ async def wait_for_custom_job(
3474
+ self,
3475
+ project_id: str,
3476
+ location: str,
3477
+ job_id: str,
3478
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
3479
+ timeout: float | None = None,
3480
+ metadata: Sequence[tuple[str, str]] = (),
3481
+ poll_interval: int = 10,
3482
+ ) -> types.CustomJob:
3483
+ """Make async calls to Vertex AI to check the custom job state until it is complete."""
3484
+ client = await self.get_job_service_client(region=location)
3485
+ while True:
3486
+ try:
3487
+ self.log.info("Requesting a custom job with id %s", job_id)
3488
+ job: types.CustomJob = await self.get_custom_job(
3489
+ project_id=project_id,
3490
+ location=location,
3491
+ job_id=job_id,
3492
+ retry=retry,
3493
+ timeout=timeout,
3494
+ metadata=metadata,
3495
+ client=client,
3496
+ )
3497
+ except Exception as ex:
3498
+ self.log.exception("Exception occurred while requesting job %s", job_id)
3499
+ raise AirflowException(ex)
3500
+ self.log.info("Status of the custom job %s is %s", job.name, job.state.name)
3501
+ if job.state in self.JOB_COMPLETE_STATES:
3502
+ return job
3503
+ self.log.info("Sleeping for %s seconds.", poll_interval)
3504
+ await asyncio.sleep(poll_interval)
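wait_for_custom_job polls until the job reaches one of the JOB_COMPLETE_STATES, which include failed, cancelled, and paused as well as succeeded, so the caller still has to inspect the terminal state. A sketch (the JobState import is assumed to come from google.cloud.aiplatform_v1):

import asyncio

from google.cloud.aiplatform_v1 import JobState

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook

async def wait_and_check() -> bool:
    hook = CustomJobAsyncHook(gcp_conn_id="google_cloud_default")
    job = await hook.wait_for_custom_job(
        project_id="my-project",   # placeholder
        location="us-central1",
        job_id="1234567890",       # placeholder CustomJob ID
        poll_interval=30,          # re-check the state every 30 seconds
    )
    return job.state == JobState.JOB_STATE_SUCCEEDED

succeeded = asyncio.run(wait_and_check())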
3505
+
3506
+ async def wait_for_training_pipeline(
3507
+ self,
3508
+ project_id: str,
3509
+ location: str,
3510
+ pipeline_id: str,
3511
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
3512
+ timeout: float | None = None,
3513
+ metadata: Sequence[tuple[str, str]] = (),
3514
+ poll_interval: int = 10,
3515
+ ) -> types.TrainingPipeline:
3516
+ """Make async calls to Vertex AI to check the training pipeline state until it is complete."""
3517
+ client = await self.get_pipeline_service_client(region=location)
3518
+ while True:
3519
+ try:
3520
+ self.log.info("Requesting a training pipeline with id %s", pipeline_id)
3521
+ pipeline: types.TrainingPipeline = await self.get_training_pipeline(
3522
+ project_id=project_id,
3523
+ location=location,
3524
+ pipeline_id=pipeline_id,
3525
+ retry=retry,
3526
+ timeout=timeout,
3527
+ metadata=metadata,
3528
+ client=client,
3529
+ )
3530
+ except Exception as ex:
3531
+ self.log.exception("Exception occurred while requesting training pipeline %s", pipeline_id)
3532
+ raise AirflowException(ex)
3533
+ self.log.info("Status of the training pipeline %s is %s", pipeline.name, pipeline.state.name)
3534
+ if pipeline.state in self.PIPELINE_COMPLETE_STATES:
3535
+ return pipeline
3536
+ self.log.info("Sleeping for %s seconds.", poll_interval)
3537
+ await asyncio.sleep(poll_interval)
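wait_for_training_pipeline is the pipeline-level counterpart: it polls until one of the PIPELINE_COMPLETE_STATES is reached and returns the terminal TrainingPipeline, which again may be failed, cancelled, or paused rather than succeeded. A short sketch mirroring the job example above (the PipelineState import is assumed to come from google.cloud.aiplatform_v1):

import asyncio

from google.cloud.aiplatform_v1 import PipelineState

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook

async def wait_for_pipeline() -> bool:
    hook = CustomJobAsyncHook(gcp_conn_id="google_cloud_default")
    pipeline = await hook.wait_for_training_pipeline(
        project_id="my-project",    # placeholder
        location="us-central1",
        pipeline_id="1234567890",   # placeholder TrainingPipeline ID
        poll_interval=60,
    )
    return pipeline.state == PipelineState.PIPELINE_STATE_SUCCEEDED

asyncio.run(wait_for_pipeline())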