zenml-nightly 0.54.1.dev20240119__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. zenml/__init__.py +10 -4
  2. zenml/artifacts/artifact_config.py +12 -12
  3. zenml/artifacts/external_artifact.py +3 -3
  4. zenml/artifacts/external_artifact_config.py +8 -8
  5. zenml/artifacts/utils.py +4 -4
  6. zenml/cli/__init__.py +1 -1
  7. zenml/cli/base.py +3 -3
  8. zenml/cli/utils.py +3 -3
  9. zenml/config/compiler.py +1 -1
  10. zenml/config/pipeline_configurations.py +2 -2
  11. zenml/config/pipeline_run_configuration.py +2 -2
  12. zenml/config/step_configurations.py +2 -2
  13. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
  14. zenml/metadata/lazy_load.py +5 -5
  15. zenml/model/lazy_load.py +2 -2
  16. zenml/model/{model_version.py → model.py} +35 -28
  17. zenml/model/utils.py +33 -33
  18. zenml/model_registries/base_model_registry.py +10 -8
  19. zenml/models/v2/core/artifact_version.py +3 -3
  20. zenml/models/v2/core/model.py +3 -3
  21. zenml/models/v2/core/model_version.py +7 -7
  22. zenml/models/v2/core/run_metadata.py +2 -2
  23. zenml/new/pipelines/model_utils.py +20 -20
  24. zenml/new/pipelines/pipeline.py +47 -54
  25. zenml/new/pipelines/pipeline_context.py +1 -1
  26. zenml/new/pipelines/pipeline_decorator.py +4 -4
  27. zenml/new/steps/step_context.py +15 -15
  28. zenml/new/steps/step_decorator.py +5 -5
  29. zenml/orchestrators/input_utils.py +5 -7
  30. zenml/orchestrators/step_launcher.py +12 -19
  31. zenml/orchestrators/step_runner.py +8 -10
  32. zenml/pipelines/base_pipeline.py +1 -1
  33. zenml/pipelines/pipeline_decorator.py +6 -6
  34. zenml/steps/base_step.py +15 -15
  35. zenml/steps/step_decorator.py +6 -6
  36. zenml/steps/utils.py +68 -0
  37. zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
  38. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
  39. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +42 -41
  40. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
  41. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
  42. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
zenml/__init__.py CHANGED
@@ -42,15 +42,19 @@ from zenml.artifacts.utils import (
42
42
  save_artifact,
43
43
  load_artifact,
44
44
  )
45
- from zenml.model.utils import log_model_metadata, link_artifact_to_model
45
+ from zenml.model.utils import (
46
+ log_model_metadata,
47
+ link_artifact_to_model,
48
+ log_model_version_metadata,
49
+ )
46
50
  from zenml.artifacts.artifact_config import ArtifactConfig
47
51
  from zenml.artifacts.external_artifact import ExternalArtifact
48
- from zenml.model.model_version import ModelVersion
49
- from zenml.model.utils import log_model_version_metadata
52
+ from zenml.model.model import Model
50
53
  from zenml.new.pipelines.pipeline_context import get_pipeline_context
51
54
  from zenml.new.pipelines.pipeline_decorator import pipeline
52
55
  from zenml.new.steps.step_decorator import step
53
56
  from zenml.new.steps.step_context import get_step_context
57
+ from zenml.steps.utils import log_step_metadata
54
58
 
55
59
  __all__ = [
56
60
  "ArtifactConfig",
@@ -60,8 +64,10 @@ __all__ = [
60
64
  "load_artifact",
61
65
  "log_artifact_metadata",
62
66
  "log_model_metadata",
67
+ "log_model_version_metadata",
68
+ "log_step_metadata",
69
+ "Model",
63
70
  "link_artifact_to_model",
64
- "ModelVersion",
65
71
  "pipeline",
66
72
  "save_artifact",
67
73
  "show",
@@ -22,7 +22,7 @@ from zenml.logger import get_logger
22
22
  from zenml.new.steps.step_context import get_step_context
23
23
 
24
24
  if TYPE_CHECKING:
25
- from zenml.model.model_version import ModelVersion
25
+ from zenml.model.model import Model
26
26
 
27
27
 
28
28
  logger = get_logger(__name__)
@@ -52,7 +52,7 @@ class ArtifactConfig(BaseModel):
52
52
  version: The version of the artifact.
53
53
  tags: The tags of the artifact.
54
54
  model_name: The name of the model to link artifact to.
55
- model_version: The identifier of the model version to link the artifact
55
+ model_version: The identifier of a version of the model to link the artifact
56
56
  to. It can be an exact version ("my_version"), exact version number
57
57
  (42), stage (ModelStages.PRODUCTION or "production"), or
58
58
  (ModelStages.LATEST or None) for the latest version (default).
@@ -88,26 +88,26 @@ class ArtifactConfig(BaseModel):
88
88
  smart_union = True
89
89
 
90
90
  @property
91
- def _model_version(self) -> Optional["ModelVersion"]:
92
- """The model version linked to this artifact.
91
+ def _model(self) -> Optional["Model"]:
92
+ """The model linked to this artifact.
93
93
 
94
94
  Returns:
95
- The model version or None if the model version cannot be determined.
95
+ The model or None if the model version cannot be determined.
96
96
  """
97
97
  try:
98
- model_version = get_step_context().model_version
98
+ model_ = get_step_context().model
99
99
  except (StepContextError, RuntimeError):
100
- model_version = None
100
+ model_ = None
101
101
  # Check if another model name was specified
102
102
  if (self.model_name is not None) and (
103
- model_version is None or model_version.name != self.model_name
103
+ model_ is None or model_.name != self.model_name
104
104
  ):
105
- # Create a new ModelConfig instance with the provided model name and version
106
- from zenml.model.model_version import ModelVersion
105
+ # Create a new Model instance with the provided model name and version
106
+ from zenml.model.model import Model
107
107
 
108
- on_the_fly_config = ModelVersion(
108
+ on_the_fly_config = Model(
109
109
  name=self.model_name, version=self.model_version
110
110
  )
111
111
  return on_the_fly_config
112
112
 
113
- return model_version
113
+ return model_
@@ -56,8 +56,8 @@ class ExternalArtifact(ExternalArtifactConfiguration):
56
56
  `version`, `pipeline_run_name`, or `pipeline_name` are set, the
57
57
  latest version of the artifact will be used.
58
58
  version: Version of the artifact to search. Only used when `name` is
59
- provided. Cannot be used together with `model_version`.
60
- model_version: The model version to search in. Only used when `name`
59
+ provided. Cannot be used together with `model`.
60
+ model: The model to search in. Only used when `name`
61
61
  is provided. Cannot be used together with `version`.
62
62
  materializer: The materializer to use for saving the artifact value
63
63
  to the artifact store. Only used when `value` is provided.
@@ -149,5 +149,5 @@ class ExternalArtifact(ExternalArtifactConfiguration):
149
149
  id=self.id,
150
150
  name=self.name,
151
151
  version=self.version,
152
- model_version=self.model_version,
152
+ model=self.model,
153
153
  )
@@ -18,7 +18,7 @@ from uuid import UUID
18
18
  from pydantic import BaseModel, root_validator
19
19
 
20
20
  from zenml.logger import get_logger
21
- from zenml.model.model_version import ModelVersion
21
+ from zenml.model.model import Model
22
22
  from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
23
23
 
24
24
  logger = get_logger(__name__)
@@ -33,13 +33,13 @@ class ExternalArtifactConfiguration(BaseModel):
33
33
  id: Optional[UUID] = None
34
34
  name: Optional[str] = None
35
35
  version: Optional[str] = None
36
- model_version: Optional[ModelVersion] = None
36
+ model: Optional[Model] = None
37
37
 
38
38
  @root_validator
39
39
  def _validate_all_eac(cls, values: Dict[str, Any]) -> Dict[str, Any]:
40
- if values.get("version", None) and values.get("model_version", None):
40
+ if values.get("version", None) and values.get("model", None):
41
41
  raise ValueError(
42
- "Cannot provide both `version` and `model_version` when "
42
+ "Cannot provide both `version` and `model` when "
43
43
  "creating an external artifact."
44
44
  )
45
45
  return values
@@ -67,13 +67,13 @@ class ExternalArtifactConfiguration(BaseModel):
67
67
  response = client.get_artifact_version(
68
68
  self.name, version=self.version
69
69
  )
70
- elif self.model_version:
71
- response_ = self.model_version.get_artifact(self.name)
70
+ elif self.model:
71
+ response_ = self.model.get_artifact(self.name)
72
72
  if not isinstance(response_, ArtifactVersionResponse):
73
73
  raise RuntimeError(
74
74
  f"Failed to pull artifact `{self.name}` from the Model "
75
- f"Version (name=`{self.model_version.name}`, version="
76
- f"`{self.model_version.version}`). Please validate the "
75
+ f"(name=`{self.model.name}`, version="
76
+ f"`{self.model.version}`). Please validate the "
77
77
  "input and try again."
78
78
  )
79
79
  response = response_
zenml/artifacts/utils.py CHANGED
@@ -233,14 +233,14 @@ def save_artifact(
233
233
  saved_artifact_versions={name: response.id}
234
234
  ),
235
235
  )
236
- error_message = "model version"
237
- model_version = step_context.model_version
238
- if model_version:
236
+ error_message = "model"
237
+ model = step_context.model
238
+ if model:
239
239
  from zenml.model.utils import link_artifact_to_model
240
240
 
241
241
  link_artifact_to_model(
242
242
  artifact_version_id=response.id,
243
- model_version=model_version,
243
+ model=model,
244
244
  is_model_artifact=is_model_artifact,
245
245
  is_deployment_artifact=is_deployment_artifact,
246
246
  )
zenml/cli/__init__.py CHANGED
@@ -796,7 +796,7 @@ Administering your Models
796
796
  ----------------------------
797
797
 
798
798
  ZenML provides several CLI commands to help you administer your models and
799
- model versions as part of the Model Control Plane.
799
+ their versions as part of the Model Control Plane.
800
800
 
801
801
  To register a new model, you can use the following CLI command:
802
802
  ```bash
zenml/cli/base.py CHANGED
@@ -73,15 +73,15 @@ class ZenMLProjectTemplateLocation(BaseModel):
73
73
  ZENML_PROJECT_TEMPLATES = dict(
74
74
  e2e_batch=ZenMLProjectTemplateLocation(
75
75
  github_url="zenml-io/template-e2e-batch",
76
- github_tag="2024.01.17", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
76
+ github_tag="2024.01.18", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
77
77
  ),
78
78
  starter=ZenMLProjectTemplateLocation(
79
79
  github_url="zenml-io/template-starter",
80
- github_tag="2023.12.18", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
80
+ github_tag="2024.01.12", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
81
81
  ),
82
82
  nlp=ZenMLProjectTemplateLocation(
83
83
  github_url="zenml-io/template-nlp",
84
- github_tag="0.45.0", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
84
+ github_tag="2024.01.12", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
85
85
  ),
86
86
  )
87
87
 
zenml/cli/utils.py CHANGED
@@ -58,8 +58,8 @@ from zenml.constants import (
58
58
  from zenml.enums import GenericFilterOps, StackComponentType
59
59
  from zenml.logger import get_logger
60
60
  from zenml.model_registries.base_model_registry import (
61
- ModelVersion,
62
61
  RegisteredModel,
62
+ RegistryModelVersion,
63
63
  )
64
64
  from zenml.models import (
65
65
  BaseFilter,
@@ -1182,7 +1182,7 @@ def pretty_print_registered_model_table(
1182
1182
 
1183
1183
 
1184
1184
  def pretty_print_model_version_table(
1185
- model_versions: List["ModelVersion"],
1185
+ model_versions: List["RegistryModelVersion"],
1186
1186
  ) -> None:
1187
1187
  """Given a list of model_versions, print all associated key-value pairs.
1188
1188
 
@@ -1206,7 +1206,7 @@ def pretty_print_model_version_table(
1206
1206
 
1207
1207
 
1208
1208
  def pretty_print_model_version_details(
1209
- model_version: "ModelVersion",
1209
+ model_version: "RegistryModelVersion",
1210
1210
  ) -> None:
1211
1211
  """Given a model_version, print all associated key-value pairs.
1212
1212
 
zenml/config/compiler.py CHANGED
@@ -204,7 +204,7 @@ class Compiler:
204
204
  enable_step_logs=config.enable_step_logs,
205
205
  settings=config.settings,
206
206
  extra=config.extra,
207
- model_version=config.model_version,
207
+ model=config.model,
208
208
  parameters=config.parameters,
209
209
  )
210
210
 
@@ -19,7 +19,7 @@ from pydantic import validator
19
19
  from zenml.config.constants import DOCKER_SETTINGS_KEY
20
20
  from zenml.config.source import Source, convert_source_validator
21
21
  from zenml.config.strict_base_model import StrictBaseModel
22
- from zenml.model.model_version import ModelVersion
22
+ from zenml.model.model import Model
23
23
 
24
24
  if TYPE_CHECKING:
25
25
  from zenml.config import DockerSettings
@@ -40,7 +40,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
40
40
  extra: Dict[str, Any] = {}
41
41
  failure_hook_source: Optional[Source] = None
42
42
  success_hook_source: Optional[Source] = None
43
- model_version: Optional[ModelVersion] = None
43
+ model: Optional[Model] = None
44
44
  parameters: Optional[Dict[str, Any]] = None
45
45
 
46
46
  _convert_source = convert_source_validator(
@@ -19,7 +19,7 @@ from zenml.config.base_settings import BaseSettings
19
19
  from zenml.config.schedule import Schedule
20
20
  from zenml.config.step_configurations import StepConfigurationUpdate
21
21
  from zenml.config.strict_base_model import StrictBaseModel
22
- from zenml.model.model_version import ModelVersion
22
+ from zenml.model.model import Model
23
23
  from zenml.models import PipelineBuildBase
24
24
  from zenml.utils import pydantic_utils
25
25
 
@@ -39,5 +39,5 @@ class PipelineRunConfiguration(
39
39
  steps: Dict[str, StepConfigurationUpdate] = {}
40
40
  settings: Dict[str, BaseSettings] = {}
41
41
  extra: Dict[str, Any] = {}
42
- model_version: Optional[ModelVersion] = None
42
+ model: Optional[Model] = None
43
43
  parameters: Optional[Dict[str, Any]] = None
@@ -35,7 +35,7 @@ from zenml.config.source import Source, convert_source_validator
35
35
  from zenml.config.strict_base_model import StrictBaseModel
36
36
  from zenml.logger import get_logger
37
37
  from zenml.model.lazy_load import ModelVersionDataLazyLoader
38
- from zenml.model.model_version import ModelVersion
38
+ from zenml.model.model import Model
39
39
  from zenml.utils import deprecation_utils
40
40
 
41
41
  if TYPE_CHECKING:
@@ -135,7 +135,7 @@ class StepConfigurationUpdate(StrictBaseModel):
135
135
  extra: Dict[str, Any] = {}
136
136
  failure_hook_source: Optional[Source] = None
137
137
  success_hook_source: Optional[Source] = None
138
- model_version: Optional[ModelVersion] = None
138
+ model: Optional[Model] = None
139
139
 
140
140
  outputs: Mapping[str, PartialArtifactConfiguration] = {}
141
141
 
@@ -19,7 +19,9 @@ from typing import Any, Dict, List, Optional, Tuple, cast
19
19
 
20
20
  import mlflow
21
21
  from mlflow import MlflowClient
22
- from mlflow.entities.model_registry import ModelVersion as MLflowModelVersion
22
+ from mlflow.entities.model_registry import (
23
+ ModelVersion as MLflowModelVersion,
24
+ )
23
25
  from mlflow.exceptions import MlflowException
24
26
  from mlflow.pyfunc import load_model
25
27
 
@@ -36,9 +38,9 @@ from zenml.logger import get_logger
36
38
  from zenml.model_registries.base_model_registry import (
37
39
  BaseModelRegistry,
38
40
  ModelRegistryModelMetadata,
39
- ModelVersion,
40
41
  ModelVersionStage,
41
42
  RegisteredModel,
43
+ RegistryModelVersion,
42
44
  )
43
45
  from zenml.stack.stack import Stack
44
46
  from zenml.stack.stack_validator import StackValidator
@@ -348,7 +350,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
348
350
  description: Optional[str] = None,
349
351
  metadata: Optional[ModelRegistryModelMetadata] = None,
350
352
  **kwargs: Any,
351
- ) -> ModelVersion:
353
+ ) -> RegistryModelVersion:
352
354
  """Register a model version to the MLflow model registry.
353
355
 
354
356
  Args:
@@ -442,7 +444,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
442
444
  metadata: Optional[ModelRegistryModelMetadata] = None,
443
445
  remove_metadata: Optional[List[str]] = None,
444
446
  stage: Optional[ModelVersionStage] = None,
445
- ) -> ModelVersion:
447
+ ) -> RegistryModelVersion:
446
448
  """Update a model version in the MLflow model registry.
447
449
 
448
450
  Args:
@@ -521,7 +523,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
521
523
  self,
522
524
  name: str,
523
525
  version: str,
524
- ) -> ModelVersion:
526
+ ) -> RegistryModelVersion:
525
527
  """Get a model version from the MLflow model registry.
526
528
 
527
529
  Args:
@@ -561,7 +563,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
561
563
  created_before: Optional[datetime] = None,
562
564
  order_by_date: Optional[str] = None,
563
565
  **kwargs: Any,
564
- ) -> List[ModelVersion]:
566
+ ) -> List[RegistryModelVersion]:
565
567
  """List model versions from the MLflow model registry.
566
568
 
567
569
  Args:
@@ -704,7 +706,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
704
706
 
705
707
  def get_model_uri_artifact_store(
706
708
  self,
707
- model_version: ModelVersion,
709
+ model_version: RegistryModelVersion,
708
710
  ) -> str:
709
711
  """Get the model URI artifact store.
710
712
 
@@ -723,7 +725,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
723
725
  def _cast_mlflow_version_to_model_version(
724
726
  self,
725
727
  mlflow_model_version: MLflowModelVersion,
726
- ) -> ModelVersion:
728
+ ) -> RegistryModelVersion:
727
729
  """Cast an MLflow model version to a model version.
728
730
 
729
731
  Args:
@@ -748,7 +750,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
748
750
  )
749
751
  except ImportError:
750
752
  model_library = None
751
- return ModelVersion(
753
+ return RegistryModelVersion(
752
754
  registered_model=RegisteredModel(name=mlflow_model_version.name),
753
755
  model_format=MLFLOW_MODEL_FORMAT,
754
756
  model_library=model_library,
@@ -16,7 +16,7 @@
16
16
  from typing import TYPE_CHECKING, Optional
17
17
 
18
18
  if TYPE_CHECKING:
19
- from zenml.model.model_version import ModelVersion
19
+ from zenml.model.model import Model
20
20
  from zenml.models import RunMetadataResponse
21
21
 
22
22
 
@@ -30,18 +30,18 @@ class RunMetadataLazyGetter:
30
30
 
31
31
  def __init__(
32
32
  self,
33
- _lazy_load_model_version: "ModelVersion",
33
+ _lazy_load_model: "Model",
34
34
  _lazy_load_artifact_name: Optional[str],
35
35
  _lazy_load_artifact_version: Optional[str],
36
36
  ):
37
37
  """Initialize a RunMetadataLazyGetter.
38
38
 
39
39
  Args:
40
- _lazy_load_model_version: The model version.
40
+ _lazy_load_model: The model version.
41
41
  _lazy_load_artifact_name: The artifact name.
42
42
  _lazy_load_artifact_version: The artifact version.
43
43
  """
44
- self._lazy_load_model_version = _lazy_load_model_version
44
+ self._lazy_load_model = _lazy_load_model
45
45
  self._lazy_load_artifact_name = _lazy_load_artifact_name
46
46
  self._lazy_load_artifact_version = _lazy_load_artifact_version
47
47
 
@@ -57,7 +57,7 @@ class RunMetadataLazyGetter:
57
57
  from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse
58
58
 
59
59
  return LazyRunMetadataResponse(
60
- _lazy_load_model_version=self._lazy_load_model_version,
60
+ _lazy_load_model=self._lazy_load_model,
61
61
  _lazy_load_artifact_name=self._lazy_load_artifact_name,
62
62
  _lazy_load_artifact_version=self._lazy_load_artifact_version,
63
63
  _lazy_load_metadata_name=key,
zenml/model/lazy_load.py CHANGED
@@ -17,7 +17,7 @@ from typing import Optional
17
17
 
18
18
  from pydantic import BaseModel
19
19
 
20
- from zenml.model.model_version import ModelVersion
20
+ from zenml.model.model import Model
21
21
 
22
22
 
23
23
  class ModelVersionDataLazyLoader(BaseModel):
@@ -28,7 +28,7 @@ class ModelVersionDataLazyLoader(BaseModel):
28
28
  model version during runtime time of the step.
29
29
  """
30
30
 
31
- model_version: ModelVersion
31
+ model: Model
32
32
  artifact_name: Optional[str] = None
33
33
  artifact_version: Optional[str] = None
34
34
  metadata_name: Optional[str] = None
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
12
  # or implied. See the License for the specific language governing
13
13
  # permissions and limitations under the License.
14
- """ModelVersion user facing interface to pass into pipeline or step."""
14
+ """Model user facing interface to pass into pipeline or step."""
15
15
 
16
16
  from typing import (
17
17
  TYPE_CHECKING,
@@ -42,8 +42,8 @@ if TYPE_CHECKING:
42
42
  logger = get_logger(__name__)
43
43
 
44
44
 
45
- class ModelVersion(BaseModel):
46
- """ModelVersion class to pass into pipeline or step to set it into a model context.
45
+ class Model(BaseModel):
46
+ """Model class to pass into pipeline or step to set it into a model context.
47
47
 
48
48
  name: The name of the model.
49
49
  license: The license under which the model is created.
@@ -54,8 +54,8 @@ class ModelVersion(BaseModel):
54
54
  trade_offs: The tradeoffs of the model.
55
55
  ethics: The ethical implications of the model.
56
56
  tags: Tags associated with the model.
57
- version: The model version name, number or stage is optional and points model context
58
- to a specific version/stage. If skipped new model version will be created.
57
+ version: The version name, version number or stage is optional and points model context
58
+ to a specific version/stage. If skipped new version will be created.
59
59
  save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
60
60
  if available in active stack.
61
61
  """
@@ -97,7 +97,7 @@ class ModelVersion(BaseModel):
97
97
  self._get_or_create_model_version()
98
98
  except RuntimeError:
99
99
  logger.info(
100
- f"Model version `{self.version}` doesn't exist "
100
+ f"Version `{self.version}` of `{self.name}` model doesn't exist "
101
101
  "and cannot be fetched from the Model Control Plane."
102
102
  )
103
103
  return self._id
@@ -128,7 +128,7 @@ class ModelVersion(BaseModel):
128
128
  self._get_or_create_model_version()
129
129
  except RuntimeError:
130
130
  logger.info(
131
- f"Model version `{self.version}` doesn't exist "
131
+ f"Version `{self.version}` of `{self.name}` model doesn't exist "
132
132
  "and cannot be fetched from the Model Control Plane."
133
133
  )
134
134
  return self._number
@@ -149,7 +149,7 @@ class ModelVersion(BaseModel):
149
149
  return ModelStages(stage)
150
150
  except RuntimeError:
151
151
  logger.info(
152
- f"Model version `{self.version}` doesn't exist "
152
+ f"Version `{self.version}` of `{self.name}` model doesn't exist "
153
153
  "and cannot be fetched from the Model Control Plane."
154
154
  )
155
155
  return None
@@ -283,7 +283,7 @@ class ModelVersion(BaseModel):
283
283
  def set_stage(
284
284
  self, stage: Union[str, ModelStages], force: bool = False
285
285
  ) -> None:
286
- """Sets this Model Version to a desired stage.
286
+ """Sets this Model to a desired stage.
287
287
 
288
288
  Args:
289
289
  stage: the target stage for model version.
@@ -431,7 +431,7 @@ class ModelVersion(BaseModel):
431
431
  return LazyArtifactVersionResponse(
432
432
  _lazy_load_name=name,
433
433
  _lazy_load_version=version,
434
- _lazy_load_model_version=ModelVersion(
434
+ _lazy_load_model=Model(
435
435
  name=self.name, version=self.version or self.number
436
436
  ),
437
437
  )
@@ -441,7 +441,7 @@ class ModelVersion(BaseModel):
441
441
  return None
442
442
 
443
443
  def __eq__(self, other: object) -> bool:
444
- """Check two ModelVersions for equality.
444
+ """Check two Models for equality.
445
445
 
446
446
  Args:
447
447
  other: object to compare with
@@ -449,7 +449,7 @@ class ModelVersion(BaseModel):
449
449
  Returns:
450
450
  True, if equal, False otherwise.
451
451
  """
452
- if not isinstance(other, ModelVersion):
452
+ if not isinstance(other, Model):
453
453
  return NotImplemented
454
454
  if self.name != other.name:
455
455
  return False
@@ -509,7 +509,7 @@ class ModelVersion(BaseModel):
509
509
  model_name_or_id=self.name
510
510
  )
511
511
 
512
- difference = {}
512
+ difference: Dict[str, Any] = {}
513
513
  for key in (
514
514
  "license",
515
515
  "audience",
@@ -525,7 +525,14 @@ class ModelVersion(BaseModel):
525
525
  "config": getattr(self, key),
526
526
  "db": getattr(model, key),
527
527
  }
528
-
528
+ if self.tags:
529
+ configured_tags = set(self.tags)
530
+ db_tags = {t.name for t in model.tags}
531
+ if db_tags != configured_tags:
532
+ difference["tags added"] = list(configured_tags - db_tags)
533
+ difference["tags removed"] = list(
534
+ db_tags - configured_tags
535
+ )
529
536
  if difference:
530
537
  logger.warning(
531
538
  "Provided model configuration does not match "
@@ -588,7 +595,7 @@ class ModelVersion(BaseModel):
588
595
  "db": mv.description,
589
596
  }
590
597
  if self.tags:
591
- configured_tags = set(self.tags or [])
598
+ configured_tags = set(self.tags)
592
599
  db_tags = {t.name for t in mv.tags}
593
600
  if db_tags != configured_tags:
594
601
  difference["tags added"] = list(configured_tags - db_tags)
@@ -656,7 +663,7 @@ class ModelVersion(BaseModel):
656
663
  # model version for current model was already
657
664
  # created in the current run, not to create
658
665
  # new model versions
659
- pipeline_mv = context.pipeline_run.config.model_version
666
+ pipeline_mv = context.pipeline_run.config.model
660
667
  if (
661
668
  pipeline_mv
662
669
  and pipeline_mv.was_created_in_this_run
@@ -666,7 +673,7 @@ class ModelVersion(BaseModel):
666
673
  self.version = pipeline_mv.version
667
674
  else:
668
675
  for step in context.pipeline_run.steps.values():
669
- step_mv = step.config.model_version
676
+ step_mv = step.config.model
670
677
  if (
671
678
  step_mv
672
679
  and step_mv.was_created_in_this_run
@@ -714,21 +721,21 @@ class ModelVersion(BaseModel):
714
721
  self._number = model_version.number
715
722
  return model_version
716
723
 
717
- def _merge(self, model_version: "ModelVersion") -> None:
718
- self.license = self.license or model_version.license
719
- self.description = self.description or model_version.description
720
- self.audience = self.audience or model_version.audience
721
- self.use_cases = self.use_cases or model_version.use_cases
722
- self.limitations = self.limitations or model_version.limitations
723
- self.trade_offs = self.trade_offs or model_version.trade_offs
724
- self.ethics = self.ethics or model_version.ethics
725
- if model_version.tags is not None:
724
+ def _merge(self, model: "Model") -> None:
725
+ self.license = self.license or model.license
726
+ self.description = self.description or model.description
727
+ self.audience = self.audience or model.audience
728
+ self.use_cases = self.use_cases or model.use_cases
729
+ self.limitations = self.limitations or model.limitations
730
+ self.trade_offs = self.trade_offs or model.trade_offs
731
+ self.ethics = self.ethics or model.ethics
732
+ if model.tags is not None:
726
733
  self.tags = list(
727
- {t for t in self.tags or []}.union(set(model_version.tags))
734
+ {t for t in self.tags or []}.union(set(model.tags))
728
735
  )
729
736
 
730
737
  def __hash__(self) -> int:
731
- """Get hash of the `ModelVersion`.
738
+ """Get hash of the `Model`.
732
739
 
733
740
  Returns:
734
741
  Hash function results