zenml-nightly 0.75.0.dev20250313__py3-none-any.whl → 0.75.0.dev20250315__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. zenml/VERSION +1 -1
  2. zenml/analytics/context.py +4 -4
  3. zenml/analytics/enums.py +2 -2
  4. zenml/artifacts/utils.py +2 -2
  5. zenml/cli/__init__.py +8 -9
  6. zenml/cli/base.py +2 -2
  7. zenml/cli/code_repository.py +1 -1
  8. zenml/cli/login.py +21 -18
  9. zenml/cli/pipeline.py +3 -3
  10. zenml/cli/project.py +172 -0
  11. zenml/cli/server.py +5 -5
  12. zenml/cli/service_accounts.py +0 -1
  13. zenml/cli/service_connectors.py +15 -16
  14. zenml/cli/stack.py +0 -2
  15. zenml/cli/stack_components.py +2 -2
  16. zenml/cli/utils.py +3 -3
  17. zenml/client.py +352 -341
  18. zenml/config/global_config.py +41 -43
  19. zenml/config/server_config.py +9 -9
  20. zenml/constants.py +5 -3
  21. zenml/event_hub/event_hub.py +1 -1
  22. zenml/integrations/gcp/__init__.py +1 -0
  23. zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +5 -0
  24. zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +5 -28
  25. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +125 -78
  26. zenml/integrations/gcp/service_connectors/gcp_service_connector.py +7 -6
  27. zenml/integrations/gcp/vertex_custom_job_parameters.py +50 -0
  28. zenml/integrations/mlflow/steps/mlflow_registry.py +3 -3
  29. zenml/integrations/wandb/__init__.py +1 -1
  30. zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py +29 -9
  31. zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +2 -0
  32. zenml/login/credentials.py +26 -27
  33. zenml/login/credentials_store.py +5 -5
  34. zenml/login/pro/client.py +9 -9
  35. zenml/login/pro/utils.py +8 -8
  36. zenml/login/pro/{tenant → workspace}/__init__.py +1 -1
  37. zenml/login/pro/{tenant → workspace}/client.py +25 -25
  38. zenml/login/pro/{tenant → workspace}/models.py +27 -28
  39. zenml/model/model.py +2 -2
  40. zenml/model_registries/base_model_registry.py +1 -1
  41. zenml/models/__init__.py +29 -29
  42. zenml/models/v2/base/filter.py +1 -1
  43. zenml/models/v2/base/scoped.py +49 -53
  44. zenml/models/v2/core/action.py +12 -12
  45. zenml/models/v2/core/artifact.py +15 -15
  46. zenml/models/v2/core/artifact_version.py +15 -15
  47. zenml/models/v2/core/code_repository.py +12 -12
  48. zenml/models/v2/core/event_source.py +12 -12
  49. zenml/models/v2/core/model.py +26 -18
  50. zenml/models/v2/core/model_version.py +15 -15
  51. zenml/models/v2/core/pipeline.py +15 -15
  52. zenml/models/v2/core/pipeline_build.py +14 -14
  53. zenml/models/v2/core/pipeline_deployment.py +12 -14
  54. zenml/models/v2/core/pipeline_run.py +16 -16
  55. zenml/models/v2/core/project.py +203 -0
  56. zenml/models/v2/core/run_metadata.py +2 -2
  57. zenml/models/v2/core/run_template.py +15 -15
  58. zenml/models/v2/core/schedule.py +12 -12
  59. zenml/models/v2/core/secret.py +1 -1
  60. zenml/models/v2/core/service.py +14 -14
  61. zenml/models/v2/core/step_run.py +13 -13
  62. zenml/models/v2/core/tag.py +96 -3
  63. zenml/models/v2/core/trigger.py +13 -13
  64. zenml/models/v2/core/trigger_execution.py +2 -2
  65. zenml/models/v2/core/user.py +0 -17
  66. zenml/models/v2/misc/server_models.py +6 -6
  67. zenml/models/v2/misc/statistics.py +4 -4
  68. zenml/orchestrators/cache_utils.py +7 -7
  69. zenml/orchestrators/input_utils.py +1 -1
  70. zenml/orchestrators/step_launcher.py +1 -1
  71. zenml/orchestrators/step_run_utils.py +3 -3
  72. zenml/orchestrators/utils.py +4 -4
  73. zenml/pipelines/build_utils.py +2 -2
  74. zenml/pipelines/pipeline_definition.py +5 -5
  75. zenml/pipelines/run_utils.py +1 -1
  76. zenml/service_connectors/service_connector.py +0 -3
  77. zenml/service_connectors/service_connector_utils.py +0 -1
  78. zenml/stack/stack.py +0 -1
  79. zenml/steps/base_step.py +10 -2
  80. zenml/utils/dashboard_utils.py +1 -1
  81. zenml/utils/tag_utils.py +0 -12
  82. zenml/zen_server/cloud_utils.py +3 -3
  83. zenml/zen_server/feature_gate/endpoint_utils.py +1 -1
  84. zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
  85. zenml/zen_server/rbac/endpoint_utils.py +17 -17
  86. zenml/zen_server/rbac/models.py +47 -22
  87. zenml/zen_server/rbac/rbac_sql_zen_store.py +3 -3
  88. zenml/zen_server/rbac/utils.py +23 -25
  89. zenml/zen_server/rbac/zenml_cloud_rbac.py +7 -74
  90. zenml/zen_server/routers/artifact_version_endpoints.py +10 -10
  91. zenml/zen_server/routers/auth_endpoints.py +6 -6
  92. zenml/zen_server/routers/code_repositories_endpoints.py +12 -14
  93. zenml/zen_server/routers/model_versions_endpoints.py +13 -15
  94. zenml/zen_server/routers/models_endpoints.py +7 -9
  95. zenml/zen_server/routers/pipeline_builds_endpoints.py +14 -16
  96. zenml/zen_server/routers/pipeline_deployments_endpoints.py +13 -15
  97. zenml/zen_server/routers/pipelines_endpoints.py +16 -18
  98. zenml/zen_server/routers/{workspaces_endpoints.py → projects_endpoints.py} +111 -68
  99. zenml/zen_server/routers/run_metadata_endpoints.py +7 -9
  100. zenml/zen_server/routers/run_templates_endpoints.py +15 -17
  101. zenml/zen_server/routers/runs_endpoints.py +12 -14
  102. zenml/zen_server/routers/schedule_endpoints.py +12 -14
  103. zenml/zen_server/routers/secrets_endpoints.py +1 -3
  104. zenml/zen_server/routers/server_endpoints.py +7 -7
  105. zenml/zen_server/routers/service_connectors_endpoints.py +11 -13
  106. zenml/zen_server/routers/service_endpoints.py +7 -9
  107. zenml/zen_server/routers/stack_components_endpoints.py +9 -11
  108. zenml/zen_server/routers/stacks_endpoints.py +9 -11
  109. zenml/zen_server/routers/steps_endpoints.py +6 -6
  110. zenml/zen_server/routers/users_endpoints.py +5 -43
  111. zenml/zen_server/template_execution/utils.py +4 -4
  112. zenml/zen_server/utils.py +10 -10
  113. zenml/zen_server/zen_server_api.py +6 -5
  114. zenml/zen_stores/base_zen_store.py +38 -42
  115. zenml/zen_stores/migrations/versions/12eff0206201_rename_workspace_to_project.py +768 -0
  116. zenml/zen_stores/migrations/versions/41b28cae31ce_make_artifacts_workspace_scoped.py +3 -3
  117. zenml/zen_stores/migrations/versions/cbc6acd71f92_add_workspace_display_name.py +58 -0
  118. zenml/zen_stores/rest_zen_store.py +55 -63
  119. zenml/zen_stores/schemas/__init__.py +2 -2
  120. zenml/zen_stores/schemas/action_schemas.py +9 -9
  121. zenml/zen_stores/schemas/artifact_schemas.py +15 -17
  122. zenml/zen_stores/schemas/code_repository_schemas.py +16 -18
  123. zenml/zen_stores/schemas/event_source_schemas.py +9 -9
  124. zenml/zen_stores/schemas/model_schemas.py +15 -17
  125. zenml/zen_stores/schemas/pipeline_build_schemas.py +7 -7
  126. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
  127. zenml/zen_stores/schemas/pipeline_run_schemas.py +9 -9
  128. zenml/zen_stores/schemas/pipeline_schemas.py +9 -9
  129. zenml/zen_stores/schemas/{workspace_schemas.py → project_schemas.py} +47 -41
  130. zenml/zen_stores/schemas/run_metadata_schemas.py +5 -5
  131. zenml/zen_stores/schemas/run_template_schemas.py +9 -9
  132. zenml/zen_stores/schemas/schedule_schema.py +9 -9
  133. zenml/zen_stores/schemas/service_schemas.py +7 -7
  134. zenml/zen_stores/schemas/step_run_schemas.py +7 -7
  135. zenml/zen_stores/schemas/trigger_schemas.py +9 -9
  136. zenml/zen_stores/schemas/user_schemas.py +0 -12
  137. zenml/zen_stores/sql_zen_store.py +318 -275
  138. zenml/zen_stores/zen_store_interface.py +56 -70
  139. {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/METADATA +1 -1
  140. {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/RECORD +143 -140
  141. zenml/cli/workspace.py +0 -160
  142. zenml/models/v2/core/workspace.py +0 -131
  143. {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/LICENSE +0 -0
  144. {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/WHEEL +0 -0
  145. {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/entry_points.txt +0 -0
@@ -48,7 +48,7 @@ from zenml.logger import get_logger
48
48
  from zenml.utils import io_utils, yaml_utils
49
49
 
50
50
  if TYPE_CHECKING:
51
- from zenml.models import StackResponse, WorkspaceResponse
51
+ from zenml.models import ProjectResponse, StackResponse
52
52
  from zenml.zen_stores.base_zen_store import BaseZenStore
53
53
 
54
54
  logger = get_logger(__name__)
@@ -113,7 +113,7 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
113
113
  global config.
114
114
  store: Store configuration.
115
115
  active_stack_id: The ID of the active stack.
116
- active_workspace_name: The name of the active workspace.
116
+ active_project_name: The name of the active project.
117
117
  """
118
118
 
119
119
  user_id: uuid.UUID = Field(default_factory=uuid.uuid4)
@@ -123,10 +123,10 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
123
123
  version: Optional[str] = None
124
124
  store: Optional[SerializeAsAny[StoreConfiguration]] = None
125
125
  active_stack_id: Optional[uuid.UUID] = None
126
- active_workspace_name: Optional[str] = None
126
+ active_project_name: Optional[str] = None
127
127
 
128
128
  _zen_store: Optional["BaseZenStore"] = None
129
- _active_workspace: Optional["WorkspaceResponse"] = None
129
+ _active_project: Optional["ProjectResponse"] = None
130
130
  _active_stack: Optional["StackResponse"] = None
131
131
 
132
132
  def __init__(self, **data: Any) -> None:
@@ -385,24 +385,24 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
385
385
  def _sanitize_config(self) -> None:
386
386
  """Sanitize and save the global configuration.
387
387
 
388
- This method is called to ensure that the active stack and workspace
388
+ This method is called to ensure that the active stack and project
389
389
  are set to their default values, if possible.
390
390
  """
391
391
  # If running in a ZenML server environment, the active stack and
392
- # workspace are not relevant
392
+ # project are not relevant
393
393
  if ENV_ZENML_SERVER in os.environ:
394
394
  return
395
- active_workspace, active_stack = self.zen_store.validate_active_config(
396
- self.active_workspace_name,
395
+ active_project, active_stack = self.zen_store.validate_active_config(
396
+ self.active_project_name,
397
397
  self.active_stack_id,
398
398
  config_name="global",
399
399
  )
400
- if active_workspace:
401
- self.active_workspace_name = active_workspace.name
402
- self._active_workspace = active_workspace
400
+ if active_project:
401
+ self.active_project_name = active_project.name
402
+ self._active_project = active_project
403
403
  else:
404
- self.active_workspace_name = None
405
- self._active_workspace = None
404
+ self.active_project_name = None
405
+ self._active_project = None
406
406
 
407
407
  self.set_active_stack(active_stack)
408
408
 
@@ -719,22 +719,22 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
719
719
 
720
720
  return self._zen_store
721
721
 
722
- def set_active_workspace(
723
- self, workspace: "WorkspaceResponse"
724
- ) -> "WorkspaceResponse":
725
- """Set the workspace for the local client.
722
+ def set_active_project(
723
+ self, project: "ProjectResponse"
724
+ ) -> "ProjectResponse":
725
+ """Set the project for the local client.
726
726
 
727
727
  Args:
728
- workspace: The workspace to set active.
728
+ project: The project to set active.
729
729
 
730
730
  Returns:
731
- The workspace that was set active.
731
+ The project that was set active.
732
732
  """
733
- self.active_workspace_name = workspace.name
734
- self._active_workspace = workspace
735
- # Sanitize the global configuration to reflect the new workspace
733
+ self.active_project_name = project.name
734
+ self._active_project = project
735
+ # Sanitize the global configuration to reflect the new project
736
736
  self._sanitize_config()
737
- return workspace
737
+ return project
738
738
 
739
739
  def set_active_stack(self, stack: "StackResponse") -> None:
740
740
  """Set the active stack for the local client.
@@ -745,43 +745,41 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
745
745
  self.active_stack_id = stack.id
746
746
  self._active_stack = stack
747
747
 
748
- def get_active_workspace(self) -> "WorkspaceResponse":
749
- """Get a model of the active workspace for the local client.
748
+ def get_active_project(self) -> "ProjectResponse":
749
+ """Get a model of the active project for the local client.
750
750
 
751
751
  Returns:
752
- The model of the active workspace.
752
+ The model of the active project.
753
753
  """
754
- workspace_name = self.get_active_workspace_name()
754
+ project_name = self.get_active_project_name()
755
755
 
756
- if self._active_workspace is not None:
757
- return self._active_workspace
756
+ if self._active_project is not None:
757
+ return self._active_project
758
758
 
759
- workspace = self.zen_store.get_workspace(
760
- workspace_name_or_id=workspace_name,
759
+ project = self.zen_store.get_project(
760
+ project_name_or_id=project_name,
761
761
  )
762
- return self.set_active_workspace(workspace)
762
+ return self.set_active_project(project)
763
763
 
764
- def get_active_workspace_name(self) -> str:
765
- """Get the name of the active workspace.
766
-
767
- If the active workspace doesn't exist yet, the ZenStore is reinitialized.
764
+ def get_active_project_name(self) -> str:
765
+ """Get the name of the active project.
768
766
 
769
767
  Returns:
770
- The name of the active workspace.
768
+ The name of the active project.
771
769
 
772
770
  Raises:
773
- RuntimeError: If the active workspace is not set.
771
+ RuntimeError: If the active project is not set.
774
772
  """
775
- if self.active_workspace_name is None:
773
+ if self.active_project_name is None:
776
774
  _ = self.zen_store
777
- if self.active_workspace_name is None:
775
+ if self.active_project_name is None:
778
776
  raise RuntimeError(
779
- "No workspace is currently set as active. Please set the "
780
- "active workspace using the `zenml workspace set` CLI "
777
+ "No project is currently set as active. Please set the "
778
+ "active project using the `zenml project set` CLI "
781
779
  "command."
782
780
  )
783
781
 
784
- return self.active_workspace_name
782
+ return self.active_project_name
785
783
 
786
784
  def get_active_stack_id(self) -> UUID:
787
785
  """Get the ID of the active stack.
@@ -592,23 +592,23 @@ class ServerConfiguration(BaseModel):
592
592
  server_config.external_user_info_url = (
593
593
  f"{server_pro_config.api_url}/users/authorize_server"
594
594
  )
595
- server_config.external_server_id = server_pro_config.tenant_id
595
+ server_config.external_server_id = server_pro_config.workspace_id
596
596
  server_config.rbac_implementation_source = (
597
597
  "zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC"
598
598
  )
599
599
  server_config.feature_gate_implementation_source = "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface"
600
600
  server_config.reportable_resources = DEFAULT_REPORTABLE_RESOURCES
601
- server_config.dashboard_url = f"{server_pro_config.dashboard_url}/organizations/{server_pro_config.organization_id}/tenants/{server_pro_config.tenant_id}"
601
+ server_config.dashboard_url = f"{server_pro_config.dashboard_url}/workspaces/{server_pro_config.workspace_id}"
602
602
  server_config.metadata.update(
603
603
  dict(
604
604
  account_id=str(server_pro_config.organization_id),
605
605
  organization_id=str(server_pro_config.organization_id),
606
- tenant_id=str(server_pro_config.tenant_id),
606
+ workspace_id=str(server_pro_config.workspace_id),
607
607
  )
608
608
  )
609
- if server_pro_config.tenant_name:
609
+ if server_pro_config.workspace_name:
610
610
  server_config.metadata.update(
611
- dict(tenant_name=server_pro_config.tenant_name)
611
+ dict(workspace_name=server_pro_config.workspace_name)
612
612
  )
613
613
 
614
614
  extra_cors_allow_origins = [
@@ -660,8 +660,8 @@ class ServerProConfiguration(BaseModel):
660
660
  oauth2_audience: The OAuth2 audience.
661
661
  organization_id: The ZenML Pro organization ID.
662
662
  organization_name: The ZenML Pro organization name.
663
- tenant_id: The ZenML Pro tenant ID.
664
- tenant_name: The ZenML Pro tenant name.
663
+ workspace_id: The ZenML Pro workspace ID.
664
+ workspace_name: The ZenML Pro workspace name.
665
665
  """
666
666
 
667
667
  api_url: str
@@ -670,8 +670,8 @@ class ServerProConfiguration(BaseModel):
670
670
  oauth2_audience: str
671
671
  organization_id: UUID
672
672
  organization_name: Optional[str] = None
673
- tenant_id: UUID
674
- tenant_name: Optional[str] = None
673
+ workspace_id: UUID
674
+ workspace_name: Optional[str] = None
675
675
 
676
676
  @field_validator("api_url", "dashboard_url")
677
677
  @classmethod
zenml/constants.py CHANGED
@@ -143,13 +143,13 @@ ENV_ZENML_REPOSITORY_PATH = "ZENML_REPOSITORY_PATH"
143
143
  ENV_ZENML_PREVENT_PIPELINE_EXECUTION = "ZENML_PREVENT_PIPELINE_EXECUTION"
144
144
  ENV_ZENML_ENABLE_RICH_TRACEBACK = "ZENML_ENABLE_RICH_TRACEBACK"
145
145
  ENV_ZENML_ACTIVE_STACK_ID = "ZENML_ACTIVE_STACK_ID"
146
- ENV_ZENML_ACTIVE_WORKSPACE_ID = "ZENML_ACTIVE_WORKSPACE_ID"
146
+ ENV_ZENML_ACTIVE_PROJECT_ID = "ZENML_ACTIVE_PROJECT_ID"
147
147
  ENV_ZENML_SUPPRESS_LOGS = "ZENML_SUPPRESS_LOGS"
148
148
  ENV_ZENML_ENABLE_REPO_INIT_WARNINGS = "ZENML_ENABLE_REPO_INIT_WARNINGS"
149
149
  ENV_ZENML_SECRET_VALIDATION_LEVEL = "ZENML_SECRET_VALIDATION_LEVEL"
150
150
  ENV_ZENML_DEFAULT_USER_NAME = "ZENML_DEFAULT_USER_NAME"
151
151
  ENV_ZENML_DEFAULT_USER_PASSWORD = "ZENML_DEFAULT_USER_PASSWORD"
152
- ENV_ZENML_DEFAULT_WORKSPACE_NAME = "ZENML_DEFAULT_WORKSPACE_NAME"
152
+ ENV_ZENML_DEFAULT_PROJECT_NAME = "ZENML_DEFAULT_PROJECT_NAME"
153
153
  ENV_ZENML_STORE_PREFIX = "ZENML_STORE_"
154
154
  ENV_ZENML_SECRETS_STORE_PREFIX = "ZENML_SECRETS_STORE_"
155
155
  ENV_ZENML_BACKUP_SECRETS_STORE_PREFIX = "ZENML_BACKUP_SECRETS_STORE_"
@@ -239,7 +239,7 @@ SQL_STORE_BACKUP_DIRECTORY_NAME = "database_backup"
239
239
 
240
240
  DEFAULT_USERNAME = "default"
241
241
  DEFAULT_PASSWORD = ""
242
- DEFAULT_WORKSPACE_NAME = "default"
242
+ DEFAULT_PROJECT_NAME = "default"
243
243
  DEFAULT_STACK_AND_COMPONENT_NAME = "default"
244
244
 
245
245
  # Rich config
@@ -409,6 +409,7 @@ VERSION_1 = "/v1"
409
409
  VISUALIZE = "/visualize"
410
410
  WEBHOOKS = "/webhooks"
411
411
  WORKSPACES = "/workspaces"
412
+ PROJECTS = "/projects"
412
413
 
413
414
  # model metadata yaml file name
414
415
  MODEL_METADATA_YAML_FILE_NAME = "model_metadata.yaml"
@@ -450,6 +451,7 @@ STACK_RECIPES_GITHUB_REPO = "https://github.com/zenml-io/mlops-stacks.git"
450
451
 
451
452
  # Parameters for internal ZenML Models
452
453
  TEXT_FIELD_MAX_LENGTH = 65535
454
+ STR_ID_FIELD_MAX_LENGTH = 50
453
455
  STR_FIELD_MAX_LENGTH = 255
454
456
  MEDIUMTEXT_MAX_LENGTH = 2**24 - 1
455
457
 
@@ -126,7 +126,7 @@ class InternalEventHub(BaseEventHub):
126
126
  triggers: List[TriggerResponse] = depaginate(
127
127
  self.zen_store.list_triggers,
128
128
  trigger_filter_model=TriggerFilter(
129
- workspace=event_source.workspace.id,
129
+ project=event_source.project.id,
130
130
  event_source_id=event_source.id,
131
131
  is_active=True,
132
132
  ),
@@ -52,6 +52,7 @@ class GcpIntegration(Integration):
52
52
  "google-cloud-storage>=2.9.0",
53
53
  "google-cloud-aiplatform>=1.34.0", # includes shapely pin fix
54
54
  "google-cloud-build>=3.11.0",
55
+ "google-cloud-pipeline-components>=2.19.0",
55
56
  "kubernetes",
56
57
  ]
57
58
  REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes","kfp"]
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
23
23
  from zenml.integrations.gcp.google_credentials_mixin import (
24
24
  GoogleCredentialsConfigMixin,
25
25
  )
26
+ from zenml.integrations.gcp.vertex_custom_job_parameters import (
27
+ VertexCustomJobParameters,
28
+ )
26
29
  from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
27
30
  from zenml.models import ServiceConnectorRequirements
28
31
  from zenml.orchestrators import BaseOrchestratorConfig, BaseOrchestratorFlavor
@@ -61,6 +64,8 @@ class VertexOrchestratorSettings(BaseSettings):
61
64
  node_selector_constraint: Optional[Tuple[str, str]] = None
62
65
  pod_settings: Optional[KubernetesPodSettings] = None
63
66
 
67
+ custom_job_parameters: Optional[VertexCustomJobParameters] = None
68
+
64
69
  _node_selector_deprecation = (
65
70
  deprecation_utils.deprecate_pydantic_attributes(
66
71
  "node_selector_constraint"
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
23
23
  from zenml.integrations.gcp.google_credentials_mixin import (
24
24
  GoogleCredentialsConfigMixin,
25
25
  )
26
+ from zenml.integrations.gcp.vertex_custom_job_parameters import (
27
+ VertexCustomJobParameters,
28
+ )
26
29
  from zenml.models import ServiceConnectorRequirements
27
30
  from zenml.step_operators.base_step_operator import (
28
31
  BaseStepOperatorConfig,
@@ -33,34 +36,8 @@ if TYPE_CHECKING:
33
36
  from zenml.integrations.gcp.step_operators import VertexStepOperator
34
37
 
35
38
 
36
- class VertexStepOperatorSettings(BaseSettings):
37
- """Settings for the Vertex step operator.
38
-
39
- Attributes:
40
- accelerator_type: Defines which accelerator (GPU, TPU) is used for the
41
- job. Check out out this table to see which accelerator
42
- type and count are compatible with your chosen machine type:
43
- https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
44
- accelerator_count: Defines number of accelerators to be used for the
45
- job. Check out out this table to see which accelerator
46
- type and count are compatible with your chosen machine type:
47
- https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
48
- machine_type: Machine type specified here
49
- https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.
50
- boot_disk_size_gb: Size of the boot disk in GB. (Default: 100)
51
- https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
52
- boot_disk_type: Type of the boot disk. (Default: pd-ssd)
53
- https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
54
- persistent_resource_id: The ID of the persistent resource to use for the job.
55
- https://cloud.google.com/vertex-ai/docs/training/persistent-resource-overview
56
- """
57
-
58
- accelerator_type: Optional[str] = None
59
- accelerator_count: int = 0
60
- machine_type: str = "n1-standard-4"
61
- boot_disk_size_gb: int = 100
62
- boot_disk_type: str = "pd-ssd"
63
- persistent_resource_id: Optional[str] = None
39
+ class VertexStepOperatorSettings(VertexCustomJobParameters, BaseSettings):
40
+ """Settings for the Vertex step operator."""
64
41
 
65
42
 
66
43
  class VertexStepOperatorConfig(
@@ -49,8 +49,12 @@ from uuid import UUID
49
49
  from google.api_core import exceptions as google_exceptions
50
50
  from google.cloud import aiplatform
51
51
  from google.cloud.aiplatform_v1.types import PipelineState
52
+ from google_cloud_pipeline_components.v1.custom_job.utils import (
53
+ create_custom_training_job_from_component,
54
+ )
52
55
  from kfp import dsl
53
56
  from kfp.compiler import Compiler
57
+ from kfp.dsl.base_component import BaseComponent
54
58
 
55
59
  from zenml.config.resource_settings import ResourceSettings
56
60
  from zenml.constants import (
@@ -71,13 +75,15 @@ from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
71
75
  from zenml.integrations.gcp.google_credentials_mixin import (
72
76
  GoogleCredentialsMixin,
73
77
  )
78
+ from zenml.integrations.gcp.vertex_custom_job_parameters import (
79
+ VertexCustomJobParameters,
80
+ )
74
81
  from zenml.io import fileio
75
82
  from zenml.logger import get_logger
76
83
  from zenml.metadata.metadata_types import MetadataType, Uri
77
84
  from zenml.orchestrators import ContainerizedOrchestrator
78
85
  from zenml.orchestrators.utils import get_orchestrator_run_name
79
86
  from zenml.stack.stack_validator import StackValidator
80
- from zenml.utils import yaml_utils
81
87
  from zenml.utils.io_utils import get_global_config_directory
82
88
 
83
89
  if TYPE_CHECKING:
@@ -263,14 +269,14 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
263
269
  "schedule to a Vertex orchestrator."
264
270
  )
265
271
 
266
- def _create_dynamic_component(
272
+ def _create_container_component(
267
273
  self,
268
274
  image: str,
269
275
  command: List[str],
270
276
  arguments: List[str],
271
277
  component_name: str,
272
- ) -> dsl.PipelineTask:
273
- """Creates a dynamic container component for a Vertex pipeline.
278
+ ) -> BaseComponent:
279
+ """Creates a container component for a Vertex pipeline.
274
280
 
275
281
  Args:
276
282
  image: The image to use for the component.
@@ -279,7 +285,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
279
285
  component_name: The name of the component.
280
286
 
281
287
  Returns:
282
- The dynamic container component.
288
+ The container component.
283
289
  """
284
290
 
285
291
  def dynamic_container_component() -> dsl.ContainerSpec:
@@ -294,7 +300,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
294
300
  args=arguments,
295
301
  )
296
302
 
297
- # Change the name of the function
298
303
  new_container_spec_func = types.FunctionType(
299
304
  dynamic_container_component.__code__,
300
305
  dynamic_container_component.__globals__,
@@ -303,9 +308,50 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
303
308
  closure=dynamic_container_component.__closure__,
304
309
  )
305
310
  pipeline_task = dsl.container_component(new_container_spec_func)
306
-
307
311
  return pipeline_task
308
312
 
313
+ def _convert_to_custom_training_job(
314
+ self,
315
+ component: BaseComponent,
316
+ settings: VertexOrchestratorSettings,
317
+ environment: Dict[str, str],
318
+ ) -> BaseComponent:
319
+ """Convert a component to a custom training job component.
320
+
321
+ Args:
322
+ component: The component to convert.
323
+ settings: The settings for the custom training job.
324
+ environment: The environment variables to set in the custom
325
+ training job.
326
+
327
+ Returns:
328
+ The custom training job component.
329
+ """
330
+ custom_job_parameters = (
331
+ settings.custom_job_parameters or VertexCustomJobParameters()
332
+ )
333
+ if (
334
+ custom_job_parameters.persistent_resource_id
335
+ and not custom_job_parameters.service_account
336
+ ):
337
+ # Persistent resources require an explicit service account, but
338
+ # none was provided in the custom job parameters. We try to fall
339
+ # back to the workload service account.
340
+ custom_job_parameters.service_account = (
341
+ self.config.workload_service_account
342
+ )
343
+
344
+ custom_job_component = create_custom_training_job_from_component(
345
+ component_spec=component,
346
+ env=[
347
+ {"name": key, "value": value}
348
+ for key, value in environment.items()
349
+ ],
350
+ **custom_job_parameters.model_dump(),
351
+ )
352
+
353
+ return custom_job_component
354
+
309
355
  def prepare_or_run_pipeline(
310
356
  self,
311
357
  deployment: "PipelineDeploymentResponse",
@@ -383,7 +429,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
383
429
  Returns:
384
430
  pipeline_func
385
431
  """
386
- step_name_to_dynamic_component: Dict[str, Any] = {}
432
+ step_name_to_dynamic_component: Dict[str, BaseComponent] = {}
387
433
 
388
434
  for step_name, step in deployment.step_configurations.items():
389
435
  image = self.get_image(
@@ -397,7 +443,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
397
443
  deployment_id=deployment.id,
398
444
  )
399
445
  )
400
- dynamic_component = self._create_dynamic_component(
446
+ component = self._create_container_component(
401
447
  image, command, arguments, step_name
402
448
  )
403
449
  step_settings = cast(
@@ -442,7 +488,11 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
442
488
  key,
443
489
  )
444
490
 
445
- step_name_to_dynamic_component[step_name] = dynamic_component
491
+ step_name_to_dynamic_component[step_name] = component
492
+
493
+ environment[ENV_ZENML_VERTEX_RUN_ID] = (
494
+ dsl.PIPELINE_JOB_NAME_PLACEHOLDER
495
+ )
446
496
 
447
497
  @dsl.pipeline( # type: ignore[misc]
448
498
  display_name=orchestrator_run_name,
@@ -462,81 +512,81 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
462
512
  step_name_to_dynamic_component[upstream_step_name]
463
513
  for upstream_step_name in step.spec.upstream_steps
464
514
  ]
465
- task = (
466
- component()
467
- .set_display_name(
468
- name=component_name,
469
- )
470
- .set_caching_options(enable_caching=False)
471
- .set_env_variable(
472
- name=ENV_ZENML_VERTEX_RUN_ID,
473
- value=dsl.PIPELINE_JOB_NAME_PLACEHOLDER,
474
- )
475
- .after(*upstream_step_components)
476
- )
477
515
 
478
516
  step_settings = cast(
479
517
  VertexOrchestratorSettings, self.get_settings(step)
480
518
  )
481
- pod_settings = step_settings.pod_settings
482
-
483
- node_selector_constraint: Optional[Tuple[str, str]] = None
484
- if pod_settings and (
485
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
486
- in pod_settings.node_selectors.keys()
487
- ):
488
- node_selector_constraint = (
489
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
490
- pod_settings.node_selectors[
491
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
492
- ],
519
+
520
+ use_custom_training_job = (
521
+ step_settings.custom_job_parameters is not None
522
+ )
523
+
524
+ if use_custom_training_job:
525
+ if not step.config.resource_settings.empty:
526
+ logger.warning(
527
+ "Ignoring resource settings because "
528
+ "the step is running as a custom training job. "
529
+ "Use `custom_job_parameters.machine_type` "
530
+ "to configure the machine type instead."
531
+ )
532
+ if step_settings.node_selector_constraint:
533
+ logger.warning(
534
+ "Ignoring node selector constraint because "
535
+ "the step is running as a custom training job. "
536
+ "Use `custom_job_parameters.accelerator_type` "
537
+ "to configure the accelerator type instead."
538
+ )
539
+ component = self._convert_to_custom_training_job(
540
+ component,
541
+ settings=step_settings,
542
+ environment=environment,
543
+ )
544
+ task = (
545
+ component()
546
+ .set_display_name(name=component_name)
547
+ .set_caching_options(enable_caching=False)
548
+ .after(*upstream_step_components)
493
549
  )
494
- elif step_settings.node_selector_constraint:
495
- node_selector_constraint = (
496
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
497
- step_settings.node_selector_constraint[1],
550
+ else:
551
+ task = (
552
+ component()
553
+ .set_display_name(
554
+ name=component_name,
555
+ )
556
+ .set_caching_options(enable_caching=False)
557
+ .after(*upstream_step_components)
498
558
  )
559
+ for key, value in environment.items():
560
+ task = task.set_env_variable(name=key, value=value)
499
561
 
500
- self._configure_container_resources(
501
- dynamic_component=task,
502
- resource_settings=step.config.resource_settings,
503
- node_selector_constraint=node_selector_constraint,
504
- )
562
+ pod_settings = step_settings.pod_settings
505
563
 
506
- return dynamic_pipeline
564
+ node_selector_constraint: Optional[Tuple[str, str]] = (
565
+ None
566
+ )
567
+ if pod_settings and (
568
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
569
+ in pod_settings.node_selectors.keys()
570
+ ):
571
+ node_selector_constraint = (
572
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
573
+ pod_settings.node_selectors[
574
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
575
+ ],
576
+ )
577
+ elif step_settings.node_selector_constraint:
578
+ node_selector_constraint = (
579
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
580
+ step_settings.node_selector_constraint[1],
581
+ )
507
582
 
508
- def _update_json_with_environment(
509
- yaml_file_path: str, environment: Dict[str, str]
510
- ) -> None:
511
- """Updates the env section of the steps in the YAML file with the given environment variables.
583
+ self._configure_container_resources(
584
+ dynamic_component=task,
585
+ resource_settings=step.config.resource_settings,
586
+ node_selector_constraint=node_selector_constraint,
587
+ )
512
588
 
513
- Args:
514
- yaml_file_path: The path to the YAML file to update.
515
- environment: A dictionary of environment variables to add.
516
- """
517
- pipeline_definition = yaml_utils.read_json(pipeline_file_path)
518
-
519
- # Iterate through each component and add the environment variables
520
- for executor in pipeline_definition["deploymentSpec"]["executors"]:
521
- if (
522
- "container"
523
- in pipeline_definition["deploymentSpec"]["executors"][
524
- executor
525
- ]
526
- ):
527
- container = pipeline_definition["deploymentSpec"][
528
- "executors"
529
- ][executor]["container"]
530
- if "env" not in container:
531
- container["env"] = []
532
- for key, value in environment.items():
533
- container["env"].append({"name": key, "value": value})
534
-
535
- yaml_utils.write_json(pipeline_file_path, pipeline_definition)
536
-
537
- print(
538
- f"Updated YAML file with environment variables at {yaml_file_path}"
539
- )
589
+ return dynamic_pipeline
540
590
 
541
591
  # Save the generated pipeline to a file.
542
592
  fileio.makedirs(self.pipeline_directory)
@@ -556,9 +606,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
556
606
  ),
557
607
  )
558
608
 
559
- # Let's update the YAML file with the environment variables
560
- _update_json_with_environment(pipeline_file_path, environment)
561
-
562
609
  logger.info(
563
610
  "Writing Vertex workflow definition to `%s`.", pipeline_file_path
564
611
  )