zenml-nightly 0.54.1.dev20240119__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. zenml/__init__.py +10 -4
  2. zenml/artifacts/artifact_config.py +12 -12
  3. zenml/artifacts/external_artifact.py +3 -3
  4. zenml/artifacts/external_artifact_config.py +8 -8
  5. zenml/artifacts/utils.py +4 -4
  6. zenml/cli/__init__.py +1 -1
  7. zenml/cli/base.py +3 -3
  8. zenml/cli/utils.py +3 -3
  9. zenml/config/compiler.py +1 -1
  10. zenml/config/pipeline_configurations.py +2 -2
  11. zenml/config/pipeline_run_configuration.py +2 -2
  12. zenml/config/step_configurations.py +2 -2
  13. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
  14. zenml/metadata/lazy_load.py +5 -5
  15. zenml/model/lazy_load.py +2 -2
  16. zenml/model/{model_version.py → model.py} +35 -28
  17. zenml/model/utils.py +33 -33
  18. zenml/model_registries/base_model_registry.py +10 -8
  19. zenml/models/v2/core/artifact_version.py +3 -3
  20. zenml/models/v2/core/model.py +3 -3
  21. zenml/models/v2/core/model_version.py +7 -7
  22. zenml/models/v2/core/run_metadata.py +2 -2
  23. zenml/new/pipelines/model_utils.py +20 -20
  24. zenml/new/pipelines/pipeline.py +47 -54
  25. zenml/new/pipelines/pipeline_context.py +1 -1
  26. zenml/new/pipelines/pipeline_decorator.py +4 -4
  27. zenml/new/steps/step_context.py +15 -15
  28. zenml/new/steps/step_decorator.py +5 -5
  29. zenml/orchestrators/input_utils.py +5 -7
  30. zenml/orchestrators/step_launcher.py +12 -19
  31. zenml/orchestrators/step_runner.py +8 -10
  32. zenml/pipelines/base_pipeline.py +1 -1
  33. zenml/pipelines/pipeline_decorator.py +6 -6
  34. zenml/steps/base_step.py +15 -15
  35. zenml/steps/step_decorator.py +6 -6
  36. zenml/steps/utils.py +68 -0
  37. zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
  38. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
  39. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +42 -41
  40. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
  41. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
  42. {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
zenml/model/utils.py CHANGED
@@ -22,7 +22,7 @@ from zenml.enums import ModelStages
22
22
  from zenml.exceptions import StepContextError
23
23
  from zenml.logger import get_logger
24
24
  from zenml.metadata.metadata_types import MetadataType
25
- from zenml.model.model_version import ModelVersion
25
+ from zenml.model.model import Model
26
26
  from zenml.models import ModelVersionArtifactRequest
27
27
  from zenml.new.steps.step_context import get_step_context
28
28
 
@@ -48,9 +48,9 @@ def link_step_artifacts_to_model(
48
48
  "step."
49
49
  )
50
50
  try:
51
- model_version = step_context.model_version
51
+ model = step_context.model
52
52
  except StepContextError:
53
- model_version = None
53
+ model = None
54
54
  logger.debug("No model context found, unable to auto-link artifacts.")
55
55
 
56
56
  for artifact_name, artifact_version_id in artifact_version_ids.items():
@@ -59,47 +59,47 @@ def link_step_artifacts_to_model(
59
59
  ).artifact_config
60
60
 
61
61
  # Implicit linking
62
- if artifact_config is None and model_version is not None:
62
+ if artifact_config is None and model is not None:
63
63
  artifact_config = ArtifactConfig(name=artifact_name)
64
64
  logger.info(
65
65
  f"Implicitly linking artifact `{artifact_name}` to model "
66
- f"`{model_version.name}` version `{model_version.version}`."
66
+ f"`{model.name}` version `{model.version}`."
67
67
  )
68
68
 
69
69
  if artifact_config:
70
- link_artifact_config_to_model_version(
70
+ link_artifact_config_to_model(
71
71
  artifact_config=artifact_config,
72
72
  artifact_version_id=artifact_version_id,
73
- model_version=model_version,
73
+ model=model,
74
74
  )
75
75
 
76
76
 
77
- def link_artifact_config_to_model_version(
77
+ def link_artifact_config_to_model(
78
78
  artifact_config: ArtifactConfig,
79
79
  artifact_version_id: UUID,
80
- model_version: Optional["ModelVersion"] = None,
80
+ model: Optional["Model"] = None,
81
81
  ) -> None:
82
82
  """Link an artifact config to its model version.
83
83
 
84
84
  Args:
85
85
  artifact_config: The artifact config to link.
86
86
  artifact_version_id: The ID of the artifact to link.
87
- model_version: The model version from the step or pipeline context.
87
+ model: The model version from the step or pipeline context.
88
88
  """
89
89
  client = Client()
90
90
 
91
91
  # If the artifact config specifies a model itself then always use that
92
92
  if artifact_config.model_name is not None:
93
- from zenml.model.model_version import ModelVersion
93
+ from zenml.model.model import Model
94
94
 
95
- model_version = ModelVersion(
95
+ model = Model(
96
96
  name=artifact_config.model_name,
97
97
  version=artifact_config.model_version,
98
98
  )
99
99
 
100
- if model_version:
101
- model_version._get_or_create_model_version()
102
- model_version_response = model_version._get_model_version()
100
+ if model:
101
+ model._get_or_create_model_version()
102
+ model_version_response = model._get_model_version()
103
103
  request = ModelVersionArtifactRequest(
104
104
  user=client.active_user.id,
105
105
  workspace=client.active_workspace.id,
@@ -125,10 +125,10 @@ def log_model_version_metadata(
125
125
  metadata: The metadata to log.
126
126
  model_name: The name of the model to log metadata for. Can
127
127
  be omitted when being called inside a step with configured
128
- `model_version` in decorator.
128
+ `model` in decorator.
129
129
  model_version: The version of the model to log metadata for. Can
130
130
  be omitted when being called inside a step with configured
131
- `model_version` in decorator.
131
+ `model` in decorator.
132
132
  """
133
133
  logger.warning(
134
134
  "`log_model_version_metadata` is deprecated. Please use "
@@ -152,38 +152,38 @@ def log_model_metadata(
152
152
  metadata: The metadata to log.
153
153
  model_name: The name of the model to log metadata for. Can
154
154
  be omitted when being called inside a step with configured
155
- `model_version` in decorator.
155
+ `model` in decorator.
156
156
  model_version: The version of the model to log metadata for. Can
157
157
  be omitted when being called inside a step with configured
158
- `model_version` in decorator.
158
+ `model` in decorator.
159
159
 
160
160
  Raises:
161
161
  ValueError: If no model name/version is provided and the function is not
162
- called inside a step with configured `model_version` in decorator.
162
+ called inside a step with configured `model` in decorator.
163
163
  """
164
164
  mv = None
165
165
  try:
166
166
  step_context = get_step_context()
167
- mv = step_context.model_version
167
+ mv = step_context.model
168
168
  except RuntimeError:
169
169
  step_context = None
170
170
 
171
171
  if not step_context and not (model_name and model_version):
172
172
  raise ValueError(
173
173
  "Model name and version must be provided unless the function is "
174
- "called inside a step with configured `model_version` in decorator."
174
+ "called inside a step with configured `model` in decorator."
175
175
  )
176
176
  if mv is None:
177
- from zenml import ModelVersion
177
+ from zenml import Model
178
178
 
179
- mv = ModelVersion(name=model_name, version=model_version)
179
+ mv = Model(name=model_name, version=model_version)
180
180
 
181
181
  mv.log_metadata(metadata)
182
182
 
183
183
 
184
184
  def link_artifact_to_model(
185
185
  artifact_version_id: UUID,
186
- model_version: Optional["ModelVersion"] = None,
186
+ model: Optional["Model"] = None,
187
187
  is_model_artifact: bool = False,
188
188
  is_deployment_artifact: bool = False,
189
189
  ) -> None:
@@ -191,34 +191,34 @@ def link_artifact_to_model(
191
191
 
192
192
  Args:
193
193
  artifact_version_id: The ID of the artifact version.
194
- model_version: The model version to link to.
194
+ model: The model to link to.
195
195
  is_model_artifact: Whether the artifact is a model artifact.
196
196
  is_deployment_artifact: Whether the artifact is a deployment artifact.
197
197
 
198
198
  Raises:
199
199
  RuntimeError: If called outside of a step.
200
200
  """
201
- if not model_version:
201
+ if not model:
202
202
  is_issue = False
203
203
  try:
204
204
  step_context = get_step_context()
205
- model_version = step_context.model_version
205
+ model = step_context.model
206
206
  except StepContextError:
207
207
  is_issue = True
208
208
 
209
- if model_version is None or is_issue:
209
+ if model is None or is_issue:
210
210
  raise RuntimeError(
211
- "`link_artifact_to_model` called without `model_version` parameter "
211
+ "`link_artifact_to_model` called without `model` parameter "
212
212
  "and configured model context cannot be identified. Consider "
213
- "passing the `model_version` explicitly or configuring it in "
213
+ "passing the `model` explicitly or configuring it in "
214
214
  "@step or @pipeline decorator."
215
215
  )
216
216
 
217
- link_artifact_config_to_model_version(
217
+ link_artifact_config_to_model(
218
218
  artifact_config=ArtifactConfig(
219
219
  is_model_artifact=is_model_artifact,
220
220
  is_deployment_artifact=is_deployment_artifact,
221
221
  ),
222
222
  artifact_version_id=artifact_version_id,
223
- model_version=model_version,
223
+ model=model,
224
224
  )
@@ -123,10 +123,10 @@ class ModelRegistryModelMetadata(BaseModel):
123
123
  extra = "allow"
124
124
 
125
125
 
126
- class ModelVersion(BaseModel):
126
+ class RegistryModelVersion(BaseModel):
127
127
  """Base class for all ZenML model versions.
128
128
 
129
- The `ModelVersion` class represents a version or snapshot of a registered
129
+ The `RegistryModelVersion` class represents a version or snapshot of a registered
130
130
  model, including information such as the associated `ModelBundle`, version
131
131
  number, creation time, pipeline run information, and metadata. It serves as
132
132
  a blueprint for creating concrete model version implementations in a registry,
@@ -288,7 +288,7 @@ class BaseModelRegistry(StackComponent, ABC):
288
288
  description: Optional[str] = None,
289
289
  metadata: Optional[ModelRegistryModelMetadata] = None,
290
290
  **kwargs: Any,
291
- ) -> ModelVersion:
291
+ ) -> RegistryModelVersion:
292
292
  """Registers a model version in the model registry.
293
293
 
294
294
  Args:
@@ -333,7 +333,7 @@ class BaseModelRegistry(StackComponent, ABC):
333
333
  metadata: Optional[ModelRegistryModelMetadata] = None,
334
334
  remove_metadata: Optional[List[str]] = None,
335
335
  stage: Optional[ModelVersionStage] = None,
336
- ) -> ModelVersion:
336
+ ) -> RegistryModelVersion:
337
337
  """Updates a model version in the model registry.
338
338
 
339
339
  Args:
@@ -364,7 +364,7 @@ class BaseModelRegistry(StackComponent, ABC):
364
364
  created_before: Optional[datetime] = None,
365
365
  order_by_date: Optional[str] = None,
366
366
  **kwargs: Any,
367
- ) -> Optional[List[ModelVersion]]:
367
+ ) -> Optional[List[RegistryModelVersion]]:
368
368
  """Lists all model versions for a registered model.
369
369
 
370
370
  Args:
@@ -387,7 +387,7 @@ class BaseModelRegistry(StackComponent, ABC):
387
387
  self,
388
388
  name: str,
389
389
  stage: Optional[ModelVersionStage] = None,
390
- ) -> Optional[ModelVersion]:
390
+ ) -> Optional[RegistryModelVersion]:
391
391
  """Gets the latest model version for a registered model.
392
392
 
393
393
  This method is used to get the latest model version for a registered
@@ -410,7 +410,9 @@ class BaseModelRegistry(StackComponent, ABC):
410
410
  return None
411
411
 
412
412
  @abstractmethod
413
- def get_model_version(self, name: str, version: str) -> ModelVersion:
413
+ def get_model_version(
414
+ self, name: str, version: str
415
+ ) -> RegistryModelVersion:
414
416
  """Gets a model version for a registered model.
415
417
 
416
418
  Args:
@@ -450,7 +452,7 @@ class BaseModelRegistry(StackComponent, ABC):
450
452
  @abstractmethod
451
453
  def get_model_uri_artifact_store(
452
454
  self,
453
- model_version: ModelVersion,
455
+ model_version: RegistryModelVersion,
454
456
  ) -> str:
455
457
  """Gets the URI artifact store for a model version.
456
458
 
@@ -44,7 +44,7 @@ from zenml.models.v2.core.tag import TagResponse
44
44
  if TYPE_CHECKING:
45
45
  from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
46
46
 
47
- from zenml.model.model_version import ModelVersion
47
+ from zenml.model.model import Model
48
48
  from zenml.models.v2.core.artifact_visualization import (
49
49
  ArtifactVisualizationRequest,
50
50
  ArtifactVisualizationResponse,
@@ -477,7 +477,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
477
477
  id: Optional[UUID] = None # type: ignore[assignment]
478
478
  _lazy_load_name: Optional[str] = None
479
479
  _lazy_load_version: Optional[str] = None
480
- _lazy_load_model_version: "ModelVersion"
480
+ _lazy_load_model: "Model"
481
481
 
482
482
  def get_body(self) -> None: # type: ignore[override]
483
483
  """Protects from misuse of the lazy loader.
@@ -507,7 +507,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
507
507
  from zenml.metadata.lazy_load import RunMetadataLazyGetter
508
508
 
509
509
  return RunMetadataLazyGetter( # type: ignore[return-value]
510
- self._lazy_load_model_version,
510
+ self._lazy_load_model,
511
511
  self._lazy_load_name,
512
512
  self._lazy_load_version,
513
513
  )
@@ -31,7 +31,7 @@ from zenml.models.v2.base.scoped import (
31
31
  from zenml.utils.pagination_utils import depaginate
32
32
 
33
33
  if TYPE_CHECKING:
34
- from zenml.model.model_version import ModelVersion
34
+ from zenml.model.model import Model
35
35
  from zenml.models.v2.core.tag import TagResponse
36
36
 
37
37
 
@@ -310,7 +310,7 @@ class ModelResponse(
310
310
 
311
311
  # Helper functions
312
312
  @property
313
- def versions(self) -> List["ModelVersion"]:
313
+ def versions(self) -> List["Model"]:
314
314
  """List all versions of the model.
315
315
 
316
316
  Returns:
@@ -323,7 +323,7 @@ class ModelResponse(
323
323
  partial(client.list_model_versions, model_name_or_id=self.id)
324
324
  )
325
325
  return [
326
- mv.to_model_version(suppress_class_validation_warnings=True)
326
+ mv.to_model_class(suppress_class_validation_warnings=True)
327
327
  for mv in model_versions
328
328
  ]
329
329
 
@@ -32,7 +32,7 @@ from zenml.models.v2.base.scoped import (
32
32
  from zenml.models.v2.core.tag import TagResponse
33
33
 
34
34
  if TYPE_CHECKING:
35
- from zenml.model.model_version import ModelVersion
35
+ from zenml.model.model import Model
36
36
  from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
37
37
  from zenml.models.v2.core.model import ModelResponse
38
38
  from zenml.models.v2.core.pipeline_run import PipelineRunResponse
@@ -311,12 +311,12 @@ class ModelVersionResponse(
311
311
  return Client().zen_store.get_model_version(self.id)
312
312
 
313
313
  # Helper functions
314
- def to_model_version(
314
+ def to_model_class(
315
315
  self,
316
316
  was_created_in_this_run: bool = False,
317
317
  suppress_class_validation_warnings: bool = False,
318
- ) -> "ModelVersion":
319
- """Convert response model to ModelVersion object.
318
+ ) -> "Model":
319
+ """Convert response model to Model object.
320
320
 
321
321
  Args:
322
322
  was_created_in_this_run: Whether model version was created during
@@ -325,11 +325,11 @@ class ModelVersionResponse(
325
325
  repeated warnings.
326
326
 
327
327
  Returns:
328
- ModelVersion object
328
+ Model object
329
329
  """
330
- from zenml.model.model_version import ModelVersion
330
+ from zenml.model.model import Model
331
331
 
332
- mv = ModelVersion(
332
+ mv = Model(
333
333
  name=self.model.name,
334
334
  license=self.model.license,
335
335
  description=self.description,
@@ -30,7 +30,7 @@ from zenml.models.v2.base.scoped import (
30
30
  )
31
31
 
32
32
  if TYPE_CHECKING:
33
- from zenml.model.model_version import ModelVersion
33
+ from zenml.model.model import Model
34
34
 
35
35
  # ------------------ Request Model ------------------
36
36
 
@@ -203,7 +203,7 @@ class LazyRunMetadataResponse(RunMetadataResponse):
203
203
  _lazy_load_artifact_name: Optional[str] = None
204
204
  _lazy_load_artifact_version: Optional[str] = None
205
205
  _lazy_load_metadata_name: Optional[str] = None
206
- _lazy_load_model_version: "ModelVersion"
206
+ _lazy_load_model: "Model"
207
207
 
208
208
  def get_body(self) -> None: # type: ignore[override]
209
209
  """Protects from misuse of the lazy loader.
@@ -17,14 +17,14 @@ from typing import List, Optional
17
17
 
18
18
  from pydantic import BaseModel, PrivateAttr
19
19
 
20
- from zenml.model.model_version import ModelVersion
20
+ from zenml.model.model import Model
21
21
 
22
22
 
23
- class NewModelVersionRequest(BaseModel):
24
- """Request to create a new model version."""
23
+ class NewModelRequest(BaseModel):
24
+ """Request to create a new version of a model."""
25
25
 
26
26
  class Requester(BaseModel):
27
- """Requester of a new model version."""
27
+ """Requester of a new version of a model."""
28
28
 
29
29
  source: str
30
30
  name: str
@@ -38,35 +38,35 @@ class NewModelVersionRequest(BaseModel):
38
38
  return f"{self.source}::{self.name}"
39
39
 
40
40
  requesters: List[Requester] = []
41
- _model_version: Optional[ModelVersion] = PrivateAttr(default=None)
41
+ _model: Optional[Model] = PrivateAttr(default=None)
42
42
 
43
43
  @property
44
- def model_version(self) -> ModelVersion:
45
- """Model version getter.
44
+ def model(self) -> Model:
45
+ """Model getter.
46
46
 
47
47
  Returns:
48
- The model version.
48
+ The model.
49
49
 
50
50
  Raises:
51
- RuntimeError: If the model version is not set.
51
+ RuntimeError: If the model is not set.
52
52
  """
53
- if self._model_version is None:
54
- raise RuntimeError("Model version is not set.")
55
- return self._model_version
53
+ if self._model is None:
54
+ raise RuntimeError("Model is not set.")
55
+ return self._model
56
56
 
57
57
  def update_request(
58
58
  self,
59
- model_version: ModelVersion,
60
- requester: "NewModelVersionRequest.Requester",
59
+ model: Model,
60
+ requester: "NewModelRequest.Requester",
61
61
  ) -> None:
62
- """Update from `ModelVersion` in place.
62
+ """Update from `Model` in place.
63
63
 
64
64
  Args:
65
- model_version: `ModelVersion` to use.
66
- requester: Requester of a new model version.
65
+ model: `Model` to use.
66
+ requester: Requester of a new version of a model.
67
67
  """
68
68
  self.requesters.append(requester)
69
- if self._model_version is None:
70
- self._model_version = model_version
69
+ if self._model is None:
70
+ self._model = model
71
71
 
72
- self._model_version._merge(model_version)
72
+ self._model._merge(model)
@@ -70,7 +70,7 @@ from zenml.models import (
70
70
  ScheduleRequest,
71
71
  )
72
72
  from zenml.new.pipelines import build_utils
73
- from zenml.new.pipelines.model_utils import NewModelVersionRequest
73
+ from zenml.new.pipelines.model_utils import NewModelRequest
74
74
  from zenml.orchestrators.utils import get_run_name
75
75
  from zenml.stack import Stack
76
76
  from zenml.steps import BaseStep
@@ -93,7 +93,7 @@ if TYPE_CHECKING:
93
93
  from zenml.config.base_settings import SettingsOrDict
94
94
  from zenml.config.source import Source
95
95
  from zenml.model.lazy_load import ModelVersionDataLazyLoader
96
- from zenml.model.model_version import ModelVersion
96
+ from zenml.model.model import Model
97
97
 
98
98
  StepConfigurationUpdateOrDict = Union[
99
99
  Dict[str, Any], StepConfigurationUpdate
@@ -126,7 +126,7 @@ class Pipeline:
126
126
  extra: Optional[Dict[str, Any]] = None,
127
127
  on_failure: Optional["HookSpecification"] = None,
128
128
  on_success: Optional["HookSpecification"] = None,
129
- model_version: Optional["ModelVersion"] = None,
129
+ model: Optional["Model"] = None,
130
130
  ) -> None:
131
131
  """Initializes a pipeline.
132
132
 
@@ -147,7 +147,7 @@ class Pipeline:
147
147
  on_success: Callback function in event of success of the step. Can
148
148
  be a function with no arguments, or a source path to such a
149
149
  function (e.g. `module.my_function`).
150
- model_version: configuration of the model version in the Model Control Plane.
150
+ model: configuration of the model in the Model Control Plane.
151
151
  """
152
152
  self._invocations: Dict[str, StepInvocation] = {}
153
153
  self._run_args: Dict[str, Any] = {}
@@ -166,7 +166,7 @@ class Pipeline:
166
166
  extra=extra,
167
167
  on_failure=on_failure,
168
168
  on_success=on_success,
169
- model_version=model_version,
169
+ model=model,
170
170
  )
171
171
  self.entrypoint = entrypoint
172
172
  self._parameters: Dict[str, Any] = {}
@@ -305,7 +305,7 @@ class Pipeline:
305
305
  extra: Optional[Dict[str, Any]] = None,
306
306
  on_failure: Optional["HookSpecification"] = None,
307
307
  on_success: Optional["HookSpecification"] = None,
308
- model_version: Optional["ModelVersion"] = None,
308
+ model: Optional["Model"] = None,
309
309
  parameters: Optional[Dict[str, Any]] = None,
310
310
  merge: bool = True,
311
311
  ) -> T:
@@ -341,7 +341,7 @@ class Pipeline:
341
341
  configurations. If `False` the given configurations will
342
342
  overwrite all existing ones. See the general description of this
343
343
  method for an example.
344
- model_version: configuration of the model version in the Model Control Plane.
344
+ model: configuration of the model version in the Model Control Plane.
345
345
  parameters: input parameters for the pipeline.
346
346
 
347
347
  Returns:
@@ -367,7 +367,7 @@ class Pipeline:
367
367
  "extra": extra,
368
368
  "failure_hook_source": failure_hook_source,
369
369
  "success_hook_source": success_hook_source,
370
- "model_version": model_version,
370
+ "model": model,
371
371
  "parameters": parameters,
372
372
  }
373
373
  )
@@ -872,34 +872,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
872
872
  def _update_new_requesters(
873
873
  self,
874
874
  requester_name: str,
875
- model_version: "ModelVersion",
875
+ model: "Model",
876
876
  new_versions_requested: Dict[
877
- Tuple[str, Optional[str]], NewModelVersionRequest
877
+ Tuple[str, Optional[str]], NewModelRequest
878
878
  ],
879
- other_model_versions: Set["ModelVersion"],
879
+ other_models: Set["Model"],
880
880
  ) -> None:
881
881
  key = (
882
- model_version.name,
883
- str(model_version.version) if model_version.version else None,
882
+ model.name,
883
+ str(model.version) if model.version else None,
884
884
  )
885
- if model_version.version is None:
885
+ if model.version is None:
886
886
  version_existed = False
887
887
  else:
888
888
  try:
889
- model_version._get_model_version()
889
+ model._get_model_version()
890
890
  version_existed = key not in new_versions_requested
891
891
  except KeyError:
892
892
  version_existed = False
893
893
  if not version_existed:
894
- model_version.was_created_in_this_run = True
894
+ model.was_created_in_this_run = True
895
895
  new_versions_requested[key].update_request(
896
- model_version,
897
- NewModelVersionRequest.Requester(
898
- source="step", name=requester_name
899
- ),
896
+ model,
897
+ NewModelRequest.Requester(source="step", name=requester_name),
900
898
  )
901
899
  else:
902
- other_model_versions.add(model_version)
900
+ other_models.add(model)
903
901
 
904
902
  def prepare_model_versions(
905
903
  self, deployment: "PipelineDeploymentBase"
@@ -910,35 +908,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
910
908
  deployment: The pipeline deployment configuration.
911
909
  """
912
910
  new_versions_requested: Dict[
913
- Tuple[str, Optional[str]], NewModelVersionRequest
914
- ] = defaultdict(NewModelVersionRequest)
915
- other_model_versions: Set["ModelVersion"] = set()
911
+ Tuple[str, Optional[str]], NewModelRequest
912
+ ] = defaultdict(NewModelRequest)
913
+ other_models: Set["Model"] = set()
916
914
  all_steps_have_own_config = True
917
915
  for step in deployment.step_configurations.values():
918
- step_model_version = step.config.model_version
916
+ step_model = step.config.model
919
917
  all_steps_have_own_config = (
920
- all_steps_have_own_config
921
- and step.config.model_version is not None
918
+ all_steps_have_own_config and step.config.model is not None
922
919
  )
923
- if step_model_version:
920
+ if step_model:
924
921
  self._update_new_requesters(
925
- model_version=step_model_version,
922
+ model=step_model,
926
923
  requester_name=step.config.name,
927
924
  new_versions_requested=new_versions_requested,
928
- other_model_versions=other_model_versions,
925
+ other_models=other_models,
929
926
  )
930
927
  if not all_steps_have_own_config:
931
- pipeline_model_version = (
932
- deployment.pipeline_configuration.model_version
933
- )
934
- if pipeline_model_version:
928
+ pipeline_model = deployment.pipeline_configuration.model
929
+ if pipeline_model:
935
930
  self._update_new_requesters(
936
- model_version=pipeline_model_version,
931
+ model=pipeline_model,
937
932
  requester_name=self.name,
938
933
  new_versions_requested=new_versions_requested,
939
- other_model_versions=other_model_versions,
934
+ other_models=other_models,
940
935
  )
941
- elif deployment.pipeline_configuration.model_version is not None:
936
+ elif deployment.pipeline_configuration.model is not None:
942
937
  logger.warning(
943
938
  f"ModelConfig of pipeline `{self.name}` is overridden in all "
944
939
  f"steps. "
@@ -946,13 +941,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
946
941
 
947
942
  self._validate_new_version_requests(new_versions_requested)
948
943
 
949
- for other_model_version in other_model_versions:
950
- other_model_version._validate_config_in_runtime()
944
+ for other_model in other_models:
945
+ other_model._validate_config_in_runtime()
951
946
 
952
947
  def _validate_new_version_requests(
953
948
  self,
954
949
  new_versions_requested: Dict[
955
- Tuple[str, Optional[str]], NewModelVersionRequest
950
+ Tuple[str, Optional[str]], NewModelRequest
956
951
  ],
957
952
  ) -> None:
958
953
  """Validate the model version that are used in the pipeline run.
@@ -967,13 +962,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
967
962
  logger.warning(
968
963
  f"New version of model version `{model_name}::{model_version or 'NEW'}` "
969
964
  f"requested in multiple decorators:\n{data.requesters}\n We recommend "
970
- "that `ModelVersion` requesting new version is configured only in one "
965
+ "that `Model` requesting new version is configured only in one "
971
966
  "place of the pipeline."
972
967
  )
973
- data.model_version._validate_config_in_runtime()
968
+ data.model._validate_config_in_runtime()
974
969
  self.__new_unnamed_model_versions_in_current_run__[
975
- data.model_version.name
976
- ] = data.model_version.number
970
+ data.model.name
971
+ ] = data.model.number
977
972
 
978
973
  def get_runs(self, **kwargs: Any) -> List["PipelineRunResponse"]:
979
974
  """(Deprecated) Get runs of this pipeline.
@@ -1400,18 +1395,16 @@ To avoid this consider setting pipeline parameters only in one place (config or
1400
1395
  {k: v for k, v in _from_config_file.items() if k in matcher}
1401
1396
  )
1402
1397
 
1403
- if "model_version" in _from_config_file:
1404
- if "model_version" in self._from_config_file:
1405
- _from_config_file[
1406
- "model_version"
1407
- ] = self._from_config_file["model_version"]
1398
+ if "model" in _from_config_file:
1399
+ if "model" in self._from_config_file:
1400
+ _from_config_file["model"] = self._from_config_file[
1401
+ "model"
1402
+ ]
1408
1403
  else:
1409
- from zenml.model.model_version import ModelVersion
1404
+ from zenml.model.model import Model
1410
1405
 
1411
- _from_config_file[
1412
- "model_version"
1413
- ] = ModelVersion.parse_obj(
1414
- _from_config_file["model_version"]
1406
+ _from_config_file["model"] = Model.parse_obj(
1407
+ _from_config_file["model"]
1415
1408
  )
1416
1409
  self._from_config_file = _from_config_file
1417
1410
 
@@ -107,4 +107,4 @@ class PipelineContext:
107
107
  self.enable_step_logs = pipeline_configuration.enable_step_logs
108
108
  self.settings = pipeline_configuration.settings
109
109
  self.extra = pipeline_configuration.extra
110
- self.model_version = pipeline_configuration.model_version
110
+ self.model = pipeline_configuration.model