zenml-nightly 0.54.1.dev20240118__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/__init__.py +10 -4
- zenml/artifacts/artifact_config.py +12 -12
- zenml/artifacts/external_artifact.py +3 -3
- zenml/artifacts/external_artifact_config.py +8 -8
- zenml/artifacts/utils.py +4 -4
- zenml/cli/__init__.py +4 -4
- zenml/cli/artifact.py +38 -18
- zenml/cli/base.py +3 -3
- zenml/cli/model.py +24 -16
- zenml/cli/server.py +9 -0
- zenml/cli/utils.py +3 -3
- zenml/client.py +13 -2
- zenml/config/compiler.py +1 -1
- zenml/config/pipeline_configurations.py +2 -2
- zenml/config/pipeline_run_configuration.py +2 -2
- zenml/config/step_configurations.py +2 -2
- zenml/integrations/__init__.py +2 -4
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
- zenml/metadata/lazy_load.py +5 -5
- zenml/model/lazy_load.py +2 -2
- zenml/model/{model_version.py → model.py} +47 -38
- zenml/model/utils.py +33 -33
- zenml/model_registries/base_model_registry.py +10 -8
- zenml/models/__init__.py +2 -0
- zenml/models/v2/base/filter.py +3 -0
- zenml/models/v2/base/scoped.py +59 -0
- zenml/models/v2/core/artifact.py +2 -2
- zenml/models/v2/core/artifact_version.py +6 -6
- zenml/models/v2/core/model.py +6 -6
- zenml/models/v2/core/model_version.py +9 -9
- zenml/models/v2/core/run_metadata.py +2 -2
- zenml/new/pipelines/model_utils.py +20 -20
- zenml/new/pipelines/pipeline.py +47 -54
- zenml/new/pipelines/pipeline_context.py +1 -1
- zenml/new/pipelines/pipeline_decorator.py +4 -4
- zenml/new/steps/step_context.py +15 -15
- zenml/new/steps/step_decorator.py +5 -5
- zenml/orchestrators/input_utils.py +5 -7
- zenml/orchestrators/step_launcher.py +12 -19
- zenml/orchestrators/step_runner.py +8 -10
- zenml/pipelines/base_pipeline.py +1 -1
- zenml/pipelines/pipeline_decorator.py +6 -6
- zenml/steps/base_step.py +15 -15
- zenml/steps/step_decorator.py +6 -6
- zenml/steps/utils.py +68 -0
- zenml/zen_server/deploy/helm/templates/server-db-job.yaml +1 -1
- zenml/zen_server/deploy/helm/templates/server-secret.yaml +1 -1
- zenml/zen_server/deploy/helm/templates/serviceaccount.yaml +1 -1
- zenml/zen_server/utils.py +19 -1
- zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
- zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +16 -4
- zenml/zen_stores/rest_zen_store.py +2 -2
- zenml/zen_stores/sql_zen_store.py +4 -1
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +58 -57
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
@@ -19,7 +19,9 @@ from typing import Any, Dict, List, Optional, Tuple, cast
|
|
19
19
|
|
20
20
|
import mlflow
|
21
21
|
from mlflow import MlflowClient
|
22
|
-
from mlflow.entities.model_registry import
|
22
|
+
from mlflow.entities.model_registry import (
|
23
|
+
ModelVersion as MLflowModelVersion,
|
24
|
+
)
|
23
25
|
from mlflow.exceptions import MlflowException
|
24
26
|
from mlflow.pyfunc import load_model
|
25
27
|
|
@@ -36,9 +38,9 @@ from zenml.logger import get_logger
|
|
36
38
|
from zenml.model_registries.base_model_registry import (
|
37
39
|
BaseModelRegistry,
|
38
40
|
ModelRegistryModelMetadata,
|
39
|
-
ModelVersion,
|
40
41
|
ModelVersionStage,
|
41
42
|
RegisteredModel,
|
43
|
+
RegistryModelVersion,
|
42
44
|
)
|
43
45
|
from zenml.stack.stack import Stack
|
44
46
|
from zenml.stack.stack_validator import StackValidator
|
@@ -348,7 +350,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
348
350
|
description: Optional[str] = None,
|
349
351
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
350
352
|
**kwargs: Any,
|
351
|
-
) ->
|
353
|
+
) -> RegistryModelVersion:
|
352
354
|
"""Register a model version to the MLflow model registry.
|
353
355
|
|
354
356
|
Args:
|
@@ -442,7 +444,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
442
444
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
443
445
|
remove_metadata: Optional[List[str]] = None,
|
444
446
|
stage: Optional[ModelVersionStage] = None,
|
445
|
-
) ->
|
447
|
+
) -> RegistryModelVersion:
|
446
448
|
"""Update a model version in the MLflow model registry.
|
447
449
|
|
448
450
|
Args:
|
@@ -521,7 +523,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
521
523
|
self,
|
522
524
|
name: str,
|
523
525
|
version: str,
|
524
|
-
) ->
|
526
|
+
) -> RegistryModelVersion:
|
525
527
|
"""Get a model version from the MLflow model registry.
|
526
528
|
|
527
529
|
Args:
|
@@ -561,7 +563,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
561
563
|
created_before: Optional[datetime] = None,
|
562
564
|
order_by_date: Optional[str] = None,
|
563
565
|
**kwargs: Any,
|
564
|
-
) -> List[
|
566
|
+
) -> List[RegistryModelVersion]:
|
565
567
|
"""List model versions from the MLflow model registry.
|
566
568
|
|
567
569
|
Args:
|
@@ -704,7 +706,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
704
706
|
|
705
707
|
def get_model_uri_artifact_store(
|
706
708
|
self,
|
707
|
-
model_version:
|
709
|
+
model_version: RegistryModelVersion,
|
708
710
|
) -> str:
|
709
711
|
"""Get the model URI artifact store.
|
710
712
|
|
@@ -723,7 +725,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
723
725
|
def _cast_mlflow_version_to_model_version(
|
724
726
|
self,
|
725
727
|
mlflow_model_version: MLflowModelVersion,
|
726
|
-
) ->
|
728
|
+
) -> RegistryModelVersion:
|
727
729
|
"""Cast an MLflow model version to a model version.
|
728
730
|
|
729
731
|
Args:
|
@@ -748,7 +750,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
748
750
|
)
|
749
751
|
except ImportError:
|
750
752
|
model_library = None
|
751
|
-
return
|
753
|
+
return RegistryModelVersion(
|
752
754
|
registered_model=RegisteredModel(name=mlflow_model_version.name),
|
753
755
|
model_format=MLFLOW_MODEL_FORMAT,
|
754
756
|
model_library=model_library,
|
zenml/metadata/lazy_load.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16
16
|
from typing import TYPE_CHECKING, Optional
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
19
|
-
from zenml.model.
|
19
|
+
from zenml.model.model import Model
|
20
20
|
from zenml.models import RunMetadataResponse
|
21
21
|
|
22
22
|
|
@@ -30,18 +30,18 @@ class RunMetadataLazyGetter:
|
|
30
30
|
|
31
31
|
def __init__(
|
32
32
|
self,
|
33
|
-
|
33
|
+
_lazy_load_model: "Model",
|
34
34
|
_lazy_load_artifact_name: Optional[str],
|
35
35
|
_lazy_load_artifact_version: Optional[str],
|
36
36
|
):
|
37
37
|
"""Initialize a RunMetadataLazyGetter.
|
38
38
|
|
39
39
|
Args:
|
40
|
-
|
40
|
+
_lazy_load_model: The model version.
|
41
41
|
_lazy_load_artifact_name: The artifact name.
|
42
42
|
_lazy_load_artifact_version: The artifact version.
|
43
43
|
"""
|
44
|
-
self.
|
44
|
+
self._lazy_load_model = _lazy_load_model
|
45
45
|
self._lazy_load_artifact_name = _lazy_load_artifact_name
|
46
46
|
self._lazy_load_artifact_version = _lazy_load_artifact_version
|
47
47
|
|
@@ -57,7 +57,7 @@ class RunMetadataLazyGetter:
|
|
57
57
|
from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse
|
58
58
|
|
59
59
|
return LazyRunMetadataResponse(
|
60
|
-
|
60
|
+
_lazy_load_model=self._lazy_load_model,
|
61
61
|
_lazy_load_artifact_name=self._lazy_load_artifact_name,
|
62
62
|
_lazy_load_artifact_version=self._lazy_load_artifact_version,
|
63
63
|
_lazy_load_metadata_name=key,
|
zenml/model/lazy_load.py
CHANGED
@@ -17,7 +17,7 @@ from typing import Optional
|
|
17
17
|
|
18
18
|
from pydantic import BaseModel
|
19
19
|
|
20
|
-
from zenml.model.
|
20
|
+
from zenml.model.model import Model
|
21
21
|
|
22
22
|
|
23
23
|
class ModelVersionDataLazyLoader(BaseModel):
|
@@ -28,7 +28,7 @@ class ModelVersionDataLazyLoader(BaseModel):
|
|
28
28
|
model version during runtime time of the step.
|
29
29
|
"""
|
30
30
|
|
31
|
-
|
31
|
+
model: Model
|
32
32
|
artifact_name: Optional[str] = None
|
33
33
|
artifact_version: Optional[str] = None
|
34
34
|
metadata_name: Optional[str] = None
|
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
12
|
# or implied. See the License for the specific language governing
|
13
13
|
# permissions and limitations under the License.
|
14
|
-
"""
|
14
|
+
"""Model user facing interface to pass into pipeline or step."""
|
15
15
|
|
16
16
|
from typing import (
|
17
17
|
TYPE_CHECKING,
|
@@ -42,8 +42,8 @@ if TYPE_CHECKING:
|
|
42
42
|
logger = get_logger(__name__)
|
43
43
|
|
44
44
|
|
45
|
-
class
|
46
|
-
"""
|
45
|
+
class Model(BaseModel):
|
46
|
+
"""Model class to pass into pipeline or step to set it into a model context.
|
47
47
|
|
48
48
|
name: The name of the model.
|
49
49
|
license: The license under which the model is created.
|
@@ -54,8 +54,8 @@ class ModelVersion(BaseModel):
|
|
54
54
|
trade_offs: The tradeoffs of the model.
|
55
55
|
ethics: The ethical implications of the model.
|
56
56
|
tags: Tags associated with the model.
|
57
|
-
version: The
|
58
|
-
to a specific version/stage. If skipped new
|
57
|
+
version: The version name, version number or stage is optional and points model context
|
58
|
+
to a specific version/stage. If skipped new version will be created.
|
59
59
|
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
|
60
60
|
if available in active stack.
|
61
61
|
"""
|
@@ -97,7 +97,7 @@ class ModelVersion(BaseModel):
|
|
97
97
|
self._get_or_create_model_version()
|
98
98
|
except RuntimeError:
|
99
99
|
logger.info(
|
100
|
-
f"
|
100
|
+
f"Version `{self.version}` of `{self.name}` model doesn't exist "
|
101
101
|
"and cannot be fetched from the Model Control Plane."
|
102
102
|
)
|
103
103
|
return self._id
|
@@ -128,7 +128,7 @@ class ModelVersion(BaseModel):
|
|
128
128
|
self._get_or_create_model_version()
|
129
129
|
except RuntimeError:
|
130
130
|
logger.info(
|
131
|
-
f"
|
131
|
+
f"Version `{self.version}` of `{self.name}` model doesn't exist "
|
132
132
|
"and cannot be fetched from the Model Control Plane."
|
133
133
|
)
|
134
134
|
return self._number
|
@@ -149,7 +149,7 @@ class ModelVersion(BaseModel):
|
|
149
149
|
return ModelStages(stage)
|
150
150
|
except RuntimeError:
|
151
151
|
logger.info(
|
152
|
-
f"
|
152
|
+
f"Version `{self.version}` of `{self.name}` model doesn't exist "
|
153
153
|
"and cannot be fetched from the Model Control Plane."
|
154
154
|
)
|
155
155
|
return None
|
@@ -283,7 +283,7 @@ class ModelVersion(BaseModel):
|
|
283
283
|
def set_stage(
|
284
284
|
self, stage: Union[str, ModelStages], force: bool = False
|
285
285
|
) -> None:
|
286
|
-
"""Sets this Model
|
286
|
+
"""Sets this Model to a desired stage.
|
287
287
|
|
288
288
|
Args:
|
289
289
|
stage: the target stage for model version.
|
@@ -431,7 +431,7 @@ class ModelVersion(BaseModel):
|
|
431
431
|
return LazyArtifactVersionResponse(
|
432
432
|
_lazy_load_name=name,
|
433
433
|
_lazy_load_version=version,
|
434
|
-
|
434
|
+
_lazy_load_model=Model(
|
435
435
|
name=self.name, version=self.version or self.number
|
436
436
|
),
|
437
437
|
)
|
@@ -441,7 +441,7 @@ class ModelVersion(BaseModel):
|
|
441
441
|
return None
|
442
442
|
|
443
443
|
def __eq__(self, other: object) -> bool:
|
444
|
-
"""Check two
|
444
|
+
"""Check two Models for equality.
|
445
445
|
|
446
446
|
Args:
|
447
447
|
other: object to compare with
|
@@ -449,7 +449,7 @@ class ModelVersion(BaseModel):
|
|
449
449
|
Returns:
|
450
450
|
True, if equal, False otherwise.
|
451
451
|
"""
|
452
|
-
if not isinstance(other,
|
452
|
+
if not isinstance(other, Model):
|
453
453
|
return NotImplemented
|
454
454
|
if self.name != other.name:
|
455
455
|
return False
|
@@ -509,7 +509,7 @@ class ModelVersion(BaseModel):
|
|
509
509
|
model_name_or_id=self.name
|
510
510
|
)
|
511
511
|
|
512
|
-
difference = {}
|
512
|
+
difference: Dict[str, Any] = {}
|
513
513
|
for key in (
|
514
514
|
"license",
|
515
515
|
"audience",
|
@@ -519,12 +519,20 @@ class ModelVersion(BaseModel):
|
|
519
519
|
"ethics",
|
520
520
|
"save_models_to_registry",
|
521
521
|
):
|
522
|
-
if getattr(self, key
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
522
|
+
if self_attr := getattr(self, key, None):
|
523
|
+
if self_attr != getattr(model, key):
|
524
|
+
difference[key] = {
|
525
|
+
"config": getattr(self, key),
|
526
|
+
"db": getattr(model, key),
|
527
|
+
}
|
528
|
+
if self.tags:
|
529
|
+
configured_tags = set(self.tags)
|
530
|
+
db_tags = {t.name for t in model.tags}
|
531
|
+
if db_tags != configured_tags:
|
532
|
+
difference["tags added"] = list(configured_tags - db_tags)
|
533
|
+
difference["tags removed"] = list(
|
534
|
+
db_tags - configured_tags
|
535
|
+
)
|
528
536
|
if difference:
|
529
537
|
logger.warning(
|
530
538
|
"Provided model configuration does not match "
|
@@ -581,16 +589,17 @@ class ModelVersion(BaseModel):
|
|
581
589
|
self._id = mv.id
|
582
590
|
|
583
591
|
difference: Dict[str, Any] = {}
|
584
|
-
if mv.description != self.description:
|
592
|
+
if self.description and mv.description != self.description:
|
585
593
|
difference["description"] = {
|
586
594
|
"config": self.description,
|
587
595
|
"db": mv.description,
|
588
596
|
}
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
597
|
+
if self.tags:
|
598
|
+
configured_tags = set(self.tags)
|
599
|
+
db_tags = {t.name for t in mv.tags}
|
600
|
+
if db_tags != configured_tags:
|
601
|
+
difference["tags added"] = list(configured_tags - db_tags)
|
602
|
+
difference["tags removed"] = list(db_tags - configured_tags)
|
594
603
|
if difference:
|
595
604
|
logger.warning(
|
596
605
|
"Provided model version configuration does not match existing model "
|
@@ -654,7 +663,7 @@ class ModelVersion(BaseModel):
|
|
654
663
|
# model version for current model was already
|
655
664
|
# created in the current run, not to create
|
656
665
|
# new model versions
|
657
|
-
pipeline_mv = context.pipeline_run.config.
|
666
|
+
pipeline_mv = context.pipeline_run.config.model
|
658
667
|
if (
|
659
668
|
pipeline_mv
|
660
669
|
and pipeline_mv.was_created_in_this_run
|
@@ -664,7 +673,7 @@ class ModelVersion(BaseModel):
|
|
664
673
|
self.version = pipeline_mv.version
|
665
674
|
else:
|
666
675
|
for step in context.pipeline_run.steps.values():
|
667
|
-
step_mv = step.config.
|
676
|
+
step_mv = step.config.model
|
668
677
|
if (
|
669
678
|
step_mv
|
670
679
|
and step_mv.was_created_in_this_run
|
@@ -712,21 +721,21 @@ class ModelVersion(BaseModel):
|
|
712
721
|
self._number = model_version.number
|
713
722
|
return model_version
|
714
723
|
|
715
|
-
def _merge(self,
|
716
|
-
self.license = self.license or
|
717
|
-
self.description = self.description or
|
718
|
-
self.audience = self.audience or
|
719
|
-
self.use_cases = self.use_cases or
|
720
|
-
self.limitations = self.limitations or
|
721
|
-
self.trade_offs = self.trade_offs or
|
722
|
-
self.ethics = self.ethics or
|
723
|
-
if
|
724
|
+
def _merge(self, model: "Model") -> None:
|
725
|
+
self.license = self.license or model.license
|
726
|
+
self.description = self.description or model.description
|
727
|
+
self.audience = self.audience or model.audience
|
728
|
+
self.use_cases = self.use_cases or model.use_cases
|
729
|
+
self.limitations = self.limitations or model.limitations
|
730
|
+
self.trade_offs = self.trade_offs or model.trade_offs
|
731
|
+
self.ethics = self.ethics or model.ethics
|
732
|
+
if model.tags is not None:
|
724
733
|
self.tags = list(
|
725
|
-
{t for t in self.tags or []}.union(set(
|
734
|
+
{t for t in self.tags or []}.union(set(model.tags))
|
726
735
|
)
|
727
736
|
|
728
737
|
def __hash__(self) -> int:
|
729
|
-
"""Get hash of the `
|
738
|
+
"""Get hash of the `Model`.
|
730
739
|
|
731
740
|
Returns:
|
732
741
|
Hash function results
|
zenml/model/utils.py
CHANGED
@@ -22,7 +22,7 @@ from zenml.enums import ModelStages
|
|
22
22
|
from zenml.exceptions import StepContextError
|
23
23
|
from zenml.logger import get_logger
|
24
24
|
from zenml.metadata.metadata_types import MetadataType
|
25
|
-
from zenml.model.
|
25
|
+
from zenml.model.model import Model
|
26
26
|
from zenml.models import ModelVersionArtifactRequest
|
27
27
|
from zenml.new.steps.step_context import get_step_context
|
28
28
|
|
@@ -48,9 +48,9 @@ def link_step_artifacts_to_model(
|
|
48
48
|
"step."
|
49
49
|
)
|
50
50
|
try:
|
51
|
-
|
51
|
+
model = step_context.model
|
52
52
|
except StepContextError:
|
53
|
-
|
53
|
+
model = None
|
54
54
|
logger.debug("No model context found, unable to auto-link artifacts.")
|
55
55
|
|
56
56
|
for artifact_name, artifact_version_id in artifact_version_ids.items():
|
@@ -59,47 +59,47 @@ def link_step_artifacts_to_model(
|
|
59
59
|
).artifact_config
|
60
60
|
|
61
61
|
# Implicit linking
|
62
|
-
if artifact_config is None and
|
62
|
+
if artifact_config is None and model is not None:
|
63
63
|
artifact_config = ArtifactConfig(name=artifact_name)
|
64
64
|
logger.info(
|
65
65
|
f"Implicitly linking artifact `{artifact_name}` to model "
|
66
|
-
f"`{
|
66
|
+
f"`{model.name}` version `{model.version}`."
|
67
67
|
)
|
68
68
|
|
69
69
|
if artifact_config:
|
70
|
-
|
70
|
+
link_artifact_config_to_model(
|
71
71
|
artifact_config=artifact_config,
|
72
72
|
artifact_version_id=artifact_version_id,
|
73
|
-
|
73
|
+
model=model,
|
74
74
|
)
|
75
75
|
|
76
76
|
|
77
|
-
def
|
77
|
+
def link_artifact_config_to_model(
|
78
78
|
artifact_config: ArtifactConfig,
|
79
79
|
artifact_version_id: UUID,
|
80
|
-
|
80
|
+
model: Optional["Model"] = None,
|
81
81
|
) -> None:
|
82
82
|
"""Link an artifact config to its model version.
|
83
83
|
|
84
84
|
Args:
|
85
85
|
artifact_config: The artifact config to link.
|
86
86
|
artifact_version_id: The ID of the artifact to link.
|
87
|
-
|
87
|
+
model: The model version from the step or pipeline context.
|
88
88
|
"""
|
89
89
|
client = Client()
|
90
90
|
|
91
91
|
# If the artifact config specifies a model itself then always use that
|
92
92
|
if artifact_config.model_name is not None:
|
93
|
-
from zenml.model.
|
93
|
+
from zenml.model.model import Model
|
94
94
|
|
95
|
-
|
95
|
+
model = Model(
|
96
96
|
name=artifact_config.model_name,
|
97
97
|
version=artifact_config.model_version,
|
98
98
|
)
|
99
99
|
|
100
|
-
if
|
101
|
-
|
102
|
-
model_version_response =
|
100
|
+
if model:
|
101
|
+
model._get_or_create_model_version()
|
102
|
+
model_version_response = model._get_model_version()
|
103
103
|
request = ModelVersionArtifactRequest(
|
104
104
|
user=client.active_user.id,
|
105
105
|
workspace=client.active_workspace.id,
|
@@ -125,10 +125,10 @@ def log_model_version_metadata(
|
|
125
125
|
metadata: The metadata to log.
|
126
126
|
model_name: The name of the model to log metadata for. Can
|
127
127
|
be omitted when being called inside a step with configured
|
128
|
-
`
|
128
|
+
`model` in decorator.
|
129
129
|
model_version: The version of the model to log metadata for. Can
|
130
130
|
be omitted when being called inside a step with configured
|
131
|
-
`
|
131
|
+
`model` in decorator.
|
132
132
|
"""
|
133
133
|
logger.warning(
|
134
134
|
"`log_model_version_metadata` is deprecated. Please use "
|
@@ -152,38 +152,38 @@ def log_model_metadata(
|
|
152
152
|
metadata: The metadata to log.
|
153
153
|
model_name: The name of the model to log metadata for. Can
|
154
154
|
be omitted when being called inside a step with configured
|
155
|
-
`
|
155
|
+
`model` in decorator.
|
156
156
|
model_version: The version of the model to log metadata for. Can
|
157
157
|
be omitted when being called inside a step with configured
|
158
|
-
`
|
158
|
+
`model` in decorator.
|
159
159
|
|
160
160
|
Raises:
|
161
161
|
ValueError: If no model name/version is provided and the function is not
|
162
|
-
called inside a step with configured `
|
162
|
+
called inside a step with configured `model` in decorator.
|
163
163
|
"""
|
164
164
|
mv = None
|
165
165
|
try:
|
166
166
|
step_context = get_step_context()
|
167
|
-
mv = step_context.
|
167
|
+
mv = step_context.model
|
168
168
|
except RuntimeError:
|
169
169
|
step_context = None
|
170
170
|
|
171
171
|
if not step_context and not (model_name and model_version):
|
172
172
|
raise ValueError(
|
173
173
|
"Model name and version must be provided unless the function is "
|
174
|
-
"called inside a step with configured `
|
174
|
+
"called inside a step with configured `model` in decorator."
|
175
175
|
)
|
176
176
|
if mv is None:
|
177
|
-
from zenml import
|
177
|
+
from zenml import Model
|
178
178
|
|
179
|
-
mv =
|
179
|
+
mv = Model(name=model_name, version=model_version)
|
180
180
|
|
181
181
|
mv.log_metadata(metadata)
|
182
182
|
|
183
183
|
|
184
184
|
def link_artifact_to_model(
|
185
185
|
artifact_version_id: UUID,
|
186
|
-
|
186
|
+
model: Optional["Model"] = None,
|
187
187
|
is_model_artifact: bool = False,
|
188
188
|
is_deployment_artifact: bool = False,
|
189
189
|
) -> None:
|
@@ -191,34 +191,34 @@ def link_artifact_to_model(
|
|
191
191
|
|
192
192
|
Args:
|
193
193
|
artifact_version_id: The ID of the artifact version.
|
194
|
-
|
194
|
+
model: The model to link to.
|
195
195
|
is_model_artifact: Whether the artifact is a model artifact.
|
196
196
|
is_deployment_artifact: Whether the artifact is a deployment artifact.
|
197
197
|
|
198
198
|
Raises:
|
199
199
|
RuntimeError: If called outside of a step.
|
200
200
|
"""
|
201
|
-
if not
|
201
|
+
if not model:
|
202
202
|
is_issue = False
|
203
203
|
try:
|
204
204
|
step_context = get_step_context()
|
205
|
-
|
205
|
+
model = step_context.model
|
206
206
|
except StepContextError:
|
207
207
|
is_issue = True
|
208
208
|
|
209
|
-
if
|
209
|
+
if model is None or is_issue:
|
210
210
|
raise RuntimeError(
|
211
|
-
"`link_artifact_to_model` called without `
|
211
|
+
"`link_artifact_to_model` called without `model` parameter "
|
212
212
|
"and configured model context cannot be identified. Consider "
|
213
|
-
"passing the `
|
213
|
+
"passing the `model` explicitly or configuring it in "
|
214
214
|
"@step or @pipeline decorator."
|
215
215
|
)
|
216
216
|
|
217
|
-
|
217
|
+
link_artifact_config_to_model(
|
218
218
|
artifact_config=ArtifactConfig(
|
219
219
|
is_model_artifact=is_model_artifact,
|
220
220
|
is_deployment_artifact=is_deployment_artifact,
|
221
221
|
),
|
222
222
|
artifact_version_id=artifact_version_id,
|
223
|
-
|
223
|
+
model=model,
|
224
224
|
)
|
@@ -123,10 +123,10 @@ class ModelRegistryModelMetadata(BaseModel):
|
|
123
123
|
extra = "allow"
|
124
124
|
|
125
125
|
|
126
|
-
class
|
126
|
+
class RegistryModelVersion(BaseModel):
|
127
127
|
"""Base class for all ZenML model versions.
|
128
128
|
|
129
|
-
The `
|
129
|
+
The `RegistryModelVersion` class represents a version or snapshot of a registered
|
130
130
|
model, including information such as the associated `ModelBundle`, version
|
131
131
|
number, creation time, pipeline run information, and metadata. It serves as
|
132
132
|
a blueprint for creating concrete model version implementations in a registry,
|
@@ -288,7 +288,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
288
288
|
description: Optional[str] = None,
|
289
289
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
290
290
|
**kwargs: Any,
|
291
|
-
) ->
|
291
|
+
) -> RegistryModelVersion:
|
292
292
|
"""Registers a model version in the model registry.
|
293
293
|
|
294
294
|
Args:
|
@@ -333,7 +333,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
333
333
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
334
334
|
remove_metadata: Optional[List[str]] = None,
|
335
335
|
stage: Optional[ModelVersionStage] = None,
|
336
|
-
) ->
|
336
|
+
) -> RegistryModelVersion:
|
337
337
|
"""Updates a model version in the model registry.
|
338
338
|
|
339
339
|
Args:
|
@@ -364,7 +364,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
364
364
|
created_before: Optional[datetime] = None,
|
365
365
|
order_by_date: Optional[str] = None,
|
366
366
|
**kwargs: Any,
|
367
|
-
) -> Optional[List[
|
367
|
+
) -> Optional[List[RegistryModelVersion]]:
|
368
368
|
"""Lists all model versions for a registered model.
|
369
369
|
|
370
370
|
Args:
|
@@ -387,7 +387,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
387
387
|
self,
|
388
388
|
name: str,
|
389
389
|
stage: Optional[ModelVersionStage] = None,
|
390
|
-
) -> Optional[
|
390
|
+
) -> Optional[RegistryModelVersion]:
|
391
391
|
"""Gets the latest model version for a registered model.
|
392
392
|
|
393
393
|
This method is used to get the latest model version for a registered
|
@@ -410,7 +410,9 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
410
410
|
return None
|
411
411
|
|
412
412
|
@abstractmethod
|
413
|
-
def get_model_version(
|
413
|
+
def get_model_version(
|
414
|
+
self, name: str, version: str
|
415
|
+
) -> RegistryModelVersion:
|
414
416
|
"""Gets a model version for a registered model.
|
415
417
|
|
416
418
|
Args:
|
@@ -450,7 +452,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
450
452
|
@abstractmethod
|
451
453
|
def get_model_uri_artifact_store(
|
452
454
|
self,
|
453
|
-
model_version:
|
455
|
+
model_version: RegistryModelVersion,
|
454
456
|
) -> str:
|
455
457
|
"""Gets the URI artifact store for a model version.
|
456
458
|
|
zenml/models/__init__.py
CHANGED
@@ -34,6 +34,7 @@ from zenml.models.v2.base.scoped import (
|
|
34
34
|
WorkspaceScopedResponse,
|
35
35
|
WorkspaceScopedResponseBody,
|
36
36
|
WorkspaceScopedResponseMetadata,
|
37
|
+
WorkspaceScopedTaggableFilter
|
37
38
|
)
|
38
39
|
from zenml.models.v2.base.filter import (
|
39
40
|
BaseFilter,
|
@@ -478,6 +479,7 @@ __all__ = [
|
|
478
479
|
"WorkspaceScopedResponse",
|
479
480
|
"WorkspaceScopedResponseBody",
|
480
481
|
"WorkspaceScopedResponseMetadata",
|
482
|
+
"WorkspaceScopedTaggableFilter",
|
481
483
|
"BaseFilter",
|
482
484
|
"StrFilter",
|
483
485
|
"BoolFilter",
|
zenml/models/v2/base/filter.py
CHANGED
@@ -264,6 +264,9 @@ class BaseFilter(BaseModel):
|
|
264
264
|
# List of fields that are not even mentioned as options in the CLI.
|
265
265
|
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = []
|
266
266
|
|
267
|
+
# List of fields that are wrapped with `fastapi.Query(default)` in API.
|
268
|
+
API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = []
|
269
|
+
|
267
270
|
sort_by: str = Field(
|
268
271
|
default="created", description="Which column to sort by."
|
269
272
|
)
|