zenml-nightly 0.54.1.dev20240119__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/__init__.py +10 -4
- zenml/artifacts/artifact_config.py +12 -12
- zenml/artifacts/external_artifact.py +3 -3
- zenml/artifacts/external_artifact_config.py +8 -8
- zenml/artifacts/utils.py +4 -4
- zenml/cli/__init__.py +1 -1
- zenml/cli/base.py +3 -3
- zenml/cli/utils.py +3 -3
- zenml/config/compiler.py +1 -1
- zenml/config/pipeline_configurations.py +2 -2
- zenml/config/pipeline_run_configuration.py +2 -2
- zenml/config/step_configurations.py +2 -2
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
- zenml/metadata/lazy_load.py +5 -5
- zenml/model/lazy_load.py +2 -2
- zenml/model/{model_version.py → model.py} +35 -28
- zenml/model/utils.py +33 -33
- zenml/model_registries/base_model_registry.py +10 -8
- zenml/models/v2/core/artifact_version.py +3 -3
- zenml/models/v2/core/model.py +3 -3
- zenml/models/v2/core/model_version.py +7 -7
- zenml/models/v2/core/run_metadata.py +2 -2
- zenml/new/pipelines/model_utils.py +20 -20
- zenml/new/pipelines/pipeline.py +47 -54
- zenml/new/pipelines/pipeline_context.py +1 -1
- zenml/new/pipelines/pipeline_decorator.py +4 -4
- zenml/new/steps/step_context.py +15 -15
- zenml/new/steps/step_decorator.py +5 -5
- zenml/orchestrators/input_utils.py +5 -7
- zenml/orchestrators/step_launcher.py +12 -19
- zenml/orchestrators/step_runner.py +8 -10
- zenml/pipelines/base_pipeline.py +1 -1
- zenml/pipelines/pipeline_decorator.py +6 -6
- zenml/steps/base_step.py +15 -15
- zenml/steps/step_decorator.py +6 -6
- zenml/steps/utils.py +68 -0
- zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +42 -41
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
zenml/model/utils.py
CHANGED
@@ -22,7 +22,7 @@ from zenml.enums import ModelStages
|
|
22
22
|
from zenml.exceptions import StepContextError
|
23
23
|
from zenml.logger import get_logger
|
24
24
|
from zenml.metadata.metadata_types import MetadataType
|
25
|
-
from zenml.model.
|
25
|
+
from zenml.model.model import Model
|
26
26
|
from zenml.models import ModelVersionArtifactRequest
|
27
27
|
from zenml.new.steps.step_context import get_step_context
|
28
28
|
|
@@ -48,9 +48,9 @@ def link_step_artifacts_to_model(
|
|
48
48
|
"step."
|
49
49
|
)
|
50
50
|
try:
|
51
|
-
|
51
|
+
model = step_context.model
|
52
52
|
except StepContextError:
|
53
|
-
|
53
|
+
model = None
|
54
54
|
logger.debug("No model context found, unable to auto-link artifacts.")
|
55
55
|
|
56
56
|
for artifact_name, artifact_version_id in artifact_version_ids.items():
|
@@ -59,47 +59,47 @@ def link_step_artifacts_to_model(
|
|
59
59
|
).artifact_config
|
60
60
|
|
61
61
|
# Implicit linking
|
62
|
-
if artifact_config is None and
|
62
|
+
if artifact_config is None and model is not None:
|
63
63
|
artifact_config = ArtifactConfig(name=artifact_name)
|
64
64
|
logger.info(
|
65
65
|
f"Implicitly linking artifact `{artifact_name}` to model "
|
66
|
-
f"`{
|
66
|
+
f"`{model.name}` version `{model.version}`."
|
67
67
|
)
|
68
68
|
|
69
69
|
if artifact_config:
|
70
|
-
|
70
|
+
link_artifact_config_to_model(
|
71
71
|
artifact_config=artifact_config,
|
72
72
|
artifact_version_id=artifact_version_id,
|
73
|
-
|
73
|
+
model=model,
|
74
74
|
)
|
75
75
|
|
76
76
|
|
77
|
-
def
|
77
|
+
def link_artifact_config_to_model(
|
78
78
|
artifact_config: ArtifactConfig,
|
79
79
|
artifact_version_id: UUID,
|
80
|
-
|
80
|
+
model: Optional["Model"] = None,
|
81
81
|
) -> None:
|
82
82
|
"""Link an artifact config to its model version.
|
83
83
|
|
84
84
|
Args:
|
85
85
|
artifact_config: The artifact config to link.
|
86
86
|
artifact_version_id: The ID of the artifact to link.
|
87
|
-
|
87
|
+
model: The model version from the step or pipeline context.
|
88
88
|
"""
|
89
89
|
client = Client()
|
90
90
|
|
91
91
|
# If the artifact config specifies a model itself then always use that
|
92
92
|
if artifact_config.model_name is not None:
|
93
|
-
from zenml.model.
|
93
|
+
from zenml.model.model import Model
|
94
94
|
|
95
|
-
|
95
|
+
model = Model(
|
96
96
|
name=artifact_config.model_name,
|
97
97
|
version=artifact_config.model_version,
|
98
98
|
)
|
99
99
|
|
100
|
-
if
|
101
|
-
|
102
|
-
model_version_response =
|
100
|
+
if model:
|
101
|
+
model._get_or_create_model_version()
|
102
|
+
model_version_response = model._get_model_version()
|
103
103
|
request = ModelVersionArtifactRequest(
|
104
104
|
user=client.active_user.id,
|
105
105
|
workspace=client.active_workspace.id,
|
@@ -125,10 +125,10 @@ def log_model_version_metadata(
|
|
125
125
|
metadata: The metadata to log.
|
126
126
|
model_name: The name of the model to log metadata for. Can
|
127
127
|
be omitted when being called inside a step with configured
|
128
|
-
`
|
128
|
+
`model` in decorator.
|
129
129
|
model_version: The version of the model to log metadata for. Can
|
130
130
|
be omitted when being called inside a step with configured
|
131
|
-
`
|
131
|
+
`model` in decorator.
|
132
132
|
"""
|
133
133
|
logger.warning(
|
134
134
|
"`log_model_version_metadata` is deprecated. Please use "
|
@@ -152,38 +152,38 @@ def log_model_metadata(
|
|
152
152
|
metadata: The metadata to log.
|
153
153
|
model_name: The name of the model to log metadata for. Can
|
154
154
|
be omitted when being called inside a step with configured
|
155
|
-
`
|
155
|
+
`model` in decorator.
|
156
156
|
model_version: The version of the model to log metadata for. Can
|
157
157
|
be omitted when being called inside a step with configured
|
158
|
-
`
|
158
|
+
`model` in decorator.
|
159
159
|
|
160
160
|
Raises:
|
161
161
|
ValueError: If no model name/version is provided and the function is not
|
162
|
-
called inside a step with configured `
|
162
|
+
called inside a step with configured `model` in decorator.
|
163
163
|
"""
|
164
164
|
mv = None
|
165
165
|
try:
|
166
166
|
step_context = get_step_context()
|
167
|
-
mv = step_context.
|
167
|
+
mv = step_context.model
|
168
168
|
except RuntimeError:
|
169
169
|
step_context = None
|
170
170
|
|
171
171
|
if not step_context and not (model_name and model_version):
|
172
172
|
raise ValueError(
|
173
173
|
"Model name and version must be provided unless the function is "
|
174
|
-
"called inside a step with configured `
|
174
|
+
"called inside a step with configured `model` in decorator."
|
175
175
|
)
|
176
176
|
if mv is None:
|
177
|
-
from zenml import
|
177
|
+
from zenml import Model
|
178
178
|
|
179
|
-
mv =
|
179
|
+
mv = Model(name=model_name, version=model_version)
|
180
180
|
|
181
181
|
mv.log_metadata(metadata)
|
182
182
|
|
183
183
|
|
184
184
|
def link_artifact_to_model(
|
185
185
|
artifact_version_id: UUID,
|
186
|
-
|
186
|
+
model: Optional["Model"] = None,
|
187
187
|
is_model_artifact: bool = False,
|
188
188
|
is_deployment_artifact: bool = False,
|
189
189
|
) -> None:
|
@@ -191,34 +191,34 @@ def link_artifact_to_model(
|
|
191
191
|
|
192
192
|
Args:
|
193
193
|
artifact_version_id: The ID of the artifact version.
|
194
|
-
|
194
|
+
model: The model to link to.
|
195
195
|
is_model_artifact: Whether the artifact is a model artifact.
|
196
196
|
is_deployment_artifact: Whether the artifact is a deployment artifact.
|
197
197
|
|
198
198
|
Raises:
|
199
199
|
RuntimeError: If called outside of a step.
|
200
200
|
"""
|
201
|
-
if not
|
201
|
+
if not model:
|
202
202
|
is_issue = False
|
203
203
|
try:
|
204
204
|
step_context = get_step_context()
|
205
|
-
|
205
|
+
model = step_context.model
|
206
206
|
except StepContextError:
|
207
207
|
is_issue = True
|
208
208
|
|
209
|
-
if
|
209
|
+
if model is None or is_issue:
|
210
210
|
raise RuntimeError(
|
211
|
-
"`link_artifact_to_model` called without `
|
211
|
+
"`link_artifact_to_model` called without `model` parameter "
|
212
212
|
"and configured model context cannot be identified. Consider "
|
213
|
-
"passing the `
|
213
|
+
"passing the `model` explicitly or configuring it in "
|
214
214
|
"@step or @pipeline decorator."
|
215
215
|
)
|
216
216
|
|
217
|
-
|
217
|
+
link_artifact_config_to_model(
|
218
218
|
artifact_config=ArtifactConfig(
|
219
219
|
is_model_artifact=is_model_artifact,
|
220
220
|
is_deployment_artifact=is_deployment_artifact,
|
221
221
|
),
|
222
222
|
artifact_version_id=artifact_version_id,
|
223
|
-
|
223
|
+
model=model,
|
224
224
|
)
|
@@ -123,10 +123,10 @@ class ModelRegistryModelMetadata(BaseModel):
|
|
123
123
|
extra = "allow"
|
124
124
|
|
125
125
|
|
126
|
-
class
|
126
|
+
class RegistryModelVersion(BaseModel):
|
127
127
|
"""Base class for all ZenML model versions.
|
128
128
|
|
129
|
-
The `
|
129
|
+
The `RegistryModelVersion` class represents a version or snapshot of a registered
|
130
130
|
model, including information such as the associated `ModelBundle`, version
|
131
131
|
number, creation time, pipeline run information, and metadata. It serves as
|
132
132
|
a blueprint for creating concrete model version implementations in a registry,
|
@@ -288,7 +288,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
288
288
|
description: Optional[str] = None,
|
289
289
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
290
290
|
**kwargs: Any,
|
291
|
-
) ->
|
291
|
+
) -> RegistryModelVersion:
|
292
292
|
"""Registers a model version in the model registry.
|
293
293
|
|
294
294
|
Args:
|
@@ -333,7 +333,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
333
333
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
334
334
|
remove_metadata: Optional[List[str]] = None,
|
335
335
|
stage: Optional[ModelVersionStage] = None,
|
336
|
-
) ->
|
336
|
+
) -> RegistryModelVersion:
|
337
337
|
"""Updates a model version in the model registry.
|
338
338
|
|
339
339
|
Args:
|
@@ -364,7 +364,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
364
364
|
created_before: Optional[datetime] = None,
|
365
365
|
order_by_date: Optional[str] = None,
|
366
366
|
**kwargs: Any,
|
367
|
-
) -> Optional[List[
|
367
|
+
) -> Optional[List[RegistryModelVersion]]:
|
368
368
|
"""Lists all model versions for a registered model.
|
369
369
|
|
370
370
|
Args:
|
@@ -387,7 +387,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
387
387
|
self,
|
388
388
|
name: str,
|
389
389
|
stage: Optional[ModelVersionStage] = None,
|
390
|
-
) -> Optional[
|
390
|
+
) -> Optional[RegistryModelVersion]:
|
391
391
|
"""Gets the latest model version for a registered model.
|
392
392
|
|
393
393
|
This method is used to get the latest model version for a registered
|
@@ -410,7 +410,9 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
410
410
|
return None
|
411
411
|
|
412
412
|
@abstractmethod
|
413
|
-
def get_model_version(
|
413
|
+
def get_model_version(
|
414
|
+
self, name: str, version: str
|
415
|
+
) -> RegistryModelVersion:
|
414
416
|
"""Gets a model version for a registered model.
|
415
417
|
|
416
418
|
Args:
|
@@ -450,7 +452,7 @@ class BaseModelRegistry(StackComponent, ABC):
|
|
450
452
|
@abstractmethod
|
451
453
|
def get_model_uri_artifact_store(
|
452
454
|
self,
|
453
|
-
model_version:
|
455
|
+
model_version: RegistryModelVersion,
|
454
456
|
) -> str:
|
455
457
|
"""Gets the URI artifact store for a model version.
|
456
458
|
|
@@ -44,7 +44,7 @@ from zenml.models.v2.core.tag import TagResponse
|
|
44
44
|
if TYPE_CHECKING:
|
45
45
|
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
|
46
46
|
|
47
|
-
from zenml.model.
|
47
|
+
from zenml.model.model import Model
|
48
48
|
from zenml.models.v2.core.artifact_visualization import (
|
49
49
|
ArtifactVisualizationRequest,
|
50
50
|
ArtifactVisualizationResponse,
|
@@ -477,7 +477,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
|
|
477
477
|
id: Optional[UUID] = None # type: ignore[assignment]
|
478
478
|
_lazy_load_name: Optional[str] = None
|
479
479
|
_lazy_load_version: Optional[str] = None
|
480
|
-
|
480
|
+
_lazy_load_model: "Model"
|
481
481
|
|
482
482
|
def get_body(self) -> None: # type: ignore[override]
|
483
483
|
"""Protects from misuse of the lazy loader.
|
@@ -507,7 +507,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
|
|
507
507
|
from zenml.metadata.lazy_load import RunMetadataLazyGetter
|
508
508
|
|
509
509
|
return RunMetadataLazyGetter( # type: ignore[return-value]
|
510
|
-
self.
|
510
|
+
self._lazy_load_model,
|
511
511
|
self._lazy_load_name,
|
512
512
|
self._lazy_load_version,
|
513
513
|
)
|
zenml/models/v2/core/model.py
CHANGED
@@ -31,7 +31,7 @@ from zenml.models.v2.base.scoped import (
|
|
31
31
|
from zenml.utils.pagination_utils import depaginate
|
32
32
|
|
33
33
|
if TYPE_CHECKING:
|
34
|
-
from zenml.model.
|
34
|
+
from zenml.model.model import Model
|
35
35
|
from zenml.models.v2.core.tag import TagResponse
|
36
36
|
|
37
37
|
|
@@ -310,7 +310,7 @@ class ModelResponse(
|
|
310
310
|
|
311
311
|
# Helper functions
|
312
312
|
@property
|
313
|
-
def versions(self) -> List["
|
313
|
+
def versions(self) -> List["Model"]:
|
314
314
|
"""List all versions of the model.
|
315
315
|
|
316
316
|
Returns:
|
@@ -323,7 +323,7 @@ class ModelResponse(
|
|
323
323
|
partial(client.list_model_versions, model_name_or_id=self.id)
|
324
324
|
)
|
325
325
|
return [
|
326
|
-
mv.
|
326
|
+
mv.to_model_class(suppress_class_validation_warnings=True)
|
327
327
|
for mv in model_versions
|
328
328
|
]
|
329
329
|
|
@@ -32,7 +32,7 @@ from zenml.models.v2.base.scoped import (
|
|
32
32
|
from zenml.models.v2.core.tag import TagResponse
|
33
33
|
|
34
34
|
if TYPE_CHECKING:
|
35
|
-
from zenml.model.
|
35
|
+
from zenml.model.model import Model
|
36
36
|
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
37
37
|
from zenml.models.v2.core.model import ModelResponse
|
38
38
|
from zenml.models.v2.core.pipeline_run import PipelineRunResponse
|
@@ -311,12 +311,12 @@ class ModelVersionResponse(
|
|
311
311
|
return Client().zen_store.get_model_version(self.id)
|
312
312
|
|
313
313
|
# Helper functions
|
314
|
-
def
|
314
|
+
def to_model_class(
|
315
315
|
self,
|
316
316
|
was_created_in_this_run: bool = False,
|
317
317
|
suppress_class_validation_warnings: bool = False,
|
318
|
-
) -> "
|
319
|
-
"""Convert response model to
|
318
|
+
) -> "Model":
|
319
|
+
"""Convert response model to Model object.
|
320
320
|
|
321
321
|
Args:
|
322
322
|
was_created_in_this_run: Whether model version was created during
|
@@ -325,11 +325,11 @@ class ModelVersionResponse(
|
|
325
325
|
repeated warnings.
|
326
326
|
|
327
327
|
Returns:
|
328
|
-
|
328
|
+
Model object
|
329
329
|
"""
|
330
|
-
from zenml.model.
|
330
|
+
from zenml.model.model import Model
|
331
331
|
|
332
|
-
mv =
|
332
|
+
mv = Model(
|
333
333
|
name=self.model.name,
|
334
334
|
license=self.model.license,
|
335
335
|
description=self.description,
|
@@ -30,7 +30,7 @@ from zenml.models.v2.base.scoped import (
|
|
30
30
|
)
|
31
31
|
|
32
32
|
if TYPE_CHECKING:
|
33
|
-
from zenml.model.
|
33
|
+
from zenml.model.model import Model
|
34
34
|
|
35
35
|
# ------------------ Request Model ------------------
|
36
36
|
|
@@ -203,7 +203,7 @@ class LazyRunMetadataResponse(RunMetadataResponse):
|
|
203
203
|
_lazy_load_artifact_name: Optional[str] = None
|
204
204
|
_lazy_load_artifact_version: Optional[str] = None
|
205
205
|
_lazy_load_metadata_name: Optional[str] = None
|
206
|
-
|
206
|
+
_lazy_load_model: "Model"
|
207
207
|
|
208
208
|
def get_body(self) -> None: # type: ignore[override]
|
209
209
|
"""Protects from misuse of the lazy loader.
|
@@ -17,14 +17,14 @@ from typing import List, Optional
|
|
17
17
|
|
18
18
|
from pydantic import BaseModel, PrivateAttr
|
19
19
|
|
20
|
-
from zenml.model.
|
20
|
+
from zenml.model.model import Model
|
21
21
|
|
22
22
|
|
23
|
-
class
|
24
|
-
"""Request to create a new model
|
23
|
+
class NewModelRequest(BaseModel):
|
24
|
+
"""Request to create a new version of a model."""
|
25
25
|
|
26
26
|
class Requester(BaseModel):
|
27
|
-
"""Requester of a new model
|
27
|
+
"""Requester of a new version of a model."""
|
28
28
|
|
29
29
|
source: str
|
30
30
|
name: str
|
@@ -38,35 +38,35 @@ class NewModelVersionRequest(BaseModel):
|
|
38
38
|
return f"{self.source}::{self.name}"
|
39
39
|
|
40
40
|
requesters: List[Requester] = []
|
41
|
-
|
41
|
+
_model: Optional[Model] = PrivateAttr(default=None)
|
42
42
|
|
43
43
|
@property
|
44
|
-
def
|
45
|
-
"""Model
|
44
|
+
def model(self) -> Model:
|
45
|
+
"""Model getter.
|
46
46
|
|
47
47
|
Returns:
|
48
|
-
The model
|
48
|
+
The model.
|
49
49
|
|
50
50
|
Raises:
|
51
|
-
RuntimeError: If the model
|
51
|
+
RuntimeError: If the model is not set.
|
52
52
|
"""
|
53
|
-
if self.
|
54
|
-
raise RuntimeError("Model
|
55
|
-
return self.
|
53
|
+
if self._model is None:
|
54
|
+
raise RuntimeError("Model is not set.")
|
55
|
+
return self._model
|
56
56
|
|
57
57
|
def update_request(
|
58
58
|
self,
|
59
|
-
|
60
|
-
requester: "
|
59
|
+
model: Model,
|
60
|
+
requester: "NewModelRequest.Requester",
|
61
61
|
) -> None:
|
62
|
-
"""Update from `
|
62
|
+
"""Update from `Model` in place.
|
63
63
|
|
64
64
|
Args:
|
65
|
-
|
66
|
-
requester: Requester of a new model
|
65
|
+
model: `Model` to use.
|
66
|
+
requester: Requester of a new version of a model.
|
67
67
|
"""
|
68
68
|
self.requesters.append(requester)
|
69
|
-
if self.
|
70
|
-
self.
|
69
|
+
if self._model is None:
|
70
|
+
self._model = model
|
71
71
|
|
72
|
-
self.
|
72
|
+
self._model._merge(model)
|
zenml/new/pipelines/pipeline.py
CHANGED
@@ -70,7 +70,7 @@ from zenml.models import (
|
|
70
70
|
ScheduleRequest,
|
71
71
|
)
|
72
72
|
from zenml.new.pipelines import build_utils
|
73
|
-
from zenml.new.pipelines.model_utils import
|
73
|
+
from zenml.new.pipelines.model_utils import NewModelRequest
|
74
74
|
from zenml.orchestrators.utils import get_run_name
|
75
75
|
from zenml.stack import Stack
|
76
76
|
from zenml.steps import BaseStep
|
@@ -93,7 +93,7 @@ if TYPE_CHECKING:
|
|
93
93
|
from zenml.config.base_settings import SettingsOrDict
|
94
94
|
from zenml.config.source import Source
|
95
95
|
from zenml.model.lazy_load import ModelVersionDataLazyLoader
|
96
|
-
from zenml.model.
|
96
|
+
from zenml.model.model import Model
|
97
97
|
|
98
98
|
StepConfigurationUpdateOrDict = Union[
|
99
99
|
Dict[str, Any], StepConfigurationUpdate
|
@@ -126,7 +126,7 @@ class Pipeline:
|
|
126
126
|
extra: Optional[Dict[str, Any]] = None,
|
127
127
|
on_failure: Optional["HookSpecification"] = None,
|
128
128
|
on_success: Optional["HookSpecification"] = None,
|
129
|
-
|
129
|
+
model: Optional["Model"] = None,
|
130
130
|
) -> None:
|
131
131
|
"""Initializes a pipeline.
|
132
132
|
|
@@ -147,7 +147,7 @@ class Pipeline:
|
|
147
147
|
on_success: Callback function in event of success of the step. Can
|
148
148
|
be a function with no arguments, or a source path to such a
|
149
149
|
function (e.g. `module.my_function`).
|
150
|
-
|
150
|
+
model: configuration of the model in the Model Control Plane.
|
151
151
|
"""
|
152
152
|
self._invocations: Dict[str, StepInvocation] = {}
|
153
153
|
self._run_args: Dict[str, Any] = {}
|
@@ -166,7 +166,7 @@ class Pipeline:
|
|
166
166
|
extra=extra,
|
167
167
|
on_failure=on_failure,
|
168
168
|
on_success=on_success,
|
169
|
-
|
169
|
+
model=model,
|
170
170
|
)
|
171
171
|
self.entrypoint = entrypoint
|
172
172
|
self._parameters: Dict[str, Any] = {}
|
@@ -305,7 +305,7 @@ class Pipeline:
|
|
305
305
|
extra: Optional[Dict[str, Any]] = None,
|
306
306
|
on_failure: Optional["HookSpecification"] = None,
|
307
307
|
on_success: Optional["HookSpecification"] = None,
|
308
|
-
|
308
|
+
model: Optional["Model"] = None,
|
309
309
|
parameters: Optional[Dict[str, Any]] = None,
|
310
310
|
merge: bool = True,
|
311
311
|
) -> T:
|
@@ -341,7 +341,7 @@ class Pipeline:
|
|
341
341
|
configurations. If `False` the given configurations will
|
342
342
|
overwrite all existing ones. See the general description of this
|
343
343
|
method for an example.
|
344
|
-
|
344
|
+
model: configuration of the model version in the Model Control Plane.
|
345
345
|
parameters: input parameters for the pipeline.
|
346
346
|
|
347
347
|
Returns:
|
@@ -367,7 +367,7 @@ class Pipeline:
|
|
367
367
|
"extra": extra,
|
368
368
|
"failure_hook_source": failure_hook_source,
|
369
369
|
"success_hook_source": success_hook_source,
|
370
|
-
"
|
370
|
+
"model": model,
|
371
371
|
"parameters": parameters,
|
372
372
|
}
|
373
373
|
)
|
@@ -872,34 +872,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
872
872
|
def _update_new_requesters(
|
873
873
|
self,
|
874
874
|
requester_name: str,
|
875
|
-
|
875
|
+
model: "Model",
|
876
876
|
new_versions_requested: Dict[
|
877
|
-
Tuple[str, Optional[str]],
|
877
|
+
Tuple[str, Optional[str]], NewModelRequest
|
878
878
|
],
|
879
|
-
|
879
|
+
other_models: Set["Model"],
|
880
880
|
) -> None:
|
881
881
|
key = (
|
882
|
-
|
883
|
-
str(
|
882
|
+
model.name,
|
883
|
+
str(model.version) if model.version else None,
|
884
884
|
)
|
885
|
-
if
|
885
|
+
if model.version is None:
|
886
886
|
version_existed = False
|
887
887
|
else:
|
888
888
|
try:
|
889
|
-
|
889
|
+
model._get_model_version()
|
890
890
|
version_existed = key not in new_versions_requested
|
891
891
|
except KeyError:
|
892
892
|
version_existed = False
|
893
893
|
if not version_existed:
|
894
|
-
|
894
|
+
model.was_created_in_this_run = True
|
895
895
|
new_versions_requested[key].update_request(
|
896
|
-
|
897
|
-
|
898
|
-
source="step", name=requester_name
|
899
|
-
),
|
896
|
+
model,
|
897
|
+
NewModelRequest.Requester(source="step", name=requester_name),
|
900
898
|
)
|
901
899
|
else:
|
902
|
-
|
900
|
+
other_models.add(model)
|
903
901
|
|
904
902
|
def prepare_model_versions(
|
905
903
|
self, deployment: "PipelineDeploymentBase"
|
@@ -910,35 +908,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
910
908
|
deployment: The pipeline deployment configuration.
|
911
909
|
"""
|
912
910
|
new_versions_requested: Dict[
|
913
|
-
Tuple[str, Optional[str]],
|
914
|
-
] = defaultdict(
|
915
|
-
|
911
|
+
Tuple[str, Optional[str]], NewModelRequest
|
912
|
+
] = defaultdict(NewModelRequest)
|
913
|
+
other_models: Set["Model"] = set()
|
916
914
|
all_steps_have_own_config = True
|
917
915
|
for step in deployment.step_configurations.values():
|
918
|
-
|
916
|
+
step_model = step.config.model
|
919
917
|
all_steps_have_own_config = (
|
920
|
-
all_steps_have_own_config
|
921
|
-
and step.config.model_version is not None
|
918
|
+
all_steps_have_own_config and step.config.model is not None
|
922
919
|
)
|
923
|
-
if
|
920
|
+
if step_model:
|
924
921
|
self._update_new_requesters(
|
925
|
-
|
922
|
+
model=step_model,
|
926
923
|
requester_name=step.config.name,
|
927
924
|
new_versions_requested=new_versions_requested,
|
928
|
-
|
925
|
+
other_models=other_models,
|
929
926
|
)
|
930
927
|
if not all_steps_have_own_config:
|
931
|
-
|
932
|
-
|
933
|
-
)
|
934
|
-
if pipeline_model_version:
|
928
|
+
pipeline_model = deployment.pipeline_configuration.model
|
929
|
+
if pipeline_model:
|
935
930
|
self._update_new_requesters(
|
936
|
-
|
931
|
+
model=pipeline_model,
|
937
932
|
requester_name=self.name,
|
938
933
|
new_versions_requested=new_versions_requested,
|
939
|
-
|
934
|
+
other_models=other_models,
|
940
935
|
)
|
941
|
-
elif deployment.pipeline_configuration.
|
936
|
+
elif deployment.pipeline_configuration.model is not None:
|
942
937
|
logger.warning(
|
943
938
|
f"ModelConfig of pipeline `{self.name}` is overridden in all "
|
944
939
|
f"steps. "
|
@@ -946,13 +941,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
946
941
|
|
947
942
|
self._validate_new_version_requests(new_versions_requested)
|
948
943
|
|
949
|
-
for
|
950
|
-
|
944
|
+
for other_model in other_models:
|
945
|
+
other_model._validate_config_in_runtime()
|
951
946
|
|
952
947
|
def _validate_new_version_requests(
|
953
948
|
self,
|
954
949
|
new_versions_requested: Dict[
|
955
|
-
Tuple[str, Optional[str]],
|
950
|
+
Tuple[str, Optional[str]], NewModelRequest
|
956
951
|
],
|
957
952
|
) -> None:
|
958
953
|
"""Validate the model version that are used in the pipeline run.
|
@@ -967,13 +962,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
967
962
|
logger.warning(
|
968
963
|
f"New version of model version `{model_name}::{model_version or 'NEW'}` "
|
969
964
|
f"requested in multiple decorators:\n{data.requesters}\n We recommend "
|
970
|
-
"that `
|
965
|
+
"that `Model` requesting new version is configured only in one "
|
971
966
|
"place of the pipeline."
|
972
967
|
)
|
973
|
-
data.
|
968
|
+
data.model._validate_config_in_runtime()
|
974
969
|
self.__new_unnamed_model_versions_in_current_run__[
|
975
|
-
data.
|
976
|
-
] = data.
|
970
|
+
data.model.name
|
971
|
+
] = data.model.number
|
977
972
|
|
978
973
|
def get_runs(self, **kwargs: Any) -> List["PipelineRunResponse"]:
|
979
974
|
"""(Deprecated) Get runs of this pipeline.
|
@@ -1400,18 +1395,16 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
1400
1395
|
{k: v for k, v in _from_config_file.items() if k in matcher}
|
1401
1396
|
)
|
1402
1397
|
|
1403
|
-
if "
|
1404
|
-
if "
|
1405
|
-
_from_config_file[
|
1406
|
-
"
|
1407
|
-
]
|
1398
|
+
if "model" in _from_config_file:
|
1399
|
+
if "model" in self._from_config_file:
|
1400
|
+
_from_config_file["model"] = self._from_config_file[
|
1401
|
+
"model"
|
1402
|
+
]
|
1408
1403
|
else:
|
1409
|
-
from zenml.model.
|
1404
|
+
from zenml.model.model import Model
|
1410
1405
|
|
1411
|
-
_from_config_file[
|
1412
|
-
"
|
1413
|
-
] = ModelVersion.parse_obj(
|
1414
|
-
_from_config_file["model_version"]
|
1406
|
+
_from_config_file["model"] = Model.parse_obj(
|
1407
|
+
_from_config_file["model"]
|
1415
1408
|
)
|
1416
1409
|
self._from_config_file = _from_config_file
|
1417
1410
|
|
@@ -107,4 +107,4 @@ class PipelineContext:
|
|
107
107
|
self.enable_step_logs = pipeline_configuration.enable_step_logs
|
108
108
|
self.settings = pipeline_configuration.settings
|
109
109
|
self.extra = pipeline_configuration.extra
|
110
|
-
self.
|
110
|
+
self.model = pipeline_configuration.model
|