zenml-nightly 0.75.0.dev20250313__py3-none-any.whl → 0.75.0.dev20250314__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/analytics/context.py +4 -4
- zenml/analytics/enums.py +2 -2
- zenml/artifacts/utils.py +2 -2
- zenml/cli/__init__.py +8 -9
- zenml/cli/base.py +2 -2
- zenml/cli/code_repository.py +1 -1
- zenml/cli/pipeline.py +3 -3
- zenml/cli/project.py +172 -0
- zenml/cli/service_accounts.py +0 -1
- zenml/cli/service_connectors.py +15 -16
- zenml/cli/stack.py +0 -2
- zenml/cli/stack_components.py +2 -2
- zenml/cli/utils.py +3 -3
- zenml/client.py +347 -340
- zenml/config/global_config.py +41 -43
- zenml/constants.py +5 -3
- zenml/event_hub/event_hub.py +1 -1
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +7 -6
- zenml/integrations/mlflow/steps/mlflow_registry.py +3 -3
- zenml/integrations/wandb/__init__.py +1 -1
- zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py +29 -9
- zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +2 -0
- zenml/model/model.py +2 -2
- zenml/model_registries/base_model_registry.py +1 -1
- zenml/models/__init__.py +29 -29
- zenml/models/v2/base/filter.py +1 -1
- zenml/models/v2/base/scoped.py +49 -53
- zenml/models/v2/core/action.py +12 -12
- zenml/models/v2/core/artifact.py +15 -15
- zenml/models/v2/core/artifact_version.py +15 -15
- zenml/models/v2/core/code_repository.py +12 -12
- zenml/models/v2/core/event_source.py +12 -12
- zenml/models/v2/core/model.py +17 -17
- zenml/models/v2/core/model_version.py +15 -15
- zenml/models/v2/core/pipeline.py +15 -15
- zenml/models/v2/core/pipeline_build.py +14 -14
- zenml/models/v2/core/pipeline_deployment.py +12 -14
- zenml/models/v2/core/pipeline_run.py +16 -16
- zenml/models/v2/core/project.py +203 -0
- zenml/models/v2/core/run_metadata.py +2 -2
- zenml/models/v2/core/run_template.py +15 -15
- zenml/models/v2/core/schedule.py +12 -12
- zenml/models/v2/core/secret.py +1 -1
- zenml/models/v2/core/service.py +14 -14
- zenml/models/v2/core/step_run.py +13 -13
- zenml/models/v2/core/trigger.py +13 -13
- zenml/models/v2/core/trigger_execution.py +2 -2
- zenml/models/v2/core/user.py +0 -17
- zenml/models/v2/misc/statistics.py +4 -4
- zenml/orchestrators/cache_utils.py +7 -7
- zenml/orchestrators/input_utils.py +1 -1
- zenml/orchestrators/step_launcher.py +1 -1
- zenml/orchestrators/step_run_utils.py +2 -2
- zenml/orchestrators/utils.py +4 -4
- zenml/pipelines/build_utils.py +2 -2
- zenml/pipelines/pipeline_definition.py +5 -5
- zenml/pipelines/run_utils.py +1 -1
- zenml/service_connectors/service_connector.py +0 -3
- zenml/service_connectors/service_connector_utils.py +0 -1
- zenml/stack/stack.py +0 -1
- zenml/steps/base_step.py +10 -2
- zenml/zen_server/rbac/endpoint_utils.py +17 -17
- zenml/zen_server/rbac/models.py +20 -20
- zenml/zen_server/rbac/rbac_sql_zen_store.py +3 -3
- zenml/zen_server/rbac/utils.py +23 -25
- zenml/zen_server/rbac/zenml_cloud_rbac.py +12 -16
- zenml/zen_server/routers/artifact_version_endpoints.py +10 -10
- zenml/zen_server/routers/auth_endpoints.py +6 -6
- zenml/zen_server/routers/code_repositories_endpoints.py +12 -14
- zenml/zen_server/routers/model_versions_endpoints.py +13 -15
- zenml/zen_server/routers/models_endpoints.py +7 -9
- zenml/zen_server/routers/pipeline_builds_endpoints.py +14 -16
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +13 -15
- zenml/zen_server/routers/pipelines_endpoints.py +16 -18
- zenml/zen_server/routers/{workspaces_endpoints.py → projects_endpoints.py} +111 -68
- zenml/zen_server/routers/run_metadata_endpoints.py +7 -9
- zenml/zen_server/routers/run_templates_endpoints.py +15 -17
- zenml/zen_server/routers/runs_endpoints.py +12 -14
- zenml/zen_server/routers/schedule_endpoints.py +12 -14
- zenml/zen_server/routers/secrets_endpoints.py +1 -3
- zenml/zen_server/routers/server_endpoints.py +5 -5
- zenml/zen_server/routers/service_connectors_endpoints.py +11 -13
- zenml/zen_server/routers/service_endpoints.py +7 -9
- zenml/zen_server/routers/stack_components_endpoints.py +9 -11
- zenml/zen_server/routers/stacks_endpoints.py +9 -11
- zenml/zen_server/routers/steps_endpoints.py +6 -6
- zenml/zen_server/routers/users_endpoints.py +5 -43
- zenml/zen_server/template_execution/utils.py +4 -4
- zenml/zen_server/utils.py +10 -10
- zenml/zen_server/zen_server_api.py +3 -2
- zenml/zen_stores/base_zen_store.py +35 -39
- zenml/zen_stores/migrations/versions/12eff0206201_rename_workspace_to_project.py +768 -0
- zenml/zen_stores/migrations/versions/41b28cae31ce_make_artifacts_workspace_scoped.py +3 -3
- zenml/zen_stores/migrations/versions/cbc6acd71f92_add_workspace_display_name.py +58 -0
- zenml/zen_stores/rest_zen_store.py +54 -62
- zenml/zen_stores/schemas/__init__.py +2 -2
- zenml/zen_stores/schemas/action_schemas.py +9 -9
- zenml/zen_stores/schemas/artifact_schemas.py +15 -17
- zenml/zen_stores/schemas/code_repository_schemas.py +16 -18
- zenml/zen_stores/schemas/event_source_schemas.py +9 -9
- zenml/zen_stores/schemas/model_schemas.py +15 -17
- zenml/zen_stores/schemas/pipeline_build_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_run_schemas.py +9 -9
- zenml/zen_stores/schemas/pipeline_schemas.py +9 -9
- zenml/zen_stores/schemas/{workspace_schemas.py → project_schemas.py} +47 -41
- zenml/zen_stores/schemas/run_metadata_schemas.py +5 -5
- zenml/zen_stores/schemas/run_template_schemas.py +9 -9
- zenml/zen_stores/schemas/schedule_schema.py +9 -9
- zenml/zen_stores/schemas/service_schemas.py +7 -7
- zenml/zen_stores/schemas/step_run_schemas.py +7 -7
- zenml/zen_stores/schemas/trigger_schemas.py +9 -9
- zenml/zen_stores/schemas/user_schemas.py +0 -12
- zenml/zen_stores/sql_zen_store.py +258 -268
- zenml/zen_stores/zen_store_interface.py +56 -70
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250314.dist-info}/METADATA +1 -1
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250314.dist-info}/RECORD +121 -119
- zenml/cli/workspace.py +0 -160
- zenml/models/v2/core/workspace.py +0 -131
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250314.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250314.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.75.0.dev20250313.dist-info → zenml_nightly-0.75.0.dev20250314.dist-info}/entry_points.txt +0 -0
zenml/config/global_config.py
CHANGED
@@ -48,7 +48,7 @@ from zenml.logger import get_logger
|
|
48
48
|
from zenml.utils import io_utils, yaml_utils
|
49
49
|
|
50
50
|
if TYPE_CHECKING:
|
51
|
-
from zenml.models import
|
51
|
+
from zenml.models import ProjectResponse, StackResponse
|
52
52
|
from zenml.zen_stores.base_zen_store import BaseZenStore
|
53
53
|
|
54
54
|
logger = get_logger(__name__)
|
@@ -113,7 +113,7 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
113
113
|
global config.
|
114
114
|
store: Store configuration.
|
115
115
|
active_stack_id: The ID of the active stack.
|
116
|
-
|
116
|
+
active_project_name: The name of the active project.
|
117
117
|
"""
|
118
118
|
|
119
119
|
user_id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
@@ -123,10 +123,10 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
123
123
|
version: Optional[str] = None
|
124
124
|
store: Optional[SerializeAsAny[StoreConfiguration]] = None
|
125
125
|
active_stack_id: Optional[uuid.UUID] = None
|
126
|
-
|
126
|
+
active_project_name: Optional[str] = None
|
127
127
|
|
128
128
|
_zen_store: Optional["BaseZenStore"] = None
|
129
|
-
|
129
|
+
_active_project: Optional["ProjectResponse"] = None
|
130
130
|
_active_stack: Optional["StackResponse"] = None
|
131
131
|
|
132
132
|
def __init__(self, **data: Any) -> None:
|
@@ -385,24 +385,24 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
385
385
|
def _sanitize_config(self) -> None:
|
386
386
|
"""Sanitize and save the global configuration.
|
387
387
|
|
388
|
-
This method is called to ensure that the active stack and
|
388
|
+
This method is called to ensure that the active stack and project
|
389
389
|
are set to their default values, if possible.
|
390
390
|
"""
|
391
391
|
# If running in a ZenML server environment, the active stack and
|
392
|
-
#
|
392
|
+
# project are not relevant
|
393
393
|
if ENV_ZENML_SERVER in os.environ:
|
394
394
|
return
|
395
|
-
|
396
|
-
self.
|
395
|
+
active_project, active_stack = self.zen_store.validate_active_config(
|
396
|
+
self.active_project_name,
|
397
397
|
self.active_stack_id,
|
398
398
|
config_name="global",
|
399
399
|
)
|
400
|
-
if
|
401
|
-
self.
|
402
|
-
self.
|
400
|
+
if active_project:
|
401
|
+
self.active_project_name = active_project.name
|
402
|
+
self._active_project = active_project
|
403
403
|
else:
|
404
|
-
self.
|
405
|
-
self.
|
404
|
+
self.active_project_name = None
|
405
|
+
self._active_project = None
|
406
406
|
|
407
407
|
self.set_active_stack(active_stack)
|
408
408
|
|
@@ -719,22 +719,22 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
719
719
|
|
720
720
|
return self._zen_store
|
721
721
|
|
722
|
-
def
|
723
|
-
self,
|
724
|
-
) -> "
|
725
|
-
"""Set the
|
722
|
+
def set_active_project(
|
723
|
+
self, project: "ProjectResponse"
|
724
|
+
) -> "ProjectResponse":
|
725
|
+
"""Set the project for the local client.
|
726
726
|
|
727
727
|
Args:
|
728
|
-
|
728
|
+
project: The project to set active.
|
729
729
|
|
730
730
|
Returns:
|
731
|
-
The
|
731
|
+
The project that was set active.
|
732
732
|
"""
|
733
|
-
self.
|
734
|
-
self.
|
735
|
-
# Sanitize the global configuration to reflect the new
|
733
|
+
self.active_project_name = project.name
|
734
|
+
self._active_project = project
|
735
|
+
# Sanitize the global configuration to reflect the new project
|
736
736
|
self._sanitize_config()
|
737
|
-
return
|
737
|
+
return project
|
738
738
|
|
739
739
|
def set_active_stack(self, stack: "StackResponse") -> None:
|
740
740
|
"""Set the active stack for the local client.
|
@@ -745,43 +745,41 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
|
|
745
745
|
self.active_stack_id = stack.id
|
746
746
|
self._active_stack = stack
|
747
747
|
|
748
|
-
def
|
749
|
-
"""Get a model of the active
|
748
|
+
def get_active_project(self) -> "ProjectResponse":
|
749
|
+
"""Get a model of the active project for the local client.
|
750
750
|
|
751
751
|
Returns:
|
752
|
-
The model of the active
|
752
|
+
The model of the active project.
|
753
753
|
"""
|
754
|
-
|
754
|
+
project_name = self.get_active_project_name()
|
755
755
|
|
756
|
-
if self.
|
757
|
-
return self.
|
756
|
+
if self._active_project is not None:
|
757
|
+
return self._active_project
|
758
758
|
|
759
|
-
|
760
|
-
|
759
|
+
project = self.zen_store.get_project(
|
760
|
+
project_name_or_id=project_name,
|
761
761
|
)
|
762
|
-
return self.
|
762
|
+
return self.set_active_project(project)
|
763
763
|
|
764
|
-
def
|
765
|
-
"""Get the name of the active
|
766
|
-
|
767
|
-
If the active workspace doesn't exist yet, the ZenStore is reinitialized.
|
764
|
+
def get_active_project_name(self) -> str:
|
765
|
+
"""Get the name of the active project.
|
768
766
|
|
769
767
|
Returns:
|
770
|
-
The name of the active
|
768
|
+
The name of the active project.
|
771
769
|
|
772
770
|
Raises:
|
773
|
-
RuntimeError: If the active
|
771
|
+
RuntimeError: If the active project is not set.
|
774
772
|
"""
|
775
|
-
if self.
|
773
|
+
if self.active_project_name is None:
|
776
774
|
_ = self.zen_store
|
777
|
-
if self.
|
775
|
+
if self.active_project_name is None:
|
778
776
|
raise RuntimeError(
|
779
|
-
"No
|
780
|
-
"active
|
777
|
+
"No project is currently set as active. Please set the "
|
778
|
+
"active project using the `zenml project set` CLI "
|
781
779
|
"command."
|
782
780
|
)
|
783
781
|
|
784
|
-
return self.
|
782
|
+
return self.active_project_name
|
785
783
|
|
786
784
|
def get_active_stack_id(self) -> UUID:
|
787
785
|
"""Get the ID of the active stack.
|
zenml/constants.py
CHANGED
@@ -143,13 +143,13 @@ ENV_ZENML_REPOSITORY_PATH = "ZENML_REPOSITORY_PATH"
|
|
143
143
|
ENV_ZENML_PREVENT_PIPELINE_EXECUTION = "ZENML_PREVENT_PIPELINE_EXECUTION"
|
144
144
|
ENV_ZENML_ENABLE_RICH_TRACEBACK = "ZENML_ENABLE_RICH_TRACEBACK"
|
145
145
|
ENV_ZENML_ACTIVE_STACK_ID = "ZENML_ACTIVE_STACK_ID"
|
146
|
-
|
146
|
+
ENV_ZENML_ACTIVE_PROJECT_ID = "ZENML_ACTIVE_PROJECT_ID"
|
147
147
|
ENV_ZENML_SUPPRESS_LOGS = "ZENML_SUPPRESS_LOGS"
|
148
148
|
ENV_ZENML_ENABLE_REPO_INIT_WARNINGS = "ZENML_ENABLE_REPO_INIT_WARNINGS"
|
149
149
|
ENV_ZENML_SECRET_VALIDATION_LEVEL = "ZENML_SECRET_VALIDATION_LEVEL"
|
150
150
|
ENV_ZENML_DEFAULT_USER_NAME = "ZENML_DEFAULT_USER_NAME"
|
151
151
|
ENV_ZENML_DEFAULT_USER_PASSWORD = "ZENML_DEFAULT_USER_PASSWORD"
|
152
|
-
|
152
|
+
ENV_ZENML_DEFAULT_PROJECT_NAME = "ZENML_DEFAULT_PROJECT_NAME"
|
153
153
|
ENV_ZENML_STORE_PREFIX = "ZENML_STORE_"
|
154
154
|
ENV_ZENML_SECRETS_STORE_PREFIX = "ZENML_SECRETS_STORE_"
|
155
155
|
ENV_ZENML_BACKUP_SECRETS_STORE_PREFIX = "ZENML_BACKUP_SECRETS_STORE_"
|
@@ -239,7 +239,7 @@ SQL_STORE_BACKUP_DIRECTORY_NAME = "database_backup"
|
|
239
239
|
|
240
240
|
DEFAULT_USERNAME = "default"
|
241
241
|
DEFAULT_PASSWORD = ""
|
242
|
-
|
242
|
+
DEFAULT_PROJECT_NAME = "default"
|
243
243
|
DEFAULT_STACK_AND_COMPONENT_NAME = "default"
|
244
244
|
|
245
245
|
# Rich config
|
@@ -409,6 +409,7 @@ VERSION_1 = "/v1"
|
|
409
409
|
VISUALIZE = "/visualize"
|
410
410
|
WEBHOOKS = "/webhooks"
|
411
411
|
WORKSPACES = "/workspaces"
|
412
|
+
PROJECTS = "/projects"
|
412
413
|
|
413
414
|
# model metadata yaml file name
|
414
415
|
MODEL_METADATA_YAML_FILE_NAME = "model_metadata.yaml"
|
@@ -450,6 +451,7 @@ STACK_RECIPES_GITHUB_REPO = "https://github.com/zenml-io/mlops-stacks.git"
|
|
450
451
|
|
451
452
|
# Parameters for internal ZenML Models
|
452
453
|
TEXT_FIELD_MAX_LENGTH = 65535
|
454
|
+
STR_ID_FIELD_MAX_LENGTH = 50
|
453
455
|
STR_FIELD_MAX_LENGTH = 255
|
454
456
|
MEDIUMTEXT_MAX_LENGTH = 2**24 - 1
|
455
457
|
|
zenml/event_hub/event_hub.py
CHANGED
@@ -126,7 +126,7 @@ class InternalEventHub(BaseEventHub):
|
|
126
126
|
triggers: List[TriggerResponse] = depaginate(
|
127
127
|
self.zen_store.list_triggers,
|
128
128
|
trigger_filter_model=TriggerFilter(
|
129
|
-
|
129
|
+
project=event_source.project.id,
|
130
130
|
event_source_id=event_source.id,
|
131
131
|
is_active=True,
|
132
132
|
),
|
@@ -438,7 +438,7 @@ class GCPUserAccountConfig(GCPBaseProjectIDConfig, GCPUserAccountCredentials):
|
|
438
438
|
class GCPServiceAccountConfig(GCPBaseConfig, GCPServiceAccountCredentials):
|
439
439
|
"""GCP service account configuration."""
|
440
440
|
|
441
|
-
|
441
|
+
project_id: Optional[str] = None
|
442
442
|
|
443
443
|
@property
|
444
444
|
def gcp_project_id(self) -> str:
|
@@ -450,14 +450,14 @@ class GCPServiceAccountConfig(GCPBaseConfig, GCPServiceAccountCredentials):
|
|
450
450
|
Returns:
|
451
451
|
The GCP project ID.
|
452
452
|
"""
|
453
|
-
if self.
|
454
|
-
self.
|
453
|
+
if self.project_id is None:
|
454
|
+
self.project_id = json.loads(
|
455
455
|
self.service_account_json.get_secret_value()
|
456
456
|
)["project_id"]
|
457
457
|
# Guaranteed by the field validator
|
458
|
-
assert self.
|
458
|
+
assert self.project_id is not None
|
459
459
|
|
460
|
-
return self.
|
460
|
+
return self.project_id
|
461
461
|
|
462
462
|
|
463
463
|
class GCPExternalAccountConfig(
|
@@ -798,7 +798,8 @@ connector will distribute the service account credentials JSON to clients
|
|
798
798
|
instead (not recommended).
|
799
799
|
|
800
800
|
A GCP project is required and the connector may only be used to access GCP
|
801
|
-
resources in the specified project.
|
801
|
+
resources in the specified project. If the `project_id` is not provided, the
|
802
|
+
connector will use the one extracted from the service account key JSON.
|
802
803
|
|
803
804
|
If you already have the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
804
805
|
configured to point to a service account key JSON file, it will be automatically
|
@@ -92,7 +92,7 @@ def mlflow_register_model_step(
|
|
92
92
|
pipeline_name = step_context.pipeline.name
|
93
93
|
current_run_name = step_context.pipeline_run.name
|
94
94
|
pipeline_run_uuid = str(step_context.pipeline_run.id)
|
95
|
-
|
95
|
+
zenml_project = str(step_context.pipeline.project.name)
|
96
96
|
|
97
97
|
# Get MLflow run ID either from params or from experiment tracker using
|
98
98
|
# pipeline name and run name
|
@@ -144,8 +144,8 @@ def mlflow_register_model_step(
|
|
144
144
|
metadata.zenml_run_name = run_name
|
145
145
|
if metadata.zenml_pipeline_run_uuid is None:
|
146
146
|
metadata.zenml_pipeline_run_uuid = pipeline_run_uuid
|
147
|
-
if metadata.
|
148
|
-
metadata.
|
147
|
+
if metadata.zenml_project is None:
|
148
|
+
metadata.zenml_project = zenml_project
|
149
149
|
if getattr(metadata, "mlflow_run_id", None) is None:
|
150
150
|
setattr(metadata, "mlflow_run_id", mlflow_run_id)
|
151
151
|
|
@@ -30,7 +30,7 @@ class WandbIntegration(Integration):
|
|
30
30
|
"""Definition of Plotly integration for ZenML."""
|
31
31
|
|
32
32
|
NAME = WANDB
|
33
|
-
REQUIREMENTS = ["wandb>=0.12.12", "Pillow>=9.1.0"]
|
33
|
+
REQUIREMENTS = ["wandb>=0.12.12,<1.0.0", "Pillow>=9.1.0", "weave>=0.51.33,<1.0.0"]
|
34
34
|
REQUIREMENTS_IGNORED_ON_UNINSTALL = ["Pillow"]
|
35
35
|
|
36
36
|
@classmethod
|
@@ -14,7 +14,7 @@
|
|
14
14
|
"""Implementation for the wandb experiment tracker."""
|
15
15
|
|
16
16
|
import os
|
17
|
-
from typing import TYPE_CHECKING,
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast
|
18
18
|
|
19
19
|
import wandb
|
20
20
|
|
@@ -30,8 +30,6 @@ from zenml.logger import get_logger
|
|
30
30
|
from zenml.metadata.metadata_types import Uri
|
31
31
|
|
32
32
|
if TYPE_CHECKING:
|
33
|
-
from wandb import Settings
|
34
|
-
|
35
33
|
from zenml.config.step_run_info import StepRunInfo
|
36
34
|
from zenml.metadata.metadata_types import MetadataType
|
37
35
|
|
@@ -76,9 +74,7 @@ class WandbExperimentTracker(BaseExperimentTracker):
|
|
76
74
|
wandb_run_name = (
|
77
75
|
settings.run_name or f"{info.run_name}_{info.pipeline_step_name}"
|
78
76
|
)
|
79
|
-
self._initialize_wandb(
|
80
|
-
run_name=wandb_run_name, tags=tags, settings=settings.settings
|
81
|
-
)
|
77
|
+
self._initialize_wandb(run_name=wandb_run_name, tags=tags, info=info)
|
82
78
|
|
83
79
|
def get_step_run_metadata(
|
84
80
|
self, info: "StepRunInfo"
|
@@ -131,25 +127,49 @@ class WandbExperimentTracker(BaseExperimentTracker):
|
|
131
127
|
|
132
128
|
def _initialize_wandb(
|
133
129
|
self,
|
130
|
+
info: "StepRunInfo",
|
134
131
|
run_name: str,
|
135
132
|
tags: List[str],
|
136
|
-
settings: Union["Settings", Dict[str, Any], None] = None,
|
137
133
|
) -> None:
|
138
134
|
"""Initializes a wandb run.
|
139
135
|
|
140
136
|
Args:
|
137
|
+
info: Step run information.
|
141
138
|
run_name: Name of the wandb run to create.
|
142
139
|
tags: Tags to attach to the wandb run.
|
143
|
-
settings: Additional settings for the wandb run.
|
144
140
|
"""
|
145
141
|
logger.info(
|
146
142
|
f"Initializing wandb with entity {self.config.entity}, project "
|
147
143
|
f"name: {self.config.project_name}, run_name: {run_name}."
|
148
144
|
)
|
145
|
+
settings = cast(
|
146
|
+
WandbExperimentTrackerSettings, self.get_settings(info)
|
147
|
+
)
|
149
148
|
wandb.init(
|
150
149
|
entity=self.config.entity,
|
151
150
|
project=self.config.project_name,
|
152
151
|
name=run_name,
|
153
152
|
tags=tags,
|
154
|
-
settings=settings,
|
153
|
+
settings=settings.settings,
|
155
154
|
)
|
155
|
+
|
156
|
+
if settings.enable_weave:
|
157
|
+
import weave
|
158
|
+
|
159
|
+
if self.config.project_name:
|
160
|
+
logger.info("Initializing weave")
|
161
|
+
weave.init(project_name=self.config.project_name)
|
162
|
+
else:
|
163
|
+
logger.info(
|
164
|
+
"Weave enabled but no project_name specified. "
|
165
|
+
"Skipping weave initialization."
|
166
|
+
)
|
167
|
+
else:
|
168
|
+
import weave
|
169
|
+
|
170
|
+
if self.config.project_name:
|
171
|
+
logger.info("Disabling weave")
|
172
|
+
weave.init(
|
173
|
+
project_name=self.config.project_name,
|
174
|
+
settings={"disabled": True},
|
175
|
+
)
|
@@ -46,11 +46,13 @@ class WandbExperimentTrackerSettings(BaseSettings):
|
|
46
46
|
run_name: The Wandb run name.
|
47
47
|
tags: Tags for the Wandb run.
|
48
48
|
settings: Settings for the Wandb run.
|
49
|
+
enable_weave: Whether to enable Weave integration.
|
49
50
|
"""
|
50
51
|
|
51
52
|
run_name: Optional[str] = None
|
52
53
|
tags: List[str] = []
|
53
54
|
settings: Dict[str, Any] = {}
|
55
|
+
enable_weave: bool = False
|
54
56
|
|
55
57
|
@field_validator("settings", mode="before")
|
56
58
|
@classmethod
|
zenml/model/model.py
CHANGED
@@ -548,7 +548,7 @@ class Model(BaseModel):
|
|
548
548
|
limitations=self.limitations,
|
549
549
|
trade_offs=self.trade_offs,
|
550
550
|
ethics=self.ethics,
|
551
|
-
|
551
|
+
project=zenml_client.active_project.id,
|
552
552
|
save_models_to_registry=self.save_models_to_registry,
|
553
553
|
)
|
554
554
|
model_request = ModelRequest.model_validate(model_request)
|
@@ -702,7 +702,7 @@ class Model(BaseModel):
|
|
702
702
|
|
703
703
|
client = Client()
|
704
704
|
model_version_request = ModelVersionRequest(
|
705
|
-
|
705
|
+
project=client.active_project.id,
|
706
706
|
name=str(self.version) if self.version else None,
|
707
707
|
description=self.description,
|
708
708
|
model=model.id,
|
@@ -68,7 +68,7 @@ class ModelRegistryModelMetadata(BaseModel):
|
|
68
68
|
zenml_pipeline_uuid: Optional[str] = None
|
69
69
|
zenml_pipeline_run_uuid: Optional[str] = None
|
70
70
|
zenml_step_name: Optional[str] = None
|
71
|
-
|
71
|
+
zenml_project: Optional[str] = None
|
72
72
|
|
73
73
|
@property
|
74
74
|
def custom_attributes(self) -> Dict[str, str]:
|
zenml/models/__init__.py
CHANGED
@@ -34,13 +34,13 @@ from zenml.models.v2.base.scoped import (
|
|
34
34
|
UserScopedResponse,
|
35
35
|
UserScopedResponseBody,
|
36
36
|
UserScopedResponseMetadata,
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
37
|
+
ProjectScopedRequest,
|
38
|
+
ProjectScopedFilter,
|
39
|
+
ProjectScopedResponse,
|
40
|
+
ProjectScopedResponseBody,
|
41
|
+
ProjectScopedResponseMetadata,
|
42
|
+
ProjectScopedResponseResources,
|
43
|
+
ProjectScopedFilter,
|
44
44
|
)
|
45
45
|
from zenml.models.v2.base.filter import (
|
46
46
|
BaseFilter,
|
@@ -285,7 +285,7 @@ from zenml.models.v2.core.stack import (
|
|
285
285
|
StackResponseMetadata,
|
286
286
|
)
|
287
287
|
from zenml.models.v2.misc.statistics import (
|
288
|
-
|
288
|
+
ProjectStatistics,
|
289
289
|
ServerStatistics,
|
290
290
|
)
|
291
291
|
from zenml.models.v2.core.step_run import (
|
@@ -318,13 +318,13 @@ from zenml.models.v2.core.user import (
|
|
318
318
|
UserResponseBody,
|
319
319
|
UserResponseMetadata,
|
320
320
|
)
|
321
|
-
from zenml.models.v2.core.
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
321
|
+
from zenml.models.v2.core.project import (
|
322
|
+
ProjectRequest,
|
323
|
+
ProjectUpdate,
|
324
|
+
ProjectFilter,
|
325
|
+
ProjectResponse,
|
326
|
+
ProjectResponseBody,
|
327
|
+
ProjectResponseMetadata,
|
328
328
|
)
|
329
329
|
|
330
330
|
# V2 Misc
|
@@ -503,13 +503,13 @@ __all__ = [
|
|
503
503
|
"UserScopedResponse",
|
504
504
|
"UserScopedResponseBody",
|
505
505
|
"UserScopedResponseMetadata",
|
506
|
-
"
|
507
|
-
"
|
508
|
-
"
|
509
|
-
"
|
510
|
-
"
|
511
|
-
"
|
512
|
-
"
|
506
|
+
"ProjectScopedRequest",
|
507
|
+
"ProjectScopedFilter",
|
508
|
+
"ProjectScopedResponse",
|
509
|
+
"ProjectScopedResponseBody",
|
510
|
+
"ProjectScopedResponseMetadata",
|
511
|
+
"ProjectScopedResponseResources",
|
512
|
+
"ProjectScopedFilter",
|
513
513
|
"BaseFilter",
|
514
514
|
"StrFilter",
|
515
515
|
"BoolFilter",
|
@@ -737,12 +737,12 @@ __all__ = [
|
|
737
737
|
"UserResponse",
|
738
738
|
"UserResponseBody",
|
739
739
|
"UserResponseMetadata",
|
740
|
-
"
|
741
|
-
"
|
742
|
-
"
|
743
|
-
"
|
744
|
-
"
|
745
|
-
"
|
740
|
+
"ProjectRequest",
|
741
|
+
"ProjectUpdate",
|
742
|
+
"ProjectFilter",
|
743
|
+
"ProjectResponse",
|
744
|
+
"ProjectResponseBody",
|
745
|
+
"ProjectResponseMetadata",
|
746
746
|
# V2 Misc
|
747
747
|
"AuthenticationMethodModel",
|
748
748
|
"DeployedStack",
|
@@ -776,5 +776,5 @@ __all__ = [
|
|
776
776
|
"ResourcesInfo",
|
777
777
|
"RunMetadataEntry",
|
778
778
|
"RunMetadataResource",
|
779
|
-
"
|
779
|
+
"ProjectStatistics",
|
780
780
|
]
|