zenml-nightly 0.75.0.dev20250313__py3-none-any.whl → 0.75.0.dev20250315__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/analytics/context.py +4 -4
- zenml/analytics/enums.py +2 -2
- zenml/artifacts/utils.py +2 -2
- zenml/cli/__init__.py +8 -9
- zenml/cli/base.py +2 -2
- zenml/cli/code_repository.py +1 -1
- zenml/cli/login.py +21 -18
- zenml/cli/pipeline.py +3 -3
- zenml/cli/project.py +172 -0
- zenml/cli/server.py +5 -5
- zenml/cli/service_accounts.py +0 -1
- zenml/cli/service_connectors.py +15 -16
- zenml/cli/stack.py +0 -2
- zenml/cli/stack_components.py +2 -2
- zenml/cli/utils.py +3 -3
- zenml/client.py +352 -341
- zenml/config/global_config.py +41 -43
- zenml/config/server_config.py +9 -9
- zenml/constants.py +5 -3
- zenml/event_hub/event_hub.py +1 -1
- zenml/integrations/gcp/__init__.py +1 -0
- zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +5 -0
- zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +5 -28
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +125 -78
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +7 -6
- zenml/integrations/gcp/vertex_custom_job_parameters.py +50 -0
- zenml/integrations/mlflow/steps/mlflow_registry.py +3 -3
- zenml/integrations/wandb/__init__.py +1 -1
- zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py +29 -9
- zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +2 -0
- zenml/login/credentials.py +26 -27
- zenml/login/credentials_store.py +5 -5
- zenml/login/pro/client.py +9 -9
- zenml/login/pro/utils.py +8 -8
- zenml/login/pro/{tenant → workspace}/__init__.py +1 -1
- zenml/login/pro/{tenant → workspace}/client.py +25 -25
- zenml/login/pro/{tenant → workspace}/models.py +27 -28
- zenml/model/model.py +2 -2
- zenml/model_registries/base_model_registry.py +1 -1
- zenml/models/__init__.py +29 -29
- zenml/models/v2/base/filter.py +1 -1
- zenml/models/v2/base/scoped.py +49 -53
- zenml/models/v2/core/action.py +12 -12
- zenml/models/v2/core/artifact.py +15 -15
- zenml/models/v2/core/artifact_version.py +15 -15
- zenml/models/v2/core/code_repository.py +12 -12
- zenml/models/v2/core/event_source.py +12 -12
- zenml/models/v2/core/model.py +26 -18
- zenml/models/v2/core/model_version.py +15 -15
- zenml/models/v2/core/pipeline.py +15 -15
- zenml/models/v2/core/pipeline_build.py +14 -14
- zenml/models/v2/core/pipeline_deployment.py +12 -14
- zenml/models/v2/core/pipeline_run.py +16 -16
- zenml/models/v2/core/project.py +203 -0
- zenml/models/v2/core/run_metadata.py +2 -2
- zenml/models/v2/core/run_template.py +15 -15
- zenml/models/v2/core/schedule.py +12 -12
- zenml/models/v2/core/secret.py +1 -1
- zenml/models/v2/core/service.py +14 -14
- zenml/models/v2/core/step_run.py +13 -13
- zenml/models/v2/core/tag.py +96 -3
- zenml/models/v2/core/trigger.py +13 -13
- zenml/models/v2/core/trigger_execution.py +2 -2
- zenml/models/v2/core/user.py +0 -17
- zenml/models/v2/misc/server_models.py +6 -6
- zenml/models/v2/misc/statistics.py +4 -4
- zenml/orchestrators/cache_utils.py +7 -7
- zenml/orchestrators/input_utils.py +1 -1
- zenml/orchestrators/step_launcher.py +1 -1
- zenml/orchestrators/step_run_utils.py +3 -3
- zenml/orchestrators/utils.py +4 -4
- zenml/pipelines/build_utils.py +2 -2
- zenml/pipelines/pipeline_definition.py +5 -5
- zenml/pipelines/run_utils.py +1 -1
- zenml/service_connectors/service_connector.py +0 -3
- zenml/service_connectors/service_connector_utils.py +0 -1
- zenml/stack/stack.py +0 -1
- zenml/steps/base_step.py +10 -2
- zenml/utils/dashboard_utils.py +1 -1
- zenml/utils/tag_utils.py +0 -12
- zenml/zen_server/cloud_utils.py +3 -3
- zenml/zen_server/feature_gate/endpoint_utils.py +1 -1
- zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
- zenml/zen_server/rbac/endpoint_utils.py +17 -17
- zenml/zen_server/rbac/models.py +47 -22
- zenml/zen_server/rbac/rbac_sql_zen_store.py +3 -3
- zenml/zen_server/rbac/utils.py +23 -25
- zenml/zen_server/rbac/zenml_cloud_rbac.py +7 -74
- zenml/zen_server/routers/artifact_version_endpoints.py +10 -10
- zenml/zen_server/routers/auth_endpoints.py +6 -6
- zenml/zen_server/routers/code_repositories_endpoints.py +12 -14
- zenml/zen_server/routers/model_versions_endpoints.py +13 -15
- zenml/zen_server/routers/models_endpoints.py +7 -9
- zenml/zen_server/routers/pipeline_builds_endpoints.py +14 -16
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +13 -15
- zenml/zen_server/routers/pipelines_endpoints.py +16 -18
- zenml/zen_server/routers/{workspaces_endpoints.py → projects_endpoints.py} +111 -68
- zenml/zen_server/routers/run_metadata_endpoints.py +7 -9
- zenml/zen_server/routers/run_templates_endpoints.py +15 -17
- zenml/zen_server/routers/runs_endpoints.py +12 -14
- zenml/zen_server/routers/schedule_endpoints.py +12 -14
- zenml/zen_server/routers/secrets_endpoints.py +1 -3
- zenml/zen_server/routers/server_endpoints.py +7 -7
- zenml/zen_server/routers/service_connectors_endpoints.py +11 -13
- zenml/zen_server/routers/service_endpoints.py +7 -9
- zenml/zen_server/routers/stack_components_endpoints.py +9 -11
- zenml/zen_server/routers/stacks_endpoints.py +9 -11
- zenml/zen_server/routers/steps_endpoints.py +6 -6
- zenml/zen_server/routers/users_endpoints.py +5 -43
- zenml/zen_server/template_execution/utils.py +4 -4
- zenml/zen_server/utils.py +10 -10
- zenml/zen_server/zen_server_api.py +6 -5
- zenml/zen_stores/base_zen_store.py +38 -42
- zenml/zen_stores/migrations/versions/12eff0206201_rename_workspace_to_project.py +768 -0
- zenml/zen_stores/migrations/versions/41b28cae31ce_make_artifacts_workspace_scoped.py +3 -3
- zenml/zen_stores/migrations/versions/cbc6acd71f92_add_workspace_display_name.py +58 -0
- zenml/zen_stores/rest_zen_store.py +55 -63
- zenml/zen_stores/schemas/__init__.py +2 -2
- zenml/zen_stores/schemas/action_schemas.py +9 -9
- zenml/zen_stores/schemas/artifact_schemas.py +15 -17
- zenml/zen_stores/schemas/code_repository_schemas.py +16 -18
- zenml/zen_stores/schemas/event_source_schemas.py +9 -9
- zenml/zen_stores/schemas/model_schemas.py +15 -17
- zenml/zen_stores/schemas/pipeline_build_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_run_schemas.py +9 -9
- zenml/zen_stores/schemas/pipeline_schemas.py +9 -9
- zenml/zen_stores/schemas/{workspace_schemas.py → project_schemas.py} +47 -41
- zenml/zen_stores/schemas/run_metadata_schemas.py +5 -5
- zenml/zen_stores/schemas/run_template_schemas.py +9 -9
- zenml/zen_stores/schemas/schedule_schema.py +9 -9
- zenml/zen_stores/schemas/service_schemas.py +7 -7
- zenml/zen_stores/schemas/step_run_schemas.py +7 -7
- zenml/zen_stores/schemas/trigger_schemas.py +9 -9
- zenml/zen_stores/schemas/user_schemas.py +0 -12
- zenml/zen_stores/sql_zen_store.py +318 -275
- zenml/zen_stores/zen_store_interface.py +56 -70
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/METADATA +1 -1
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/RECORD +143 -140
- zenml/cli/workspace.py +0 -160
- zenml/models/v2/core/workspace.py +0 -131
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250315.dist-info}/entry_points.txt +0 -0
zenml/config/global_config.py
CHANGED
@@ -48,7 +48,7 @@ from zenml.logger import get_logger
|
|
48
48
|
from zenml.utils import io_utils, yaml_utils
|
49
49
|
|
50
50
|
if TYPE_CHECKING:
|
51
|
-
from zenml.models import
|
51
|
+
from zenml.models import ProjectResponse, StackResponse
|
52
52
|
from zenml.zen_stores.base_zen_store import BaseZenStore
|
53
53
|
|
54
54
|
logger = get_logger(__name__)
|
@@ -113,7 +113,7 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
113
113
|
global config.
|
114
114
|
store: Store configuration.
|
115
115
|
active_stack_id: The ID of the active stack.
|
116
|
-
|
116
|
+
active_project_name: The name of the active project.
|
117
117
|
"""
|
118
118
|
|
119
119
|
user_id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
@@ -123,10 +123,10 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
123
123
|
version: Optional[str] = None
|
124
124
|
store: Optional[SerializeAsAny[StoreConfiguration]] = None
|
125
125
|
active_stack_id: Optional[uuid.UUID] = None
|
126
|
-
|
126
|
+
active_project_name: Optional[str] = None
|
127
127
|
|
128
128
|
_zen_store: Optional["BaseZenStore"] = None
|
129
|
-
|
129
|
+
_active_project: Optional["ProjectResponse"] = None
|
130
130
|
_active_stack: Optional["StackResponse"] = None
|
131
131
|
|
132
132
|
def __init__(self, **data: Any) -> None:
|
@@ -385,24 +385,24 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
385
385
|
def _sanitize_config(self) -> None:
|
386
386
|
"""Sanitize and save the global configuration.
|
387
387
|
|
388
|
-
This method is called to ensure that the active stack and
|
388
|
+
This method is called to ensure that the active stack and project
|
389
389
|
are set to their default values, if possible.
|
390
390
|
"""
|
391
391
|
# If running in a ZenML server environment, the active stack and
|
392
|
-
#
|
392
|
+
# project are not relevant
|
393
393
|
if ENV_ZENML_SERVER in os.environ:
|
394
394
|
return
|
395
|
-
|
396
|
-
self.
|
395
|
+
active_project, active_stack = self.zen_store.validate_active_config(
|
396
|
+
self.active_project_name,
|
397
397
|
self.active_stack_id,
|
398
398
|
config_name="global",
|
399
399
|
)
|
400
|
-
if
|
401
|
-
self.
|
402
|
-
self.
|
400
|
+
if active_project:
|
401
|
+
self.active_project_name = active_project.name
|
402
|
+
self._active_project = active_project
|
403
403
|
else:
|
404
|
-
self.
|
405
|
-
self.
|
404
|
+
self.active_project_name = None
|
405
|
+
self._active_project = None
|
406
406
|
|
407
407
|
self.set_active_stack(active_stack)
|
408
408
|
|
@@ -719,22 +719,22 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
719
719
|
|
720
720
|
return self._zen_store
|
721
721
|
|
722
|
-
def
|
723
|
-
self,
|
724
|
-
) -> "
|
725
|
-
"""Set the
|
722
|
+
def set_active_project(
|
723
|
+
self, project: "ProjectResponse"
|
724
|
+
) -> "ProjectResponse":
|
725
|
+
"""Set the project for the local client.
|
726
726
|
|
727
727
|
Args:
|
728
|
-
|
728
|
+
project: The project to set active.
|
729
729
|
|
730
730
|
Returns:
|
731
|
-
The
|
731
|
+
The project that was set active.
|
732
732
|
"""
|
733
|
-
self.
|
734
|
-
self.
|
735
|
-
# Sanitize the global configuration to reflect the new
|
733
|
+
self.active_project_name = project.name
|
734
|
+
self._active_project = project
|
735
|
+
# Sanitize the global configuration to reflect the new project
|
736
736
|
self._sanitize_config()
|
737
|
-
return
|
737
|
+
return project
|
738
738
|
|
739
739
|
def set_active_stack(self, stack: "StackResponse") -> None:
|
740
740
|
"""Set the active stack for the local client.
|
@@ -745,43 +745,41 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
745
745
|
self.active_stack_id = stack.id
|
746
746
|
self._active_stack = stack
|
747
747
|
|
748
|
-
def
|
749
|
-
"""Get a model of the active
|
748
|
+
def get_active_project(self) -> "ProjectResponse":
|
749
|
+
"""Get a model of the active project for the local client.
|
750
750
|
|
751
751
|
Returns:
|
752
|
-
The model of the active
|
752
|
+
The model of the active project.
|
753
753
|
"""
|
754
|
-
|
754
|
+
project_name = self.get_active_project_name()
|
755
755
|
|
756
|
-
if self.
|
757
|
-
return self.
|
756
|
+
if self._active_project is not None:
|
757
|
+
return self._active_project
|
758
758
|
|
759
|
-
|
760
|
-
|
759
|
+
project = self.zen_store.get_project(
|
760
|
+
project_name_or_id=project_name,
|
761
761
|
)
|
762
|
-
return self.
|
762
|
+
return self.set_active_project(project)
|
763
763
|
|
764
|
-
def
|
765
|
-
"""Get the name of the active
|
766
|
-
|
767
|
-
If the active workspace doesn't exist yet, the ZenStore is reinitialized.
|
764
|
+
def get_active_project_name(self) -> str:
|
765
|
+
"""Get the name of the active project.
|
768
766
|
|
769
767
|
Returns:
|
770
|
-
The name of the active
|
768
|
+
The name of the active project.
|
771
769
|
|
772
770
|
Raises:
|
773
|
-
RuntimeError: If the active
|
771
|
+
RuntimeError: If the active project is not set.
|
774
772
|
"""
|
775
|
-
if self.
|
773
|
+
if self.active_project_name is None:
|
776
774
|
_ = self.zen_store
|
777
|
-
if self.
|
775
|
+
if self.active_project_name is None:
|
778
776
|
raise RuntimeError(
|
779
|
-
"No
|
780
|
-
"active
|
777
|
+
"No project is currently set as active. Please set the "
|
778
|
+
"active project using the `zenml project set` CLI "
|
781
779
|
"command."
|
782
780
|
)
|
783
781
|
|
784
|
-
return self.
|
782
|
+
return self.active_project_name
|
785
783
|
|
786
784
|
def get_active_stack_id(self) -> UUID:
|
787
785
|
"""Get the ID of the active stack.
|
zenml/config/server_config.py
CHANGED
@@ -592,23 +592,23 @@ class ServerConfiguration(BaseModel):
|
|
592
592
|
server_config.external_user_info_url = (
|
593
593
|
f"{server_pro_config.api_url}/users/authorize_server"
|
594
594
|
)
|
595
|
-
server_config.external_server_id = server_pro_config.
|
595
|
+
server_config.external_server_id = server_pro_config.workspace_id
|
596
596
|
server_config.rbac_implementation_source = (
|
597
597
|
"zenml.zen_server.rbac.zenml_cloud_rbac.ZenMLCloudRBAC"
|
598
598
|
)
|
599
599
|
server_config.feature_gate_implementation_source = "zenml.zen_server.feature_gate.zenml_cloud_feature_gate.ZenMLCloudFeatureGateInterface"
|
600
600
|
server_config.reportable_resources = DEFAULT_REPORTABLE_RESOURCES
|
601
|
-
server_config.dashboard_url = f"{server_pro_config.dashboard_url}/
|
601
|
+
server_config.dashboard_url = f"{server_pro_config.dashboard_url}/workspaces/{server_pro_config.workspace_id}"
|
602
602
|
server_config.metadata.update(
|
603
603
|
dict(
|
604
604
|
account_id=str(server_pro_config.organization_id),
|
605
605
|
organization_id=str(server_pro_config.organization_id),
|
606
|
-
|
606
|
+
workspace_id=str(server_pro_config.workspace_id),
|
607
607
|
)
|
608
608
|
)
|
609
|
-
if server_pro_config.
|
609
|
+
if server_pro_config.workspace_name:
|
610
610
|
server_config.metadata.update(
|
611
|
-
dict(
|
611
|
+
dict(workspace_name=server_pro_config.workspace_name)
|
612
612
|
)
|
613
613
|
|
614
614
|
extra_cors_allow_origins = [
|
@@ -660,8 +660,8 @@ class ServerProConfiguration(BaseModel):
|
|
660
660
|
oauth2_audience: The OAuth2 audience.
|
661
661
|
organization_id: The ZenML Pro organization ID.
|
662
662
|
organization_name: The ZenML Pro organization name.
|
663
|
-
|
664
|
-
|
663
|
+
workspace_id: The ZenML Pro workspace ID.
|
664
|
+
workspace_name: The ZenML Pro workspace name.
|
665
665
|
"""
|
666
666
|
|
667
667
|
api_url: str
|
@@ -670,8 +670,8 @@ class ServerProConfiguration(BaseModel):
|
|
670
670
|
oauth2_audience: str
|
671
671
|
organization_id: UUID
|
672
672
|
organization_name: Optional[str] = None
|
673
|
-
|
674
|
-
|
673
|
+
workspace_id: UUID
|
674
|
+
workspace_name: Optional[str] = None
|
675
675
|
|
676
676
|
@field_validator("api_url", "dashboard_url")
|
677
677
|
@classmethod
|
zenml/constants.py
CHANGED
@@ -143,13 +143,13 @@ ENV_ZENML_REPOSITORY_PATH = "ZENML_REPOSITORY_PATH"
|
|
143
143
|
ENV_ZENML_PREVENT_PIPELINE_EXECUTION = "ZENML_PREVENT_PIPELINE_EXECUTION"
|
144
144
|
ENV_ZENML_ENABLE_RICH_TRACEBACK = "ZENML_ENABLE_RICH_TRACEBACK"
|
145
145
|
ENV_ZENML_ACTIVE_STACK_ID = "ZENML_ACTIVE_STACK_ID"
|
146
|
-
|
146
|
+
ENV_ZENML_ACTIVE_PROJECT_ID = "ZENML_ACTIVE_PROJECT_ID"
|
147
147
|
ENV_ZENML_SUPPRESS_LOGS = "ZENML_SUPPRESS_LOGS"
|
148
148
|
ENV_ZENML_ENABLE_REPO_INIT_WARNINGS = "ZENML_ENABLE_REPO_INIT_WARNINGS"
|
149
149
|
ENV_ZENML_SECRET_VALIDATION_LEVEL = "ZENML_SECRET_VALIDATION_LEVEL"
|
150
150
|
ENV_ZENML_DEFAULT_USER_NAME = "ZENML_DEFAULT_USER_NAME"
|
151
151
|
ENV_ZENML_DEFAULT_USER_PASSWORD = "ZENML_DEFAULT_USER_PASSWORD"
|
152
|
-
|
152
|
+
ENV_ZENML_DEFAULT_PROJECT_NAME = "ZENML_DEFAULT_PROJECT_NAME"
|
153
153
|
ENV_ZENML_STORE_PREFIX = "ZENML_STORE_"
|
154
154
|
ENV_ZENML_SECRETS_STORE_PREFIX = "ZENML_SECRETS_STORE_"
|
155
155
|
ENV_ZENML_BACKUP_SECRETS_STORE_PREFIX = "ZENML_BACKUP_SECRETS_STORE_"
|
@@ -239,7 +239,7 @@ SQL_STORE_BACKUP_DIRECTORY_NAME = "database_backup"
|
|
239
239
|
|
240
240
|
DEFAULT_USERNAME = "default"
|
241
241
|
DEFAULT_PASSWORD = ""
|
242
|
-
|
242
|
+
DEFAULT_PROJECT_NAME = "default"
|
243
243
|
DEFAULT_STACK_AND_COMPONENT_NAME = "default"
|
244
244
|
|
245
245
|
# Rich config
|
@@ -409,6 +409,7 @@ VERSION_1 = "/v1"
|
|
409
409
|
VISUALIZE = "/visualize"
|
410
410
|
WEBHOOKS = "/webhooks"
|
411
411
|
WORKSPACES = "/workspaces"
|
412
|
+
PROJECTS = "/projects"
|
412
413
|
|
413
414
|
# model metadata yaml file name
|
414
415
|
MODEL_METADATA_YAML_FILE_NAME = "model_metadata.yaml"
|
@@ -450,6 +451,7 @@ STACK_RECIPES_GITHUB_REPO = "https://github.com/zenml-io/mlops-stacks.git"
|
|
450
451
|
|
451
452
|
# Parameters for internal ZenML Models
|
452
453
|
TEXT_FIELD_MAX_LENGTH = 65535
|
454
|
+
STR_ID_FIELD_MAX_LENGTH = 50
|
453
455
|
STR_FIELD_MAX_LENGTH = 255
|
454
456
|
MEDIUMTEXT_MAX_LENGTH = 2**24 - 1
|
455
457
|
|
zenml/event_hub/event_hub.py
CHANGED
@@ -126,7 +126,7 @@ class InternalEventHub(BaseEventHub):
|
|
126
126
|
triggers: List[TriggerResponse] = depaginate(
|
127
127
|
self.zen_store.list_triggers,
|
128
128
|
trigger_filter_model=TriggerFilter(
|
129
|
-
|
129
|
+
project=event_source.project.id,
|
130
130
|
event_source_id=event_source.id,
|
131
131
|
is_active=True,
|
132
132
|
),
|
@@ -52,6 +52,7 @@ class GcpIntegration(Integration):
|
|
52
52
|
"google-cloud-storage>=2.9.0",
|
53
53
|
"google-cloud-aiplatform>=1.34.0", # includes shapely pin fix
|
54
54
|
"google-cloud-build>=3.11.0",
|
55
|
+
"google-cloud-pipeline-components>=2.19.0",
|
55
56
|
"kubernetes",
|
56
57
|
]
|
57
58
|
REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes","kfp"]
|
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
|
|
23
23
|
from zenml.integrations.gcp.google_credentials_mixin import (
|
24
24
|
GoogleCredentialsConfigMixin,
|
25
25
|
)
|
26
|
+
from zenml.integrations.gcp.vertex_custom_job_parameters import (
|
27
|
+
VertexCustomJobParameters,
|
28
|
+
)
|
26
29
|
from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
|
27
30
|
from zenml.models import ServiceConnectorRequirements
|
28
31
|
from zenml.orchestrators import BaseOrchestratorConfig, BaseOrchestratorFlavor
|
@@ -61,6 +64,8 @@ class VertexOrchestratorSettings(BaseSettings):
|
|
61
64
|
node_selector_constraint: Optional[Tuple[str, str]] = None
|
62
65
|
pod_settings: Optional[KubernetesPodSettings] = None
|
63
66
|
|
67
|
+
custom_job_parameters: Optional[VertexCustomJobParameters] = None
|
68
|
+
|
64
69
|
_node_selector_deprecation = (
|
65
70
|
deprecation_utils.deprecate_pydantic_attributes(
|
66
71
|
"node_selector_constraint"
|
@@ -23,6 +23,9 @@ from zenml.integrations.gcp import (
|
|
23
23
|
from zenml.integrations.gcp.google_credentials_mixin import (
|
24
24
|
GoogleCredentialsConfigMixin,
|
25
25
|
)
|
26
|
+
from zenml.integrations.gcp.vertex_custom_job_parameters import (
|
27
|
+
VertexCustomJobParameters,
|
28
|
+
)
|
26
29
|
from zenml.models import ServiceConnectorRequirements
|
27
30
|
from zenml.step_operators.base_step_operator import (
|
28
31
|
BaseStepOperatorConfig,
|
@@ -33,34 +36,8 @@ if TYPE_CHECKING:
|
|
33
36
|
from zenml.integrations.gcp.step_operators import VertexStepOperator
|
34
37
|
|
35
38
|
|
36
|
-
class VertexStepOperatorSettings(BaseSettings):
|
37
|
-
"""Settings for the Vertex step operator.
|
38
|
-
|
39
|
-
Attributes:
|
40
|
-
accelerator_type: Defines which accelerator (GPU, TPU) is used for the
|
41
|
-
job. Check out out this table to see which accelerator
|
42
|
-
type and count are compatible with your chosen machine type:
|
43
|
-
https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
|
44
|
-
accelerator_count: Defines number of accelerators to be used for the
|
45
|
-
job. Check out out this table to see which accelerator
|
46
|
-
type and count are compatible with your chosen machine type:
|
47
|
-
https://cloud.google.com/vertex-ai/docs/training/configure-compute#gpu-compatibility-table.
|
48
|
-
machine_type: Machine type specified here
|
49
|
-
https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types.
|
50
|
-
boot_disk_size_gb: Size of the boot disk in GB. (Default: 100)
|
51
|
-
https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
|
52
|
-
boot_disk_type: Type of the boot disk. (Default: pd-ssd)
|
53
|
-
https://cloud.google.com/vertex-ai/docs/training/configure-compute#boot_disk_options
|
54
|
-
persistent_resource_id: The ID of the persistent resource to use for the job.
|
55
|
-
https://cloud.google.com/vertex-ai/docs/training/persistent-resource-overview
|
56
|
-
"""
|
57
|
-
|
58
|
-
accelerator_type: Optional[str] = None
|
59
|
-
accelerator_count: int = 0
|
60
|
-
machine_type: str = "n1-standard-4"
|
61
|
-
boot_disk_size_gb: int = 100
|
62
|
-
boot_disk_type: str = "pd-ssd"
|
63
|
-
persistent_resource_id: Optional[str] = None
|
39
|
+
class VertexStepOperatorSettings(VertexCustomJobParameters, BaseSettings):
|
40
|
+
"""Settings for the Vertex step operator."""
|
64
41
|
|
65
42
|
|
66
43
|
class VertexStepOperatorConfig(
|
@@ -49,8 +49,12 @@ from uuid import UUID
|
|
49
49
|
from google.api_core import exceptions as google_exceptions
|
50
50
|
from google.cloud import aiplatform
|
51
51
|
from google.cloud.aiplatform_v1.types import PipelineState
|
52
|
+
from google_cloud_pipeline_components.v1.custom_job.utils import (
|
53
|
+
create_custom_training_job_from_component,
|
54
|
+
)
|
52
55
|
from kfp import dsl
|
53
56
|
from kfp.compiler import Compiler
|
57
|
+
from kfp.dsl.base_component import BaseComponent
|
54
58
|
|
55
59
|
from zenml.config.resource_settings import ResourceSettings
|
56
60
|
from zenml.constants import (
|
@@ -71,13 +75,15 @@ from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
|
|
71
75
|
from zenml.integrations.gcp.google_credentials_mixin import (
|
72
76
|
GoogleCredentialsMixin,
|
73
77
|
)
|
78
|
+
from zenml.integrations.gcp.vertex_custom_job_parameters import (
|
79
|
+
VertexCustomJobParameters,
|
80
|
+
)
|
74
81
|
from zenml.io import fileio
|
75
82
|
from zenml.logger import get_logger
|
76
83
|
from zenml.metadata.metadata_types import MetadataType, Uri
|
77
84
|
from zenml.orchestrators import ContainerizedOrchestrator
|
78
85
|
from zenml.orchestrators.utils import get_orchestrator_run_name
|
79
86
|
from zenml.stack.stack_validator import StackValidator
|
80
|
-
from zenml.utils import yaml_utils
|
81
87
|
from zenml.utils.io_utils import get_global_config_directory
|
82
88
|
|
83
89
|
if TYPE_CHECKING:
|
@@ -263,14 +269,14 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
263
269
|
"schedule to a Vertex orchestrator."
|
264
270
|
)
|
265
271
|
|
266
|
-
def
|
272
|
+
def _create_container_component(
|
267
273
|
self,
|
268
274
|
image: str,
|
269
275
|
command: List[str],
|
270
276
|
arguments: List[str],
|
271
277
|
component_name: str,
|
272
|
-
) ->
|
273
|
-
"""Creates a
|
278
|
+
) -> BaseComponent:
|
279
|
+
"""Creates a container component for a Vertex pipeline.
|
274
280
|
|
275
281
|
Args:
|
276
282
|
image: The image to use for the component.
|
@@ -279,7 +285,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
279
285
|
component_name: The name of the component.
|
280
286
|
|
281
287
|
Returns:
|
282
|
-
The
|
288
|
+
The container component.
|
283
289
|
"""
|
284
290
|
|
285
291
|
def dynamic_container_component() -> dsl.ContainerSpec:
|
@@ -294,7 +300,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
294
300
|
args=arguments,
|
295
301
|
)
|
296
302
|
|
297
|
-
# Change the name of the function
|
298
303
|
new_container_spec_func = types.FunctionType(
|
299
304
|
dynamic_container_component.__code__,
|
300
305
|
dynamic_container_component.__globals__,
|
@@ -303,9 +308,50 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
303
308
|
closure=dynamic_container_component.__closure__,
|
304
309
|
)
|
305
310
|
pipeline_task = dsl.container_component(new_container_spec_func)
|
306
|
-
|
307
311
|
return pipeline_task
|
308
312
|
|
313
|
+
def _convert_to_custom_training_job(
|
314
|
+
self,
|
315
|
+
component: BaseComponent,
|
316
|
+
settings: VertexOrchestratorSettings,
|
317
|
+
environment: Dict[str, str],
|
318
|
+
) -> BaseComponent:
|
319
|
+
"""Convert a component to a custom training job component.
|
320
|
+
|
321
|
+
Args:
|
322
|
+
component: The component to convert.
|
323
|
+
settings: The settings for the custom training job.
|
324
|
+
environment: The environment variables to set in the custom
|
325
|
+
training job.
|
326
|
+
|
327
|
+
Returns:
|
328
|
+
The custom training job component.
|
329
|
+
"""
|
330
|
+
custom_job_parameters = (
|
331
|
+
settings.custom_job_parameters or VertexCustomJobParameters()
|
332
|
+
)
|
333
|
+
if (
|
334
|
+
custom_job_parameters.persistent_resource_id
|
335
|
+
and not custom_job_parameters.service_account
|
336
|
+
):
|
337
|
+
# Persistent resources require an explicit service account, but
|
338
|
+
# none was provided in the custom job parameters. We try to fall
|
339
|
+
# back to the workload service account.
|
340
|
+
custom_job_parameters.service_account = (
|
341
|
+
self.config.workload_service_account
|
342
|
+
)
|
343
|
+
|
344
|
+
custom_job_component = create_custom_training_job_from_component(
|
345
|
+
component_spec=component,
|
346
|
+
env=[
|
347
|
+
{"name": key, "value": value}
|
348
|
+
for key, value in environment.items()
|
349
|
+
],
|
350
|
+
**custom_job_parameters.model_dump(),
|
351
|
+
)
|
352
|
+
|
353
|
+
return custom_job_component
|
354
|
+
|
309
355
|
def prepare_or_run_pipeline(
|
310
356
|
self,
|
311
357
|
deployment: "PipelineDeploymentResponse",
|
@@ -383,7 +429,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
383
429
|
Returns:
|
384
430
|
pipeline_func
|
385
431
|
"""
|
386
|
-
step_name_to_dynamic_component: Dict[str,
|
432
|
+
step_name_to_dynamic_component: Dict[str, BaseComponent] = {}
|
387
433
|
|
388
434
|
for step_name, step in deployment.step_configurations.items():
|
389
435
|
image = self.get_image(
|
@@ -397,7 +443,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
397
443
|
deployment_id=deployment.id,
|
398
444
|
)
|
399
445
|
)
|
400
|
-
|
446
|
+
component = self._create_container_component(
|
401
447
|
image, command, arguments, step_name
|
402
448
|
)
|
403
449
|
step_settings = cast(
|
@@ -442,7 +488,11 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
442
488
|
key,
|
443
489
|
)
|
444
490
|
|
445
|
-
step_name_to_dynamic_component[step_name] =
|
491
|
+
step_name_to_dynamic_component[step_name] = component
|
492
|
+
|
493
|
+
environment[ENV_ZENML_VERTEX_RUN_ID] = (
|
494
|
+
dsl.PIPELINE_JOB_NAME_PLACEHOLDER
|
495
|
+
)
|
446
496
|
|
447
497
|
@dsl.pipeline( # type: ignore[misc]
|
448
498
|
display_name=orchestrator_run_name,
|
@@ -462,81 +512,81 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
462
512
|
step_name_to_dynamic_component[upstream_step_name]
|
463
513
|
for upstream_step_name in step.spec.upstream_steps
|
464
514
|
]
|
465
|
-
task = (
|
466
|
-
component()
|
467
|
-
.set_display_name(
|
468
|
-
name=component_name,
|
469
|
-
)
|
470
|
-
.set_caching_options(enable_caching=False)
|
471
|
-
.set_env_variable(
|
472
|
-
name=ENV_ZENML_VERTEX_RUN_ID,
|
473
|
-
value=dsl.PIPELINE_JOB_NAME_PLACEHOLDER,
|
474
|
-
)
|
475
|
-
.after(*upstream_step_components)
|
476
|
-
)
|
477
515
|
|
478
516
|
step_settings = cast(
|
479
517
|
VertexOrchestratorSettings, self.get_settings(step)
|
480
518
|
)
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
519
|
+
|
520
|
+
use_custom_training_job = (
|
521
|
+
step_settings.custom_job_parameters is not None
|
522
|
+
)
|
523
|
+
|
524
|
+
if use_custom_training_job:
|
525
|
+
if not step.config.resource_settings.empty:
|
526
|
+
logger.warning(
|
527
|
+
"Ignoring resource settings because "
|
528
|
+
"the step is running as a custom training job. "
|
529
|
+
"Use `custom_job_parameters.machine_type` "
|
530
|
+
"to configure the machine type instead."
|
531
|
+
)
|
532
|
+
if step_settings.node_selector_constraint:
|
533
|
+
logger.warning(
|
534
|
+
"Ignoring node selector constraint because "
|
535
|
+
"the step is running as a custom training job. "
|
536
|
+
"Use `custom_job_parameters.accelerator_type` "
|
537
|
+
"to configure the accelerator type instead."
|
538
|
+
)
|
539
|
+
component = self._convert_to_custom_training_job(
|
540
|
+
component,
|
541
|
+
settings=step_settings,
|
542
|
+
environment=environment,
|
543
|
+
)
|
544
|
+
task = (
|
545
|
+
component()
|
546
|
+
.set_display_name(name=component_name)
|
547
|
+
.set_caching_options(enable_caching=False)
|
548
|
+
.after(*upstream_step_components)
|
493
549
|
)
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
550
|
+
else:
|
551
|
+
task = (
|
552
|
+
component()
|
553
|
+
.set_display_name(
|
554
|
+
name=component_name,
|
555
|
+
)
|
556
|
+
.set_caching_options(enable_caching=False)
|
557
|
+
.after(*upstream_step_components)
|
498
558
|
)
|
559
|
+
for key, value in environment.items():
|
560
|
+
task = task.set_env_variable(name=key, value=value)
|
499
561
|
|
500
|
-
|
501
|
-
dynamic_component=task,
|
502
|
-
resource_settings=step.config.resource_settings,
|
503
|
-
node_selector_constraint=node_selector_constraint,
|
504
|
-
)
|
562
|
+
pod_settings = step_settings.pod_settings
|
505
563
|
|
506
|
-
|
564
|
+
node_selector_constraint: Optional[Tuple[str, str]] = (
|
565
|
+
None
|
566
|
+
)
|
567
|
+
if pod_settings and (
|
568
|
+
GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
|
569
|
+
in pod_settings.node_selectors.keys()
|
570
|
+
):
|
571
|
+
node_selector_constraint = (
|
572
|
+
GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
|
573
|
+
pod_settings.node_selectors[
|
574
|
+
GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
|
575
|
+
],
|
576
|
+
)
|
577
|
+
elif step_settings.node_selector_constraint:
|
578
|
+
node_selector_constraint = (
|
579
|
+
GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
|
580
|
+
step_settings.node_selector_constraint[1],
|
581
|
+
)
|
507
582
|
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
583
|
+
self._configure_container_resources(
|
584
|
+
dynamic_component=task,
|
585
|
+
resource_settings=step.config.resource_settings,
|
586
|
+
node_selector_constraint=node_selector_constraint,
|
587
|
+
)
|
512
588
|
|
513
|
-
|
514
|
-
yaml_file_path: The path to the YAML file to update.
|
515
|
-
environment: A dictionary of environment variables to add.
|
516
|
-
"""
|
517
|
-
pipeline_definition = yaml_utils.read_json(pipeline_file_path)
|
518
|
-
|
519
|
-
# Iterate through each component and add the environment variables
|
520
|
-
for executor in pipeline_definition["deploymentSpec"]["executors"]:
|
521
|
-
if (
|
522
|
-
"container"
|
523
|
-
in pipeline_definition["deploymentSpec"]["executors"][
|
524
|
-
executor
|
525
|
-
]
|
526
|
-
):
|
527
|
-
container = pipeline_definition["deploymentSpec"][
|
528
|
-
"executors"
|
529
|
-
][executor]["container"]
|
530
|
-
if "env" not in container:
|
531
|
-
container["env"] = []
|
532
|
-
for key, value in environment.items():
|
533
|
-
container["env"].append({"name": key, "value": value})
|
534
|
-
|
535
|
-
yaml_utils.write_json(pipeline_file_path, pipeline_definition)
|
536
|
-
|
537
|
-
print(
|
538
|
-
f"Updated YAML file with environment variables at {yaml_file_path}"
|
539
|
-
)
|
589
|
+
return dynamic_pipeline
|
540
590
|
|
541
591
|
# Save the generated pipeline to a file.
|
542
592
|
fileio.makedirs(self.pipeline_directory)
|
@@ -556,9 +606,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
|
|
556
606
|
),
|
557
607
|
)
|
558
608
|
|
559
|
-
# Let's update the YAML file with the environment variables
|
560
|
-
_update_json_with_environment(pipeline_file_path, environment)
|
561
|
-
|
562
609
|
logger.info(
|
563
610
|
"Writing Vertex workflow definition to `%s`.", pipeline_file_path
|
564
611
|
)
|