apache-airflow-providers-google 10.10.0__py3-none-any.whl → 10.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/cloud/hooks/cloud_run.py +4 -2
  3. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +131 -27
  4. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +1 -9
  5. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +121 -4
  6. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -11
  7. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -10
  8. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +220 -6
  9. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +409 -0
  10. airflow/providers/google/cloud/links/vertex_ai.py +49 -0
  11. airflow/providers/google/cloud/operators/dataproc.py +32 -10
  12. airflow/providers/google/cloud/operators/gcs.py +1 -1
  13. airflow/providers/google/cloud/operators/mlengine.py +116 -0
  14. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +45 -0
  15. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +2 -8
  16. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +287 -201
  17. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +1 -9
  18. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +2 -9
  19. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +451 -12
  20. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +464 -0
  21. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +7 -1
  22. airflow/providers/google/get_provider_info.py +5 -0
  23. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/METADATA +6 -6
  24. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/RECORD +29 -27
  25. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/LICENSE +0 -0
  26. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/NOTICE +0 -0
  27. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/WHEEL +0 -0
  28. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/entry_points.txt +0 -0
  29. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/top_level.txt +0 -0
@@ -61,6 +61,10 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
61
61
  model_instance_schema_uri: str | None = None,
62
62
  model_parameters_schema_uri: str | None = None,
63
63
  model_prediction_schema_uri: str | None = None,
64
+ parent_model: str | None = None,
65
+ is_default_version: bool | None = None,
66
+ model_version_aliases: list[str] | None = None,
67
+ model_version_description: str | None = None,
64
68
  labels: dict[str, str] | None = None,
65
69
  training_encryption_spec_key_name: str | None = None,
66
70
  model_encryption_spec_key_name: str | None = None,
@@ -114,6 +118,10 @@ class CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
114
118
  self.model_parameters_schema_uri = model_parameters_schema_uri
115
119
  self.model_prediction_schema_uri = model_prediction_schema_uri
116
120
  self.labels = labels
121
+ self.parent_model = parent_model
122
+ self.is_default_version = is_default_version
123
+ self.model_version_aliases = model_version_aliases
124
+ self.model_version_description = model_version_description
117
125
  self.training_encryption_spec_key_name = training_encryption_spec_key_name
118
126
  self.model_encryption_spec_key_name = model_encryption_spec_key_name
119
127
  self.staging_bucket = staging_bucket
@@ -192,48 +200,66 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
192
200
  the network.
193
201
  :param model_description: The description of the Model.
194
202
  :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
195
- Storage describing the format of a single instance, which
196
- are used in
197
- ``PredictRequest.instances``,
198
- ``ExplainRequest.instances``
199
- and
200
- ``BatchPredictionJob.input_config``.
201
- The schema is defined as an OpenAPI 3.0.2 `Schema
202
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
203
- AutoML Models always have this field populated by AI
204
- Platform. Note: The URI given on output will be immutable
205
- and probably different, including the URI scheme, than the
206
- one given on input. The output URI will point to a location
207
- where the user only has a read access.
203
+ Storage describing the format of a single instance, which
204
+ are used in
205
+ ``PredictRequest.instances``,
206
+ ``ExplainRequest.instances``
207
+ and
208
+ ``BatchPredictionJob.input_config``.
209
+ The schema is defined as an OpenAPI 3.0.2 `Schema
210
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
211
+ AutoML Models always have this field populated by AI
212
+ Platform. Note: The URI given on output will be immutable
213
+ and probably different, including the URI scheme, than the
214
+ one given on input. The output URI will point to a location
215
+ where the user only has a read access.
208
216
  :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
209
- Storage describing the parameters of prediction and
210
- explanation via
211
- ``PredictRequest.parameters``,
212
- ``ExplainRequest.parameters``
213
- and
214
- ``BatchPredictionJob.model_parameters``.
215
- The schema is defined as an OpenAPI 3.0.2 `Schema
216
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
217
- AutoML Models always have this field populated by AI
218
- Platform, if no parameters are supported it is set to an
219
- empty string. Note: The URI given on output will be
220
- immutable and probably different, including the URI scheme,
221
- than the one given on input. The output URI will point to a
222
- location where the user only has a read access.
217
+ Storage describing the parameters of prediction and
218
+ explanation via
219
+ ``PredictRequest.parameters``,
220
+ ``ExplainRequest.parameters``
221
+ and
222
+ ``BatchPredictionJob.model_parameters``.
223
+ The schema is defined as an OpenAPI 3.0.2 `Schema
224
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
225
+ AutoML Models always have this field populated by AI
226
+ Platform, if no parameters are supported it is set to an
227
+ empty string. Note: The URI given on output will be
228
+ immutable and probably different, including the URI scheme,
229
+ than the one given on input. The output URI will point to a
230
+ location where the user only has a read access.
223
231
  :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
224
- Storage describing the format of a single prediction
225
- produced by this Model, which are returned via
226
- ``PredictResponse.predictions``,
227
- ``ExplainResponse.explanations``,
228
- and
229
- ``BatchPredictionJob.output_config``.
230
- The schema is defined as an OpenAPI 3.0.2 `Schema
231
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
232
- AutoML Models always have this field populated by AI
233
- Platform. Note: The URI given on output will be immutable
234
- and probably different, including the URI scheme, than the
235
- one given on input. The output URI will point to a location
236
- where the user only has a read access.
232
+ Storage describing the format of a single prediction
233
+ produced by this Model, which are returned via
234
+ ``PredictResponse.predictions``,
235
+ ``ExplainResponse.explanations``,
236
+ and
237
+ ``BatchPredictionJob.output_config``.
238
+ The schema is defined as an OpenAPI 3.0.2 `Schema
239
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
240
+ AutoML Models always have this field populated by AI
241
+ Platform. Note: The URI given on output will be immutable
242
+ and probably different, including the URI scheme, than the
243
+ one given on input. The output URI will point to a location
244
+ where the user only has a read access.
245
+ :param parent_model: Optional. The resource name or model ID of an existing model.
246
+ The new model uploaded by this job will be a version of `parent_model`.
247
+ Only set this field when training a new version of an existing model.
248
+ :param is_default_version: Optional. When set to True, the newly uploaded model version will
249
+ automatically have alias "default" included. Subsequent uses of
250
+ the model produced by this job without a version specified will
251
+ use this "default" version.
252
+ When set to False, the "default" alias will not be moved.
253
+ Actions targeting the model version produced by this job will need
254
+ to specifically reference this version by ID or alias.
255
+ New model uploads, i.e. version 1, will always be "default" aliased.
256
+ :param model_version_aliases: Optional. User provided version aliases so that the model version
257
+ uploaded by this job can be referenced via alias instead of
258
+ auto-generated version ID. A default version alias will be created
259
+ for the first version of the model.
260
+ The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
261
+ :param model_version_description: Optional. The description of the model version
262
+ being uploaded by this job.
237
263
  :param project_id: Project to run training in.
238
264
  :param region: Location to run training in.
239
265
  :param labels: Optional. The labels with user-defined metadata to
@@ -409,6 +435,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
409
435
  template_fields = (
410
436
  "region",
411
437
  "command",
438
+ "parent_model",
412
439
  "dataset_id",
413
440
  "impersonation_chain",
414
441
  )
@@ -428,6 +455,8 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
428
455
  gcp_conn_id=self.gcp_conn_id,
429
456
  impersonation_chain=self.impersonation_chain,
430
457
  )
458
+ self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
459
+
431
460
  model, training_id, custom_job_id = self.hook.create_custom_container_training_job(
432
461
  project_id=self.project_id,
433
462
  region=self.region,
@@ -445,6 +474,10 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
445
474
  model_instance_schema_uri=self.model_instance_schema_uri,
446
475
  model_parameters_schema_uri=self.model_parameters_schema_uri,
447
476
  model_prediction_schema_uri=self.model_prediction_schema_uri,
477
+ parent_model=self.parent_model,
478
+ is_default_version=self.is_default_version,
479
+ model_version_aliases=self.model_version_aliases,
480
+ model_version_description=self.model_version_description,
448
481
  labels=self.labels,
449
482
  training_encryption_spec_key_name=self.training_encryption_spec_key_name,
450
483
  model_encryption_spec_key_name=self.model_encryption_spec_key_name,
@@ -481,6 +514,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
481
514
  if model:
482
515
  result = Model.to_dict(model)
483
516
  model_id = self.hook.extract_model_id(result)
517
+ self.xcom_push(context, key="model_id", value=model_id)
484
518
  VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
485
519
  else:
486
520
  result = model # type: ignore
@@ -537,78 +571,96 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
537
571
  the network.
538
572
  :param model_description: The description of the Model.
539
573
  :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
540
- Storage describing the format of a single instance, which
541
- are used in
542
- ``PredictRequest.instances``,
543
- ``ExplainRequest.instances``
544
- and
545
- ``BatchPredictionJob.input_config``.
546
- The schema is defined as an OpenAPI 3.0.2 `Schema
547
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
548
- AutoML Models always have this field populated by AI
549
- Platform. Note: The URI given on output will be immutable
550
- and probably different, including the URI scheme, than the
551
- one given on input. The output URI will point to a location
552
- where the user only has a read access.
574
+ Storage describing the format of a single instance, which
575
+ are used in
576
+ ``PredictRequest.instances``,
577
+ ``ExplainRequest.instances``
578
+ and
579
+ ``BatchPredictionJob.input_config``.
580
+ The schema is defined as an OpenAPI 3.0.2 `Schema
581
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
582
+ AutoML Models always have this field populated by AI
583
+ Platform. Note: The URI given on output will be immutable
584
+ and probably different, including the URI scheme, than the
585
+ one given on input. The output URI will point to a location
586
+ where the user only has a read access.
553
587
  :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
554
- Storage describing the parameters of prediction and
555
- explanation via
556
- ``PredictRequest.parameters``,
557
- ``ExplainRequest.parameters``
558
- and
559
- ``BatchPredictionJob.model_parameters``.
560
- The schema is defined as an OpenAPI 3.0.2 `Schema
561
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
562
- AutoML Models always have this field populated by AI
563
- Platform, if no parameters are supported it is set to an
564
- empty string. Note: The URI given on output will be
565
- immutable and probably different, including the URI scheme,
566
- than the one given on input. The output URI will point to a
567
- location where the user only has a read access.
588
+ Storage describing the parameters of prediction and
589
+ explanation via
590
+ ``PredictRequest.parameters``,
591
+ ``ExplainRequest.parameters``
592
+ and
593
+ ``BatchPredictionJob.model_parameters``.
594
+ The schema is defined as an OpenAPI 3.0.2 `Schema
595
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
596
+ AutoML Models always have this field populated by AI
597
+ Platform, if no parameters are supported it is set to an
598
+ empty string. Note: The URI given on output will be
599
+ immutable and probably different, including the URI scheme,
600
+ than the one given on input. The output URI will point to a
601
+ location where the user only has a read access.
568
602
  :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
569
- Storage describing the format of a single prediction
570
- produced by this Model, which are returned via
571
- ``PredictResponse.predictions``,
572
- ``ExplainResponse.explanations``,
573
- and
574
- ``BatchPredictionJob.output_config``.
575
- The schema is defined as an OpenAPI 3.0.2 `Schema
576
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
577
- AutoML Models always have this field populated by AI
578
- Platform. Note: The URI given on output will be immutable
579
- and probably different, including the URI scheme, than the
580
- one given on input. The output URI will point to a location
581
- where the user only has a read access.
603
+ Storage describing the format of a single prediction
604
+ produced by this Model, which are returned via
605
+ ``PredictResponse.predictions``,
606
+ ``ExplainResponse.explanations``,
607
+ and
608
+ ``BatchPredictionJob.output_config``.
609
+ The schema is defined as an OpenAPI 3.0.2 `Schema
610
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
611
+ AutoML Models always have this field populated by AI
612
+ Platform. Note: The URI given on output will be immutable
613
+ and probably different, including the URI scheme, than the
614
+ one given on input. The output URI will point to a location
615
+ where the user only has a read access.
616
+ :param parent_model: Optional. The resource name or model ID of an existing model.
617
+ The new model uploaded by this job will be a version of `parent_model`.
618
+ Only set this field when training a new version of an existing model.
619
+ :param is_default_version: Optional. When set to True, the newly uploaded model version will
620
+ automatically have alias "default" included. Subsequent uses of
621
+ the model produced by this job without a version specified will
622
+ use this "default" version.
623
+ When set to False, the "default" alias will not be moved.
624
+ Actions targeting the model version produced by this job will need
625
+ to specifically reference this version by ID or alias.
626
+ New model uploads, i.e. version 1, will always be "default" aliased.
627
+ :param model_version_aliases: Optional. User provided version aliases so that the model version
628
+ uploaded by this job can be referenced via alias instead of
629
+ auto-generated version ID. A default version alias will be created
630
+ for the first version of the model.
631
+ The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
632
+ :param model_version_description: Optional. The description of the model version
633
+ being uploaded by this job.
582
634
  :param project_id: Project to run training in.
583
635
  :param region: Location to run training in.
584
636
  :param labels: Optional. The labels with user-defined metadata to
585
- organize TrainingPipelines.
586
- Label keys and values can be no longer than 64
587
- characters, can only
588
- contain lowercase letters, numeric characters,
589
- underscores and dashes. International characters
590
- are allowed.
591
- See https://goo.gl/xmQnxf for more information
592
- and examples of labels.
637
+ organize TrainingPipelines.
638
+ Label keys and values can be no longer than 64
639
+ characters, can only
640
+ contain lowercase letters, numeric characters,
641
+ underscores and dashes. International characters
642
+ are allowed.
643
+ See https://goo.gl/xmQnxf for more information
644
+ and examples of labels.
593
645
  :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
594
- managed encryption key used to protect the training pipeline. Has the
595
- form:
596
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
597
- The key needs to be in the same region as where the compute
598
- resource is created.
646
+ managed encryption key used to protect the training pipeline. Has the
647
+ form:
648
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
649
+ The key needs to be in the same region as where the compute
650
+ resource is created.
599
651
 
600
- If set, this TrainingPipeline will be secured by this key.
652
+ If set, this TrainingPipeline will be secured by this key.
601
653
 
602
- Note: Model trained by this TrainingPipeline is also secured
603
- by this key if ``model_to_upload`` is not set separately.
654
+ Note: Model trained by this TrainingPipeline is also secured
655
+ by this key if ``model_to_upload`` is not set separately.
604
656
  :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
605
- managed encryption key used to protect the model. Has the
606
- form:
607
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
608
- The key needs to be in the same region as where the compute
609
- resource is created.
657
+ managed encryption key used to protect the model. Has the
658
+ form:
659
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
660
+ The key needs to be in the same region as where the compute
661
+ resource is created.
610
662
 
611
- If set, the trained Model will be secured by this key.
663
+ If set, the trained Model will be secured by this key.
612
664
  :param staging_bucket: Bucket used to stage source and training artifacts.
613
665
  :param dataset: Vertex AI to fit this training against.
614
666
  :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
@@ -628,19 +680,19 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
628
680
  and
629
681
  ``annotation_schema_uri``.
630
682
  :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
631
- the Model. The name can be up to 128 characters long and can be consist
632
- of any UTF-8 characters.
683
+ the Model. The name can be up to 128 characters long and can be consist
684
+ of any UTF-8 characters.
633
685
 
634
- If not provided upon creation, the job's display_name is used.
686
+ If not provided upon creation, the job's display_name is used.
635
687
  :param model_labels: Optional. The labels with user-defined metadata to
636
- organize your Models.
637
- Label keys and values can be no longer than 64
638
- characters, can only
639
- contain lowercase letters, numeric characters,
640
- underscores and dashes. International characters
641
- are allowed.
642
- See https://goo.gl/xmQnxf for more information
643
- and examples of labels.
688
+ organize your Models.
689
+ Label keys and values can be no longer than 64
690
+ characters, can only
691
+ contain lowercase letters, numeric characters,
692
+ underscores and dashes. International characters
693
+ are allowed.
694
+ See https://goo.gl/xmQnxf for more information
695
+ and examples of labels.
644
696
  :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
645
697
  staging directory will be used.
646
698
 
@@ -653,38 +705,38 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
653
705
  - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
654
706
  logs, i.e. <base_output_dir>/logs/
655
707
  :param service_account: Specifies the service account for workload run-as account.
656
- Users submitting jobs must have act-as permission on this run-as account.
708
+ Users submitting jobs must have act-as permission on this run-as account.
657
709
  :param network: The full name of the Compute Engine network to which the job
658
- should be peered.
659
- Private services access must already be configured for the network.
660
- If left unspecified, the job is not peered with any network.
710
+ should be peered.
711
+ Private services access must already be configured for the network.
712
+ If left unspecified, the job is not peered with any network.
661
713
  :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
662
- The BigQuery project location where the training data is to
663
- be written to. In the given project a new dataset is created
664
- with name
665
- ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
666
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
667
- training input data will be written into that dataset. In
668
- the dataset three tables will be created, ``training``,
669
- ``validation`` and ``test``.
670
-
671
- - AIP_DATA_FORMAT = "bigquery".
672
- - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
673
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
674
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
714
+ The BigQuery project location where the training data is to
715
+ be written to. In the given project a new dataset is created
716
+ with name
717
+ ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
718
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
719
+ training input data will be written into that dataset. In
720
+ the dataset three tables will be created, ``training``,
721
+ ``validation`` and ``test``.
722
+
723
+ - AIP_DATA_FORMAT = "bigquery".
724
+ - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
725
+ - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
726
+ - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
675
727
  :param args: Command line arguments to be passed to the Python script.
676
728
  :param environment_variables: Environment variables to be passed to the container.
677
- Should be a dictionary where keys are environment variable names
678
- and values are environment variable values for those names.
679
- At most 10 environment variables can be specified.
680
- The Name of the environment variable must be unique.
729
+ Should be a dictionary where keys are environment variable names
730
+ and values are environment variable values for those names.
731
+ At most 10 environment variables can be specified.
732
+ The Name of the environment variable must be unique.
681
733
  :param replica_count: The number of worker replicas. If replica count = 1 then one chief
682
- replica will be provisioned. If replica_count > 1 the remainder will be
683
- provisioned as a worker replica pool.
734
+ replica will be provisioned. If replica_count > 1 the remainder will be
735
+ provisioned as a worker replica pool.
684
736
  :param machine_type: The type of machine to use for training.
685
737
  :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
686
- NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
687
- NVIDIA_TESLA_T4
738
+ NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
739
+ NVIDIA_TESLA_T4
688
740
  :param accelerator_count: The number of accelerators to attach to a worker replica.
689
741
  :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
690
742
  Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
@@ -752,6 +804,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
752
804
  """
753
805
 
754
806
  template_fields = (
807
+ "parent_model",
755
808
  "region",
756
809
  "dataset_id",
757
810
  "impersonation_chain",
@@ -774,6 +827,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
774
827
  gcp_conn_id=self.gcp_conn_id,
775
828
  impersonation_chain=self.impersonation_chain,
776
829
  )
830
+ self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
777
831
  model, training_id, custom_job_id = self.hook.create_custom_python_package_training_job(
778
832
  project_id=self.project_id,
779
833
  region=self.region,
@@ -792,6 +846,10 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
792
846
  model_instance_schema_uri=self.model_instance_schema_uri,
793
847
  model_parameters_schema_uri=self.model_parameters_schema_uri,
794
848
  model_prediction_schema_uri=self.model_prediction_schema_uri,
849
+ parent_model=self.parent_model,
850
+ is_default_version=self.is_default_version,
851
+ model_version_aliases=self.model_version_aliases,
852
+ model_version_description=self.model_version_description,
795
853
  labels=self.labels,
796
854
  training_encryption_spec_key_name=self.training_encryption_spec_key_name,
797
855
  model_encryption_spec_key_name=self.model_encryption_spec_key_name,
@@ -828,6 +886,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
828
886
  if model:
829
887
  result = Model.to_dict(model)
830
888
  model_id = self.hook.extract_model_id(result)
889
+ self.xcom_push(context, key="model_id", value=model_id)
831
890
  VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
832
891
  else:
833
892
  result = model # type: ignore
@@ -884,78 +943,96 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
884
943
  the network.
885
944
  :param model_description: The description of the Model.
886
945
  :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
887
- Storage describing the format of a single instance, which
888
- are used in
889
- ``PredictRequest.instances``,
890
- ``ExplainRequest.instances``
891
- and
892
- ``BatchPredictionJob.input_config``.
893
- The schema is defined as an OpenAPI 3.0.2 `Schema
894
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
895
- AutoML Models always have this field populated by AI
896
- Platform. Note: The URI given on output will be immutable
897
- and probably different, including the URI scheme, than the
898
- one given on input. The output URI will point to a location
899
- where the user only has a read access.
946
+ Storage describing the format of a single instance, which
947
+ are used in
948
+ ``PredictRequest.instances``,
949
+ ``ExplainRequest.instances``
950
+ and
951
+ ``BatchPredictionJob.input_config``.
952
+ The schema is defined as an OpenAPI 3.0.2 `Schema
953
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
954
+ AutoML Models always have this field populated by AI
955
+ Platform. Note: The URI given on output will be immutable
956
+ and probably different, including the URI scheme, than the
957
+ one given on input. The output URI will point to a location
958
+ where the user only has read access.
900
959
  :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
901
- Storage describing the parameters of prediction and
902
- explanation via
903
- ``PredictRequest.parameters``,
904
- ``ExplainRequest.parameters``
905
- and
906
- ``BatchPredictionJob.model_parameters``.
907
- The schema is defined as an OpenAPI 3.0.2 `Schema
908
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
909
- AutoML Models always have this field populated by AI
910
- Platform, if no parameters are supported it is set to an
911
- empty string. Note: The URI given on output will be
912
- immutable and probably different, including the URI scheme,
913
- than the one given on input. The output URI will point to a
914
- location where the user only has a read access.
960
+ Storage describing the parameters of prediction and
961
+ explanation via
962
+ ``PredictRequest.parameters``,
963
+ ``ExplainRequest.parameters``
964
+ and
965
+ ``BatchPredictionJob.model_parameters``.
966
+ The schema is defined as an OpenAPI 3.0.2 `Schema
967
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
968
+ AutoML Models always have this field populated by AI
969
+ Platform, if no parameters are supported it is set to an
970
+ empty string. Note: The URI given on output will be
971
+ immutable and probably different, including the URI scheme,
972
+ than the one given on input. The output URI will point to a
973
+ location where the user only has read access.
915
974
  :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
916
- Storage describing the format of a single prediction
917
- produced by this Model, which are returned via
918
- ``PredictResponse.predictions``,
919
- ``ExplainResponse.explanations``,
920
- and
921
- ``BatchPredictionJob.output_config``.
922
- The schema is defined as an OpenAPI 3.0.2 `Schema
923
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
924
- AutoML Models always have this field populated by AI
925
- Platform. Note: The URI given on output will be immutable
926
- and probably different, including the URI scheme, than the
927
- one given on input. The output URI will point to a location
928
- where the user only has a read access.
975
+ Storage describing the format of a single prediction
976
+ produced by this Model, which are returned via
977
+ ``PredictResponse.predictions``,
978
+ ``ExplainResponse.explanations``,
979
+ and
980
+ ``BatchPredictionJob.output_config``.
981
+ The schema is defined as an OpenAPI 3.0.2 `Schema
982
+ Object <https://tinyurl.com/y538mdwt#schema-object>`__.
983
+ AutoML Models always have this field populated by AI
984
+ Platform. Note: The URI given on output will be immutable
985
+ and probably different, including the URI scheme, than the
986
+ one given on input. The output URI will point to a location
987
+ where the user only has read access.
988
+ :param parent_model: Optional. The resource name or model ID of an existing model.
989
+ The new model uploaded by this job will be a version of `parent_model`.
990
+ Only set this field when training a new version of an existing model.
991
+ :param is_default_version: Optional. When set to True, the newly uploaded model version will
992
+ automatically have alias "default" included. Subsequent uses of
993
+ the model produced by this job without a version specified will
994
+ use this "default" version.
995
+ When set to False, the "default" alias will not be moved.
996
+ Actions targeting the model version produced by this job will need
997
+ to specifically reference this version by ID or alias.
998
+ New model uploads, i.e. version 1, will always be "default" aliased.
999
+ :param model_version_aliases: Optional. User provided version aliases so that the model version
1000
+ uploaded by this job can be referenced via alias instead of
1001
+ auto-generated version ID. A default version alias will be created
1002
+ for the first version of the model.
1003
+ The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
1004
+ :param model_version_description: Optional. The description of the model version
1005
+ being uploaded by this job.
929
1006
  :param project_id: Project to run training in.
930
1007
  :param region: Location to run training in.
931
1008
  :param labels: Optional. The labels with user-defined metadata to
932
- organize TrainingPipelines.
933
- Label keys and values can be no longer than 64
934
- characters, can only
935
- contain lowercase letters, numeric characters,
936
- underscores and dashes. International characters
937
- are allowed.
938
- See https://goo.gl/xmQnxf for more information
939
- and examples of labels.
1009
+ organize TrainingPipelines.
1010
+ Label keys and values can be no longer than 64
1011
+ characters, can only
1012
+ contain lowercase letters, numeric characters,
1013
+ underscores and dashes. International characters
1014
+ are allowed.
1015
+ See https://goo.gl/xmQnxf for more information
1016
+ and examples of labels.
940
1017
  :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
941
- managed encryption key used to protect the training pipeline. Has the
942
- form:
943
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
944
- The key needs to be in the same region as where the compute
945
- resource is created.
1018
+ managed encryption key used to protect the training pipeline. Has the
1019
+ form:
1020
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1021
+ The key needs to be in the same region as where the compute
1022
+ resource is created.
946
1023
 
947
- If set, this TrainingPipeline will be secured by this key.
1024
+ If set, this TrainingPipeline will be secured by this key.
948
1025
 
949
- Note: Model trained by this TrainingPipeline is also secured
950
- by this key if ``model_to_upload`` is not set separately.
1026
+ Note: Model trained by this TrainingPipeline is also secured
1027
+ by this key if ``model_to_upload`` is not set separately.
951
1028
  :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
952
- managed encryption key used to protect the model. Has the
953
- form:
954
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
955
- The key needs to be in the same region as where the compute
956
- resource is created.
1029
+ managed encryption key used to protect the model. Has the
1030
+ form:
1031
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1032
+ The key needs to be in the same region as where the compute
1033
+ resource is created.
957
1034
 
958
- If set, the trained Model will be secured by this key.
1035
+ If set, the trained Model will be secured by this key.
959
1036
  :param staging_bucket: Bucket used to stage source and training artifacts.
960
1037
  :param dataset: Vertex AI dataset to fit this training against.
961
1038
  :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
@@ -1101,6 +1178,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1101
1178
  template_fields = (
1102
1179
  "region",
1103
1180
  "script_path",
1181
+ "parent_model",
1104
1182
  "requirements",
1105
1183
  "dataset_id",
1106
1184
  "impersonation_chain",
@@ -1123,6 +1201,8 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1123
1201
  gcp_conn_id=self.gcp_conn_id,
1124
1202
  impersonation_chain=self.impersonation_chain,
1125
1203
  )
1204
+ self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
1205
+
1126
1206
  model, training_id, custom_job_id = self.hook.create_custom_training_job(
1127
1207
  project_id=self.project_id,
1128
1208
  region=self.region,
@@ -1141,6 +1221,10 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1141
1221
  model_instance_schema_uri=self.model_instance_schema_uri,
1142
1222
  model_parameters_schema_uri=self.model_parameters_schema_uri,
1143
1223
  model_prediction_schema_uri=self.model_prediction_schema_uri,
1224
+ parent_model=self.parent_model,
1225
+ is_default_version=self.is_default_version,
1226
+ model_version_aliases=self.model_version_aliases,
1227
+ model_version_description=self.model_version_description,
1144
1228
  labels=self.labels,
1145
1229
  training_encryption_spec_key_name=self.training_encryption_spec_key_name,
1146
1230
  model_encryption_spec_key_name=self.model_encryption_spec_key_name,
@@ -1177,6 +1261,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
1177
1261
  if model:
1178
1262
  result = Model.to_dict(model)
1179
1263
  model_id = self.hook.extract_model_id(result)
1264
+ self.xcom_push(context, key="model_id", value=model_id)
1180
1265
  VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1181
1266
  else:
1182
1267
  result = model # type: ignore
@@ -1276,7 +1361,8 @@ class DeleteCustomTrainingJobOperator(GoogleCloudBaseOperator):
1276
1361
 
1277
1362
 
1278
1363
  class ListCustomTrainingJobOperator(GoogleCloudBaseOperator):
1279
- """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location.
1364
+ """
1365
+ Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location.
1280
1366
 
1281
1367
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
1282
1368
  :param region: Required. The ID of the Google Cloud region that the service belongs to.