zenml-nightly 0.54.1.dev20240118__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. zenml/__init__.py +10 -4
  2. zenml/artifacts/artifact_config.py +12 -12
  3. zenml/artifacts/external_artifact.py +3 -3
  4. zenml/artifacts/external_artifact_config.py +8 -8
  5. zenml/artifacts/utils.py +4 -4
  6. zenml/cli/__init__.py +4 -4
  7. zenml/cli/artifact.py +38 -18
  8. zenml/cli/base.py +3 -3
  9. zenml/cli/model.py +24 -16
  10. zenml/cli/server.py +9 -0
  11. zenml/cli/utils.py +3 -3
  12. zenml/client.py +13 -2
  13. zenml/config/compiler.py +1 -1
  14. zenml/config/pipeline_configurations.py +2 -2
  15. zenml/config/pipeline_run_configuration.py +2 -2
  16. zenml/config/step_configurations.py +2 -2
  17. zenml/integrations/__init__.py +2 -4
  18. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
  19. zenml/metadata/lazy_load.py +5 -5
  20. zenml/model/lazy_load.py +2 -2
  21. zenml/model/{model_version.py → model.py} +47 -38
  22. zenml/model/utils.py +33 -33
  23. zenml/model_registries/base_model_registry.py +10 -8
  24. zenml/models/__init__.py +2 -0
  25. zenml/models/v2/base/filter.py +3 -0
  26. zenml/models/v2/base/scoped.py +59 -0
  27. zenml/models/v2/core/artifact.py +2 -2
  28. zenml/models/v2/core/artifact_version.py +6 -6
  29. zenml/models/v2/core/model.py +6 -6
  30. zenml/models/v2/core/model_version.py +9 -9
  31. zenml/models/v2/core/run_metadata.py +2 -2
  32. zenml/new/pipelines/model_utils.py +20 -20
  33. zenml/new/pipelines/pipeline.py +47 -54
  34. zenml/new/pipelines/pipeline_context.py +1 -1
  35. zenml/new/pipelines/pipeline_decorator.py +4 -4
  36. zenml/new/steps/step_context.py +15 -15
  37. zenml/new/steps/step_decorator.py +5 -5
  38. zenml/orchestrators/input_utils.py +5 -7
  39. zenml/orchestrators/step_launcher.py +12 -19
  40. zenml/orchestrators/step_runner.py +8 -10
  41. zenml/pipelines/base_pipeline.py +1 -1
  42. zenml/pipelines/pipeline_decorator.py +6 -6
  43. zenml/steps/base_step.py +15 -15
  44. zenml/steps/step_decorator.py +6 -6
  45. zenml/steps/utils.py +68 -0
  46. zenml/zen_server/deploy/helm/templates/server-db-job.yaml +1 -1
  47. zenml/zen_server/deploy/helm/templates/server-secret.yaml +1 -1
  48. zenml/zen_server/deploy/helm/templates/serviceaccount.yaml +1 -1
  49. zenml/zen_server/utils.py +19 -1
  50. zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
  51. zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +16 -4
  52. zenml/zen_stores/rest_zen_store.py +2 -2
  53. zenml/zen_stores/sql_zen_store.py +4 -1
  54. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
  55. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +58 -57
  56. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
  57. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
  58. {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
@@ -23,10 +23,12 @@ from typing import (
23
23
  Optional,
24
24
  Type,
25
25
  TypeVar,
26
+ Union,
26
27
  )
27
28
  from uuid import UUID
28
29
 
29
30
  from pydantic import Field
31
+ from sqlmodel import col
30
32
 
31
33
  from zenml.models.v2.base.base import (
32
34
  BaseRequest,
@@ -37,6 +39,8 @@ from zenml.models.v2.base.base import (
37
39
  from zenml.models.v2.base.filter import AnyQuery, BaseFilter
38
40
 
39
41
  if TYPE_CHECKING:
42
+ from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
43
+
40
44
  from zenml.models.v2.core.user import UserResponse
41
45
  from zenml.models.v2.core.workspace import WorkspaceResponse
42
46
  from zenml.zen_stores.schemas import BaseSchema
@@ -274,3 +278,58 @@ class WorkspaceScopedFilter(BaseFilter):
274
278
  query = query.where(scope_filter)
275
279
 
276
280
  return query
281
+
282
+
283
+ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
284
+ """Model to enable advanced scoping with workspace and tagging."""
285
+
286
+ tag: Optional[str] = Field(
287
+ description="Tag to apply to the filter query.", default=None
288
+ )
289
+
290
+ FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
291
+ *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
292
+ "tag",
293
+ ]
294
+
295
+ def apply_filter(
296
+ self,
297
+ query: AnyQuery,
298
+ table: Type["AnySchema"],
299
+ ) -> AnyQuery:
300
+ """Applies the filter to a query.
301
+
302
+ Args:
303
+ query: The query to which to apply the filter.
304
+ table: The query table.
305
+
306
+ Returns:
307
+ The query with filter applied.
308
+ """
309
+ from zenml.zen_stores.schemas import TagResourceSchema
310
+
311
+ query = super().apply_filter(query=query, table=table)
312
+ if self.tag:
313
+ query = (
314
+ query.join(getattr(table, "tags"))
315
+ .join(TagResourceSchema.tag)
316
+ .distinct()
317
+ )
318
+
319
+ return query
320
+
321
+ def get_custom_filters(
322
+ self,
323
+ ) -> List[Union["BinaryExpression[Any]", "BooleanClauseList[Any]"]]:
324
+ """Get custom tag filters.
325
+
326
+ Returns:
327
+ A list of custom filters.
328
+ """
329
+ from zenml.zen_stores.schemas import TagSchema
330
+
331
+ custom_filters = super().get_custom_filters()
332
+ if self.tag:
333
+ custom_filters.append(col(TagSchema.name) == self.tag) # type: ignore[arg-type]
334
+
335
+ return custom_filters
@@ -24,7 +24,7 @@ from zenml.models.v2.base.base import (
24
24
  BaseResponseBody,
25
25
  BaseResponseMetadata,
26
26
  )
27
- from zenml.models.v2.base.filter import BaseFilter
27
+ from zenml.models.v2.base.scoped import WorkspaceScopedTaggableFilter
28
28
  from zenml.models.v2.core.tag import TagResponse
29
29
 
30
30
  if TYPE_CHECKING:
@@ -133,7 +133,7 @@ class ArtifactResponse(
133
133
  # ------------------ Filter Model ------------------
134
134
 
135
135
 
136
- class ArtifactFilter(BaseFilter):
136
+ class ArtifactFilter(WorkspaceScopedTaggableFilter):
137
137
  """Model to enable advanced filtering of artifacts."""
138
138
 
139
139
  name: Optional[str] = None
@@ -32,11 +32,11 @@ from zenml.enums import ArtifactType, GenericFilterOps
32
32
  from zenml.logger import get_logger
33
33
  from zenml.models.v2.base.filter import StrFilter
34
34
  from zenml.models.v2.base.scoped import (
35
- WorkspaceScopedFilter,
36
35
  WorkspaceScopedRequest,
37
36
  WorkspaceScopedResponse,
38
37
  WorkspaceScopedResponseBody,
39
38
  WorkspaceScopedResponseMetadata,
39
+ WorkspaceScopedTaggableFilter,
40
40
  )
41
41
  from zenml.models.v2.core.artifact import ArtifactResponse
42
42
  from zenml.models.v2.core.tag import TagResponse
@@ -44,7 +44,7 @@ from zenml.models.v2.core.tag import TagResponse
44
44
  if TYPE_CHECKING:
45
45
  from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
46
46
 
47
- from zenml.model.model_version import ModelVersion
47
+ from zenml.model.model import Model
48
48
  from zenml.models.v2.core.artifact_visualization import (
49
49
  ArtifactVisualizationRequest,
50
50
  ArtifactVisualizationResponse,
@@ -347,14 +347,14 @@ class ArtifactVersionResponse(
347
347
  # ------------------ Filter Model ------------------
348
348
 
349
349
 
350
- class ArtifactVersionFilter(WorkspaceScopedFilter):
350
+ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
351
351
  """Model to enable advanced filtering of artifact versions."""
352
352
 
353
353
  # `name` and `only_unused` refer to properties related to other entities
354
354
  # rather than a field in the db, hence they needs to be handled
355
355
  # explicitly
356
356
  FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
357
- *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
357
+ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
358
358
  "name",
359
359
  "only_unused",
360
360
  "has_custom_name",
@@ -477,7 +477,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
477
477
  id: Optional[UUID] = None # type: ignore[assignment]
478
478
  _lazy_load_name: Optional[str] = None
479
479
  _lazy_load_version: Optional[str] = None
480
- _lazy_load_model_version: "ModelVersion"
480
+ _lazy_load_model: "Model"
481
481
 
482
482
  def get_body(self) -> None: # type: ignore[override]
483
483
  """Protects from misuse of the lazy loader.
@@ -507,7 +507,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
507
507
  from zenml.metadata.lazy_load import RunMetadataLazyGetter
508
508
 
509
509
  return RunMetadataLazyGetter( # type: ignore[return-value]
510
- self._lazy_load_model_version,
510
+ self._lazy_load_model,
511
511
  self._lazy_load_name,
512
512
  self._lazy_load_version,
513
513
  )
@@ -22,16 +22,16 @@ from pydantic import BaseModel, Field
22
22
 
23
23
  from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
24
24
  from zenml.models.v2.base.scoped import (
25
- WorkspaceScopedFilter,
26
25
  WorkspaceScopedRequest,
27
26
  WorkspaceScopedResponse,
28
27
  WorkspaceScopedResponseBody,
29
28
  WorkspaceScopedResponseMetadata,
29
+ WorkspaceScopedTaggableFilter,
30
30
  )
31
31
  from zenml.utils.pagination_utils import depaginate
32
32
 
33
33
  if TYPE_CHECKING:
34
- from zenml.model.model_version import ModelVersion
34
+ from zenml.model.model import Model
35
35
  from zenml.models.v2.core.tag import TagResponse
36
36
 
37
37
 
@@ -310,7 +310,7 @@ class ModelResponse(
310
310
 
311
311
  # Helper functions
312
312
  @property
313
- def versions(self) -> List["ModelVersion"]:
313
+ def versions(self) -> List["Model"]:
314
314
  """List all versions of the model.
315
315
 
316
316
  Returns:
@@ -323,7 +323,7 @@ class ModelResponse(
323
323
  partial(client.list_model_versions, model_name_or_id=self.id)
324
324
  )
325
325
  return [
326
- mv.to_model_version(suppress_class_validation_warnings=True)
326
+ mv.to_model_class(suppress_class_validation_warnings=True)
327
327
  for mv in model_versions
328
328
  ]
329
329
 
@@ -331,7 +331,7 @@ class ModelResponse(
331
331
  # ------------------ Filter Model ------------------
332
332
 
333
333
 
334
- class ModelFilter(WorkspaceScopedFilter):
334
+ class ModelFilter(WorkspaceScopedTaggableFilter):
335
335
  """Model to enable advanced filtering of all Workspaces."""
336
336
 
337
337
  name: Optional[str] = Field(
@@ -346,7 +346,7 @@ class ModelFilter(WorkspaceScopedFilter):
346
346
  )
347
347
 
348
348
  CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
349
- *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
349
+ *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS,
350
350
  "workspace_id",
351
351
  "user_id",
352
352
  ]
@@ -23,16 +23,16 @@ from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
23
23
  from zenml.enums import ModelStages
24
24
  from zenml.models.v2.base.filter import AnyQuery
25
25
  from zenml.models.v2.base.scoped import (
26
- WorkspaceScopedFilter,
27
26
  WorkspaceScopedRequest,
28
27
  WorkspaceScopedResponse,
29
28
  WorkspaceScopedResponseBody,
30
29
  WorkspaceScopedResponseMetadata,
30
+ WorkspaceScopedTaggableFilter,
31
31
  )
32
32
  from zenml.models.v2.core.tag import TagResponse
33
33
 
34
34
  if TYPE_CHECKING:
35
- from zenml.model.model_version import ModelVersion
35
+ from zenml.model.model import Model
36
36
  from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
37
37
  from zenml.models.v2.core.model import ModelResponse
38
38
  from zenml.models.v2.core.pipeline_run import PipelineRunResponse
@@ -311,12 +311,12 @@ class ModelVersionResponse(
311
311
  return Client().zen_store.get_model_version(self.id)
312
312
 
313
313
  # Helper functions
314
- def to_model_version(
314
+ def to_model_class(
315
315
  self,
316
316
  was_created_in_this_run: bool = False,
317
317
  suppress_class_validation_warnings: bool = False,
318
- ) -> "ModelVersion":
319
- """Convert response model to ModelVersion object.
318
+ ) -> "Model":
319
+ """Convert response model to Model object.
320
320
 
321
321
  Args:
322
322
  was_created_in_this_run: Whether model version was created during
@@ -325,11 +325,11 @@ class ModelVersionResponse(
325
325
  repeated warnings.
326
326
 
327
327
  Returns:
328
- ModelVersion object
328
+ Model object
329
329
  """
330
- from zenml.model.model_version import ModelVersion
330
+ from zenml.model.model import Model
331
331
 
332
- mv = ModelVersion(
332
+ mv = Model(
333
333
  name=self.model.name,
334
334
  license=self.model.license,
335
335
  description=self.description,
@@ -578,7 +578,7 @@ class ModelVersionResponse(
578
578
  # ------------------ Filter Model ------------------
579
579
 
580
580
 
581
- class ModelVersionFilter(WorkspaceScopedFilter):
581
+ class ModelVersionFilter(WorkspaceScopedTaggableFilter):
582
582
  """Filter model for model versions."""
583
583
 
584
584
  name: Optional[str] = Field(
@@ -30,7 +30,7 @@ from zenml.models.v2.base.scoped import (
30
30
  )
31
31
 
32
32
  if TYPE_CHECKING:
33
- from zenml.model.model_version import ModelVersion
33
+ from zenml.model.model import Model
34
34
 
35
35
  # ------------------ Request Model ------------------
36
36
 
@@ -203,7 +203,7 @@ class LazyRunMetadataResponse(RunMetadataResponse):
203
203
  _lazy_load_artifact_name: Optional[str] = None
204
204
  _lazy_load_artifact_version: Optional[str] = None
205
205
  _lazy_load_metadata_name: Optional[str] = None
206
- _lazy_load_model_version: "ModelVersion"
206
+ _lazy_load_model: "Model"
207
207
 
208
208
  def get_body(self) -> None: # type: ignore[override]
209
209
  """Protects from misuse of the lazy loader.
@@ -17,14 +17,14 @@ from typing import List, Optional
17
17
 
18
18
  from pydantic import BaseModel, PrivateAttr
19
19
 
20
- from zenml.model.model_version import ModelVersion
20
+ from zenml.model.model import Model
21
21
 
22
22
 
23
- class NewModelVersionRequest(BaseModel):
24
- """Request to create a new model version."""
23
+ class NewModelRequest(BaseModel):
24
+ """Request to create a new version of a model."""
25
25
 
26
26
  class Requester(BaseModel):
27
- """Requester of a new model version."""
27
+ """Requester of a new version of a model."""
28
28
 
29
29
  source: str
30
30
  name: str
@@ -38,35 +38,35 @@ class NewModelVersionRequest(BaseModel):
38
38
  return f"{self.source}::{self.name}"
39
39
 
40
40
  requesters: List[Requester] = []
41
- _model_version: Optional[ModelVersion] = PrivateAttr(default=None)
41
+ _model: Optional[Model] = PrivateAttr(default=None)
42
42
 
43
43
  @property
44
- def model_version(self) -> ModelVersion:
45
- """Model version getter.
44
+ def model(self) -> Model:
45
+ """Model getter.
46
46
 
47
47
  Returns:
48
- The model version.
48
+ The model.
49
49
 
50
50
  Raises:
51
- RuntimeError: If the model version is not set.
51
+ RuntimeError: If the model is not set.
52
52
  """
53
- if self._model_version is None:
54
- raise RuntimeError("Model version is not set.")
55
- return self._model_version
53
+ if self._model is None:
54
+ raise RuntimeError("Model is not set.")
55
+ return self._model
56
56
 
57
57
  def update_request(
58
58
  self,
59
- model_version: ModelVersion,
60
- requester: "NewModelVersionRequest.Requester",
59
+ model: Model,
60
+ requester: "NewModelRequest.Requester",
61
61
  ) -> None:
62
- """Update from `ModelVersion` in place.
62
+ """Update from `Model` in place.
63
63
 
64
64
  Args:
65
- model_version: `ModelVersion` to use.
66
- requester: Requester of a new model version.
65
+ model: `Model` to use.
66
+ requester: Requester of a new version of a model.
67
67
  """
68
68
  self.requesters.append(requester)
69
- if self._model_version is None:
70
- self._model_version = model_version
69
+ if self._model is None:
70
+ self._model = model
71
71
 
72
- self._model_version._merge(model_version)
72
+ self._model._merge(model)
@@ -70,7 +70,7 @@ from zenml.models import (
70
70
  ScheduleRequest,
71
71
  )
72
72
  from zenml.new.pipelines import build_utils
73
- from zenml.new.pipelines.model_utils import NewModelVersionRequest
73
+ from zenml.new.pipelines.model_utils import NewModelRequest
74
74
  from zenml.orchestrators.utils import get_run_name
75
75
  from zenml.stack import Stack
76
76
  from zenml.steps import BaseStep
@@ -93,7 +93,7 @@ if TYPE_CHECKING:
93
93
  from zenml.config.base_settings import SettingsOrDict
94
94
  from zenml.config.source import Source
95
95
  from zenml.model.lazy_load import ModelVersionDataLazyLoader
96
- from zenml.model.model_version import ModelVersion
96
+ from zenml.model.model import Model
97
97
 
98
98
  StepConfigurationUpdateOrDict = Union[
99
99
  Dict[str, Any], StepConfigurationUpdate
@@ -126,7 +126,7 @@ class Pipeline:
126
126
  extra: Optional[Dict[str, Any]] = None,
127
127
  on_failure: Optional["HookSpecification"] = None,
128
128
  on_success: Optional["HookSpecification"] = None,
129
- model_version: Optional["ModelVersion"] = None,
129
+ model: Optional["Model"] = None,
130
130
  ) -> None:
131
131
  """Initializes a pipeline.
132
132
 
@@ -147,7 +147,7 @@ class Pipeline:
147
147
  on_success: Callback function in event of success of the step. Can
148
148
  be a function with no arguments, or a source path to such a
149
149
  function (e.g. `module.my_function`).
150
- model_version: configuration of the model version in the Model Control Plane.
150
+ model: configuration of the model in the Model Control Plane.
151
151
  """
152
152
  self._invocations: Dict[str, StepInvocation] = {}
153
153
  self._run_args: Dict[str, Any] = {}
@@ -166,7 +166,7 @@ class Pipeline:
166
166
  extra=extra,
167
167
  on_failure=on_failure,
168
168
  on_success=on_success,
169
- model_version=model_version,
169
+ model=model,
170
170
  )
171
171
  self.entrypoint = entrypoint
172
172
  self._parameters: Dict[str, Any] = {}
@@ -305,7 +305,7 @@ class Pipeline:
305
305
  extra: Optional[Dict[str, Any]] = None,
306
306
  on_failure: Optional["HookSpecification"] = None,
307
307
  on_success: Optional["HookSpecification"] = None,
308
- model_version: Optional["ModelVersion"] = None,
308
+ model: Optional["Model"] = None,
309
309
  parameters: Optional[Dict[str, Any]] = None,
310
310
  merge: bool = True,
311
311
  ) -> T:
@@ -341,7 +341,7 @@ class Pipeline:
341
341
  configurations. If `False` the given configurations will
342
342
  overwrite all existing ones. See the general description of this
343
343
  method for an example.
344
- model_version: configuration of the model version in the Model Control Plane.
344
+ model: configuration of the model version in the Model Control Plane.
345
345
  parameters: input parameters for the pipeline.
346
346
 
347
347
  Returns:
@@ -367,7 +367,7 @@ class Pipeline:
367
367
  "extra": extra,
368
368
  "failure_hook_source": failure_hook_source,
369
369
  "success_hook_source": success_hook_source,
370
- "model_version": model_version,
370
+ "model": model,
371
371
  "parameters": parameters,
372
372
  }
373
373
  )
@@ -872,34 +872,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
872
872
  def _update_new_requesters(
873
873
  self,
874
874
  requester_name: str,
875
- model_version: "ModelVersion",
875
+ model: "Model",
876
876
  new_versions_requested: Dict[
877
- Tuple[str, Optional[str]], NewModelVersionRequest
877
+ Tuple[str, Optional[str]], NewModelRequest
878
878
  ],
879
- other_model_versions: Set["ModelVersion"],
879
+ other_models: Set["Model"],
880
880
  ) -> None:
881
881
  key = (
882
- model_version.name,
883
- str(model_version.version) if model_version.version else None,
882
+ model.name,
883
+ str(model.version) if model.version else None,
884
884
  )
885
- if model_version.version is None:
885
+ if model.version is None:
886
886
  version_existed = False
887
887
  else:
888
888
  try:
889
- model_version._get_model_version()
889
+ model._get_model_version()
890
890
  version_existed = key not in new_versions_requested
891
891
  except KeyError:
892
892
  version_existed = False
893
893
  if not version_existed:
894
- model_version.was_created_in_this_run = True
894
+ model.was_created_in_this_run = True
895
895
  new_versions_requested[key].update_request(
896
- model_version,
897
- NewModelVersionRequest.Requester(
898
- source="step", name=requester_name
899
- ),
896
+ model,
897
+ NewModelRequest.Requester(source="step", name=requester_name),
900
898
  )
901
899
  else:
902
- other_model_versions.add(model_version)
900
+ other_models.add(model)
903
901
 
904
902
  def prepare_model_versions(
905
903
  self, deployment: "PipelineDeploymentBase"
@@ -910,35 +908,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
910
908
  deployment: The pipeline deployment configuration.
911
909
  """
912
910
  new_versions_requested: Dict[
913
- Tuple[str, Optional[str]], NewModelVersionRequest
914
- ] = defaultdict(NewModelVersionRequest)
915
- other_model_versions: Set["ModelVersion"] = set()
911
+ Tuple[str, Optional[str]], NewModelRequest
912
+ ] = defaultdict(NewModelRequest)
913
+ other_models: Set["Model"] = set()
916
914
  all_steps_have_own_config = True
917
915
  for step in deployment.step_configurations.values():
918
- step_model_version = step.config.model_version
916
+ step_model = step.config.model
919
917
  all_steps_have_own_config = (
920
- all_steps_have_own_config
921
- and step.config.model_version is not None
918
+ all_steps_have_own_config and step.config.model is not None
922
919
  )
923
- if step_model_version:
920
+ if step_model:
924
921
  self._update_new_requesters(
925
- model_version=step_model_version,
922
+ model=step_model,
926
923
  requester_name=step.config.name,
927
924
  new_versions_requested=new_versions_requested,
928
- other_model_versions=other_model_versions,
925
+ other_models=other_models,
929
926
  )
930
927
  if not all_steps_have_own_config:
931
- pipeline_model_version = (
932
- deployment.pipeline_configuration.model_version
933
- )
934
- if pipeline_model_version:
928
+ pipeline_model = deployment.pipeline_configuration.model
929
+ if pipeline_model:
935
930
  self._update_new_requesters(
936
- model_version=pipeline_model_version,
931
+ model=pipeline_model,
937
932
  requester_name=self.name,
938
933
  new_versions_requested=new_versions_requested,
939
- other_model_versions=other_model_versions,
934
+ other_models=other_models,
940
935
  )
941
- elif deployment.pipeline_configuration.model_version is not None:
936
+ elif deployment.pipeline_configuration.model is not None:
942
937
  logger.warning(
943
938
  f"ModelConfig of pipeline `{self.name}` is overridden in all "
944
939
  f"steps. "
@@ -946,13 +941,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
946
941
 
947
942
  self._validate_new_version_requests(new_versions_requested)
948
943
 
949
- for other_model_version in other_model_versions:
950
- other_model_version._validate_config_in_runtime()
944
+ for other_model in other_models:
945
+ other_model._validate_config_in_runtime()
951
946
 
952
947
  def _validate_new_version_requests(
953
948
  self,
954
949
  new_versions_requested: Dict[
955
- Tuple[str, Optional[str]], NewModelVersionRequest
950
+ Tuple[str, Optional[str]], NewModelRequest
956
951
  ],
957
952
  ) -> None:
958
953
  """Validate the model version that are used in the pipeline run.
@@ -967,13 +962,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
967
962
  logger.warning(
968
963
  f"New version of model version `{model_name}::{model_version or 'NEW'}` "
969
964
  f"requested in multiple decorators:\n{data.requesters}\n We recommend "
970
- "that `ModelVersion` requesting new version is configured only in one "
965
+ "that `Model` requesting new version is configured only in one "
971
966
  "place of the pipeline."
972
967
  )
973
- data.model_version._validate_config_in_runtime()
968
+ data.model._validate_config_in_runtime()
974
969
  self.__new_unnamed_model_versions_in_current_run__[
975
- data.model_version.name
976
- ] = data.model_version.number
970
+ data.model.name
971
+ ] = data.model.number
977
972
 
978
973
  def get_runs(self, **kwargs: Any) -> List["PipelineRunResponse"]:
979
974
  """(Deprecated) Get runs of this pipeline.
@@ -1400,18 +1395,16 @@ To avoid this consider setting pipeline parameters only in one place (config or
1400
1395
  {k: v for k, v in _from_config_file.items() if k in matcher}
1401
1396
  )
1402
1397
 
1403
- if "model_version" in _from_config_file:
1404
- if "model_version" in self._from_config_file:
1405
- _from_config_file[
1406
- "model_version"
1407
- ] = self._from_config_file["model_version"]
1398
+ if "model" in _from_config_file:
1399
+ if "model" in self._from_config_file:
1400
+ _from_config_file["model"] = self._from_config_file[
1401
+ "model"
1402
+ ]
1408
1403
  else:
1409
- from zenml.model.model_version import ModelVersion
1404
+ from zenml.model.model import Model
1410
1405
 
1411
- _from_config_file[
1412
- "model_version"
1413
- ] = ModelVersion.parse_obj(
1414
- _from_config_file["model_version"]
1406
+ _from_config_file["model"] = Model.parse_obj(
1407
+ _from_config_file["model"]
1415
1408
  )
1416
1409
  self._from_config_file = _from_config_file
1417
1410
 
@@ -107,4 +107,4 @@ class PipelineContext:
107
107
  self.enable_step_logs = pipeline_configuration.enable_step_logs
108
108
  self.settings = pipeline_configuration.settings
109
109
  self.extra = pipeline_configuration.extra
110
- self.model_version = pipeline_configuration.model_version
110
+ self.model = pipeline_configuration.model
@@ -26,7 +26,7 @@ from typing import (
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  from zenml.config.base_settings import SettingsOrDict
29
- from zenml.model.model_version import ModelVersion
29
+ from zenml.model.model import Model
30
30
  from zenml.new.pipelines.pipeline import Pipeline
31
31
 
32
32
  HookSpecification = Union[str, FunctionType]
@@ -62,7 +62,7 @@ def pipeline(
62
62
  extra: Optional[Dict[str, Any]] = None,
63
63
  on_failure: Optional["HookSpecification"] = None,
64
64
  on_success: Optional["HookSpecification"] = None,
65
- model_version: Optional["ModelVersion"] = None,
65
+ model: Optional["Model"] = None,
66
66
  ) -> Union["Pipeline", Callable[["F"], "Pipeline"]]:
67
67
  """Decorator to create a pipeline.
68
68
 
@@ -81,7 +81,7 @@ def pipeline(
81
81
  on_success: Callback function in event of success of the step. Can be a
82
82
  function with no arguments, or a source path to such a function
83
83
  (e.g. `module.my_function`).
84
- model_version: configuration of the model version in the Model Control Plane.
84
+ model: configuration of the model in the Model Control Plane.
85
85
 
86
86
  Returns:
87
87
  A pipeline instance.
@@ -99,7 +99,7 @@ def pipeline(
99
99
  extra=extra,
100
100
  on_failure=on_failure,
101
101
  on_success=on_success,
102
- model_version=model_version,
102
+ model=model,
103
103
  entrypoint=func,
104
104
  )
105
105