apache-airflow-providers-google 10.17.0rc1-py3-none-any.whl → 10.18.0rc1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registries, and is provided for informational purposes only.
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/cloud/hooks/automl.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +64 -33
- airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
- airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
- airflow/providers/google/cloud/hooks/dataflow.py +246 -32
- airflow/providers/google/cloud/hooks/dataplex.py +6 -2
- airflow/providers/google/cloud/hooks/dlp.py +14 -14
- airflow/providers/google/cloud/hooks/gcs.py +6 -2
- airflow/providers/google/cloud/hooks/gdm.py +2 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/mlengine.py +8 -4
- airflow/providers/google/cloud/hooks/pubsub.py +1 -1
- airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
- airflow/providers/google/cloud/links/vertex_ai.py +2 -1
- airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
- airflow/providers/google/cloud/operators/automl.py +13 -12
- airflow/providers/google/cloud/operators/bigquery.py +36 -22
- airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
- airflow/providers/google/cloud/operators/bigtable.py +7 -6
- airflow/providers/google/cloud/operators/cloud_build.py +12 -11
- airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
- airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
- airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
- airflow/providers/google/cloud/operators/compute.py +12 -11
- airflow/providers/google/cloud/operators/datacatalog.py +21 -20
- airflow/providers/google/cloud/operators/dataflow.py +59 -42
- airflow/providers/google/cloud/operators/datafusion.py +11 -10
- airflow/providers/google/cloud/operators/datapipeline.py +3 -2
- airflow/providers/google/cloud/operators/dataprep.py +5 -4
- airflow/providers/google/cloud/operators/dataproc.py +19 -16
- airflow/providers/google/cloud/operators/datastore.py +8 -7
- airflow/providers/google/cloud/operators/dlp.py +31 -30
- airflow/providers/google/cloud/operators/functions.py +4 -3
- airflow/providers/google/cloud/operators/gcs.py +66 -41
- airflow/providers/google/cloud/operators/kubernetes_engine.py +232 -12
- airflow/providers/google/cloud/operators/life_sciences.py +2 -1
- airflow/providers/google/cloud/operators/mlengine.py +11 -10
- airflow/providers/google/cloud/operators/pubsub.py +6 -5
- airflow/providers/google/cloud/operators/spanner.py +7 -6
- airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
- airflow/providers/google/cloud/operators/stackdriver.py +11 -10
- airflow/providers/google/cloud/operators/tasks.py +14 -13
- airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
- airflow/providers/google/cloud/operators/translate_speech.py +2 -1
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
- airflow/providers/google/cloud/operators/vision.py +13 -12
- airflow/providers/google/cloud/operators/workflows.py +10 -9
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/sensors/bigtable.py +2 -1
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/sensors/dataflow.py +239 -52
- airflow/providers/google/cloud/sensors/datafusion.py +2 -1
- airflow/providers/google/cloud/sensors/dataproc.py +3 -2
- airflow/providers/google/cloud/sensors/gcs.py +14 -12
- airflow/providers/google/cloud/sensors/tasks.py +2 -1
- airflow/providers/google/cloud/sensors/workflows.py +2 -1
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
- airflow/providers/google/cloud/triggers/bigquery.py +14 -3
- airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
- airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/triggers/dataflow.py +504 -4
- airflow/providers/google/cloud/triggers/dataproc.py +110 -26
- airflow/providers/google/cloud/triggers/mlengine.py +2 -1
- airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
- airflow/providers/google/common/hooks/base_google.py +45 -7
- airflow/providers/google/firebase/hooks/firestore.py +2 -2
- airflow/providers/google/firebase/operators/firestore.py +2 -1
- airflow/providers/google/get_provider_info.py +3 -2
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/METADATA +8 -8
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/RECORD +88 -89
- airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0rc1.dist-info}/entry_points.txt +0 -0
@@ -19,7 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+import asyncio
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated import deprecated
 from google.api_core.client_options import ClientOptions
@@ -31,15 +32,24 @@ from google.cloud.aiplatform import (
     datasets,
     models,
 )
-from google.cloud.aiplatform_v1 import
+from google.cloud.aiplatform_v1 import (
+    JobServiceAsyncClient,
+    JobServiceClient,
+    JobState,
+    PipelineServiceAsyncClient,
+    PipelineServiceClient,
+    PipelineState,
+    types,
+)
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
-    from google.api_core.retry import Retry
+    from google.api_core.retry import AsyncRetry, Retry
+    from google.auth.credentials import Credentials
     from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager
     from google.cloud.aiplatform_v1.services.pipeline_service.pagers import (
         ListPipelineJobsPager,
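
The imports added here (asyncio, the *AsyncClient variants, JobState/PipelineState, AsyncRetry, GoogleBaseAsyncHook) wire the hook up for asynchronous, deferrable use; the new airflow/providers/google/cloud/triggers/vertex_ai.py in the file list above is the consumer. As a rough sketch only — assuming direct use of the google-cloud-aiplatform v1 clients, not the provider's own trigger API — an async waiter would poll a training pipeline until it reaches a terminal PipelineState:

    import asyncio

    from google.cloud.aiplatform_v1 import PipelineServiceAsyncClient, PipelineState

    # Terminal lifecycle states; a deferrable trigger stops polling once one is reached.
    TERMINAL_STATES = {
        PipelineState.PIPELINE_STATE_SUCCEEDED,
        PipelineState.PIPELINE_STATE_FAILED,
        PipelineState.PIPELINE_STATE_CANCELLED,
    }

    async def wait_for_training_pipeline(
        client: PipelineServiceAsyncClient, name: str, poll_interval: float = 30.0
    ):
        # `name` is a full resource path, e.g.
        # projects/<project>/locations/<region>/trainingPipelines/<id> (placeholder).
        while True:
            pipeline = await client.get_training_pipeline(name=name)
            if pipeline.state in TERMINAL_STATES:
                return pipeline
            await asyncio.sleep(poll_interval)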
@@ -101,7 +111,7 @@ class CustomJobHook(GoogleBaseHook):
         self,
         display_name: str,
         container_uri: str,
-        command: Sequence[str] =
+        command: Sequence[str] = (),
         model_serving_container_image_uri: str | None = None,
         model_serving_container_predict_route: str | None = None,
         model_serving_container_health_route: str | None = None,
@@ -168,7 +178,7 @@ class CustomJobHook(GoogleBaseHook):
         training_encryption_spec_key_name: str | None = None,
         model_encryption_spec_key_name: str | None = None,
         staging_bucket: str | None = None,
-    ):
+    ) -> CustomPythonPackageTrainingJob:
         """Return CustomPythonPackageTrainingJob object."""
         return CustomPythonPackageTrainingJob(
             display_name=display_name,
@@ -218,7 +228,7 @@ class CustomJobHook(GoogleBaseHook):
         training_encryption_spec_key_name: str | None = None,
         model_encryption_spec_key_name: str | None = None,
         staging_bucket: str | None = None,
-    ):
+    ) -> CustomTrainingJob:
         """Return CustomTrainingJob object."""
         return CustomTrainingJob(
             display_name=display_name,
@@ -246,10 +256,15 @@ class CustomJobHook(GoogleBaseHook):
         )
 
     @staticmethod
-    def extract_model_id(obj: dict) -> str:
+    def extract_model_id(obj: dict[str, Any]) -> str:
         """Return unique id of the Model."""
         return obj["name"].rpartition("/")[-1]
 
+    @staticmethod
+    def extract_model_id_from_training_pipeline(training_pipeline: dict[str, Any]) -> str:
+        """Return a unique Model id from a serialized TrainingPipeline proto."""
+        return training_pipeline["model_to_upload"]["name"].rpartition("/")[-1]
+
     @staticmethod
     def extract_training_id(resource_name: str) -> str:
         """Return unique id of the Training pipeline."""
@@ -260,6 +275,11 @@ class CustomJobHook(GoogleBaseHook):
         """Return unique id of the Custom Job pipeline."""
         return custom_job_name.rpartition("/")[-1]
 
+    @staticmethod
+    def extract_custom_job_id_from_training_pipeline(training_pipeline: dict[str, Any]) -> str:
+        """Return a unique Custom Job id from a serialized TrainingPipeline proto."""
+        return training_pipeline["training_task_metadata"]["backingCustomJob"].rpartition("/")[-1]
+
     def wait_for_operation(self, operation: Operation, timeout: float | None = None):
         """Wait for long-lasting operation to complete."""
         try:
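
Together with extract_model_id above, the two new static helpers read resource ids out of a TrainingPipeline that has been serialized to a dict (note the camelCase backingCustomJob key, matching proto JSON serialization). A usage sketch with an illustrative payload — the values are placeholders, not provider test data:

    from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

    # Illustrative dict shaped like a serialized TrainingPipeline proto; the
    # field names match the lookups performed by the new helpers.
    training_pipeline = {
        "model_to_upload": {"name": "projects/p/locations/us-central1/models/123"},
        "training_task_metadata": {
            "backingCustomJob": "projects/p/locations/us-central1/customJobs/456"
        },
    }

    CustomJobHook.extract_model_id_from_training_pipeline(training_pipeline)       # "123"
    CustomJobHook.extract_custom_job_id_from_training_pipeline(training_pipeline)  # "456"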
@@ -310,7 +330,7 @@ class CustomJobHook(GoogleBaseHook):
         model_version_aliases: list[str] | None = None,
         model_version_description: str | None = None,
     ) -> tuple[models.Model | None, str, str]:
-        """Run
+        """Run a training pipeline job and wait until its completion."""
         model = job.run(
             dataset=dataset,
             annotation_schema_uri=annotation_schema_uri,
@@ -609,7 +629,7 @@ class CustomJobHook(GoogleBaseHook):
         region: str,
         display_name: str,
         container_uri: str,
-        command: Sequence[str] =
+        command: Sequence[str] = (),
         model_serving_container_image_uri: str | None = None,
         model_serving_container_predict_route: str | None = None,
         model_serving_container_health_route: str | None = None,
@@ -1754,80 +1774,1180 @@ class CustomJobHook(GoogleBaseHook):
         return model, training_id, custom_job_id
 
     @GoogleBaseHook.fallback_to_default_project_id
-
-        reason="Please use `PipelineJobHook.delete_pipeline_job`",
-        category=AirflowProviderDeprecationWarning,
-    )
-    def delete_pipeline_job(
+    def submit_custom_container_training_job(
         self,
+        *,
         project_id: str,
         region: str,
-
-
-
-
-
+        display_name: str,
+        container_uri: str,
+        command: Sequence[str] = (),
+        model_serving_container_image_uri: str | None = None,
+        model_serving_container_predict_route: str | None = None,
+        model_serving_container_health_route: str | None = None,
+        model_serving_container_command: Sequence[str] | None = None,
+        model_serving_container_args: Sequence[str] | None = None,
+        model_serving_container_environment_variables: dict[str, str] | None = None,
+        model_serving_container_ports: Sequence[int] | None = None,
+        model_description: str | None = None,
+        model_instance_schema_uri: str | None = None,
+        model_parameters_schema_uri: str | None = None,
+        model_prediction_schema_uri: str | None = None,
+        parent_model: str | None = None,
+        is_default_version: bool | None = None,
+        model_version_aliases: list[str] | None = None,
+        model_version_description: str | None = None,
+        labels: dict[str, str] | None = None,
+        training_encryption_spec_key_name: str | None = None,
+        model_encryption_spec_key_name: str | None = None,
+        staging_bucket: str | None = None,
+        # RUN
+        dataset: None
+        | (
+            datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset
+        ) = None,
+        annotation_schema_uri: str | None = None,
+        model_display_name: str | None = None,
+        model_labels: dict[str, str] | None = None,
+        base_output_dir: str | None = None,
+        service_account: str | None = None,
+        network: str | None = None,
+        bigquery_destination: str | None = None,
+        args: list[str | float | int] | None = None,
+        environment_variables: dict[str, str] | None = None,
+        replica_count: int = 1,
+        machine_type: str = "n1-standard-4",
+        accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
+        accelerator_count: int = 0,
+        boot_disk_type: str = "pd-ssd",
+        boot_disk_size_gb: int = 100,
+        training_fraction_split: float | None = None,
+        validation_fraction_split: float | None = None,
+        test_fraction_split: float | None = None,
+        training_filter_split: str | None = None,
+        validation_filter_split: str | None = None,
+        test_filter_split: str | None = None,
+        predefined_split_column_name: str | None = None,
+        timestamp_split_column_name: str | None = None,
+        tensorboard: str | None = None,
+    ) -> CustomContainerTrainingJob:
         """
-
-
-        This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method.
+        Create and submit a Custom Container Training Job pipeline, then exit without waiting for it to complete.
 
-        :param
-        :param
-
-        :param
-        :param
-
-
-
-
+        :param display_name: Required. The user-defined name of this TrainingPipeline.
+        :param command: The command to be invoked when the container is started.
+            It overrides the entrypoint instruction in Dockerfile when provided
+        :param container_uri: Required: Uri of the training container image in the GCR.
+        :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
+            of the Model serving container suitable for serving the model produced by the
+            training script.
+        :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
+            HTTP path to send prediction requests to the container, and which must be supported
+            by it. If not specified a default HTTP path will be used by Vertex AI.
+        :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
+            HTTP path to send health check requests to the container, and which must be supported
+            by it. If not specified a standard HTTP path will be used by AI Platform.
+        :param model_serving_container_command: The command with which the container is run. Not executed
+            within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
+            Variable references $(VAR_NAME) are expanded using the container's
+            environment. If a variable cannot be resolved, the reference in the
+            input string will be unchanged. The $(VAR_NAME) syntax can be escaped
+            with a double $$, ie: $$(VAR_NAME). Escaped references will never be
+            expanded, regardless of whether the variable exists or not.
+        :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
+            this is not provided. Variable references $(VAR_NAME) are expanded using the
+            container's environment. If a variable cannot be resolved, the reference
+            in the input string will be unchanged. The $(VAR_NAME) syntax can be
+            escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
+            never be expanded, regardless of whether the variable exists or not.
+        :param model_serving_container_environment_variables: The environment variables that are to be
+            present in the container. Should be a dictionary where keys are environment variable names
+            and values are environment variable values for those names.
+        :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
+            field is primarily informational, it gives Vertex AI information about the
+            network connections the container uses. Listing or not a port here has
+            no impact on whether the port is actually exposed, any port listening on
+            the default "0.0.0.0" address inside a container will be accessible from
+            the network.
+        :param model_description: The description of the Model.
+        :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
+            Storage describing the format of a single instance, which
+            are used in
+            ``PredictRequest.instances``,
+            ``ExplainRequest.instances``
+            and
+            ``BatchPredictionJob.input_config``.
+            The schema is defined as an OpenAPI 3.0.2 `Schema
+            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
+            AutoML Models always have this field populated by AI
+            Platform. Note: The URI given on output will be immutable
+            and probably different, including the URI scheme, than the
+            one given on input. The output URI will point to a location
+            where the user only has a read access.
+        :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
+            Storage describing the parameters of prediction and
+            explanation via
+            ``PredictRequest.parameters``,
+            ``ExplainRequest.parameters``
+            and
+            ``BatchPredictionJob.model_parameters``.
+            The schema is defined as an OpenAPI 3.0.2 `Schema
+            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
+            AutoML Models always have this field populated by AI
+            Platform, if no parameters are supported it is set to an
+            empty string. Note: The URI given on output will be
+            immutable and probably different, including the URI scheme,
+            than the one given on input. The output URI will point to a
+            location where the user only has a read access.
+        :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
+            Storage describing the format of a single prediction
+            produced by this Model, which are returned via
+            ``PredictResponse.predictions``,
+            ``ExplainResponse.explanations``,
+            and
+            ``BatchPredictionJob.output_config``.
+            The schema is defined as an OpenAPI 3.0.2 `Schema
+            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
+            AutoML Models always have this field populated by AI
+            Platform. Note: The URI given on output will be immutable
+            and probably different, including the URI scheme, than the
+            one given on input. The output URI will point to a location
+            where the user only has a read access.
+        :param parent_model: Optional. The resource name or model ID of an existing model.
+            The new model uploaded by this job will be a version of `parent_model`.
+            Only set this field when training a new version of an existing model.
+        :param is_default_version: Optional. When set to True, the newly uploaded model version will
+            automatically have alias "default" included. Subsequent uses of
+            the model produced by this job without a version specified will
+            use this "default" version.
+            When set to False, the "default" alias will not be moved.
+            Actions targeting the model version produced by this job will need
+            to specifically reference this version by ID or alias.
+            New model uploads, i.e. version 1, will always be "default" aliased.
+        :param model_version_aliases: Optional. User provided version aliases so that the model version
+            uploaded by this job can be referenced via alias instead of
+            auto-generated version ID. A default version alias will be created
+            for the first version of the model.
+            The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
+        :param model_version_description: Optional. The description of the model version
+            being uploaded by this job.
+        :param project_id: Project to run training in.
+        :param region: Location to run training in.
+        :param labels: Optional. The labels with user-defined metadata to
+            organize TrainingPipelines.
+            Label keys and values can be no longer than 64
+            characters, can only
+            contain lowercase letters, numeric characters,
+            underscores and dashes. International characters
+            are allowed.
+            See https://goo.gl/xmQnxf for more information
+            and examples of labels.
+        :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+            managed encryption key used to protect the training pipeline. Has the
+            form:
+            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+            The key needs to be in the same region as where the compute
+            resource is created.
 
-            request={
-                "name": name,
-            },
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
+            If set, this TrainingPipeline will be secured by this key.
 
-
-
-
-
-
-
-
-
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> Operation:
-        """
-        Delete a TrainingPipeline.
+            Note: Model trained by this TrainingPipeline is also secured
+            by this key if ``model_to_upload`` is not set separately.
+        :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+            managed encryption key used to protect the model. Has the
+            form:
+            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+            The key needs to be in the same region as where the compute
+            resource is created.
 
-
-        :param
-        :param
-        :param
-
-
-
-        client = self.get_pipeline_service_client(region)
-        name = client.training_pipeline_path(project_id, region, training_pipeline)
+            If set, the trained Model will be secured by this key.
+        :param staging_bucket: Bucket used to stage source and training artifacts.
+        :param dataset: Vertex AI to fit this training against.
+        :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
+            annotation schema. The schema is defined as an OpenAPI 3.0.2
+            [Schema Object]
+            (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
 
-
-
-
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
+            Only Annotations that both match this schema and belong to
+            DataItems not ignored by the split method are used in
+            respectively training, validation or test role, depending on
+            the role of the DataItem they are on.
 
-
-
-
+            When used in conjunction with
+            ``annotations_filter``,
+            the Annotations used for training are filtered by both
+            ``annotations_filter``
+            and
+            ``annotation_schema_uri``.
+        :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
+            the Model. The name can be up to 128 characters long and can be consist
+            of any UTF-8 characters.
+
+            If not provided upon creation, the job's display_name is used.
+        :param model_labels: Optional. The labels with user-defined metadata to
+            organize your Models.
+            Label keys and values can be no longer than 64
+            characters, can only
+            contain lowercase letters, numeric characters,
+            underscores and dashes. International characters
+            are allowed.
+            See https://goo.gl/xmQnxf for more information
+            and examples of labels.
+        :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
+            staging directory will be used.
+
+            Vertex AI sets the following environment variables when it runs your training code:
+
+            - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
+              i.e. <base_output_dir>/model/
+            - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
+              i.e. <base_output_dir>/checkpoints/
+            - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
+              logs, i.e. <base_output_dir>/logs/
+
+        :param service_account: Specifies the service account for workload run-as account.
+            Users submitting jobs must have act-as permission on this run-as account.
+        :param network: The full name of the Compute Engine network to which the job
+            should be peered.
+            Private services access must already be configured for the network.
+            If left unspecified, the job is not peered with any network.
+        :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
+            The BigQuery project location where the training data is to
+            be written to. In the given project a new dataset is created
+            with name
+            ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
+            where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+            training input data will be written into that dataset. In
+            the dataset three tables will be created, ``training``,
+            ``validation`` and ``test``.
+
+            - AIP_DATA_FORMAT = "bigquery".
+            - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
+            - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
+            - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+        :param args: Command line arguments to be passed to the Python script.
+        :param environment_variables: Environment variables to be passed to the container.
+            Should be a dictionary where keys are environment variable names
+            and values are environment variable values for those names.
+            At most 10 environment variables can be specified.
+            The Name of the environment variable must be unique.
+        :param replica_count: The number of worker replicas. If replica count = 1 then one chief
+            replica will be provisioned. If replica_count > 1 the remainder will be
+            provisioned as a worker replica pool.
+        :param machine_type: The type of machine to use for training.
+        :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
+            NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
+            NVIDIA_TESLA_T4
+        :param accelerator_count: The number of accelerators to attach to a worker replica.
+        :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
+            Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+            `pd-standard` (Persistent Disk Hard Disk Drive).
+        :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
+            boot disk size must be within the range of [100, 64000].
+        :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+            the Model. This is ignored if Dataset is not provided.
+        :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+            validate the Model. This is ignored if Dataset is not provided.
+        :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+            the Model. This is ignored if Dataset is not provided.
+        :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to train the Model. A filter with same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to validate the Model. A filter with same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to test the Model. A filter with same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The value of the key (either the label's value or
+            value in the column) must be one of {``training``,
+            ``validation``, ``test``}, and it defines to which set the
+            given piece of data is assigned. If for a piece of data the
+            key is not present or has an invalid value, that piece is
+            ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The value of the key values of the key (the values in
+            the column) must be in RFC 3339 `date-time` format, where
+            `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+            piece of data the key is not present or has an invalid value,
+            that piece is ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
+            logs. Format:
+            ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+            For more information on configuring your service account please visit:
+            https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        """
+        self._job = self.get_custom_container_training_job(
+            project=project_id,
+            location=region,
+            display_name=display_name,
+            container_uri=container_uri,
+            command=command,
+            model_serving_container_image_uri=model_serving_container_image_uri,
+            model_serving_container_predict_route=model_serving_container_predict_route,
+            model_serving_container_health_route=model_serving_container_health_route,
+            model_serving_container_command=model_serving_container_command,
+            model_serving_container_args=model_serving_container_args,
+            model_serving_container_environment_variables=model_serving_container_environment_variables,
+            model_serving_container_ports=model_serving_container_ports,
+            model_description=model_description,
+            model_instance_schema_uri=model_instance_schema_uri,
+            model_parameters_schema_uri=model_parameters_schema_uri,
+            model_prediction_schema_uri=model_prediction_schema_uri,
+            labels=labels,
+            training_encryption_spec_key_name=training_encryption_spec_key_name,
+            model_encryption_spec_key_name=model_encryption_spec_key_name,
+            staging_bucket=staging_bucket,
+        )
+
+        if not self._job:
+            raise AirflowException("CustomContainerTrainingJob instance creation failed.")
+
+        self._job.submit(
+            dataset=dataset,
+            annotation_schema_uri=annotation_schema_uri,
+            model_display_name=model_display_name,
+            model_labels=model_labels,
+            base_output_dir=base_output_dir,
+            service_account=service_account,
+            network=network,
+            bigquery_destination=bigquery_destination,
+            args=args,
+            environment_variables=environment_variables,
+            replica_count=replica_count,
+            machine_type=machine_type,
+            accelerator_type=accelerator_type,
+            accelerator_count=accelerator_count,
+            boot_disk_type=boot_disk_type,
+            boot_disk_size_gb=boot_disk_size_gb,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
+            test_fraction_split=test_fraction_split,
+            training_filter_split=training_filter_split,
+            validation_filter_split=validation_filter_split,
+            test_filter_split=test_filter_split,
+            predefined_split_column_name=predefined_split_column_name,
+            timestamp_split_column_name=timestamp_split_column_name,
+            tensorboard=tensorboard,
+            parent_model=parent_model,
+            is_default_version=is_default_version,
+            model_version_aliases=model_version_aliases,
+            model_version_description=model_version_description,
+            sync=False,
+        )
+        return self._job
+
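
In contrast to the blocking create_custom_container_training_job, this submit_* variant hands the pipeline off via job.submit(..., sync=False) and returns the CustomContainerTrainingJob immediately; completion is then observed elsewhere (the deferrable operators use the new triggers for that). A hedged usage sketch — connection id, project, bucket, and image are placeholders:

    from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

    hook = CustomJobHook(gcp_conn_id="google_cloud_default")  # placeholder connection
    job = hook.submit_custom_container_training_job(
        project_id="my-project",                           # placeholder
        region="us-central1",
        display_name="custom-container-training",
        container_uri="gcr.io/my-project/trainer:latest",  # placeholder image
        command=["python3", "-m", "trainer.task"],
        staging_bucket="gs://my-staging-bucket",           # placeholder bucket
    )
    # Returns before training finishes; the job's state can then be polled
    # (e.g. by a trigger) until it reaches a terminal value.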
2154
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
2155
|
+
def submit_custom_python_package_training_job(
|
2156
|
+
self,
|
2157
|
+
*,
|
2158
|
+
project_id: str,
|
2159
|
+
region: str,
|
2160
|
+
display_name: str,
|
2161
|
+
python_package_gcs_uri: str,
|
2162
|
+
python_module_name: str,
|
2163
|
+
container_uri: str,
|
2164
|
+
model_serving_container_image_uri: str | None = None,
|
2165
|
+
model_serving_container_predict_route: str | None = None,
|
2166
|
+
model_serving_container_health_route: str | None = None,
|
2167
|
+
model_serving_container_command: Sequence[str] | None = None,
|
2168
|
+
model_serving_container_args: Sequence[str] | None = None,
|
2169
|
+
model_serving_container_environment_variables: dict[str, str] | None = None,
|
2170
|
+
model_serving_container_ports: Sequence[int] | None = None,
|
2171
|
+
model_description: str | None = None,
|
2172
|
+
model_instance_schema_uri: str | None = None,
|
2173
|
+
model_parameters_schema_uri: str | None = None,
|
2174
|
+
model_prediction_schema_uri: str | None = None,
|
2175
|
+
labels: dict[str, str] | None = None,
|
2176
|
+
training_encryption_spec_key_name: str | None = None,
|
2177
|
+
model_encryption_spec_key_name: str | None = None,
|
2178
|
+
staging_bucket: str | None = None,
|
2179
|
+
# RUN
|
2180
|
+
dataset: None
|
2181
|
+
| (
|
2182
|
+
datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset
|
2183
|
+
) = None,
|
2184
|
+
annotation_schema_uri: str | None = None,
|
2185
|
+
model_display_name: str | None = None,
|
2186
|
+
model_labels: dict[str, str] | None = None,
|
2187
|
+
base_output_dir: str | None = None,
|
2188
|
+
service_account: str | None = None,
|
2189
|
+
network: str | None = None,
|
2190
|
+
bigquery_destination: str | None = None,
|
2191
|
+
args: list[str | float | int] | None = None,
|
2192
|
+
environment_variables: dict[str, str] | None = None,
|
2193
|
+
replica_count: int = 1,
|
2194
|
+
machine_type: str = "n1-standard-4",
|
2195
|
+
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
|
2196
|
+
accelerator_count: int = 0,
|
2197
|
+
boot_disk_type: str = "pd-ssd",
|
2198
|
+
boot_disk_size_gb: int = 100,
|
2199
|
+
training_fraction_split: float | None = None,
|
2200
|
+
validation_fraction_split: float | None = None,
|
2201
|
+
test_fraction_split: float | None = None,
|
2202
|
+
training_filter_split: str | None = None,
|
2203
|
+
validation_filter_split: str | None = None,
|
2204
|
+
test_filter_split: str | None = None,
|
2205
|
+
predefined_split_column_name: str | None = None,
|
2206
|
+
timestamp_split_column_name: str | None = None,
|
2207
|
+
tensorboard: str | None = None,
|
2208
|
+
parent_model: str | None = None,
|
2209
|
+
is_default_version: bool | None = None,
|
2210
|
+
model_version_aliases: list[str] | None = None,
|
2211
|
+
model_version_description: str | None = None,
|
2212
|
+
) -> CustomPythonPackageTrainingJob:
|
2213
|
+
"""
|
2214
|
+
Create and submit a Custom Python Package Training Job pipeline, then exit without waiting for it to complete.
|
2215
|
+
|
2216
|
+
:param display_name: Required. The user-defined name of this TrainingPipeline.
|
2217
|
+
:param python_package_gcs_uri: Required: GCS location of the training python package.
|
2218
|
+
:param python_module_name: Required: The module name of the training python package.
|
2219
|
+
:param container_uri: Required: Uri of the training container image in the GCR.
|
2220
|
+
:param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
|
2221
|
+
of the Model serving container suitable for serving the model produced by the
|
2222
|
+
training script.
|
2223
|
+
:param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
|
2224
|
+
HTTP path to send prediction requests to the container, and which must be supported
|
2225
|
+
by it. If not specified a default HTTP path will be used by Vertex AI.
|
2226
|
+
:param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
|
2227
|
+
HTTP path to send health check requests to the container, and which must be supported
|
2228
|
+
by it. If not specified a standard HTTP path will be used by AI Platform.
|
2229
|
+
:param model_serving_container_command: The command with which the container is run. Not executed
|
2230
|
+
within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
|
2231
|
+
Variable references $(VAR_NAME) are expanded using the container's
|
2232
|
+
environment. If a variable cannot be resolved, the reference in the
|
2233
|
+
input string will be unchanged. The $(VAR_NAME) syntax can be escaped
|
2234
|
+
with a double $$, ie: $$(VAR_NAME). Escaped references will never be
|
2235
|
+
expanded, regardless of whether the variable exists or not.
|
2236
|
+
:param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
|
2237
|
+
this is not provided. Variable references $(VAR_NAME) are expanded using the
|
2238
|
+
container's environment. If a variable cannot be resolved, the reference
|
2239
|
+
in the input string will be unchanged. The $(VAR_NAME) syntax can be
|
2240
|
+
escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
|
2241
|
+
never be expanded, regardless of whether the variable exists or not.
|
2242
|
+
:param model_serving_container_environment_variables: The environment variables that are to be
|
2243
|
+
present in the container. Should be a dictionary where keys are environment variable names
|
2244
|
+
and values are environment variable values for those names.
|
2245
|
+
:param model_serving_container_ports: Declaration of ports that are exposed by the container. This
|
2246
|
+
field is primarily informational, it gives Vertex AI information about the
|
2247
|
+
network connections the container uses. Listing or not a port here has
|
2248
|
+
no impact on whether the port is actually exposed, any port listening on
|
2249
|
+
the default "0.0.0.0" address inside a container will be accessible from
|
2250
|
+
the network.
|
2251
|
+
:param model_description: The description of the Model.
|
2252
|
+
:param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
|
2253
|
+
Storage describing the format of a single instance, which
|
2254
|
+
are used in
|
2255
|
+
``PredictRequest.instances``,
|
2256
|
+
``ExplainRequest.instances``
|
2257
|
+
and
|
2258
|
+
``BatchPredictionJob.input_config``.
|
2259
|
+
The schema is defined as an OpenAPI 3.0.2 `Schema
|
2260
|
+
Object <https://tinyurl.com/y538mdwt#schema-object>`__.
|
2261
|
+
AutoML Models always have this field populated by AI
|
2262
|
+
Platform. Note: The URI given on output will be immutable
|
2263
|
+
and probably different, including the URI scheme, than the
|
2264
|
+
one given on input. The output URI will point to a location
|
2265
|
+
where the user only has a read access.
|
2266
|
+
:param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
|
2267
|
+
Storage describing the parameters of prediction and
|
2268
|
+
explanation via
|
2269
|
+
``PredictRequest.parameters``,
|
2270
|
+
``ExplainRequest.parameters``
|
2271
|
+
and
|
2272
|
+
``BatchPredictionJob.model_parameters``.
|
2273
|
+
The schema is defined as an OpenAPI 3.0.2 `Schema
|
2274
|
+
Object <https://tinyurl.com/y538mdwt#schema-object>`__.
|
2275
|
+
AutoML Models always have this field populated by AI
|
2276
|
+
Platform, if no parameters are supported it is set to an
|
2277
|
+
empty string. Note: The URI given on output will be
|
2278
|
+
immutable and probably different, including the URI scheme,
|
2279
|
+
than the one given on input. The output URI will point to a
|
2280
|
+
location where the user only has a read access.
|
2281
|
+
:param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
|
2282
|
+
Storage describing the format of a single prediction
|
2283
|
+
produced by this Model, which are returned via
|
2284
|
+
``PredictResponse.predictions``,
|
2285
|
+
``ExplainResponse.explanations``,
|
2286
|
+
and
|
2287
|
+
``BatchPredictionJob.output_config``.
|
2288
|
+
The schema is defined as an OpenAPI 3.0.2 `Schema
|
2289
|
+
Object <https://tinyurl.com/y538mdwt#schema-object>`__.
|
2290
|
+
AutoML Models always have this field populated by AI
|
2291
|
+
Platform. Note: The URI given on output will be immutable
|
2292
|
+
and probably different, including the URI scheme, than the
|
2293
|
+
one given on input. The output URI will point to a location
|
2294
|
+
where the user only has a read access.
|
2295
|
+
:param parent_model: Optional. The resource name or model ID of an existing model.
|
2296
|
+
The new model uploaded by this job will be a version of `parent_model`.
|
2297
|
+
Only set this field when training a new version of an existing model.
|
2298
|
+
:param is_default_version: Optional. When set to True, the newly uploaded model version will
|
2299
|
+
automatically have alias "default" included. Subsequent uses of
|
2300
|
+
the model produced by this job without a version specified will
|
2301
|
+
use this "default" version.
|
2302
|
+
When set to False, the "default" alias will not be moved.
|
2303
|
+
Actions targeting the model version produced by this job will need
|
2304
|
+
to specifically reference this version by ID or alias.
|
2305
|
+
New model uploads, i.e. version 1, will always be "default" aliased.
|
2306
|
+
:param model_version_aliases: Optional. User provided version aliases so that the model version
|
2307
|
+
uploaded by this job can be referenced via alias instead of
|
2308
|
+
auto-generated version ID. A default version alias will be created
|
2309
|
+
for the first version of the model.
|
2310
|
+
The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
|
2311
|
+
:param model_version_description: Optional. The description of the model version
|
2312
|
+
being uploaded by this job.
|
2313
|
+
:param project_id: Project to run training in.
|
2314
|
+
:param region: Location to run training in.
|
2315
|
+
:param labels: Optional. The labels with user-defined metadata to
|
2316
|
+
organize TrainingPipelines.
|
2317
|
+
Label keys and values can be no longer than 64
|
2318
|
+
characters, can only
|
2319
|
+
contain lowercase letters, numeric characters,
|
2320
|
+
underscores and dashes. International characters
|
2321
|
+
are allowed.
|
2322
|
+
See https://goo.gl/xmQnxf for more information
|
2323
|
+
and examples of labels.
|
2324
|
+
:param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
|
2325
|
+
managed encryption key used to protect the training pipeline. Has the
|
2326
|
+
form:
|
2327
|
+
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
|
2328
|
+
The key needs to be in the same region as where the compute
|
2329
|
+
resource is created.
|
2330
|
+
|
2331
|
+
If set, this TrainingPipeline will be secured by this key.
|
2332
|
+
|
2333
|
+
Note: Model trained by this TrainingPipeline is also secured
|
2334
|
+
by this key if ``model_to_upload`` is not set separately.
|
2335
|
+
:param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
|
2336
|
+
managed encryption key used to protect the model. Has the
|
2337
|
+
form:
|
2338
|
+
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
|
2339
|
+
The key needs to be in the same region as where the compute
|
2340
|
+
resource is created.
|
2341
|
+
|
2342
|
+
If set, the trained Model will be secured by this key.
|
2343
|
+
:param staging_bucket: Bucket used to stage source and training artifacts.
|
2344
|
+
:param dataset: Vertex AI to fit this training against.
|
2345
|
+
:param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
|
2346
|
+
annotation schema. The schema is defined as an OpenAPI 3.0.2
|
2347
|
+
[Schema Object]
|
2348
|
+
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
|
2349
|
+
|
2350
|
+
Only Annotations that both match this schema and belong to
|
2351
|
+
DataItems not ignored by the split method are used in
|
2352
|
+
respectively training, validation or test role, depending on
|
2353
|
+
the role of the DataItem they are on.
|
2354
|
+
|
2355
|
+
When used in conjunction with
|
2356
|
+
``annotations_filter``,
|
2357
|
+
the Annotations used for training are filtered by both
|
2358
|
+
``annotations_filter``
|
2359
|
+
and
|
2360
|
+
``annotation_schema_uri``.
|
2361
|
+
:param model_display_name: If the script produces a managed Vertex AI Model. The display name of
|
2362
|
+
the Model. The name can be up to 128 characters long and can be consist
|
2363
|
+
of any UTF-8 characters.
|
2364
|
+
|
2365
|
+
If not provided upon creation, the job's display_name is used.
|
2366
|
+
:param model_labels: Optional. The labels with user-defined metadata to
|
2367
|
+
organize your Models.
|
2368
|
+
Label keys and values can be no longer than 64
|
2369
|
+
characters, can only
|
2370
|
+
contain lowercase letters, numeric characters,
|
2371
|
+
underscores and dashes. International characters
|
2372
|
+
are allowed.
|
2373
|
+
See https://goo.gl/xmQnxf for more information
|
2374
|
+
and examples of labels.
|
2375
|
+
:param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
|
2376
|
+
staging directory will be used.
|
2377
|
+
|
2378
|
+
Vertex AI sets the following environment variables when it runs your training code:
|
2379
|
+
|
2380
|
+
- AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
|
2381
|
+
i.e. <base_output_dir>/model/
|
2382
|
+
- AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
|
2383
|
+
i.e. <base_output_dir>/checkpoints/
|
2384
|
+
- AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
|
2385
|
+
logs, i.e. <base_output_dir>/logs/
|
2386
|
+
:param service_account: Specifies the service account for workload run-as account.
|
2387
|
+
Users submitting jobs must have act-as permission on this run-as account.
|
2388
|
+
:param network: The full name of the Compute Engine network to which the job
|
2389
|
+
should be peered.
|
2390
|
+
Private services access must already be configured for the network.
|
2391
|
+
If left unspecified, the job is not peered with any network.
|
2392
|
+
:param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
|
2393
|
+
The BigQuery project location where the training data is to
|
2394
|
+
be written to. In the given project a new dataset is created
|
2395
|
+
with name
|
2396
|
+
``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
|
2397
|
+
where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
|
2398
|
+
training input data will be written into that dataset. In
|
2399
|
+
the dataset three tables will be created, ``training``,
|
2400
|
+
``validation`` and ``test``.
|
2401
|
+
|
2402
|
+
- AIP_DATA_FORMAT = "bigquery".
|
2403
|
+
- AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
|
2404
|
+
- AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
|
2405
|
+
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
|
2406
|
+
:param args: Command line arguments to be passed to the Python script.
|
2407
|
+
:param environment_variables: Environment variables to be passed to the container.
|
2408
|
+
Should be a dictionary where keys are environment variable names
|
2409
|
+
and values are environment variable values for those names.
|
2410
|
+
At most 10 environment variables can be specified.
|
2411
|
+
The Name of the environment variable must be unique.
|
2412
|
+
:param replica_count: The number of worker replicas. If replica count = 1 then one chief
|
2413
|
+
replica will be provisioned. If replica_count > 1 the remainder will be
|
2414
|
+
provisioned as a worker replica pool.
|
2415
|
+
:param machine_type: The type of machine to use for training.
|
2416
|
+
:param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
|
2417
|
+
NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
|
2418
|
+
NVIDIA_TESLA_T4
|
2419
|
+
:param accelerator_count: The number of accelerators to attach to a worker replica.
|
2420
|
+
:param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
|
2421
|
+
Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
|
2422
|
+
`pd-standard` (Persistent Disk Hard Disk Drive).
|
2423
|
+
:param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
|
2424
|
+
boot disk size must be within the range of [100, 64000].
|
2425
|
+
:param training_fraction_split: Optional. The fraction of the input data that is to be used to train
|
2426
|
+
+            the Model. This is ignored if Dataset is not provided.
+        :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+            validate the Model. This is ignored if Dataset is not provided.
+        :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+            the Model. This is ignored if Dataset is not provided.
+        :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to train the Model. A filter with the same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to validate the Model. A filter with the same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to test the Model. A filter with the same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The value of the key (either the label's value or
+            value in the column) must be one of {``training``,
+            ``validation``, ``test``}, and it defines to which set the
+            given piece of data is assigned. If for a piece of data the
+            key is not present or has an invalid value, that piece is
+            ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The values of the key (the values in
+            the column) must be in RFC 3339 `date-time` format, where
+            `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+            piece of data the key is not present or has an invalid value,
+            that piece is ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
+            logs. Format:
+            ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+            For more information on configuring your service account please visit:
+            https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        """
+        self._job = self.get_custom_python_package_training_job(
+            project=project_id,
+            location=region,
+            display_name=display_name,
+            python_package_gcs_uri=python_package_gcs_uri,
+            python_module_name=python_module_name,
+            container_uri=container_uri,
+            model_serving_container_image_uri=model_serving_container_image_uri,
+            model_serving_container_predict_route=model_serving_container_predict_route,
+            model_serving_container_health_route=model_serving_container_health_route,
+            model_serving_container_command=model_serving_container_command,
+            model_serving_container_args=model_serving_container_args,
+            model_serving_container_environment_variables=model_serving_container_environment_variables,
+            model_serving_container_ports=model_serving_container_ports,
+            model_description=model_description,
+            model_instance_schema_uri=model_instance_schema_uri,
+            model_parameters_schema_uri=model_parameters_schema_uri,
+            model_prediction_schema_uri=model_prediction_schema_uri,
+            labels=labels,
+            training_encryption_spec_key_name=training_encryption_spec_key_name,
+            model_encryption_spec_key_name=model_encryption_spec_key_name,
+            staging_bucket=staging_bucket,
+        )
+
+        if not self._job:
+            raise AirflowException("CustomPythonPackageTrainingJob instance creation failed.")
+
+        self._job.run(
+            dataset=dataset,
+            annotation_schema_uri=annotation_schema_uri,
+            model_display_name=model_display_name,
+            model_labels=model_labels,
+            base_output_dir=base_output_dir,
+            service_account=service_account,
+            network=network,
+            bigquery_destination=bigquery_destination,
+            args=args,
+            environment_variables=environment_variables,
+            replica_count=replica_count,
+            machine_type=machine_type,
+            accelerator_type=accelerator_type,
+            accelerator_count=accelerator_count,
+            boot_disk_type=boot_disk_type,
+            boot_disk_size_gb=boot_disk_size_gb,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
+            test_fraction_split=test_fraction_split,
+            training_filter_split=training_filter_split,
+            validation_filter_split=validation_filter_split,
+            test_filter_split=test_filter_split,
+            predefined_split_column_name=predefined_split_column_name,
+            timestamp_split_column_name=timestamp_split_column_name,
+            tensorboard=tensorboard,
+            parent_model=parent_model,
+            is_default_version=is_default_version,
+            model_version_aliases=model_version_aliases,
+            model_version_description=model_version_description,
+            sync=False,
+        )
+
+        return self._job
+
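The `timestamp_split_column_name` values above must be RFC 3339 `date-time` strings with a `"Z"` time-offset. A minimal sketch (standard library only; the concrete timestamp is just an illustration) of producing a compliant column value:

```python
from datetime import datetime, timezone

# RFC 3339 date-time with time-offset "Z", matching the docstring's
# example value 1985-04-12T23:20:50.52Z.
ts = datetime(1985, 4, 12, 23, 20, 50, 520000, tzinfo=timezone.utc)
rfc3339 = ts.isoformat().replace("+00:00", "Z")
print(rfc3339)  # 1985-04-12T23:20:50.520000Z
```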
+    @GoogleBaseHook.fallback_to_default_project_id
+    def submit_custom_training_job(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        display_name: str,
+        script_path: str,
+        container_uri: str,
+        requirements: Sequence[str] | None = None,
+        model_serving_container_image_uri: str | None = None,
+        model_serving_container_predict_route: str | None = None,
+        model_serving_container_health_route: str | None = None,
+        model_serving_container_command: Sequence[str] | None = None,
+        model_serving_container_args: Sequence[str] | None = None,
+        model_serving_container_environment_variables: dict[str, str] | None = None,
+        model_serving_container_ports: Sequence[int] | None = None,
+        model_description: str | None = None,
+        model_instance_schema_uri: str | None = None,
+        model_parameters_schema_uri: str | None = None,
+        model_prediction_schema_uri: str | None = None,
+        parent_model: str | None = None,
+        is_default_version: bool | None = None,
+        model_version_aliases: list[str] | None = None,
+        model_version_description: str | None = None,
+        labels: dict[str, str] | None = None,
+        training_encryption_spec_key_name: str | None = None,
+        model_encryption_spec_key_name: str | None = None,
+        staging_bucket: str | None = None,
+        # RUN
+        dataset: None
+        | (
+            datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset
+        ) = None,
+        annotation_schema_uri: str | None = None,
+        model_display_name: str | None = None,
+        model_labels: dict[str, str] | None = None,
+        base_output_dir: str | None = None,
+        service_account: str | None = None,
+        network: str | None = None,
+        bigquery_destination: str | None = None,
+        args: list[str | float | int] | None = None,
+        environment_variables: dict[str, str] | None = None,
+        replica_count: int = 1,
+        machine_type: str = "n1-standard-4",
+        accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
+        accelerator_count: int = 0,
+        boot_disk_type: str = "pd-ssd",
+        boot_disk_size_gb: int = 100,
+        training_fraction_split: float | None = None,
+        validation_fraction_split: float | None = None,
+        test_fraction_split: float | None = None,
+        training_filter_split: str | None = None,
+        validation_filter_split: str | None = None,
+        test_filter_split: str | None = None,
+        predefined_split_column_name: str | None = None,
+        timestamp_split_column_name: str | None = None,
+        tensorboard: str | None = None,
+    ) -> CustomTrainingJob:
+        """
+        Create and submit a Custom Training Job pipeline, then exit without waiting for it to complete.
+
+        Neither the training model nor the backing custom job are available at the moment when the training
+        pipeline is submitted; both are created only after a period of time. Therefore, it is not possible
+        to extract and return them in this method; this should be done with a separate client request.
+
+        :param display_name: Required. The user-defined name of this TrainingPipeline.
+        :param script_path: Required. Local path to the training script.
+        :param container_uri: Required. URI of the training container image in the GCR.
+        :param requirements: List of python package dependencies of the script.
+        :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
+            of the Model serving container suitable for serving the model produced by the
+            training script.
+        :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, an
+            HTTP path to send prediction requests to the container, and which must be supported
+            by it. If not specified, a default HTTP path will be used by Vertex AI.
+        :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
+            HTTP path to send health check requests to the container, and which must be supported
+            by it. If not specified, a standard HTTP path will be used by AI Platform.
+        :param model_serving_container_command: The command with which the container is run. Not executed
+            within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
+            Variable references $(VAR_NAME) are expanded using the container's
+            environment. If a variable cannot be resolved, the reference in the
+            input string will be unchanged. The $(VAR_NAME) syntax can be escaped
+            with a double $$, i.e. $$(VAR_NAME). Escaped references will never be
+            expanded, regardless of whether the variable exists or not.
+        :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
+            this is not provided. Variable references $(VAR_NAME) are expanded using the
+            container's environment. If a variable cannot be resolved, the reference
+            in the input string will be unchanged. The $(VAR_NAME) syntax can be
+            escaped with a double $$, i.e. $$(VAR_NAME). Escaped references will
+            never be expanded, regardless of whether the variable exists or not.
+        :param model_serving_container_environment_variables: The environment variables that are to be
+            present in the container. Should be a dictionary where keys are environment variable names
+            and values are environment variable values for those names.
+        :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
+            field is primarily informational; it gives Vertex AI information about the
+            network connections the container uses. Listing a port here or not has
+            no impact on whether the port is actually exposed; any port listening on
+            the default "0.0.0.0" address inside a container will be accessible from
+            the network.
+        :param model_description: The description of the Model.
+        :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
+            Storage describing the format of a single instance, which
+            is used in
+            ``PredictRequest.instances``,
+            ``ExplainRequest.instances``
+            and
+            ``BatchPredictionJob.input_config``.
+            The schema is defined as an OpenAPI 3.0.2 `Schema
+            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
+            AutoML Models always have this field populated by AI
+            Platform. Note: The URI given on output will be immutable
+            and probably different, including the URI scheme, than the
+            one given on input. The output URI will point to a location
+            where the user only has read access.
+        :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
+            Storage describing the parameters of prediction and
+            explanation via
+            ``PredictRequest.parameters``,
+            ``ExplainRequest.parameters``
+            and
+            ``BatchPredictionJob.model_parameters``.
+            The schema is defined as an OpenAPI 3.0.2 `Schema
+            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
+            AutoML Models always have this field populated by AI
+            Platform; if no parameters are supported, it is set to an
+            empty string. Note: The URI given on output will be
+            immutable and probably different, including the URI scheme,
+            than the one given on input. The output URI will point to a
+            location where the user only has read access.
+        :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
+            Storage describing the format of a single prediction
+            produced by this Model, which is returned via
+            ``PredictResponse.predictions``,
+            ``ExplainResponse.explanations``,
+            and
+            ``BatchPredictionJob.output_config``.
+            The schema is defined as an OpenAPI 3.0.2 `Schema
+            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
+            AutoML Models always have this field populated by AI
+            Platform. Note: The URI given on output will be immutable
+            and probably different, including the URI scheme, than the
+            one given on input. The output URI will point to a location
+            where the user only has read access.
+        :param parent_model: Optional. The resource name or model ID of an existing model.
+            The new model uploaded by this job will be a version of `parent_model`.
+            Only set this field when training a new version of an existing model.
+        :param is_default_version: Optional. When set to True, the newly uploaded model version will
+            automatically have alias "default" included. Subsequent uses of
+            the model produced by this job without a version specified will
+            use this "default" version.
+            When set to False, the "default" alias will not be moved.
+            Actions targeting the model version produced by this job will need
+            to specifically reference this version by ID or alias.
+            New model uploads, i.e. version 1, will always be "default" aliased.
+        :param model_version_aliases: Optional. User provided version aliases so that the model version
+            uploaded by this job can be referenced via alias instead of
+            auto-generated version ID. A default version alias will be created
+            for the first version of the model.
+            The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
+        :param model_version_description: Optional. The description of the model version
+            being uploaded by this job.
+        :param project_id: Project to run training in.
+        :param region: Location to run training in.
+        :param labels: Optional. The labels with user-defined metadata to
+            organize TrainingPipelines.
+            Label keys and values can be no longer than 64
+            characters, can only
+            contain lowercase letters, numeric characters,
+            underscores and dashes. International characters
+            are allowed.
+            See https://goo.gl/xmQnxf for more information
+            and examples of labels.
+        :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+            managed encryption key used to protect the training pipeline. Has the
+            form:
+            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+            The key needs to be in the same region as where the compute
+            resource is created.
+
+            If set, this TrainingPipeline will be secured by this key.
+
+            Note: Model trained by this TrainingPipeline is also secured
+            by this key if ``model_to_upload`` is not set separately.
+        :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+            managed encryption key used to protect the model. Has the
+            form:
+            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+            The key needs to be in the same region as where the compute
+            resource is created.
+
+            If set, the trained Model will be secured by this key.
+        :param staging_bucket: Bucket used to stage source and training artifacts.
+        :param dataset: Vertex AI Dataset to fit this training against.
+        :param annotation_schema_uri: Google Cloud Storage URI that points to a YAML file describing the
+            annotation schema. The schema is defined as an OpenAPI 3.0.2
+            [Schema Object]
+            (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
+
+            Only Annotations that both match this schema and belong to
+            DataItems not ignored by the split method are used in
+            respectively training, validation or test role, depending on
+            the role of the DataItem they are on.
+
+            When used in conjunction with
+            ``annotations_filter``,
+            the Annotations used for training are filtered by both
+            ``annotations_filter``
+            and
+            ``annotation_schema_uri``.
+        :param model_display_name: If the script produces a managed Vertex AI Model, the display name of
+            the Model. The name can be up to 128 characters long and can consist
+            of any UTF-8 characters.
+
+            If not provided upon creation, the job's display_name is used.
+        :param model_labels: Optional. The labels with user-defined metadata to
+            organize your Models.
+            Label keys and values can be no longer than 64
+            characters, can only
+            contain lowercase letters, numeric characters,
+            underscores and dashes. International characters
+            are allowed.
+            See https://goo.gl/xmQnxf for more information
+            and examples of labels.
+        :param base_output_dir: GCS output directory of the job. If not provided, a timestamped directory
+            in the staging directory will be used.
+
+            Vertex AI sets the following environment variables when it runs your training code:
+
+            - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
+              i.e. <base_output_dir>/model/
+            - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
+              i.e. <base_output_dir>/checkpoints/
+            - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
+              logs, i.e. <base_output_dir>/logs/
+        :param service_account: Specifies the service account for the workload run-as account.
+            Users submitting jobs must have act-as permission on this run-as account.
+        :param network: The full name of the Compute Engine network to which the job
+            should be peered.
+            Private services access must already be configured for the network.
+            If left unspecified, the job is not peered with any network.
+        :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
+            The BigQuery project location where the training data is to
+            be written to. In the given project a new dataset is created
+            with name
+            ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
+            where the timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+            training input data will be written into that dataset. In
+            the dataset three tables will be created, ``training``,
+            ``validation`` and ``test``.
+
+            - AIP_DATA_FORMAT = "bigquery".
+            - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
+            - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
+            - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+        :param args: Command line arguments to be passed to the Python script.
+        :param environment_variables: Environment variables to be passed to the container.
+            Should be a dictionary where keys are environment variable names
+            and values are environment variable values for those names.
+            At most 10 environment variables can be specified.
+            The name of each environment variable must be unique.
+        :param replica_count: The number of worker replicas. If replica_count = 1, then one chief
+            replica will be provisioned. If replica_count > 1, the remainder will be
+            provisioned as a worker replica pool.
+        :param machine_type: The type of machine to use for training.
+        :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
+            NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
+            NVIDIA_TESLA_T4
+        :param accelerator_count: The number of accelerators to attach to a worker replica.
+        :param boot_disk_type: Type of the boot disk; default is `pd-ssd`.
+            Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+            `pd-standard` (Persistent Disk Hard Disk Drive).
+        :param boot_disk_size_gb: Size in GB of the boot disk; default is 100GB.
+            Boot disk size must be within the range of [100, 64000].
+        :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+            the Model. This is ignored if Dataset is not provided.
+        :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+            validate the Model. This is ignored if Dataset is not provided.
+        :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+            the Model. This is ignored if Dataset is not provided.
+        :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to train the Model. A filter with the same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to validate the Model. A filter with the same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+            this filter are used to test the Model. A filter with the same syntax
+            as the one used in DatasetService.ListDataItems may be used. If a
+            single DataItem is matched by more than one of the FilterSplit filters,
+            then it is assigned to the first set that applies to it in the training,
+            validation, test order. This is ignored if Dataset is not provided.
+        :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The value of the key (either the label's value or
+            value in the column) must be one of {``training``,
+            ``validation``, ``test``}, and it defines to which set the
+            given piece of data is assigned. If for a piece of data the
+            key is not present or has an invalid value, that piece is
+            ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
+            columns. The values of the key (the values in
+            the column) must be in RFC 3339 `date-time` format, where
+            `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+            piece of data the key is not present or has an invalid value,
+            that piece is ignored by the pipeline.
+
+            Supported only for tabular and time series Datasets.
+        :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
+            logs. Format:
+            ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+            For more information on configuring your service account please visit:
+            https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        """
+        self._job = self.get_custom_training_job(
+            project=project_id,
+            location=region,
+            display_name=display_name,
+            script_path=script_path,
+            container_uri=container_uri,
+            requirements=requirements,
+            model_serving_container_image_uri=model_serving_container_image_uri,
+            model_serving_container_predict_route=model_serving_container_predict_route,
+            model_serving_container_health_route=model_serving_container_health_route,
+            model_serving_container_command=model_serving_container_command,
+            model_serving_container_args=model_serving_container_args,
+            model_serving_container_environment_variables=model_serving_container_environment_variables,
+            model_serving_container_ports=model_serving_container_ports,
+            model_description=model_description,
+            model_instance_schema_uri=model_instance_schema_uri,
+            model_parameters_schema_uri=model_parameters_schema_uri,
+            model_prediction_schema_uri=model_prediction_schema_uri,
+            labels=labels,
+            training_encryption_spec_key_name=training_encryption_spec_key_name,
+            model_encryption_spec_key_name=model_encryption_spec_key_name,
+            staging_bucket=staging_bucket,
+        )
+
+        if not self._job:
+            raise AirflowException("CustomTrainingJob instance creation failed.")
+
+        self._job.submit(
+            dataset=dataset,
+            annotation_schema_uri=annotation_schema_uri,
+            model_display_name=model_display_name,
+            model_labels=model_labels,
+            base_output_dir=base_output_dir,
+            service_account=service_account,
+            network=network,
+            bigquery_destination=bigquery_destination,
+            args=args,
+            environment_variables=environment_variables,
+            replica_count=replica_count,
+            machine_type=machine_type,
+            accelerator_type=accelerator_type,
+            accelerator_count=accelerator_count,
+            boot_disk_type=boot_disk_type,
+            boot_disk_size_gb=boot_disk_size_gb,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
+            test_fraction_split=test_fraction_split,
+            training_filter_split=training_filter_split,
+            validation_filter_split=validation_filter_split,
+            test_filter_split=test_filter_split,
+            predefined_split_column_name=predefined_split_column_name,
+            timestamp_split_column_name=timestamp_split_column_name,
+            tensorboard=tensorboard,
+            parent_model=parent_model,
+            is_default_version=is_default_version,
+            model_version_aliases=model_version_aliases,
+            model_version_description=model_version_description,
+            sync=False,
+        )
+        return self._job
+
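A minimal usage sketch of the new `submit_custom_training_job` method, assuming a configured `google_cloud_default` connection; the project, script, image, and bucket values are placeholders:

```python
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")
# Returns right after submission (the hook calls submit(..., sync=False));
# the TrainingPipeline and backing CustomJob must be fetched separately.
job = hook.submit_custom_training_job(
    project_id="my-project",  # placeholder
    region="us-central1",
    display_name="train-my-model",
    script_path="train.py",  # local training script, placeholder
    container_uri="us-docker.pkg.dev/my-project/training/image:latest",  # placeholder
    staging_bucket="gs://my-staging-bucket",  # placeholder
)
```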
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_training_pipeline(
+        self,
+        project_id: str,
+        region: str,
+        training_pipeline: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """
+        Delete a TrainingPipeline.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_pipeline_service_client(region)
+        name = client.training_pipeline_path(project_id, region, training_pipeline)
+
+        result = client.delete_training_pipeline(
+            request={"name": name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
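`delete_training_pipeline` returns a long-running `Operation`; a sketch of blocking on it (identifiers are placeholders):

```python
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

hook = CustomJobHook(gcp_conn_id="google_cloud_default")
operation = hook.delete_training_pipeline(
    project_id="my-project",  # placeholder
    region="us-central1",
    training_pipeline="1234567890",  # placeholder TrainingPipeline ID
)
operation.result()  # block until the server-side deletion completes
```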
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_custom_job(
+        self,
         project_id: str,
         region: str,
         custom_job: str,
@@ -2178,3 +3298,240 @@ class CustomJobHook(GoogleBaseHook):
             metadata=metadata,
         )
         return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    @deprecated(
+        reason="Please use `PipelineJobHook.delete_pipeline_job`",
+        category=AirflowProviderDeprecationWarning,
+    )
+    def delete_pipeline_job(
+        self,
+        project_id: str,
+        region: str,
+        pipeline_job: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """
+        Delete a PipelineJob.
+
+        This method is deprecated; please use `PipelineJobHook.delete_pipeline_job` instead.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param pipeline_job: Required. The name of the PipelineJob resource to be deleted.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_pipeline_service_client(region)
+        name = client.pipeline_job_path(project_id, region, pipeline_job)
+
+        result = client.delete_pipeline_job(
+            request={"name": name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+
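Callers of the now-deprecated `delete_pipeline_job` can switch to the hook named in the deprecation message; a sketch, assuming `PipelineJobHook` lives in the sibling `pipeline_job` module of this package:

```python
# Assumed import path: a sibling module of the one in this diff.
from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook

hook = PipelineJobHook(gcp_conn_id="google_cloud_default")
hook.delete_pipeline_job(
    project_id="my-project",  # placeholder
    region="us-central1",
    pipeline_job="my-pipeline-job",  # placeholder PipelineJob ID
)
```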
+class CustomJobAsyncHook(GoogleBaseAsyncHook):
+    """Async hook for Custom Job Service Client."""
+
+    sync_hook_class = CustomJobHook
+    JOB_COMPLETE_STATES = {
+        JobState.JOB_STATE_CANCELLED,
+        JobState.JOB_STATE_FAILED,
+        JobState.JOB_STATE_PAUSED,
+        JobState.JOB_STATE_SUCCEEDED,
+    }
+    PIPELINE_COMPLETE_STATES = (
+        PipelineState.PIPELINE_STATE_CANCELLED,
+        PipelineState.PIPELINE_STATE_FAILED,
+        PipelineState.PIPELINE_STATE_PAUSED,
+        PipelineState.PIPELINE_STATE_SUCCEEDED,
+    )
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+            **kwargs,
+        )
+        self._job: None | (
+            CustomContainerTrainingJob | CustomPythonPackageTrainingJob | CustomTrainingJob
+        ) = None
+
+    async def get_credentials(self) -> Credentials:
+        return (await self.get_sync_hook()).get_credentials()
+
+    async def get_job_service_client(
+        self,
+        region: str | None = None,
+    ) -> JobServiceAsyncClient:
+        """Retrieve Vertex AI JobServiceAsyncClient object."""
+        if region and region != "global":
+            client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
+        else:
+            client_options = ClientOptions()
+        return JobServiceAsyncClient(
+            credentials=(await self.get_credentials()),
+            client_info=CLIENT_INFO,
+            client_options=client_options,
+        )
+
+    async def get_pipeline_service_client(
+        self,
+        region: str | None = None,
+    ) -> PipelineServiceAsyncClient:
+        """Retrieve Vertex AI PipelineServiceAsyncClient object."""
+        if region and region != "global":
+            client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
+        else:
+            client_options = ClientOptions()
+        return PipelineServiceAsyncClient(
+            credentials=(await self.get_credentials()),
+            client_info=CLIENT_INFO,
+            client_options=client_options,
+        )
+
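Both client getters above derive a regional API endpoint unless the region is unset or `"global"`. A standalone sketch of that selection logic (the helper name is ours, not the provider's):

```python
from __future__ import annotations

from google.api_core.client_options import ClientOptions


def client_options_for(region: str | None) -> ClientOptions:
    # Same branch as in the getters: regional endpoint unless the
    # region is unset or "global".
    if region and region != "global":
        return ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
    return ClientOptions()


print(client_options_for("europe-west4").api_endpoint)  # europe-west4-aiplatform.googleapis.com:443
print(client_options_for("global").api_endpoint)  # None (library default endpoint)
```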
+    async def get_custom_job(
+        self,
+        project_id: str,
+        location: str,
+        job_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | _MethodDefault | None = DEFAULT,
+        metadata: Sequence[tuple[str, str]] = (),
+        client: JobServiceAsyncClient | None = None,
+    ) -> types.CustomJob:
+        """
+        Get a CustomJob proto message from JobServiceAsyncClient.
+
+        :param project_id: Required. The ID of the Google Cloud project that the job belongs to.
+        :param location: Required. The ID of the Google Cloud region that the job belongs to.
+        :param job_id: Required. The custom job id.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        :param client: The async job service client.
+        """
+        if not client:
+            client = await self.get_job_service_client(region=location)
+        job_name = client.custom_job_path(project_id, location, job_id)
+        result: types.CustomJob = await client.get_custom_job(
+            request={"name": job_name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    async def get_training_pipeline(
+        self,
+        project_id: str,
+        location: str,
+        pipeline_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | _MethodDefault | None = DEFAULT,
+        metadata: Sequence[tuple[str, str]] = (),
+        client: PipelineServiceAsyncClient | None = None,
+    ) -> types.TrainingPipeline:
+        """
+        Get a TrainingPipeline proto message from PipelineServiceAsyncClient.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud region that the service belongs to.
+        :param pipeline_id: Required. The ID of the PipelineJob resource.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        :param client: The async pipeline service client.
+        """
+        if not client:
+            client = await self.get_pipeline_service_client(region=location)
+        pipeline_name = client.training_pipeline_path(
+            project=project_id,
+            location=location,
+            training_pipeline=pipeline_id,
+        )
+        response: types.TrainingPipeline = await client.get_training_pipeline(
+            request={"name": pipeline_name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return response
+
+    async def wait_for_custom_job(
+        self,
+        project_id: str,
+        location: str,
+        job_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        poll_interval: int = 10,
+    ) -> types.CustomJob:
+        """Make async calls to Vertex AI to check the custom job state until it is complete."""
+        client = await self.get_job_service_client(region=location)
+        while True:
+            try:
+                self.log.info("Requesting a custom job with id %s", job_id)
+                job: types.CustomJob = await self.get_custom_job(
+                    project_id=project_id,
+                    location=location,
+                    job_id=job_id,
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata,
+                    client=client,
+                )
+            except Exception as ex:
+                self.log.exception("Exception occurred while requesting job %s", job_id)
+                raise AirflowException(ex)
+            self.log.info("Status of the custom job %s is %s", job.name, job.state.name)
+            if job.state in self.JOB_COMPLETE_STATES:
+                return job
+            self.log.info("Sleeping for %s seconds.", poll_interval)
+            await asyncio.sleep(poll_interval)
+
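`wait_for_custom_job` polls until the job state reaches one of `JOB_COMPLETE_STATES`. A sketch of driving it from synchronous code (identifiers are placeholders):

```python
import asyncio

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook


async def main() -> None:
    hook = CustomJobAsyncHook(gcp_conn_id="google_cloud_default")
    job = await hook.wait_for_custom_job(
        project_id="my-project",  # placeholder
        location="us-central1",
        job_id="1234567890",  # placeholder CustomJob ID
        poll_interval=30,  # poll every 30 s instead of the default 10 s
    )
    print(job.state.name)  # one of the JOB_COMPLETE_STATES


asyncio.run(main())
```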
+    async def wait_for_training_pipeline(
+        self,
+        project_id: str,
+        location: str,
+        pipeline_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        poll_interval: int = 10,
+    ) -> types.TrainingPipeline:
+        """Make async calls to Vertex AI to check the training pipeline state until it is complete."""
+        client = await self.get_pipeline_service_client(region=location)
+        while True:
+            try:
+                self.log.info("Requesting a training pipeline with id %s", pipeline_id)
+                pipeline: types.TrainingPipeline = await self.get_training_pipeline(
+                    project_id=project_id,
+                    location=location,
+                    pipeline_id=pipeline_id,
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata,
+                    client=client,
+                )
+            except Exception as ex:
+                self.log.exception("Exception occurred while requesting training pipeline %s", pipeline_id)
+                raise AirflowException(ex)
+            self.log.info("Status of the training pipeline %s is %s", pipeline.name, pipeline.state.name)
+            if pipeline.state in self.PIPELINE_COMPLETE_STATES:
+                return pipeline
+            self.log.info("Sleeping for %s seconds.", poll_interval)
+            await asyncio.sleep(poll_interval)
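`wait_for_training_pipeline` follows the same pattern against the pipeline service. Note that `PIPELINE_COMPLETE_STATES` includes the failed, cancelled, and paused states, so completion is not success. A sketch that checks the terminal state (assuming `PipelineState` is importable from `google.cloud.aiplatform_v1.types`, with placeholder identifiers):

```python
import asyncio

from google.cloud.aiplatform_v1.types import PipelineState

from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook


async def ensure_pipeline_succeeded() -> None:
    hook = CustomJobAsyncHook(gcp_conn_id="google_cloud_default")
    pipeline = await hook.wait_for_training_pipeline(
        project_id="my-project",  # placeholder
        location="us-central1",
        pipeline_id="1234567890",  # placeholder TrainingPipeline ID
    )
    # Completion is not success: failed, cancelled, and paused states
    # also terminate the wait loop above.
    if pipeline.state != PipelineState.PIPELINE_STATE_SUCCEEDED:
        raise RuntimeError(f"Training pipeline ended in state {pipeline.state.name}")


asyncio.run(ensure_pipeline_succeeded())
```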