apache-airflow-providers-google 10.10.0__py3-none-any.whl → 10.10.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (29)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/cloud/hooks/cloud_run.py +4 -2
  3. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +131 -27
  4. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +1 -9
  5. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +121 -4
  6. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -11
  7. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -10
  8. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +220 -6
  9. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +409 -0
  10. airflow/providers/google/cloud/links/vertex_ai.py +49 -0
  11. airflow/providers/google/cloud/operators/dataproc.py +32 -10
  12. airflow/providers/google/cloud/operators/gcs.py +1 -1
  13. airflow/providers/google/cloud/operators/mlengine.py +116 -0
  14. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +45 -0
  15. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +2 -8
  16. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +287 -201
  17. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +1 -9
  18. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +2 -9
  19. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +451 -12
  20. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +464 -0
  21. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +7 -1
  22. airflow/providers/google/get_provider_info.py +5 -0
  23. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/METADATA +6 -6
  24. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/RECORD +29 -27
  25. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/LICENSE +0 -0
  26. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/NOTICE +0 -0
  27. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/WHEEL +0 -0
  28. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/entry_points.txt +0 -0
  29. {apache_airflow_providers_google-10.10.0.dist-info → apache_airflow_providers_google-10.10.1.dist-info}/top_level.txt +0 -0
airflow/providers/google/cloud/links/vertex_ai.py

@@ -49,6 +49,10 @@ VERTEX_AI_ENDPOINT_LINK = (
     VERTEX_AI_BASE_LINK + "/locations/{region}/endpoints/{endpoint_id}?project={project_id}"
 )
 VERTEX_AI_ENDPOINT_LIST_LINK = VERTEX_AI_BASE_LINK + "/endpoints?project={project_id}"
+VERTEX_AI_PIPELINE_JOB_LINK = (
+    VERTEX_AI_BASE_LINK + "/locations/{region}/pipelines/runs/{pipeline_id}?project={project_id}"
+)
+VERTEX_AI_PIPELINE_JOB_LIST_LINK = VERTEX_AI_BASE_LINK + "/pipelines/runs?project={project_id}"
 
 
 class VertexAIModelLink(BaseGoogleLink):

@@ -319,3 +323,48 @@ class VertexAIEndpointListLink(BaseGoogleLink):
             "project_id": task_instance.project_id,
         },
     )
+
+
+class VertexAIPipelineJobLink(BaseGoogleLink):
+    """Helper class for constructing Vertex AI PipelineJob link."""
+
+    name = "Pipeline Job"
+    key = "pipeline_job_conf"
+    format_str = VERTEX_AI_PIPELINE_JOB_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        pipeline_id: str,
+    ):
+        task_instance.xcom_push(
+            context=context,
+            key=VertexAIPipelineJobLink.key,
+            value={
+                "pipeline_id": pipeline_id,
+                "region": task_instance.region,
+                "project_id": task_instance.project_id,
+            },
+        )
+
+
+class VertexAIPipelineJobListLink(BaseGoogleLink):
+    """Helper class for constructing Vertex AI PipelineJobList link."""
+
+    name = "Pipeline Job List"
+    key = "pipeline_job_list_conf"
+    format_str = VERTEX_AI_PIPELINE_JOB_LIST_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+    ):
+        task_instance.xcom_push(
+            context=context,
+            key=VertexAIPipelineJobListLink.key,
+            value={
+                "project_id": task_instance.project_id,
+            },
+        )
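
The two new link classes above are consumed through Airflow's operator extra-links mechanism. For illustration only (the operator below is hypothetical; only the VertexAIPipelineJobLink API comes from this diff), an operator would expose and persist the link roughly like this:

    from airflow.models.baseoperator import BaseOperator
    from airflow.providers.google.cloud.links.vertex_ai import VertexAIPipelineJobLink


    class SketchPipelineJobOperator(BaseOperator):
        """Hypothetical operator showing how the new link is persisted."""

        operator_extra_links = (VertexAIPipelineJobLink(),)

        def __init__(self, *, project_id: str, region: str, pipeline_job_id: str, **kwargs):
            super().__init__(**kwargs)
            # persist() reads region and project_id from the task instance,
            # so the operator must expose both attributes.
            self.project_id = project_id
            self.region = region
            self.pipeline_job_id = pipeline_job_id

        def execute(self, context):
            # ... submit or look up the pipeline job here ...
            VertexAIPipelineJobLink.persist(
                context=context,
                task_instance=self,
                pipeline_id=self.pipeline_job_id,
            )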
airflow/providers/google/cloud/operators/dataproc.py

@@ -1790,6 +1790,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     :param deferrable: Run operator in the deferrable mode.
     :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
+    :param cancel_on_kill: Flag which indicates whether cancel the workflow, when on_kill is called
     """
 
     template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")

@@ -1812,6 +1813,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         polling_interval_seconds: int = 10,
+        cancel_on_kill: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)

@@ -1830,6 +1832,8 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.polling_interval_seconds = polling_interval_seconds
+        self.cancel_on_kill = cancel_on_kill
+        self.operation_name: str | None = None
 
     def execute(self, context: Context):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

@@ -1845,24 +1849,26 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        self.workflow_id = operation.operation.name.split("/")[-1]
+        operation_name = operation.operation.name
+        self.operation_name = operation_name
+        workflow_id = operation_name.split("/")[-1]
         project_id = self.project_id or hook.project_id
         if project_id:
             DataprocWorkflowLink.persist(
                 context=context,
                 operator=self,
-                workflow_id=self.workflow_id,
+                workflow_id=workflow_id,
                 region=self.region,
                 project_id=project_id,
             )
-        self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
+        self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
         if not self.deferrable:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
-            self.log.info("Workflow %s completed successfully", self.workflow_id)
+            self.log.info("Workflow %s completed successfully", workflow_id)
         else:
             self.defer(
                 trigger=DataprocWorkflowTrigger(
-                    name=operation.operation.name,
+                    name=operation_name,
                     project_id=self.project_id,
                     region=self.region,
                     gcp_conn_id=self.gcp_conn_id,

@@ -1884,6 +1890,11 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
 
         self.log.info("Workflow %s completed successfully", event["operation_name"])
 
+    def on_kill(self) -> None:
+        if self.cancel_on_kill and self.operation_name:
+            hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+            hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
+
 
 class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator):
     """Instantiate a WorkflowTemplate Inline on Google Cloud Dataproc.
@@ -1926,6 +1937,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         account from the list granting this role to the originating account (templated).
     :param deferrable: Run operator in the deferrable mode.
     :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
+    :param cancel_on_kill: Flag which indicates whether cancel the workflow, when on_kill is called
     """
 
     template_fields: Sequence[str] = ("template", "impersonation_chain")

@@ -1946,6 +1958,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         polling_interval_seconds: int = 10,
+        cancel_on_kill: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)

@@ -1963,6 +1976,8 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.polling_interval_seconds = polling_interval_seconds
+        self.cancel_on_kill = cancel_on_kill
+        self.operation_name: str | None = None
 
     def execute(self, context: Context):
         self.log.info("Instantiating Inline Template")

@@ -1977,23 +1992,25 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        self.workflow_id = operation.operation.name.split("/")[-1]
+        operation_name = operation.operation.name
+        self.operation_name = operation_name
+        workflow_id = operation_name.split("/")[-1]
         if project_id:
             DataprocWorkflowLink.persist(
                 context=context,
                 operator=self,
-                workflow_id=self.workflow_id,
+                workflow_id=workflow_id,
                 region=self.region,
                 project_id=project_id,
             )
         if not self.deferrable:
-            self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
+            self.log.info("Template instantiated. Workflow Id : %s", workflow_id)
             operation.result()
-            self.log.info("Workflow %s completed successfully", self.workflow_id)
+            self.log.info("Workflow %s completed successfully", workflow_id)
         else:
             self.defer(
                 trigger=DataprocWorkflowTrigger(
-                    name=operation.operation.name,
+                    name=operation_name,
                     project_id=self.project_id or hook.project_id,
                     region=self.region,
                     gcp_conn_id=self.gcp_conn_id,

@@ -2015,6 +2032,11 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
 
         self.log.info("Workflow %s completed successfully", event["operation_name"])
 
+    def on_kill(self) -> None:
+        if self.cancel_on_kill and self.operation_name:
+            hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+            hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
+
 
 class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
     """Submit a job to a cluster.
airflow/providers/google/cloud/operators/gcs.py

@@ -919,7 +919,7 @@ class GCSSynchronizeBucketsOperator(GoogleCloudBaseOperator):
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:GCSSynchronizeBuckets`
+        :ref:`howto/operator:GCSSynchronizeBucketsOperator`
 
     :param source_bucket: The name of the bucket containing the source objects.
     :param destination_bucket: The name of the bucket containing the destination objects.
airflow/providers/google/cloud/operators/mlengine.py

@@ -82,6 +82,10 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
     """
     Start a Google Cloud ML Engine prediction job.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction.CreateBatchPredictionJobOperator`
+    instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineStartBatchPredictionJobOperator`

@@ -210,6 +214,14 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator):
         self._labels = labels
         self._impersonation_chain = impersonation_chain
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `CreateBatchPredictionJobOperator`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
         if not self._project_id:
             raise AirflowException("Google Cloud project id is required.")
         if not self._job_id:
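
A hedged migration sketch toward the Vertex AI replacement. Only parameters named in the provider docstrings are shown; the values are placeholders, and the prediction input source and output destination arguments (omitted here) still need to be supplied per the operator's signature:

    from airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job import (
        CreateBatchPredictionJobOperator,
    )

    create_batch_prediction_job = CreateBatchPredictionJobOperator(
        task_id="create_batch_prediction_job",
        project_id="my-project",                 # placeholder
        region="us-central1",                    # placeholder
        job_display_name="my-batch-prediction",  # placeholder
        model_name="projects/my-project/locations/us-central1/models/1234567890",  # placeholder
        instances_format="jsonl",
        # plus the prediction input source / output destination parameters
    )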
@@ -364,6 +376,9 @@ class MLEngineCreateModelOperator(GoogleCloudBaseOperator):
     """
     Creates a new model.
 
+    This operator is deprecated. Please use appropriate VertexAI operator from
+    :class:`airflow.providers.google.cloud.operators.vertex_ai` instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineCreateModelOperator`

@@ -407,6 +422,14 @@ class MLEngineCreateModelOperator(GoogleCloudBaseOperator):
         self._gcp_conn_id = gcp_conn_id
         self._impersonation_chain = impersonation_chain
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use appropriate VertexAI operator.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
     def execute(self, context: Context):
         hook = MLEngineHook(
             gcp_conn_id=self._gcp_conn_id,
@@ -429,6 +452,9 @@ class MLEngineGetModelOperator(GoogleCloudBaseOperator):
     """
     Gets a particular model.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator` instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineGetModelOperator`

@@ -472,6 +498,14 @@ class MLEngineGetModelOperator(GoogleCloudBaseOperator):
         self._gcp_conn_id = gcp_conn_id
         self._impersonation_chain = impersonation_chain
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `GetModelOperator`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
     def execute(self, context: Context):
         hook = MLEngineHook(
             gcp_conn_id=self._gcp_conn_id,
@@ -493,6 +527,10 @@ class MLEngineDeleteModelOperator(GoogleCloudBaseOperator):
     """
     Deletes a model.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator` instead.
+
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineDeleteModelOperator`

@@ -541,6 +579,14 @@ class MLEngineDeleteModelOperator(GoogleCloudBaseOperator):
         self._gcp_conn_id = gcp_conn_id
         self._impersonation_chain = impersonation_chain
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `DeleteModelOperator`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
     def execute(self, context: Context):
         hook = MLEngineHook(
             gcp_conn_id=self._gcp_conn_id,
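
For the model read/delete operators, the suggested replacements live in the vertex_ai.model_service module; a hedged sketch (identifiers are placeholders, and parameter names should be confirmed against the operator signatures):

    from airflow.providers.google.cloud.operators.vertex_ai.model_service import (
        DeleteModelOperator,
        GetModelOperator,
    )

    get_model = GetModelOperator(
        task_id="get_model",
        project_id="my-project",  # placeholder
        region="us-central1",     # placeholder
        model_id="1234567890",    # placeholder Vertex AI model ID
    )

    delete_model = DeleteModelOperator(
        task_id="delete_model",
        project_id="my-project",  # placeholder
        region="us-central1",     # placeholder
        model_id="1234567890",    # placeholder Vertex AI model ID
    )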
@@ -682,6 +728,8 @@ class MLEngineCreateVersionOperator(GoogleCloudBaseOperator):
     """
     Creates a new version in the model.
 
+    This operator is deprecated. Please use parent_model parameter of VertexAI operators instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineCreateVersionOperator`

@@ -731,6 +779,14 @@ class MLEngineCreateVersionOperator(GoogleCloudBaseOperator):
         self._impersonation_chain = impersonation_chain
         self._validate_inputs()
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use parent_model parameter for VertexAI operators instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
     def _validate_inputs(self):
         if not self._model_name:
             raise AirflowException("The model_name parameter could not be empty.")
@@ -763,6 +819,10 @@ class MLEngineSetDefaultVersionOperator(GoogleCloudBaseOperator):
     """
     Sets a version in the model.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.SetDefaultVersionOnModelOperator`
+    instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineSetDefaultVersionOperator`

@@ -812,6 +872,14 @@ class MLEngineSetDefaultVersionOperator(GoogleCloudBaseOperator):
         self._impersonation_chain = impersonation_chain
         self._validate_inputs()
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `SetDefaultVersionOnModelOperator` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
     def _validate_inputs(self):
         if not self._model_name:
             raise AirflowException("The model_name parameter could not be empty.")
@@ -844,6 +912,10 @@ class MLEngineListVersionsOperator(GoogleCloudBaseOperator):
     """
     Lists all available versions of the model.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.ListModelVersionsOperator`
+    instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineListVersionsOperator`

@@ -889,6 +961,14 @@ class MLEngineListVersionsOperator(GoogleCloudBaseOperator):
         self._impersonation_chain = impersonation_chain
         self._validate_inputs()
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `ListModelVersionsOperator` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
     def _validate_inputs(self):
         if not self._model_name:
             raise AirflowException("The model_name parameter could not be empty.")
@@ -918,6 +998,10 @@ class MLEngineDeleteVersionOperator(GoogleCloudBaseOperator):
     """
     Deletes the version from the model.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelVersionOperator`
+    instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineDeleteVersionOperator`

@@ -967,6 +1051,14 @@ class MLEngineDeleteVersionOperator(GoogleCloudBaseOperator):
         self._impersonation_chain = impersonation_chain
         self._validate_inputs()
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `DeleteModelVersionOperator` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
     def _validate_inputs(self):
         if not self._model_name:
             raise AirflowException("The model_name parameter could not be empty.")
@@ -998,6 +1090,10 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
     """
     Operator for launching a MLEngine training job.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`
+    instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:MLEngineStartTrainingJobOperator`

@@ -1124,6 +1220,14 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator):
         self.deferrable = deferrable
         self.cancel_on_kill = cancel_on_kill
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `CreateCustomPythonPackageTrainingJobOperator` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
         custom = self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM"
         custom_image = (
             custom
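
A hedged migration sketch toward the Vertex AI custom-job operator named in the warning (the package URI, module name, and training container image are placeholders; confirm the full parameter list against CreateCustomPythonPackageTrainingJobOperator):

    from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
        CreateCustomPythonPackageTrainingJobOperator,
    )

    create_training_job = CreateCustomPythonPackageTrainingJobOperator(
        task_id="create_custom_python_package_training_job",
        project_id="my-project",                                     # placeholder
        region="us-central1",                                        # placeholder
        display_name="my-training-job",                              # placeholder
        python_package_gcs_uri="gs://my-bucket/trainer-0.1.tar.gz",  # placeholder
        python_module_name="trainer.task",                           # placeholder
        container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-12:latest",  # placeholder
    )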
@@ -1328,6 +1432,10 @@ class MLEngineTrainingCancelJobOperator(GoogleCloudBaseOperator):
     """
     Operator for cleaning up failed MLEngine training job.
 
+    This operator is deprecated. Please use
+    :class:`airflow.providers.google.cloud.operators.vertex_ai.custom_job.CancelCustomTrainingJobOperator`
+    instead.
+
     :param job_id: A unique templated id for the submitted Google MLEngine
         training job. (templated)
     :param project_id: The Google Cloud project name within which MLEngine training job should run.

@@ -1366,6 +1474,14 @@ class MLEngineTrainingCancelJobOperator(GoogleCloudBaseOperator):
         self._gcp_conn_id = gcp_conn_id
         self._impersonation_chain = impersonation_chain
 
+        warnings.warn(
+            "This operator is deprecated. All the functionality of legacy "
+            "MLEngine and new features are available on the Vertex AI platform. "
+            "Please use `CancelCustomTrainingJobOperator` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=3,
+        )
+
         if not self._project_id:
             raise AirflowException("Google Cloud project id is required.")
 
airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py

@@ -15,7 +15,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 """This module contains Google Vertex AI operators."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence

@@ -50,6 +52,10 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
         region: str,
         display_name: str,
         labels: dict[str, str] | None = None,
+        parent_model: str | None = None,
+        is_default_version: bool | None = None,
+        model_version_aliases: list[str] | None = None,
+        model_version_description: str | None = None,
         training_encryption_spec_key_name: str | None = None,
         model_encryption_spec_key_name: str | None = None,
         # RUN

@@ -67,6 +73,10 @@ class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
         self.region = region
         self.display_name = display_name
         self.labels = labels
+        self.parent_model = parent_model
+        self.is_default_version = is_default_version
+        self.model_version_aliases = model_version_aliases
+        self.model_version_description = model_version_description
         self.training_encryption_spec_key_name = training_encryption_spec_key_name
         self.model_encryption_spec_key_name = model_encryption_spec_key_name
         # START Run param
@@ -90,6 +100,7 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
     """Create AutoML Forecasting Training job."""
 
     template_fields = (
+        "parent_model",
         "dataset_id",
         "region",
         "impersonation_chain",

@@ -158,11 +169,16 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
         model, training_id = self.hook.create_auto_ml_forecasting_training_job(
             project_id=self.project_id,
             region=self.region,
             display_name=self.display_name,
             dataset=datasets.TimeSeriesDataset(dataset_name=self.dataset_id),
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
             target_column=self.target_column,
             time_column=self.time_column,
             time_series_identifier_column=self.time_series_identifier_column,

@@ -202,6 +218,7 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
+            self.xcom_push(context, key="model_id", value=model_id)
             VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
         else:
             result = model  # type: ignore
@@ -214,6 +231,7 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
     """Create Auto ML Image Training job."""
 
     template_fields = (
+        "parent_model",
         "dataset_id",
         "region",
         "impersonation_chain",

@@ -254,11 +272,16 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
         model, training_id = self.hook.create_auto_ml_image_training_job(
             project_id=self.project_id,
             region=self.region,
             display_name=self.display_name,
             dataset=datasets.ImageDataset(dataset_name=self.dataset_id),
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
             prediction_type=self.prediction_type,
             multi_label=self.multi_label,
             model_type=self.model_type,

@@ -282,6 +305,7 @@ class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
+            self.xcom_push(context, key="model_id", value=model_id)
             VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
         else:
             result = model  # type: ignore
@@ -294,6 +318,7 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
     """Create Auto ML Tabular Training job."""
 
     template_fields = (
+        "parent_model",
         "dataset_id",
         "region",
         "impersonation_chain",

@@ -351,6 +376,7 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             impersonation_chain=self.impersonation_chain,
         )
         credentials, _ = self.hook.get_credentials_and_project_id()
+        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
         model, training_id = self.hook.create_auto_ml_tabular_training_job(
             project_id=self.project_id,
             region=self.region,

@@ -360,6 +386,10 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
                 project=self.project_id,
                 credentials=credentials,
             ),
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
             target_column=self.target_column,
             optimization_prediction_type=self.optimization_prediction_type,
             optimization_objective=self.optimization_objective,

@@ -393,6 +423,7 @@ class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
+            self.xcom_push(context, key="model_id", value=model_id)
             VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
         else:
             result = model  # type: ignore
@@ -405,6 +436,7 @@ class CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator):
     """Create Auto ML Text Training job."""
 
     template_fields = [
+        "parent_model",
         "dataset_id",
         "region",
         "impersonation_chain",

@@ -439,6 +471,7 @@ class CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
         model, training_id = self.hook.create_auto_ml_text_training_job(
             project_id=self.project_id,
             region=self.region,

@@ -459,11 +492,16 @@ class CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             model_display_name=self.model_display_name,
             model_labels=self.model_labels,
             sync=self.sync,
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
         )
 
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
+            self.xcom_push(context, key="model_id", value=model_id)
             VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
         else:
             result = model  # type: ignore
@@ -476,6 +514,7 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
     """Create Auto ML Video Training job."""
 
     template_fields = (
+        "parent_model",
         "dataset_id",
         "region",
         "impersonation_chain",

@@ -504,6 +543,7 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        self.parent_model = self.parent_model.rpartition("@")[0] if self.parent_model else None
         model, training_id = self.hook.create_auto_ml_video_training_job(
             project_id=self.project_id,
             region=self.region,

@@ -521,11 +561,16 @@ class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
             model_display_name=self.model_display_name,
             model_labels=self.model_labels,
             sync=self.sync,
+            parent_model=self.parent_model,
+            is_default_version=self.is_default_version,
+            model_version_aliases=self.model_version_aliases,
+            model_version_description=self.model_version_description,
         )
 
         if model:
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
+            self.xcom_push(context, key="model_id", value=model_id)
             VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
         else:
             result = model  # type: ignore
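
The new versioning parameters and the model_id XCom can be combined; a hedged sketch that uses only fields visible in this diff (dataset, column, and resource names are placeholders, and other required forecasting arguments are omitted for brevity):

    from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
        CreateAutoMLForecastingTrainingJobOperator,
    )

    create_forecasting_job = CreateAutoMLForecastingTrainingJobOperator(
        task_id="create_auto_ml_forecasting_training_job",
        project_id="my-project",                   # placeholder
        region="us-central1",                      # placeholder
        display_name="forecast-model",             # placeholder
        dataset_id="1234567890",                   # placeholder TimeSeriesDataset ID
        target_column="sales",                     # placeholder
        time_column="date",                        # placeholder
        time_series_identifier_column="store_id",  # placeholder
        # New in 10.10.1: register the result as a version of an existing model.
        parent_model="projects/my-project/locations/us-central1/models/987654321",  # placeholder
        is_default_version=True,
        model_version_aliases=["staging"],
        model_version_description="AutoML forecasting retrain",
        # ...other required forecasting arguments omitted for brevity...
    )

    # Downstream tasks can pull the trained model's ID, now pushed under key="model_id":
    # ti.xcom_pull(task_ids="create_auto_ml_forecasting_training_job", key="model_id")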
airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py

@@ -15,15 +15,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module contains Google Vertex AI operators.
 
-.. spelling:word-list::
+"""This module contains Google Vertex AI operators."""
 
-    jsonl
-    codepoints
-    aiplatform
-    gapic
-"""
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence

@@ -54,7 +48,7 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
     :param region: Required. The ID of the Google Cloud region that the service belongs to.
     :param batch_prediction_job: Required. The BatchPredictionJob to create.
     :param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
-        up to 128 characters long and can be consist of any UTF-8 characters.
+        up to 128 characters long and can consist of any UTF-8 characters.
     :param model_name: Required. A fully-qualified model resource name or model ID.
     :param instances_format: Required. The format in which instances are provided. Must be one of the
         formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using