zenml-nightly 0.54.1.dev20240118__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58):
  1. zenml/__init__.py +10 -4
  2. zenml/artifacts/artifact_config.py +12 -12
  3. zenml/artifacts/external_artifact.py +3 -3
  4. zenml/artifacts/external_artifact_config.py +8 -8
  5. zenml/artifacts/utils.py +4 -4
  6. zenml/cli/__init__.py +4 -4
  7. zenml/cli/artifact.py +38 -18
  8. zenml/cli/base.py +3 -3
  9. zenml/cli/model.py +24 -16
  10. zenml/cli/server.py +9 -0
  11. zenml/cli/utils.py +3 -3
  12. zenml/client.py +13 -2
  13. zenml/config/compiler.py +1 -1
  14. zenml/config/pipeline_configurations.py +2 -2
  15. zenml/config/pipeline_run_configuration.py +2 -2
  16. zenml/config/step_configurations.py +2 -2
  17. zenml/integrations/__init__.py +2 -4
  18. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
  19. zenml/metadata/lazy_load.py +5 -5
  20. zenml/model/lazy_load.py +2 -2
  21. zenml/model/{model_version.py → model.py} +47 -38
  22. zenml/model/utils.py +33 -33
  23. zenml/model_registries/base_model_registry.py +10 -8
  24. zenml/models/__init__.py +2 -0
  25. zenml/models/v2/base/filter.py +3 -0
  26. zenml/models/v2/base/scoped.py +59 -0
  27. zenml/models/v2/core/artifact.py +2 -2
  28. zenml/models/v2/core/artifact_version.py +6 -6
  29. zenml/models/v2/core/model.py +6 -6
  30. zenml/models/v2/core/model_version.py +9 -9
  31. zenml/models/v2/core/run_metadata.py +2 -2
  32. zenml/new/pipelines/model_utils.py +20 -20
  33. zenml/new/pipelines/pipeline.py +47 -54
  34. zenml/new/pipelines/pipeline_context.py +1 -1
  35. zenml/new/pipelines/pipeline_decorator.py +4 -4
  36. zenml/new/steps/step_context.py +15 -15
  37. zenml/new/steps/step_decorator.py +5 -5
  38. zenml/orchestrators/input_utils.py +5 -7
  39. zenml/orchestrators/step_launcher.py +12 -19
  40. zenml/orchestrators/step_runner.py +8 -10
  41. zenml/pipelines/base_pipeline.py +1 -1
  42. zenml/pipelines/pipeline_decorator.py +6 -6
  43. zenml/steps/base_step.py +15 -15
  44. zenml/steps/step_decorator.py +6 -6
  45. zenml/steps/utils.py +68 -0
  46. zenml/zen_server/deploy/helm/templates/server-db-job.yaml +1 -1
  47. zenml/zen_server/deploy/helm/templates/server-secret.yaml +1 -1
  48. zenml/zen_server/deploy/helm/templates/serviceaccount.yaml +1 -1
  49. zenml/zen_server/utils.py +19 -1
  50. zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
  51. zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +16 -4
  52. zenml/zen_stores/rest_zen_store.py +2 -2
  53. zenml/zen_stores/sql_zen_store.py +4 -1
  54. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
  55. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +58 -57
  56. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
  57. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
  58. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
@@ -19,7 +19,9 @@ from typing import Any, Dict, List, Optional, Tuple, cast
19
19
 
20
20
  import mlflow
21
21
  from mlflow import MlflowClient
22
- from mlflow.entities.model_registry import ModelVersion as MLflowModelVersion
22
+ from mlflow.entities.model_registry import (
23
+ ModelVersion as MLflowModelVersion,
24
+ )
23
25
  from mlflow.exceptions import MlflowException
24
26
  from mlflow.pyfunc import load_model
25
27
 
@@ -36,9 +38,9 @@ from zenml.logger import get_logger
36
38
  from zenml.model_registries.base_model_registry import (
37
39
  BaseModelRegistry,
38
40
  ModelRegistryModelMetadata,
39
- ModelVersion,
40
41
  ModelVersionStage,
41
42
  RegisteredModel,
43
+ RegistryModelVersion,
42
44
  )
43
45
  from zenml.stack.stack import Stack
44
46
  from zenml.stack.stack_validator import StackValidator
@@ -348,7 +350,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
348
350
  description: Optional[str] = None,
349
351
  metadata: Optional[ModelRegistryModelMetadata] = None,
350
352
  **kwargs: Any,
351
- ) -> ModelVersion:
353
+ ) -> RegistryModelVersion:
352
354
  """Register a model version to the MLflow model registry.
353
355
 
354
356
  Args:
@@ -442,7 +444,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
442
444
  metadata: Optional[ModelRegistryModelMetadata] = None,
443
445
  remove_metadata: Optional[List[str]] = None,
444
446
  stage: Optional[ModelVersionStage] = None,
445
- ) -> ModelVersion:
447
+ ) -> RegistryModelVersion:
446
448
  """Update a model version in the MLflow model registry.
447
449
 
448
450
  Args:
@@ -521,7 +523,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
521
523
  self,
522
524
  name: str,
523
525
  version: str,
524
- ) -> ModelVersion:
526
+ ) -> RegistryModelVersion:
525
527
  """Get a model version from the MLflow model registry.
526
528
 
527
529
  Args:
@@ -561,7 +563,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
561
563
  created_before: Optional[datetime] = None,
562
564
  order_by_date: Optional[str] = None,
563
565
  **kwargs: Any,
564
- ) -> List[ModelVersion]:
566
+ ) -> List[RegistryModelVersion]:
565
567
  """List model versions from the MLflow model registry.
566
568
 
567
569
  Args:
@@ -704,7 +706,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
704
706
 
705
707
  def get_model_uri_artifact_store(
706
708
  self,
707
- model_version: ModelVersion,
709
+ model_version: RegistryModelVersion,
708
710
  ) -> str:
709
711
  """Get the model URI artifact store.
710
712
 
@@ -723,7 +725,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
723
725
  def _cast_mlflow_version_to_model_version(
724
726
  self,
725
727
  mlflow_model_version: MLflowModelVersion,
726
- ) -> ModelVersion:
728
+ ) -> RegistryModelVersion:
727
729
  """Cast an MLflow model version to a model version.
728
730
 
729
731
  Args:
@@ -748,7 +750,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
748
750
  )
749
751
  except ImportError:
750
752
  model_library = None
751
- return ModelVersion(
753
+ return RegistryModelVersion(
752
754
  registered_model=RegisteredModel(name=mlflow_model_version.name),
753
755
  model_format=MLFLOW_MODEL_FORMAT,
754
756
  model_library=model_library,
@@ -16,7 +16,7 @@
16
16
  from typing import TYPE_CHECKING, Optional
17
17
 
18
18
  if TYPE_CHECKING:
19
- from zenml.model.model_version import ModelVersion
19
+ from zenml.model.model import Model
20
20
  from zenml.models import RunMetadataResponse
21
21
 
22
22
 
@@ -30,18 +30,18 @@ class RunMetadataLazyGetter:
30
30
 
31
31
  def __init__(
32
32
  self,
33
- _lazy_load_model_version: "ModelVersion",
33
+ _lazy_load_model: "Model",
34
34
  _lazy_load_artifact_name: Optional[str],
35
35
  _lazy_load_artifact_version: Optional[str],
36
36
  ):
37
37
  """Initialize a RunMetadataLazyGetter.
38
38
 
39
39
  Args:
40
- _lazy_load_model_version: The model version.
40
+ _lazy_load_model: The model version.
41
41
  _lazy_load_artifact_name: The artifact name.
42
42
  _lazy_load_artifact_version: The artifact version.
43
43
  """
44
- self._lazy_load_model_version = _lazy_load_model_version
44
+ self._lazy_load_model = _lazy_load_model
45
45
  self._lazy_load_artifact_name = _lazy_load_artifact_name
46
46
  self._lazy_load_artifact_version = _lazy_load_artifact_version
47
47
 
@@ -57,7 +57,7 @@ class RunMetadataLazyGetter:
57
57
  from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse
58
58
 
59
59
  return LazyRunMetadataResponse(
60
- _lazy_load_model_version=self._lazy_load_model_version,
60
+ _lazy_load_model=self._lazy_load_model,
61
61
  _lazy_load_artifact_name=self._lazy_load_artifact_name,
62
62
  _lazy_load_artifact_version=self._lazy_load_artifact_version,
63
63
  _lazy_load_metadata_name=key,
zenml/model/lazy_load.py CHANGED
@@ -17,7 +17,7 @@ from typing import Optional
17
17
 
18
18
  from pydantic import BaseModel
19
19
 
20
- from zenml.model.model_version import ModelVersion
20
+ from zenml.model.model import Model
21
21
 
22
22
 
23
23
  class ModelVersionDataLazyLoader(BaseModel):
@@ -28,7 +28,7 @@ class ModelVersionDataLazyLoader(BaseModel):
28
28
  model version during runtime time of the step.
29
29
  """
30
30
 
31
- model_version: ModelVersion
31
+ model: Model
32
32
  artifact_name: Optional[str] = None
33
33
  artifact_version: Optional[str] = None
34
34
  metadata_name: Optional[str] = None
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
12
  # or implied. See the License for the specific language governing
13
13
  # permissions and limitations under the License.
14
- """ModelVersion user facing interface to pass into pipeline or step."""
14
+ """Model user facing interface to pass into pipeline or step."""
15
15
 
16
16
  from typing import (
17
17
  TYPE_CHECKING,
@@ -42,8 +42,8 @@ if TYPE_CHECKING:
42
42
  logger = get_logger(__name__)
43
43
 
44
44
 
45
- class ModelVersion(BaseModel):
46
- """ModelVersion class to pass into pipeline or step to set it into a model context.
45
+ class Model(BaseModel):
46
+ """Model class to pass into pipeline or step to set it into a model context.
47
47
 
48
48
  name: The name of the model.
49
49
  license: The license under which the model is created.
@@ -54,8 +54,8 @@ class ModelVersion(BaseModel):
54
54
  trade_offs: The tradeoffs of the model.
55
55
  ethics: The ethical implications of the model.
56
56
  tags: Tags associated with the model.
57
- version: The model version name, number or stage is optional and points model context
58
- to a specific version/stage. If skipped new model version will be created.
57
+ version: The version name, version number or stage is optional and points model context
58
+ to a specific version/stage. If skipped new version will be created.
59
59
  save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
60
60
  if available in active stack.
61
61
  """
@@ -97,7 +97,7 @@ class ModelVersion(BaseModel):
97
97
  self._get_or_create_model_version()
98
98
  except RuntimeError:
99
99
  logger.info(
100
- f"Model version `{self.version}` doesn't exist "
100
+ f"Version `{self.version}` of `{self.name}` model doesn't exist "
101
101
  "and cannot be fetched from the Model Control Plane."
102
102
  )
103
103
  return self._id
@@ -128,7 +128,7 @@ class ModelVersion(BaseModel):
128
128
  self._get_or_create_model_version()
129
129
  except RuntimeError:
130
130
  logger.info(
131
- f"Model version `{self.version}` doesn't exist "
131
+ f"Version `{self.version}` of `{self.name}` model doesn't exist "
132
132
  "and cannot be fetched from the Model Control Plane."
133
133
  )
134
134
  return self._number
@@ -149,7 +149,7 @@ class ModelVersion(BaseModel):
149
149
  return ModelStages(stage)
150
150
  except RuntimeError:
151
151
  logger.info(
152
- f"Model version `{self.version}` doesn't exist "
152
+ f"Version `{self.version}` of `{self.name}` model doesn't exist "
153
153
  "and cannot be fetched from the Model Control Plane."
154
154
  )
155
155
  return None
@@ -283,7 +283,7 @@ class ModelVersion(BaseModel):
283
283
  def set_stage(
284
284
  self, stage: Union[str, ModelStages], force: bool = False
285
285
  ) -> None:
286
- """Sets this Model Version to a desired stage.
286
+ """Sets this Model to a desired stage.
287
287
 
288
288
  Args:
289
289
  stage: the target stage for model version.
@@ -431,7 +431,7 @@ class ModelVersion(BaseModel):
431
431
  return LazyArtifactVersionResponse(
432
432
  _lazy_load_name=name,
433
433
  _lazy_load_version=version,
434
- _lazy_load_model_version=ModelVersion(
434
+ _lazy_load_model=Model(
435
435
  name=self.name, version=self.version or self.number
436
436
  ),
437
437
  )
@@ -441,7 +441,7 @@ class ModelVersion(BaseModel):
441
441
  return None
442
442
 
443
443
  def __eq__(self, other: object) -> bool:
444
- """Check two ModelVersions for equality.
444
+ """Check two Models for equality.
445
445
 
446
446
  Args:
447
447
  other: object to compare with
@@ -449,7 +449,7 @@ class ModelVersion(BaseModel):
449
449
  Returns:
450
450
  True, if equal, False otherwise.
451
451
  """
452
- if not isinstance(other, ModelVersion):
452
+ if not isinstance(other, Model):
453
453
  return NotImplemented
454
454
  if self.name != other.name:
455
455
  return False
@@ -509,7 +509,7 @@ class ModelVersion(BaseModel):
509
509
  model_name_or_id=self.name
510
510
  )
511
511
 
512
- difference = {}
512
+ difference: Dict[str, Any] = {}
513
513
  for key in (
514
514
  "license",
515
515
  "audience",
@@ -519,12 +519,20 @@ class ModelVersion(BaseModel):
519
519
  "ethics",
520
520
  "save_models_to_registry",
521
521
  ):
522
- if getattr(self, key) != getattr(model, key):
523
- difference[key] = {
524
- "config": getattr(self, key),
525
- "db": getattr(model, key),
526
- }
527
-
522
+ if self_attr := getattr(self, key, None):
523
+ if self_attr != getattr(model, key):
524
+ difference[key] = {
525
+ "config": getattr(self, key),
526
+ "db": getattr(model, key),
527
+ }
528
+ if self.tags:
529
+ configured_tags = set(self.tags)
530
+ db_tags = {t.name for t in model.tags}
531
+ if db_tags != configured_tags:
532
+ difference["tags added"] = list(configured_tags - db_tags)
533
+ difference["tags removed"] = list(
534
+ db_tags - configured_tags
535
+ )
528
536
  if difference:
529
537
  logger.warning(
530
538
  "Provided model configuration does not match "
@@ -581,16 +589,17 @@ class ModelVersion(BaseModel):
581
589
  self._id = mv.id
582
590
 
583
591
  difference: Dict[str, Any] = {}
584
- if mv.description != self.description:
592
+ if self.description and mv.description != self.description:
585
593
  difference["description"] = {
586
594
  "config": self.description,
587
595
  "db": mv.description,
588
596
  }
589
- configured_tags = set(self.tags or [])
590
- db_tags = {t.name for t in mv.tags}
591
- if db_tags != configured_tags:
592
- difference["tags added"] = list(configured_tags - db_tags)
593
- difference["tags removed"] = list(db_tags - configured_tags)
597
+ if self.tags:
598
+ configured_tags = set(self.tags)
599
+ db_tags = {t.name for t in mv.tags}
600
+ if db_tags != configured_tags:
601
+ difference["tags added"] = list(configured_tags - db_tags)
602
+ difference["tags removed"] = list(db_tags - configured_tags)
594
603
  if difference:
595
604
  logger.warning(
596
605
  "Provided model version configuration does not match existing model "
@@ -654,7 +663,7 @@ class ModelVersion(BaseModel):
654
663
  # model version for current model was already
655
664
  # created in the current run, not to create
656
665
  # new model versions
657
- pipeline_mv = context.pipeline_run.config.model_version
666
+ pipeline_mv = context.pipeline_run.config.model
658
667
  if (
659
668
  pipeline_mv
660
669
  and pipeline_mv.was_created_in_this_run
@@ -664,7 +673,7 @@ class ModelVersion(BaseModel):
664
673
  self.version = pipeline_mv.version
665
674
  else:
666
675
  for step in context.pipeline_run.steps.values():
667
- step_mv = step.config.model_version
676
+ step_mv = step.config.model
668
677
  if (
669
678
  step_mv
670
679
  and step_mv.was_created_in_this_run
@@ -712,21 +721,21 @@ class ModelVersion(BaseModel):
712
721
  self._number = model_version.number
713
722
  return model_version
714
723
 
715
- def _merge(self, model_version: "ModelVersion") -> None:
716
- self.license = self.license or model_version.license
717
- self.description = self.description or model_version.description
718
- self.audience = self.audience or model_version.audience
719
- self.use_cases = self.use_cases or model_version.use_cases
720
- self.limitations = self.limitations or model_version.limitations
721
- self.trade_offs = self.trade_offs or model_version.trade_offs
722
- self.ethics = self.ethics or model_version.ethics
723
- if model_version.tags is not None:
724
+ def _merge(self, model: "Model") -> None:
725
+ self.license = self.license or model.license
726
+ self.description = self.description or model.description
727
+ self.audience = self.audience or model.audience
728
+ self.use_cases = self.use_cases or model.use_cases
729
+ self.limitations = self.limitations or model.limitations
730
+ self.trade_offs = self.trade_offs or model.trade_offs
731
+ self.ethics = self.ethics or model.ethics
732
+ if model.tags is not None:
724
733
  self.tags = list(
725
- {t for t in self.tags or []}.union(set(model_version.tags))
734
+ {t for t in self.tags or []}.union(set(model.tags))
726
735
  )
727
736
 
728
737
  def __hash__(self) -> int:
729
- """Get hash of the `ModelVersion`.
738
+ """Get hash of the `Model`.
730
739
 
731
740
  Returns:
732
741
  Hash function results
zenml/model/utils.py CHANGED
@@ -22,7 +22,7 @@ from zenml.enums import ModelStages
22
22
  from zenml.exceptions import StepContextError
23
23
  from zenml.logger import get_logger
24
24
  from zenml.metadata.metadata_types import MetadataType
25
- from zenml.model.model_version import ModelVersion
25
+ from zenml.model.model import Model
26
26
  from zenml.models import ModelVersionArtifactRequest
27
27
  from zenml.new.steps.step_context import get_step_context
28
28
 
@@ -48,9 +48,9 @@ def link_step_artifacts_to_model(
48
48
  "step."
49
49
  )
50
50
  try:
51
- model_version = step_context.model_version
51
+ model = step_context.model
52
52
  except StepContextError:
53
- model_version = None
53
+ model = None
54
54
  logger.debug("No model context found, unable to auto-link artifacts.")
55
55
 
56
56
  for artifact_name, artifact_version_id in artifact_version_ids.items():
@@ -59,47 +59,47 @@ def link_step_artifacts_to_model(
59
59
  ).artifact_config
60
60
 
61
61
  # Implicit linking
62
- if artifact_config is None and model_version is not None:
62
+ if artifact_config is None and model is not None:
63
63
  artifact_config = ArtifactConfig(name=artifact_name)
64
64
  logger.info(
65
65
  f"Implicitly linking artifact `{artifact_name}` to model "
66
- f"`{model_version.name}` version `{model_version.version}`."
66
+ f"`{model.name}` version `{model.version}`."
67
67
  )
68
68
 
69
69
  if artifact_config:
70
- link_artifact_config_to_model_version(
70
+ link_artifact_config_to_model(
71
71
  artifact_config=artifact_config,
72
72
  artifact_version_id=artifact_version_id,
73
- model_version=model_version,
73
+ model=model,
74
74
  )
75
75
 
76
76
 
77
- def link_artifact_config_to_model_version(
77
+ def link_artifact_config_to_model(
78
78
  artifact_config: ArtifactConfig,
79
79
  artifact_version_id: UUID,
80
- model_version: Optional["ModelVersion"] = None,
80
+ model: Optional["Model"] = None,
81
81
  ) -> None:
82
82
  """Link an artifact config to its model version.
83
83
 
84
84
  Args:
85
85
  artifact_config: The artifact config to link.
86
86
  artifact_version_id: The ID of the artifact to link.
87
- model_version: The model version from the step or pipeline context.
87
+ model: The model version from the step or pipeline context.
88
88
  """
89
89
  client = Client()
90
90
 
91
91
  # If the artifact config specifies a model itself then always use that
92
92
  if artifact_config.model_name is not None:
93
- from zenml.model.model_version import ModelVersion
93
+ from zenml.model.model import Model
94
94
 
95
- model_version = ModelVersion(
95
+ model = Model(
96
96
  name=artifact_config.model_name,
97
97
  version=artifact_config.model_version,
98
98
  )
99
99
 
100
- if model_version:
101
- model_version._get_or_create_model_version()
102
- model_version_response = model_version._get_model_version()
100
+ if model:
101
+ model._get_or_create_model_version()
102
+ model_version_response = model._get_model_version()
103
103
  request = ModelVersionArtifactRequest(
104
104
  user=client.active_user.id,
105
105
  workspace=client.active_workspace.id,
@@ -125,10 +125,10 @@ def log_model_version_metadata(
125
125
  metadata: The metadata to log.
126
126
  model_name: The name of the model to log metadata for. Can
127
127
  be omitted when being called inside a step with configured
128
- `model_version` in decorator.
128
+ `model` in decorator.
129
129
  model_version: The version of the model to log metadata for. Can
130
130
  be omitted when being called inside a step with configured
131
- `model_version` in decorator.
131
+ `model` in decorator.
132
132
  """
133
133
  logger.warning(
134
134
  "`log_model_version_metadata` is deprecated. Please use "
@@ -152,38 +152,38 @@ def log_model_metadata(
152
152
  metadata: The metadata to log.
153
153
  model_name: The name of the model to log metadata for. Can
154
154
  be omitted when being called inside a step with configured
155
- `model_version` in decorator.
155
+ `model` in decorator.
156
156
  model_version: The version of the model to log metadata for. Can
157
157
  be omitted when being called inside a step with configured
158
- `model_version` in decorator.
158
+ `model` in decorator.
159
159
 
160
160
  Raises:
161
161
  ValueError: If no model name/version is provided and the function is not
162
- called inside a step with configured `model_version` in decorator.
162
+ called inside a step with configured `model` in decorator.
163
163
  """
164
164
  mv = None
165
165
  try:
166
166
  step_context = get_step_context()
167
- mv = step_context.model_version
167
+ mv = step_context.model
168
168
  except RuntimeError:
169
169
  step_context = None
170
170
 
171
171
  if not step_context and not (model_name and model_version):
172
172
  raise ValueError(
173
173
  "Model name and version must be provided unless the function is "
174
- "called inside a step with configured `model_version` in decorator."
174
+ "called inside a step with configured `model` in decorator."
175
175
  )
176
176
  if mv is None:
177
- from zenml import ModelVersion
177
+ from zenml import Model
178
178
 
179
- mv = ModelVersion(name=model_name, version=model_version)
179
+ mv = Model(name=model_name, version=model_version)
180
180
 
181
181
  mv.log_metadata(metadata)
182
182
 
183
183
 
184
184
  def link_artifact_to_model(
185
185
  artifact_version_id: UUID,
186
- model_version: Optional["ModelVersion"] = None,
186
+ model: Optional["Model"] = None,
187
187
  is_model_artifact: bool = False,
188
188
  is_deployment_artifact: bool = False,
189
189
  ) -> None:
@@ -191,34 +191,34 @@ def link_artifact_to_model(
191
191
 
192
192
  Args:
193
193
  artifact_version_id: The ID of the artifact version.
194
- model_version: The model version to link to.
194
+ model: The model to link to.
195
195
  is_model_artifact: Whether the artifact is a model artifact.
196
196
  is_deployment_artifact: Whether the artifact is a deployment artifact.
197
197
 
198
198
  Raises:
199
199
  RuntimeError: If called outside of a step.
200
200
  """
201
- if not model_version:
201
+ if not model:
202
202
  is_issue = False
203
203
  try:
204
204
  step_context = get_step_context()
205
- model_version = step_context.model_version
205
+ model = step_context.model
206
206
  except StepContextError:
207
207
  is_issue = True
208
208
 
209
- if model_version is None or is_issue:
209
+ if model is None or is_issue:
210
210
  raise RuntimeError(
211
- "`link_artifact_to_model` called without `model_version` parameter "
211
+ "`link_artifact_to_model` called without `model` parameter "
212
212
  "and configured model context cannot be identified. Consider "
213
- "passing the `model_version` explicitly or configuring it in "
213
+ "passing the `model` explicitly or configuring it in "
214
214
  "@step or @pipeline decorator."
215
215
  )
216
216
 
217
- link_artifact_config_to_model_version(
217
+ link_artifact_config_to_model(
218
218
  artifact_config=ArtifactConfig(
219
219
  is_model_artifact=is_model_artifact,
220
220
  is_deployment_artifact=is_deployment_artifact,
221
221
  ),
222
222
  artifact_version_id=artifact_version_id,
223
- model_version=model_version,
223
+ model=model,
224
224
  )
@@ -123,10 +123,10 @@ class ModelRegistryModelMetadata(BaseModel):
123
123
  extra = "allow"
124
124
 
125
125
 
126
- class ModelVersion(BaseModel):
126
+ class RegistryModelVersion(BaseModel):
127
127
  """Base class for all ZenML model versions.
128
128
 
129
- The `ModelVersion` class represents a version or snapshot of a registered
129
+ The `RegistryModelVersion` class represents a version or snapshot of a registered
130
130
  model, including information such as the associated `ModelBundle`, version
131
131
  number, creation time, pipeline run information, and metadata. It serves as
132
132
  a blueprint for creating concrete model version implementations in a registry,
@@ -288,7 +288,7 @@ class BaseModelRegistry(StackComponent, ABC):
288
288
  description: Optional[str] = None,
289
289
  metadata: Optional[ModelRegistryModelMetadata] = None,
290
290
  **kwargs: Any,
291
- ) -> ModelVersion:
291
+ ) -> RegistryModelVersion:
292
292
  """Registers a model version in the model registry.
293
293
 
294
294
  Args:
@@ -333,7 +333,7 @@ class BaseModelRegistry(StackComponent, ABC):
333
333
  metadata: Optional[ModelRegistryModelMetadata] = None,
334
334
  remove_metadata: Optional[List[str]] = None,
335
335
  stage: Optional[ModelVersionStage] = None,
336
- ) -> ModelVersion:
336
+ ) -> RegistryModelVersion:
337
337
  """Updates a model version in the model registry.
338
338
 
339
339
  Args:
@@ -364,7 +364,7 @@ class BaseModelRegistry(StackComponent, ABC):
364
364
  created_before: Optional[datetime] = None,
365
365
  order_by_date: Optional[str] = None,
366
366
  **kwargs: Any,
367
- ) -> Optional[List[ModelVersion]]:
367
+ ) -> Optional[List[RegistryModelVersion]]:
368
368
  """Lists all model versions for a registered model.
369
369
 
370
370
  Args:
@@ -387,7 +387,7 @@ class BaseModelRegistry(StackComponent, ABC):
387
387
  self,
388
388
  name: str,
389
389
  stage: Optional[ModelVersionStage] = None,
390
- ) -> Optional[ModelVersion]:
390
+ ) -> Optional[RegistryModelVersion]:
391
391
  """Gets the latest model version for a registered model.
392
392
 
393
393
  This method is used to get the latest model version for a registered
@@ -410,7 +410,9 @@ class BaseModelRegistry(StackComponent, ABC):
410
410
  return None
411
411
 
412
412
  @abstractmethod
413
- def get_model_version(self, name: str, version: str) -> ModelVersion:
413
+ def get_model_version(
414
+ self, name: str, version: str
415
+ ) -> RegistryModelVersion:
414
416
  """Gets a model version for a registered model.
415
417
 
416
418
  Args:
@@ -450,7 +452,7 @@ class BaseModelRegistry(StackComponent, ABC):
450
452
  @abstractmethod
451
453
  def get_model_uri_artifact_store(
452
454
  self,
453
- model_version: ModelVersion,
455
+ model_version: RegistryModelVersion,
454
456
  ) -> str:
455
457
  """Gets the URI artifact store for a model version.
456
458
 
zenml/models/__init__.py CHANGED
@@ -34,6 +34,7 @@ from zenml.models.v2.base.scoped import (
34
34
  WorkspaceScopedResponse,
35
35
  WorkspaceScopedResponseBody,
36
36
  WorkspaceScopedResponseMetadata,
37
+ WorkspaceScopedTaggableFilter
37
38
  )
38
39
  from zenml.models.v2.base.filter import (
39
40
  BaseFilter,
@@ -478,6 +479,7 @@ __all__ = [
478
479
  "WorkspaceScopedResponse",
479
480
  "WorkspaceScopedResponseBody",
480
481
  "WorkspaceScopedResponseMetadata",
482
+ "WorkspaceScopedTaggableFilter",
481
483
  "BaseFilter",
482
484
  "StrFilter",
483
485
  "BoolFilter",
@@ -264,6 +264,9 @@ class BaseFilter(BaseModel):
264
264
  # List of fields that are not even mentioned as options in the CLI.
265
265
  CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = []
266
266
 
267
+ # List of fields that are wrapped with `fastapi.Query(default)` in API.
268
+ API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = []
269
+
267
270
  sort_by: str = Field(
268
271
  default="created", description="Which column to sort by."
269
272
  )