zenml-nightly 0.54.1.dev20240118__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/__init__.py +10 -4
- zenml/artifacts/artifact_config.py +12 -12
- zenml/artifacts/external_artifact.py +3 -3
- zenml/artifacts/external_artifact_config.py +8 -8
- zenml/artifacts/utils.py +4 -4
- zenml/cli/__init__.py +4 -4
- zenml/cli/artifact.py +38 -18
- zenml/cli/base.py +3 -3
- zenml/cli/model.py +24 -16
- zenml/cli/server.py +9 -0
- zenml/cli/utils.py +3 -3
- zenml/client.py +13 -2
- zenml/config/compiler.py +1 -1
- zenml/config/pipeline_configurations.py +2 -2
- zenml/config/pipeline_run_configuration.py +2 -2
- zenml/config/step_configurations.py +2 -2
- zenml/integrations/__init__.py +2 -4
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
- zenml/metadata/lazy_load.py +5 -5
- zenml/model/lazy_load.py +2 -2
- zenml/model/{model_version.py → model.py} +47 -38
- zenml/model/utils.py +33 -33
- zenml/model_registries/base_model_registry.py +10 -8
- zenml/models/__init__.py +2 -0
- zenml/models/v2/base/filter.py +3 -0
- zenml/models/v2/base/scoped.py +59 -0
- zenml/models/v2/core/artifact.py +2 -2
- zenml/models/v2/core/artifact_version.py +6 -6
- zenml/models/v2/core/model.py +6 -6
- zenml/models/v2/core/model_version.py +9 -9
- zenml/models/v2/core/run_metadata.py +2 -2
- zenml/new/pipelines/model_utils.py +20 -20
- zenml/new/pipelines/pipeline.py +47 -54
- zenml/new/pipelines/pipeline_context.py +1 -1
- zenml/new/pipelines/pipeline_decorator.py +4 -4
- zenml/new/steps/step_context.py +15 -15
- zenml/new/steps/step_decorator.py +5 -5
- zenml/orchestrators/input_utils.py +5 -7
- zenml/orchestrators/step_launcher.py +12 -19
- zenml/orchestrators/step_runner.py +8 -10
- zenml/pipelines/base_pipeline.py +1 -1
- zenml/pipelines/pipeline_decorator.py +6 -6
- zenml/steps/base_step.py +15 -15
- zenml/steps/step_decorator.py +6 -6
- zenml/steps/utils.py +68 -0
- zenml/zen_server/deploy/helm/templates/server-db-job.yaml +1 -1
- zenml/zen_server/deploy/helm/templates/server-secret.yaml +1 -1
- zenml/zen_server/deploy/helm/templates/serviceaccount.yaml +1 -1
- zenml/zen_server/utils.py +19 -1
- zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
- zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +16 -4
- zenml/zen_stores/rest_zen_store.py +2 -2
- zenml/zen_stores/sql_zen_store.py +4 -1
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +58 -57
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.54.1.dev20240118.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
zenml/models/v2/base/scoped.py
CHANGED
@@ -23,10 +23,12 @@ from typing import (
|
|
23
23
|
Optional,
|
24
24
|
Type,
|
25
25
|
TypeVar,
|
26
|
+
Union,
|
26
27
|
)
|
27
28
|
from uuid import UUID
|
28
29
|
|
29
30
|
from pydantic import Field
|
31
|
+
from sqlmodel import col
|
30
32
|
|
31
33
|
from zenml.models.v2.base.base import (
|
32
34
|
BaseRequest,
|
@@ -37,6 +39,8 @@ from zenml.models.v2.base.base import (
|
|
37
39
|
from zenml.models.v2.base.filter import AnyQuery, BaseFilter
|
38
40
|
|
39
41
|
if TYPE_CHECKING:
|
42
|
+
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
|
43
|
+
|
40
44
|
from zenml.models.v2.core.user import UserResponse
|
41
45
|
from zenml.models.v2.core.workspace import WorkspaceResponse
|
42
46
|
from zenml.zen_stores.schemas import BaseSchema
|
@@ -274,3 +278,58 @@ class WorkspaceScopedFilter(BaseFilter):
|
|
274
278
|
query = query.where(scope_filter)
|
275
279
|
|
276
280
|
return query
|
281
|
+
|
282
|
+
|
283
|
+
class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
|
284
|
+
"""Model to enable advanced scoping with workspace and tagging."""
|
285
|
+
|
286
|
+
tag: Optional[str] = Field(
|
287
|
+
description="Tag to apply to the filter query.", default=None
|
288
|
+
)
|
289
|
+
|
290
|
+
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
291
|
+
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
|
292
|
+
"tag",
|
293
|
+
]
|
294
|
+
|
295
|
+
def apply_filter(
|
296
|
+
self,
|
297
|
+
query: AnyQuery,
|
298
|
+
table: Type["AnySchema"],
|
299
|
+
) -> AnyQuery:
|
300
|
+
"""Applies the filter to a query.
|
301
|
+
|
302
|
+
Args:
|
303
|
+
query: The query to which to apply the filter.
|
304
|
+
table: The query table.
|
305
|
+
|
306
|
+
Returns:
|
307
|
+
The query with filter applied.
|
308
|
+
"""
|
309
|
+
from zenml.zen_stores.schemas import TagResourceSchema
|
310
|
+
|
311
|
+
query = super().apply_filter(query=query, table=table)
|
312
|
+
if self.tag:
|
313
|
+
query = (
|
314
|
+
query.join(getattr(table, "tags"))
|
315
|
+
.join(TagResourceSchema.tag)
|
316
|
+
.distinct()
|
317
|
+
)
|
318
|
+
|
319
|
+
return query
|
320
|
+
|
321
|
+
def get_custom_filters(
|
322
|
+
self,
|
323
|
+
) -> List[Union["BinaryExpression[Any]", "BooleanClauseList[Any]"]]:
|
324
|
+
"""Get custom tag filters.
|
325
|
+
|
326
|
+
Returns:
|
327
|
+
A list of custom filters.
|
328
|
+
"""
|
329
|
+
from zenml.zen_stores.schemas import TagSchema
|
330
|
+
|
331
|
+
custom_filters = super().get_custom_filters()
|
332
|
+
if self.tag:
|
333
|
+
custom_filters.append(col(TagSchema.name) == self.tag) # type: ignore[arg-type]
|
334
|
+
|
335
|
+
return custom_filters
|
zenml/models/v2/core/artifact.py
CHANGED
@@ -24,7 +24,7 @@ from zenml.models.v2.base.base import (
|
|
24
24
|
BaseResponseBody,
|
25
25
|
BaseResponseMetadata,
|
26
26
|
)
|
27
|
-
from zenml.models.v2.base.
|
27
|
+
from zenml.models.v2.base.scoped import WorkspaceScopedTaggableFilter
|
28
28
|
from zenml.models.v2.core.tag import TagResponse
|
29
29
|
|
30
30
|
if TYPE_CHECKING:
|
@@ -133,7 +133,7 @@ class ArtifactResponse(
|
|
133
133
|
# ------------------ Filter Model ------------------
|
134
134
|
|
135
135
|
|
136
|
-
class ArtifactFilter(
|
136
|
+
class ArtifactFilter(WorkspaceScopedTaggableFilter):
|
137
137
|
"""Model to enable advanced filtering of artifacts."""
|
138
138
|
|
139
139
|
name: Optional[str] = None
|
@@ -32,11 +32,11 @@ from zenml.enums import ArtifactType, GenericFilterOps
|
|
32
32
|
from zenml.logger import get_logger
|
33
33
|
from zenml.models.v2.base.filter import StrFilter
|
34
34
|
from zenml.models.v2.base.scoped import (
|
35
|
-
WorkspaceScopedFilter,
|
36
35
|
WorkspaceScopedRequest,
|
37
36
|
WorkspaceScopedResponse,
|
38
37
|
WorkspaceScopedResponseBody,
|
39
38
|
WorkspaceScopedResponseMetadata,
|
39
|
+
WorkspaceScopedTaggableFilter,
|
40
40
|
)
|
41
41
|
from zenml.models.v2.core.artifact import ArtifactResponse
|
42
42
|
from zenml.models.v2.core.tag import TagResponse
|
@@ -44,7 +44,7 @@ from zenml.models.v2.core.tag import TagResponse
|
|
44
44
|
if TYPE_CHECKING:
|
45
45
|
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
|
46
46
|
|
47
|
-
from zenml.model.
|
47
|
+
from zenml.model.model import Model
|
48
48
|
from zenml.models.v2.core.artifact_visualization import (
|
49
49
|
ArtifactVisualizationRequest,
|
50
50
|
ArtifactVisualizationResponse,
|
@@ -347,14 +347,14 @@ class ArtifactVersionResponse(
|
|
347
347
|
# ------------------ Filter Model ------------------
|
348
348
|
|
349
349
|
|
350
|
-
class ArtifactVersionFilter(
|
350
|
+
class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
|
351
351
|
"""Model to enable advanced filtering of artifact versions."""
|
352
352
|
|
353
353
|
# `name` and `only_unused` refer to properties related to other entities
|
354
354
|
# rather than a field in the db, hence they needs to be handled
|
355
355
|
# explicitly
|
356
356
|
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
357
|
-
*
|
357
|
+
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
|
358
358
|
"name",
|
359
359
|
"only_unused",
|
360
360
|
"has_custom_name",
|
@@ -477,7 +477,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
|
|
477
477
|
id: Optional[UUID] = None # type: ignore[assignment]
|
478
478
|
_lazy_load_name: Optional[str] = None
|
479
479
|
_lazy_load_version: Optional[str] = None
|
480
|
-
|
480
|
+
_lazy_load_model: "Model"
|
481
481
|
|
482
482
|
def get_body(self) -> None: # type: ignore[override]
|
483
483
|
"""Protects from misuse of the lazy loader.
|
@@ -507,7 +507,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
|
|
507
507
|
from zenml.metadata.lazy_load import RunMetadataLazyGetter
|
508
508
|
|
509
509
|
return RunMetadataLazyGetter( # type: ignore[return-value]
|
510
|
-
self.
|
510
|
+
self._lazy_load_model,
|
511
511
|
self._lazy_load_name,
|
512
512
|
self._lazy_load_version,
|
513
513
|
)
|
zenml/models/v2/core/model.py
CHANGED
@@ -22,16 +22,16 @@ from pydantic import BaseModel, Field
|
|
22
22
|
|
23
23
|
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
24
24
|
from zenml.models.v2.base.scoped import (
|
25
|
-
WorkspaceScopedFilter,
|
26
25
|
WorkspaceScopedRequest,
|
27
26
|
WorkspaceScopedResponse,
|
28
27
|
WorkspaceScopedResponseBody,
|
29
28
|
WorkspaceScopedResponseMetadata,
|
29
|
+
WorkspaceScopedTaggableFilter,
|
30
30
|
)
|
31
31
|
from zenml.utils.pagination_utils import depaginate
|
32
32
|
|
33
33
|
if TYPE_CHECKING:
|
34
|
-
from zenml.model.
|
34
|
+
from zenml.model.model import Model
|
35
35
|
from zenml.models.v2.core.tag import TagResponse
|
36
36
|
|
37
37
|
|
@@ -310,7 +310,7 @@ class ModelResponse(
|
|
310
310
|
|
311
311
|
# Helper functions
|
312
312
|
@property
|
313
|
-
def versions(self) -> List["
|
313
|
+
def versions(self) -> List["Model"]:
|
314
314
|
"""List all versions of the model.
|
315
315
|
|
316
316
|
Returns:
|
@@ -323,7 +323,7 @@ class ModelResponse(
|
|
323
323
|
partial(client.list_model_versions, model_name_or_id=self.id)
|
324
324
|
)
|
325
325
|
return [
|
326
|
-
mv.
|
326
|
+
mv.to_model_class(suppress_class_validation_warnings=True)
|
327
327
|
for mv in model_versions
|
328
328
|
]
|
329
329
|
|
@@ -331,7 +331,7 @@ class ModelResponse(
|
|
331
331
|
# ------------------ Filter Model ------------------
|
332
332
|
|
333
333
|
|
334
|
-
class ModelFilter(
|
334
|
+
class ModelFilter(WorkspaceScopedTaggableFilter):
|
335
335
|
"""Model to enable advanced filtering of all Workspaces."""
|
336
336
|
|
337
337
|
name: Optional[str] = Field(
|
@@ -346,7 +346,7 @@ class ModelFilter(WorkspaceScopedFilter):
|
|
346
346
|
)
|
347
347
|
|
348
348
|
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
349
|
-
*
|
349
|
+
*WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS,
|
350
350
|
"workspace_id",
|
351
351
|
"user_id",
|
352
352
|
]
|
@@ -23,16 +23,16 @@ from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
|
23
23
|
from zenml.enums import ModelStages
|
24
24
|
from zenml.models.v2.base.filter import AnyQuery
|
25
25
|
from zenml.models.v2.base.scoped import (
|
26
|
-
WorkspaceScopedFilter,
|
27
26
|
WorkspaceScopedRequest,
|
28
27
|
WorkspaceScopedResponse,
|
29
28
|
WorkspaceScopedResponseBody,
|
30
29
|
WorkspaceScopedResponseMetadata,
|
30
|
+
WorkspaceScopedTaggableFilter,
|
31
31
|
)
|
32
32
|
from zenml.models.v2.core.tag import TagResponse
|
33
33
|
|
34
34
|
if TYPE_CHECKING:
|
35
|
-
from zenml.model.
|
35
|
+
from zenml.model.model import Model
|
36
36
|
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
37
37
|
from zenml.models.v2.core.model import ModelResponse
|
38
38
|
from zenml.models.v2.core.pipeline_run import PipelineRunResponse
|
@@ -311,12 +311,12 @@ class ModelVersionResponse(
|
|
311
311
|
return Client().zen_store.get_model_version(self.id)
|
312
312
|
|
313
313
|
# Helper functions
|
314
|
-
def
|
314
|
+
def to_model_class(
|
315
315
|
self,
|
316
316
|
was_created_in_this_run: bool = False,
|
317
317
|
suppress_class_validation_warnings: bool = False,
|
318
|
-
) -> "
|
319
|
-
"""Convert response model to
|
318
|
+
) -> "Model":
|
319
|
+
"""Convert response model to Model object.
|
320
320
|
|
321
321
|
Args:
|
322
322
|
was_created_in_this_run: Whether model version was created during
|
@@ -325,11 +325,11 @@ class ModelVersionResponse(
|
|
325
325
|
repeated warnings.
|
326
326
|
|
327
327
|
Returns:
|
328
|
-
|
328
|
+
Model object
|
329
329
|
"""
|
330
|
-
from zenml.model.
|
330
|
+
from zenml.model.model import Model
|
331
331
|
|
332
|
-
mv =
|
332
|
+
mv = Model(
|
333
333
|
name=self.model.name,
|
334
334
|
license=self.model.license,
|
335
335
|
description=self.description,
|
@@ -578,7 +578,7 @@ class ModelVersionResponse(
|
|
578
578
|
# ------------------ Filter Model ------------------
|
579
579
|
|
580
580
|
|
581
|
-
class ModelVersionFilter(
|
581
|
+
class ModelVersionFilter(WorkspaceScopedTaggableFilter):
|
582
582
|
"""Filter model for model versions."""
|
583
583
|
|
584
584
|
name: Optional[str] = Field(
|
@@ -30,7 +30,7 @@ from zenml.models.v2.base.scoped import (
|
|
30
30
|
)
|
31
31
|
|
32
32
|
if TYPE_CHECKING:
|
33
|
-
from zenml.model.
|
33
|
+
from zenml.model.model import Model
|
34
34
|
|
35
35
|
# ------------------ Request Model ------------------
|
36
36
|
|
@@ -203,7 +203,7 @@ class LazyRunMetadataResponse(RunMetadataResponse):
|
|
203
203
|
_lazy_load_artifact_name: Optional[str] = None
|
204
204
|
_lazy_load_artifact_version: Optional[str] = None
|
205
205
|
_lazy_load_metadata_name: Optional[str] = None
|
206
|
-
|
206
|
+
_lazy_load_model: "Model"
|
207
207
|
|
208
208
|
def get_body(self) -> None: # type: ignore[override]
|
209
209
|
"""Protects from misuse of the lazy loader.
|
@@ -17,14 +17,14 @@ from typing import List, Optional
|
|
17
17
|
|
18
18
|
from pydantic import BaseModel, PrivateAttr
|
19
19
|
|
20
|
-
from zenml.model.
|
20
|
+
from zenml.model.model import Model
|
21
21
|
|
22
22
|
|
23
|
-
class
|
24
|
-
"""Request to create a new model
|
23
|
+
class NewModelRequest(BaseModel):
|
24
|
+
"""Request to create a new version of a model."""
|
25
25
|
|
26
26
|
class Requester(BaseModel):
|
27
|
-
"""Requester of a new model
|
27
|
+
"""Requester of a new version of a model."""
|
28
28
|
|
29
29
|
source: str
|
30
30
|
name: str
|
@@ -38,35 +38,35 @@ class NewModelVersionRequest(BaseModel):
|
|
38
38
|
return f"{self.source}::{self.name}"
|
39
39
|
|
40
40
|
requesters: List[Requester] = []
|
41
|
-
|
41
|
+
_model: Optional[Model] = PrivateAttr(default=None)
|
42
42
|
|
43
43
|
@property
|
44
|
-
def
|
45
|
-
"""Model
|
44
|
+
def model(self) -> Model:
|
45
|
+
"""Model getter.
|
46
46
|
|
47
47
|
Returns:
|
48
|
-
The model
|
48
|
+
The model.
|
49
49
|
|
50
50
|
Raises:
|
51
|
-
RuntimeError: If the model
|
51
|
+
RuntimeError: If the model is not set.
|
52
52
|
"""
|
53
|
-
if self.
|
54
|
-
raise RuntimeError("Model
|
55
|
-
return self.
|
53
|
+
if self._model is None:
|
54
|
+
raise RuntimeError("Model is not set.")
|
55
|
+
return self._model
|
56
56
|
|
57
57
|
def update_request(
|
58
58
|
self,
|
59
|
-
|
60
|
-
requester: "
|
59
|
+
model: Model,
|
60
|
+
requester: "NewModelRequest.Requester",
|
61
61
|
) -> None:
|
62
|
-
"""Update from `
|
62
|
+
"""Update from `Model` in place.
|
63
63
|
|
64
64
|
Args:
|
65
|
-
|
66
|
-
requester: Requester of a new model
|
65
|
+
model: `Model` to use.
|
66
|
+
requester: Requester of a new version of a model.
|
67
67
|
"""
|
68
68
|
self.requesters.append(requester)
|
69
|
-
if self.
|
70
|
-
self.
|
69
|
+
if self._model is None:
|
70
|
+
self._model = model
|
71
71
|
|
72
|
-
self.
|
72
|
+
self._model._merge(model)
|
zenml/new/pipelines/pipeline.py
CHANGED
@@ -70,7 +70,7 @@ from zenml.models import (
|
|
70
70
|
ScheduleRequest,
|
71
71
|
)
|
72
72
|
from zenml.new.pipelines import build_utils
|
73
|
-
from zenml.new.pipelines.model_utils import
|
73
|
+
from zenml.new.pipelines.model_utils import NewModelRequest
|
74
74
|
from zenml.orchestrators.utils import get_run_name
|
75
75
|
from zenml.stack import Stack
|
76
76
|
from zenml.steps import BaseStep
|
@@ -93,7 +93,7 @@ if TYPE_CHECKING:
|
|
93
93
|
from zenml.config.base_settings import SettingsOrDict
|
94
94
|
from zenml.config.source import Source
|
95
95
|
from zenml.model.lazy_load import ModelVersionDataLazyLoader
|
96
|
-
from zenml.model.
|
96
|
+
from zenml.model.model import Model
|
97
97
|
|
98
98
|
StepConfigurationUpdateOrDict = Union[
|
99
99
|
Dict[str, Any], StepConfigurationUpdate
|
@@ -126,7 +126,7 @@ class Pipeline:
|
|
126
126
|
extra: Optional[Dict[str, Any]] = None,
|
127
127
|
on_failure: Optional["HookSpecification"] = None,
|
128
128
|
on_success: Optional["HookSpecification"] = None,
|
129
|
-
|
129
|
+
model: Optional["Model"] = None,
|
130
130
|
) -> None:
|
131
131
|
"""Initializes a pipeline.
|
132
132
|
|
@@ -147,7 +147,7 @@ class Pipeline:
|
|
147
147
|
on_success: Callback function in event of success of the step. Can
|
148
148
|
be a function with no arguments, or a source path to such a
|
149
149
|
function (e.g. `module.my_function`).
|
150
|
-
|
150
|
+
model: configuration of the model in the Model Control Plane.
|
151
151
|
"""
|
152
152
|
self._invocations: Dict[str, StepInvocation] = {}
|
153
153
|
self._run_args: Dict[str, Any] = {}
|
@@ -166,7 +166,7 @@ class Pipeline:
|
|
166
166
|
extra=extra,
|
167
167
|
on_failure=on_failure,
|
168
168
|
on_success=on_success,
|
169
|
-
|
169
|
+
model=model,
|
170
170
|
)
|
171
171
|
self.entrypoint = entrypoint
|
172
172
|
self._parameters: Dict[str, Any] = {}
|
@@ -305,7 +305,7 @@ class Pipeline:
|
|
305
305
|
extra: Optional[Dict[str, Any]] = None,
|
306
306
|
on_failure: Optional["HookSpecification"] = None,
|
307
307
|
on_success: Optional["HookSpecification"] = None,
|
308
|
-
|
308
|
+
model: Optional["Model"] = None,
|
309
309
|
parameters: Optional[Dict[str, Any]] = None,
|
310
310
|
merge: bool = True,
|
311
311
|
) -> T:
|
@@ -341,7 +341,7 @@ class Pipeline:
|
|
341
341
|
configurations. If `False` the given configurations will
|
342
342
|
overwrite all existing ones. See the general description of this
|
343
343
|
method for an example.
|
344
|
-
|
344
|
+
model: configuration of the model version in the Model Control Plane.
|
345
345
|
parameters: input parameters for the pipeline.
|
346
346
|
|
347
347
|
Returns:
|
@@ -367,7 +367,7 @@ class Pipeline:
|
|
367
367
|
"extra": extra,
|
368
368
|
"failure_hook_source": failure_hook_source,
|
369
369
|
"success_hook_source": success_hook_source,
|
370
|
-
"
|
370
|
+
"model": model,
|
371
371
|
"parameters": parameters,
|
372
372
|
}
|
373
373
|
)
|
@@ -872,34 +872,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
872
872
|
def _update_new_requesters(
|
873
873
|
self,
|
874
874
|
requester_name: str,
|
875
|
-
|
875
|
+
model: "Model",
|
876
876
|
new_versions_requested: Dict[
|
877
|
-
Tuple[str, Optional[str]],
|
877
|
+
Tuple[str, Optional[str]], NewModelRequest
|
878
878
|
],
|
879
|
-
|
879
|
+
other_models: Set["Model"],
|
880
880
|
) -> None:
|
881
881
|
key = (
|
882
|
-
|
883
|
-
str(
|
882
|
+
model.name,
|
883
|
+
str(model.version) if model.version else None,
|
884
884
|
)
|
885
|
-
if
|
885
|
+
if model.version is None:
|
886
886
|
version_existed = False
|
887
887
|
else:
|
888
888
|
try:
|
889
|
-
|
889
|
+
model._get_model_version()
|
890
890
|
version_existed = key not in new_versions_requested
|
891
891
|
except KeyError:
|
892
892
|
version_existed = False
|
893
893
|
if not version_existed:
|
894
|
-
|
894
|
+
model.was_created_in_this_run = True
|
895
895
|
new_versions_requested[key].update_request(
|
896
|
-
|
897
|
-
|
898
|
-
source="step", name=requester_name
|
899
|
-
),
|
896
|
+
model,
|
897
|
+
NewModelRequest.Requester(source="step", name=requester_name),
|
900
898
|
)
|
901
899
|
else:
|
902
|
-
|
900
|
+
other_models.add(model)
|
903
901
|
|
904
902
|
def prepare_model_versions(
|
905
903
|
self, deployment: "PipelineDeploymentBase"
|
@@ -910,35 +908,32 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
910
908
|
deployment: The pipeline deployment configuration.
|
911
909
|
"""
|
912
910
|
new_versions_requested: Dict[
|
913
|
-
Tuple[str, Optional[str]],
|
914
|
-
] = defaultdict(
|
915
|
-
|
911
|
+
Tuple[str, Optional[str]], NewModelRequest
|
912
|
+
] = defaultdict(NewModelRequest)
|
913
|
+
other_models: Set["Model"] = set()
|
916
914
|
all_steps_have_own_config = True
|
917
915
|
for step in deployment.step_configurations.values():
|
918
|
-
|
916
|
+
step_model = step.config.model
|
919
917
|
all_steps_have_own_config = (
|
920
|
-
all_steps_have_own_config
|
921
|
-
and step.config.model_version is not None
|
918
|
+
all_steps_have_own_config and step.config.model is not None
|
922
919
|
)
|
923
|
-
if
|
920
|
+
if step_model:
|
924
921
|
self._update_new_requesters(
|
925
|
-
|
922
|
+
model=step_model,
|
926
923
|
requester_name=step.config.name,
|
927
924
|
new_versions_requested=new_versions_requested,
|
928
|
-
|
925
|
+
other_models=other_models,
|
929
926
|
)
|
930
927
|
if not all_steps_have_own_config:
|
931
|
-
|
932
|
-
|
933
|
-
)
|
934
|
-
if pipeline_model_version:
|
928
|
+
pipeline_model = deployment.pipeline_configuration.model
|
929
|
+
if pipeline_model:
|
935
930
|
self._update_new_requesters(
|
936
|
-
|
931
|
+
model=pipeline_model,
|
937
932
|
requester_name=self.name,
|
938
933
|
new_versions_requested=new_versions_requested,
|
939
|
-
|
934
|
+
other_models=other_models,
|
940
935
|
)
|
941
|
-
elif deployment.pipeline_configuration.
|
936
|
+
elif deployment.pipeline_configuration.model is not None:
|
942
937
|
logger.warning(
|
943
938
|
f"ModelConfig of pipeline `{self.name}` is overridden in all "
|
944
939
|
f"steps. "
|
@@ -946,13 +941,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
946
941
|
|
947
942
|
self._validate_new_version_requests(new_versions_requested)
|
948
943
|
|
949
|
-
for
|
950
|
-
|
944
|
+
for other_model in other_models:
|
945
|
+
other_model._validate_config_in_runtime()
|
951
946
|
|
952
947
|
def _validate_new_version_requests(
|
953
948
|
self,
|
954
949
|
new_versions_requested: Dict[
|
955
|
-
Tuple[str, Optional[str]],
|
950
|
+
Tuple[str, Optional[str]], NewModelRequest
|
956
951
|
],
|
957
952
|
) -> None:
|
958
953
|
"""Validate the model version that are used in the pipeline run.
|
@@ -967,13 +962,13 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
967
962
|
logger.warning(
|
968
963
|
f"New version of model version `{model_name}::{model_version or 'NEW'}` "
|
969
964
|
f"requested in multiple decorators:\n{data.requesters}\n We recommend "
|
970
|
-
"that `
|
965
|
+
"that `Model` requesting new version is configured only in one "
|
971
966
|
"place of the pipeline."
|
972
967
|
)
|
973
|
-
data.
|
968
|
+
data.model._validate_config_in_runtime()
|
974
969
|
self.__new_unnamed_model_versions_in_current_run__[
|
975
|
-
data.
|
976
|
-
] = data.
|
970
|
+
data.model.name
|
971
|
+
] = data.model.number
|
977
972
|
|
978
973
|
def get_runs(self, **kwargs: Any) -> List["PipelineRunResponse"]:
|
979
974
|
"""(Deprecated) Get runs of this pipeline.
|
@@ -1400,18 +1395,16 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
1400
1395
|
{k: v for k, v in _from_config_file.items() if k in matcher}
|
1401
1396
|
)
|
1402
1397
|
|
1403
|
-
if "
|
1404
|
-
if "
|
1405
|
-
_from_config_file[
|
1406
|
-
"
|
1407
|
-
]
|
1398
|
+
if "model" in _from_config_file:
|
1399
|
+
if "model" in self._from_config_file:
|
1400
|
+
_from_config_file["model"] = self._from_config_file[
|
1401
|
+
"model"
|
1402
|
+
]
|
1408
1403
|
else:
|
1409
|
-
from zenml.model.
|
1404
|
+
from zenml.model.model import Model
|
1410
1405
|
|
1411
|
-
_from_config_file[
|
1412
|
-
"
|
1413
|
-
] = ModelVersion.parse_obj(
|
1414
|
-
_from_config_file["model_version"]
|
1406
|
+
_from_config_file["model"] = Model.parse_obj(
|
1407
|
+
_from_config_file["model"]
|
1415
1408
|
)
|
1416
1409
|
self._from_config_file = _from_config_file
|
1417
1410
|
|
@@ -107,4 +107,4 @@ class PipelineContext:
|
|
107
107
|
self.enable_step_logs = pipeline_configuration.enable_step_logs
|
108
108
|
self.settings = pipeline_configuration.settings
|
109
109
|
self.extra = pipeline_configuration.extra
|
110
|
-
self.
|
110
|
+
self.model = pipeline_configuration.model
|
@@ -26,7 +26,7 @@ from typing import (
|
|
26
26
|
|
27
27
|
if TYPE_CHECKING:
|
28
28
|
from zenml.config.base_settings import SettingsOrDict
|
29
|
-
from zenml.model.
|
29
|
+
from zenml.model.model import Model
|
30
30
|
from zenml.new.pipelines.pipeline import Pipeline
|
31
31
|
|
32
32
|
HookSpecification = Union[str, FunctionType]
|
@@ -62,7 +62,7 @@ def pipeline(
|
|
62
62
|
extra: Optional[Dict[str, Any]] = None,
|
63
63
|
on_failure: Optional["HookSpecification"] = None,
|
64
64
|
on_success: Optional["HookSpecification"] = None,
|
65
|
-
|
65
|
+
model: Optional["Model"] = None,
|
66
66
|
) -> Union["Pipeline", Callable[["F"], "Pipeline"]]:
|
67
67
|
"""Decorator to create a pipeline.
|
68
68
|
|
@@ -81,7 +81,7 @@ def pipeline(
|
|
81
81
|
on_success: Callback function in event of success of the step. Can be a
|
82
82
|
function with no arguments, or a source path to such a function
|
83
83
|
(e.g. `module.my_function`).
|
84
|
-
|
84
|
+
model: configuration of the model in the Model Control Plane.
|
85
85
|
|
86
86
|
Returns:
|
87
87
|
A pipeline instance.
|
@@ -99,7 +99,7 @@ def pipeline(
|
|
99
99
|
extra=extra,
|
100
100
|
on_failure=on_failure,
|
101
101
|
on_success=on_success,
|
102
|
-
|
102
|
+
model=model,
|
103
103
|
entrypoint=func,
|
104
104
|
)
|
105
105
|
|