zenml-nightly 0.54.1.dev20240119__py3-none-any.whl → 0.54.1.dev20240120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/__init__.py +10 -4
- zenml/artifacts/artifact_config.py +12 -12
- zenml/artifacts/external_artifact.py +3 -3
- zenml/artifacts/external_artifact_config.py +8 -8
- zenml/artifacts/utils.py +4 -4
- zenml/cli/__init__.py +1 -1
- zenml/cli/base.py +3 -3
- zenml/cli/utils.py +3 -3
- zenml/config/compiler.py +1 -1
- zenml/config/pipeline_configurations.py +2 -2
- zenml/config/pipeline_run_configuration.py +2 -2
- zenml/config/step_configurations.py +2 -2
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +11 -9
- zenml/metadata/lazy_load.py +5 -5
- zenml/model/lazy_load.py +2 -2
- zenml/model/{model_version.py → model.py} +35 -28
- zenml/model/utils.py +33 -33
- zenml/model_registries/base_model_registry.py +10 -8
- zenml/models/v2/core/artifact_version.py +3 -3
- zenml/models/v2/core/model.py +3 -3
- zenml/models/v2/core/model_version.py +7 -7
- zenml/models/v2/core/run_metadata.py +2 -2
- zenml/new/pipelines/model_utils.py +20 -20
- zenml/new/pipelines/pipeline.py +47 -54
- zenml/new/pipelines/pipeline_context.py +1 -1
- zenml/new/pipelines/pipeline_decorator.py +4 -4
- zenml/new/steps/step_context.py +15 -15
- zenml/new/steps/step_decorator.py +5 -5
- zenml/orchestrators/input_utils.py +5 -7
- zenml/orchestrators/step_launcher.py +12 -19
- zenml/orchestrators/step_runner.py +8 -10
- zenml/pipelines/base_pipeline.py +1 -1
- zenml/pipelines/pipeline_decorator.py +6 -6
- zenml/steps/base_step.py +15 -15
- zenml/steps/step_decorator.py +6 -6
- zenml/steps/utils.py +68 -0
- zenml/zen_stores/migrations/versions/4d688d8f7aff_rename_model_version_to_model.py +94 -0
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/METADATA +1 -1
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/RECORD +42 -41
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.54.1.dev20240119.dist-info → zenml_nightly-0.54.1.dev20240120.dist-info}/entry_points.txt +0 -0
zenml/__init__.py
CHANGED
@@ -42,15 +42,19 @@ from zenml.artifacts.utils import (
|
|
42
42
|
save_artifact,
|
43
43
|
load_artifact,
|
44
44
|
)
|
45
|
-
from zenml.model.utils import
|
45
|
+
from zenml.model.utils import (
|
46
|
+
log_model_metadata,
|
47
|
+
link_artifact_to_model,
|
48
|
+
log_model_version_metadata,
|
49
|
+
)
|
46
50
|
from zenml.artifacts.artifact_config import ArtifactConfig
|
47
51
|
from zenml.artifacts.external_artifact import ExternalArtifact
|
48
|
-
from zenml.model.
|
49
|
-
from zenml.model.utils import log_model_version_metadata
|
52
|
+
from zenml.model.model import Model
|
50
53
|
from zenml.new.pipelines.pipeline_context import get_pipeline_context
|
51
54
|
from zenml.new.pipelines.pipeline_decorator import pipeline
|
52
55
|
from zenml.new.steps.step_decorator import step
|
53
56
|
from zenml.new.steps.step_context import get_step_context
|
57
|
+
from zenml.steps.utils import log_step_metadata
|
54
58
|
|
55
59
|
__all__ = [
|
56
60
|
"ArtifactConfig",
|
@@ -60,8 +64,10 @@ __all__ = [
|
|
60
64
|
"load_artifact",
|
61
65
|
"log_artifact_metadata",
|
62
66
|
"log_model_metadata",
|
67
|
+
"log_model_version_metadata",
|
68
|
+
"log_step_metadata",
|
69
|
+
"Model",
|
63
70
|
"link_artifact_to_model",
|
64
|
-
"ModelVersion",
|
65
71
|
"pipeline",
|
66
72
|
"save_artifact",
|
67
73
|
"show",
|
@@ -22,7 +22,7 @@ from zenml.logger import get_logger
|
|
22
22
|
from zenml.new.steps.step_context import get_step_context
|
23
23
|
|
24
24
|
if TYPE_CHECKING:
|
25
|
-
from zenml.model.
|
25
|
+
from zenml.model.model import Model
|
26
26
|
|
27
27
|
|
28
28
|
logger = get_logger(__name__)
|
@@ -52,7 +52,7 @@ class ArtifactConfig(BaseModel):
|
|
52
52
|
version: The version of the artifact.
|
53
53
|
tags: The tags of the artifact.
|
54
54
|
model_name: The name of the model to link artifact to.
|
55
|
-
model_version: The identifier of the model
|
55
|
+
model_version: The identifier of a version of the model to link the artifact
|
56
56
|
to. It can be an exact version ("my_version"), exact version number
|
57
57
|
(42), stage (ModelStages.PRODUCTION or "production"), or
|
58
58
|
(ModelStages.LATEST or None) for the latest version (default).
|
@@ -88,26 +88,26 @@ class ArtifactConfig(BaseModel):
|
|
88
88
|
smart_union = True
|
89
89
|
|
90
90
|
@property
|
91
|
-
def
|
92
|
-
"""The model
|
91
|
+
def _model(self) -> Optional["Model"]:
|
92
|
+
"""The model linked to this artifact.
|
93
93
|
|
94
94
|
Returns:
|
95
|
-
The model
|
95
|
+
The model or None if the model version cannot be determined.
|
96
96
|
"""
|
97
97
|
try:
|
98
|
-
|
98
|
+
model_ = get_step_context().model
|
99
99
|
except (StepContextError, RuntimeError):
|
100
|
-
|
100
|
+
model_ = None
|
101
101
|
# Check if another model name was specified
|
102
102
|
if (self.model_name is not None) and (
|
103
|
-
|
103
|
+
model_ is None or model_.name != self.model_name
|
104
104
|
):
|
105
|
-
# Create a new
|
106
|
-
from zenml.model.
|
105
|
+
# Create a new Model instance with the provided model name and version
|
106
|
+
from zenml.model.model import Model
|
107
107
|
|
108
|
-
on_the_fly_config =
|
108
|
+
on_the_fly_config = Model(
|
109
109
|
name=self.model_name, version=self.model_version
|
110
110
|
)
|
111
111
|
return on_the_fly_config
|
112
112
|
|
113
|
-
return
|
113
|
+
return model_
|
@@ -56,8 +56,8 @@ class ExternalArtifact(ExternalArtifactConfiguration):
|
|
56
56
|
`version`, `pipeline_run_name`, or `pipeline_name` are set, the
|
57
57
|
latest version of the artifact will be used.
|
58
58
|
version: Version of the artifact to search. Only used when `name` is
|
59
|
-
provided. Cannot be used together with `
|
60
|
-
|
59
|
+
provided. Cannot be used together with `model`.
|
60
|
+
model: The model to search in. Only used when `name`
|
61
61
|
is provided. Cannot be used together with `version`.
|
62
62
|
materializer: The materializer to use for saving the artifact value
|
63
63
|
to the artifact store. Only used when `value` is provided.
|
@@ -149,5 +149,5 @@ class ExternalArtifact(ExternalArtifactConfiguration):
|
|
149
149
|
id=self.id,
|
150
150
|
name=self.name,
|
151
151
|
version=self.version,
|
152
|
-
|
152
|
+
model=self.model,
|
153
153
|
)
|
@@ -18,7 +18,7 @@ from uuid import UUID
|
|
18
18
|
from pydantic import BaseModel, root_validator
|
19
19
|
|
20
20
|
from zenml.logger import get_logger
|
21
|
-
from zenml.model.
|
21
|
+
from zenml.model.model import Model
|
22
22
|
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
23
23
|
|
24
24
|
logger = get_logger(__name__)
|
@@ -33,13 +33,13 @@ class ExternalArtifactConfiguration(BaseModel):
|
|
33
33
|
id: Optional[UUID] = None
|
34
34
|
name: Optional[str] = None
|
35
35
|
version: Optional[str] = None
|
36
|
-
|
36
|
+
model: Optional[Model] = None
|
37
37
|
|
38
38
|
@root_validator
|
39
39
|
def _validate_all_eac(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
40
|
-
if values.get("version", None) and values.get("
|
40
|
+
if values.get("version", None) and values.get("model", None):
|
41
41
|
raise ValueError(
|
42
|
-
"Cannot provide both `version` and `
|
42
|
+
"Cannot provide both `version` and `model` when "
|
43
43
|
"creating an external artifact."
|
44
44
|
)
|
45
45
|
return values
|
@@ -67,13 +67,13 @@ class ExternalArtifactConfiguration(BaseModel):
|
|
67
67
|
response = client.get_artifact_version(
|
68
68
|
self.name, version=self.version
|
69
69
|
)
|
70
|
-
elif self.
|
71
|
-
response_ = self.
|
70
|
+
elif self.model:
|
71
|
+
response_ = self.model.get_artifact(self.name)
|
72
72
|
if not isinstance(response_, ArtifactVersionResponse):
|
73
73
|
raise RuntimeError(
|
74
74
|
f"Failed to pull artifact `{self.name}` from the Model "
|
75
|
-
f"
|
76
|
-
f"`{self.
|
75
|
+
f"(name=`{self.model.name}`, version="
|
76
|
+
f"`{self.model.version}`). Please validate the "
|
77
77
|
"input and try again."
|
78
78
|
)
|
79
79
|
response = response_
|
zenml/artifacts/utils.py
CHANGED
@@ -233,14 +233,14 @@ def save_artifact(
|
|
233
233
|
saved_artifact_versions={name: response.id}
|
234
234
|
),
|
235
235
|
)
|
236
|
-
error_message = "model
|
237
|
-
|
238
|
-
if
|
236
|
+
error_message = "model"
|
237
|
+
model = step_context.model
|
238
|
+
if model:
|
239
239
|
from zenml.model.utils import link_artifact_to_model
|
240
240
|
|
241
241
|
link_artifact_to_model(
|
242
242
|
artifact_version_id=response.id,
|
243
|
-
|
243
|
+
model=model,
|
244
244
|
is_model_artifact=is_model_artifact,
|
245
245
|
is_deployment_artifact=is_deployment_artifact,
|
246
246
|
)
|
zenml/cli/__init__.py
CHANGED
@@ -796,7 +796,7 @@ Administering your Models
|
|
796
796
|
----------------------------
|
797
797
|
|
798
798
|
ZenML provides several CLI commands to help you administer your models and
|
799
|
-
|
799
|
+
their versions as part of the Model Control Plane.
|
800
800
|
|
801
801
|
To register a new model, you can use the following CLI command:
|
802
802
|
```bash
|
zenml/cli/base.py
CHANGED
@@ -73,15 +73,15 @@ class ZenMLProjectTemplateLocation(BaseModel):
|
|
73
73
|
ZENML_PROJECT_TEMPLATES = dict(
|
74
74
|
e2e_batch=ZenMLProjectTemplateLocation(
|
75
75
|
github_url="zenml-io/template-e2e-batch",
|
76
|
-
github_tag="2024.01.
|
76
|
+
github_tag="2024.01.18", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
|
77
77
|
),
|
78
78
|
starter=ZenMLProjectTemplateLocation(
|
79
79
|
github_url="zenml-io/template-starter",
|
80
|
-
github_tag="
|
80
|
+
github_tag="2024.01.12", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
|
81
81
|
),
|
82
82
|
nlp=ZenMLProjectTemplateLocation(
|
83
83
|
github_url="zenml-io/template-nlp",
|
84
|
-
github_tag="
|
84
|
+
github_tag="2024.01.12", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
|
85
85
|
),
|
86
86
|
)
|
87
87
|
|
zenml/cli/utils.py
CHANGED
@@ -58,8 +58,8 @@ from zenml.constants import (
|
|
58
58
|
from zenml.enums import GenericFilterOps, StackComponentType
|
59
59
|
from zenml.logger import get_logger
|
60
60
|
from zenml.model_registries.base_model_registry import (
|
61
|
-
ModelVersion,
|
62
61
|
RegisteredModel,
|
62
|
+
RegistryModelVersion,
|
63
63
|
)
|
64
64
|
from zenml.models import (
|
65
65
|
BaseFilter,
|
@@ -1182,7 +1182,7 @@ def pretty_print_registered_model_table(
|
|
1182
1182
|
|
1183
1183
|
|
1184
1184
|
def pretty_print_model_version_table(
|
1185
|
-
model_versions: List["
|
1185
|
+
model_versions: List["RegistryModelVersion"],
|
1186
1186
|
) -> None:
|
1187
1187
|
"""Given a list of model_versions, print all associated key-value pairs.
|
1188
1188
|
|
@@ -1206,7 +1206,7 @@ def pretty_print_model_version_table(
|
|
1206
1206
|
|
1207
1207
|
|
1208
1208
|
def pretty_print_model_version_details(
|
1209
|
-
model_version: "
|
1209
|
+
model_version: "RegistryModelVersion",
|
1210
1210
|
) -> None:
|
1211
1211
|
"""Given a model_version, print all associated key-value pairs.
|
1212
1212
|
|
zenml/config/compiler.py
CHANGED
@@ -19,7 +19,7 @@ from pydantic import validator
|
|
19
19
|
from zenml.config.constants import DOCKER_SETTINGS_KEY
|
20
20
|
from zenml.config.source import Source, convert_source_validator
|
21
21
|
from zenml.config.strict_base_model import StrictBaseModel
|
22
|
-
from zenml.model.
|
22
|
+
from zenml.model.model import Model
|
23
23
|
|
24
24
|
if TYPE_CHECKING:
|
25
25
|
from zenml.config import DockerSettings
|
@@ -40,7 +40,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
|
|
40
40
|
extra: Dict[str, Any] = {}
|
41
41
|
failure_hook_source: Optional[Source] = None
|
42
42
|
success_hook_source: Optional[Source] = None
|
43
|
-
|
43
|
+
model: Optional[Model] = None
|
44
44
|
parameters: Optional[Dict[str, Any]] = None
|
45
45
|
|
46
46
|
_convert_source = convert_source_validator(
|
@@ -19,7 +19,7 @@ from zenml.config.base_settings import BaseSettings
|
|
19
19
|
from zenml.config.schedule import Schedule
|
20
20
|
from zenml.config.step_configurations import StepConfigurationUpdate
|
21
21
|
from zenml.config.strict_base_model import StrictBaseModel
|
22
|
-
from zenml.model.
|
22
|
+
from zenml.model.model import Model
|
23
23
|
from zenml.models import PipelineBuildBase
|
24
24
|
from zenml.utils import pydantic_utils
|
25
25
|
|
@@ -39,5 +39,5 @@ class PipelineRunConfiguration(
|
|
39
39
|
steps: Dict[str, StepConfigurationUpdate] = {}
|
40
40
|
settings: Dict[str, BaseSettings] = {}
|
41
41
|
extra: Dict[str, Any] = {}
|
42
|
-
|
42
|
+
model: Optional[Model] = None
|
43
43
|
parameters: Optional[Dict[str, Any]] = None
|
@@ -35,7 +35,7 @@ from zenml.config.source import Source, convert_source_validator
|
|
35
35
|
from zenml.config.strict_base_model import StrictBaseModel
|
36
36
|
from zenml.logger import get_logger
|
37
37
|
from zenml.model.lazy_load import ModelVersionDataLazyLoader
|
38
|
-
from zenml.model.
|
38
|
+
from zenml.model.model import Model
|
39
39
|
from zenml.utils import deprecation_utils
|
40
40
|
|
41
41
|
if TYPE_CHECKING:
|
@@ -135,7 +135,7 @@ class StepConfigurationUpdate(StrictBaseModel):
|
|
135
135
|
extra: Dict[str, Any] = {}
|
136
136
|
failure_hook_source: Optional[Source] = None
|
137
137
|
success_hook_source: Optional[Source] = None
|
138
|
-
|
138
|
+
model: Optional[Model] = None
|
139
139
|
|
140
140
|
outputs: Mapping[str, PartialArtifactConfiguration] = {}
|
141
141
|
|
@@ -19,7 +19,9 @@ from typing import Any, Dict, List, Optional, Tuple, cast
|
|
19
19
|
|
20
20
|
import mlflow
|
21
21
|
from mlflow import MlflowClient
|
22
|
-
from mlflow.entities.model_registry import
|
22
|
+
from mlflow.entities.model_registry import (
|
23
|
+
ModelVersion as MLflowModelVersion,
|
24
|
+
)
|
23
25
|
from mlflow.exceptions import MlflowException
|
24
26
|
from mlflow.pyfunc import load_model
|
25
27
|
|
@@ -36,9 +38,9 @@ from zenml.logger import get_logger
|
|
36
38
|
from zenml.model_registries.base_model_registry import (
|
37
39
|
BaseModelRegistry,
|
38
40
|
ModelRegistryModelMetadata,
|
39
|
-
ModelVersion,
|
40
41
|
ModelVersionStage,
|
41
42
|
RegisteredModel,
|
43
|
+
RegistryModelVersion,
|
42
44
|
)
|
43
45
|
from zenml.stack.stack import Stack
|
44
46
|
from zenml.stack.stack_validator import StackValidator
|
@@ -348,7 +350,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
348
350
|
description: Optional[str] = None,
|
349
351
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
350
352
|
**kwargs: Any,
|
351
|
-
) ->
|
353
|
+
) -> RegistryModelVersion:
|
352
354
|
"""Register a model version to the MLflow model registry.
|
353
355
|
|
354
356
|
Args:
|
@@ -442,7 +444,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
442
444
|
metadata: Optional[ModelRegistryModelMetadata] = None,
|
443
445
|
remove_metadata: Optional[List[str]] = None,
|
444
446
|
stage: Optional[ModelVersionStage] = None,
|
445
|
-
) ->
|
447
|
+
) -> RegistryModelVersion:
|
446
448
|
"""Update a model version in the MLflow model registry.
|
447
449
|
|
448
450
|
Args:
|
@@ -521,7 +523,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
521
523
|
self,
|
522
524
|
name: str,
|
523
525
|
version: str,
|
524
|
-
) ->
|
526
|
+
) -> RegistryModelVersion:
|
525
527
|
"""Get a model version from the MLflow model registry.
|
526
528
|
|
527
529
|
Args:
|
@@ -561,7 +563,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
561
563
|
created_before: Optional[datetime] = None,
|
562
564
|
order_by_date: Optional[str] = None,
|
563
565
|
**kwargs: Any,
|
564
|
-
) -> List[
|
566
|
+
) -> List[RegistryModelVersion]:
|
565
567
|
"""List model versions from the MLflow model registry.
|
566
568
|
|
567
569
|
Args:
|
@@ -704,7 +706,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
704
706
|
|
705
707
|
def get_model_uri_artifact_store(
|
706
708
|
self,
|
707
|
-
model_version:
|
709
|
+
model_version: RegistryModelVersion,
|
708
710
|
) -> str:
|
709
711
|
"""Get the model URI artifact store.
|
710
712
|
|
@@ -723,7 +725,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
723
725
|
def _cast_mlflow_version_to_model_version(
|
724
726
|
self,
|
725
727
|
mlflow_model_version: MLflowModelVersion,
|
726
|
-
) ->
|
728
|
+
) -> RegistryModelVersion:
|
727
729
|
"""Cast an MLflow model version to a model version.
|
728
730
|
|
729
731
|
Args:
|
@@ -748,7 +750,7 @@ class MLFlowModelRegistry(BaseModelRegistry):
|
|
748
750
|
)
|
749
751
|
except ImportError:
|
750
752
|
model_library = None
|
751
|
-
return
|
753
|
+
return RegistryModelVersion(
|
752
754
|
registered_model=RegisteredModel(name=mlflow_model_version.name),
|
753
755
|
model_format=MLFLOW_MODEL_FORMAT,
|
754
756
|
model_library=model_library,
|
zenml/metadata/lazy_load.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16
16
|
from typing import TYPE_CHECKING, Optional
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
19
|
-
from zenml.model.
|
19
|
+
from zenml.model.model import Model
|
20
20
|
from zenml.models import RunMetadataResponse
|
21
21
|
|
22
22
|
|
@@ -30,18 +30,18 @@ class RunMetadataLazyGetter:
|
|
30
30
|
|
31
31
|
def __init__(
|
32
32
|
self,
|
33
|
-
|
33
|
+
_lazy_load_model: "Model",
|
34
34
|
_lazy_load_artifact_name: Optional[str],
|
35
35
|
_lazy_load_artifact_version: Optional[str],
|
36
36
|
):
|
37
37
|
"""Initialize a RunMetadataLazyGetter.
|
38
38
|
|
39
39
|
Args:
|
40
|
-
|
40
|
+
_lazy_load_model: The model version.
|
41
41
|
_lazy_load_artifact_name: The artifact name.
|
42
42
|
_lazy_load_artifact_version: The artifact version.
|
43
43
|
"""
|
44
|
-
self.
|
44
|
+
self._lazy_load_model = _lazy_load_model
|
45
45
|
self._lazy_load_artifact_name = _lazy_load_artifact_name
|
46
46
|
self._lazy_load_artifact_version = _lazy_load_artifact_version
|
47
47
|
|
@@ -57,7 +57,7 @@ class RunMetadataLazyGetter:
|
|
57
57
|
from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse
|
58
58
|
|
59
59
|
return LazyRunMetadataResponse(
|
60
|
-
|
60
|
+
_lazy_load_model=self._lazy_load_model,
|
61
61
|
_lazy_load_artifact_name=self._lazy_load_artifact_name,
|
62
62
|
_lazy_load_artifact_version=self._lazy_load_artifact_version,
|
63
63
|
_lazy_load_metadata_name=key,
|
zenml/model/lazy_load.py
CHANGED
@@ -17,7 +17,7 @@ from typing import Optional
|
|
17
17
|
|
18
18
|
from pydantic import BaseModel
|
19
19
|
|
20
|
-
from zenml.model.
|
20
|
+
from zenml.model.model import Model
|
21
21
|
|
22
22
|
|
23
23
|
class ModelVersionDataLazyLoader(BaseModel):
|
@@ -28,7 +28,7 @@ class ModelVersionDataLazyLoader(BaseModel):
|
|
28
28
|
model version during runtime time of the step.
|
29
29
|
"""
|
30
30
|
|
31
|
-
|
31
|
+
model: Model
|
32
32
|
artifact_name: Optional[str] = None
|
33
33
|
artifact_version: Optional[str] = None
|
34
34
|
metadata_name: Optional[str] = None
|
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
12
|
# or implied. See the License for the specific language governing
|
13
13
|
# permissions and limitations under the License.
|
14
|
-
"""
|
14
|
+
"""Model user facing interface to pass into pipeline or step."""
|
15
15
|
|
16
16
|
from typing import (
|
17
17
|
TYPE_CHECKING,
|
@@ -42,8 +42,8 @@ if TYPE_CHECKING:
|
|
42
42
|
logger = get_logger(__name__)
|
43
43
|
|
44
44
|
|
45
|
-
class
|
46
|
-
"""
|
45
|
+
class Model(BaseModel):
|
46
|
+
"""Model class to pass into pipeline or step to set it into a model context.
|
47
47
|
|
48
48
|
name: The name of the model.
|
49
49
|
license: The license under which the model is created.
|
@@ -54,8 +54,8 @@ class ModelVersion(BaseModel):
|
|
54
54
|
trade_offs: The tradeoffs of the model.
|
55
55
|
ethics: The ethical implications of the model.
|
56
56
|
tags: Tags associated with the model.
|
57
|
-
version: The
|
58
|
-
to a specific version/stage. If skipped new
|
57
|
+
version: The version name, version number or stage is optional and points model context
|
58
|
+
to a specific version/stage. If skipped new version will be created.
|
59
59
|
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
|
60
60
|
if available in active stack.
|
61
61
|
"""
|
@@ -97,7 +97,7 @@ class ModelVersion(BaseModel):
|
|
97
97
|
self._get_or_create_model_version()
|
98
98
|
except RuntimeError:
|
99
99
|
logger.info(
|
100
|
-
f"
|
100
|
+
f"Version `{self.version}` of `{self.name}` model doesn't exist "
|
101
101
|
"and cannot be fetched from the Model Control Plane."
|
102
102
|
)
|
103
103
|
return self._id
|
@@ -128,7 +128,7 @@ class ModelVersion(BaseModel):
|
|
128
128
|
self._get_or_create_model_version()
|
129
129
|
except RuntimeError:
|
130
130
|
logger.info(
|
131
|
-
f"
|
131
|
+
f"Version `{self.version}` of `{self.name}` model doesn't exist "
|
132
132
|
"and cannot be fetched from the Model Control Plane."
|
133
133
|
)
|
134
134
|
return self._number
|
@@ -149,7 +149,7 @@ class ModelVersion(BaseModel):
|
|
149
149
|
return ModelStages(stage)
|
150
150
|
except RuntimeError:
|
151
151
|
logger.info(
|
152
|
-
f"
|
152
|
+
f"Version `{self.version}` of `{self.name}` model doesn't exist "
|
153
153
|
"and cannot be fetched from the Model Control Plane."
|
154
154
|
)
|
155
155
|
return None
|
@@ -283,7 +283,7 @@ class ModelVersion(BaseModel):
|
|
283
283
|
def set_stage(
|
284
284
|
self, stage: Union[str, ModelStages], force: bool = False
|
285
285
|
) -> None:
|
286
|
-
"""Sets this Model
|
286
|
+
"""Sets this Model to a desired stage.
|
287
287
|
|
288
288
|
Args:
|
289
289
|
stage: the target stage for model version.
|
@@ -431,7 +431,7 @@ class ModelVersion(BaseModel):
|
|
431
431
|
return LazyArtifactVersionResponse(
|
432
432
|
_lazy_load_name=name,
|
433
433
|
_lazy_load_version=version,
|
434
|
-
|
434
|
+
_lazy_load_model=Model(
|
435
435
|
name=self.name, version=self.version or self.number
|
436
436
|
),
|
437
437
|
)
|
@@ -441,7 +441,7 @@ class ModelVersion(BaseModel):
|
|
441
441
|
return None
|
442
442
|
|
443
443
|
def __eq__(self, other: object) -> bool:
|
444
|
-
"""Check two
|
444
|
+
"""Check two Models for equality.
|
445
445
|
|
446
446
|
Args:
|
447
447
|
other: object to compare with
|
@@ -449,7 +449,7 @@ class ModelVersion(BaseModel):
|
|
449
449
|
Returns:
|
450
450
|
True, if equal, False otherwise.
|
451
451
|
"""
|
452
|
-
if not isinstance(other,
|
452
|
+
if not isinstance(other, Model):
|
453
453
|
return NotImplemented
|
454
454
|
if self.name != other.name:
|
455
455
|
return False
|
@@ -509,7 +509,7 @@ class ModelVersion(BaseModel):
|
|
509
509
|
model_name_or_id=self.name
|
510
510
|
)
|
511
511
|
|
512
|
-
difference = {}
|
512
|
+
difference: Dict[str, Any] = {}
|
513
513
|
for key in (
|
514
514
|
"license",
|
515
515
|
"audience",
|
@@ -525,7 +525,14 @@ class ModelVersion(BaseModel):
|
|
525
525
|
"config": getattr(self, key),
|
526
526
|
"db": getattr(model, key),
|
527
527
|
}
|
528
|
-
|
528
|
+
if self.tags:
|
529
|
+
configured_tags = set(self.tags)
|
530
|
+
db_tags = {t.name for t in model.tags}
|
531
|
+
if db_tags != configured_tags:
|
532
|
+
difference["tags added"] = list(configured_tags - db_tags)
|
533
|
+
difference["tags removed"] = list(
|
534
|
+
db_tags - configured_tags
|
535
|
+
)
|
529
536
|
if difference:
|
530
537
|
logger.warning(
|
531
538
|
"Provided model configuration does not match "
|
@@ -588,7 +595,7 @@ class ModelVersion(BaseModel):
|
|
588
595
|
"db": mv.description,
|
589
596
|
}
|
590
597
|
if self.tags:
|
591
|
-
configured_tags = set(self.tags
|
598
|
+
configured_tags = set(self.tags)
|
592
599
|
db_tags = {t.name for t in mv.tags}
|
593
600
|
if db_tags != configured_tags:
|
594
601
|
difference["tags added"] = list(configured_tags - db_tags)
|
@@ -656,7 +663,7 @@ class ModelVersion(BaseModel):
|
|
656
663
|
# model version for current model was already
|
657
664
|
# created in the current run, not to create
|
658
665
|
# new model versions
|
659
|
-
pipeline_mv = context.pipeline_run.config.
|
666
|
+
pipeline_mv = context.pipeline_run.config.model
|
660
667
|
if (
|
661
668
|
pipeline_mv
|
662
669
|
and pipeline_mv.was_created_in_this_run
|
@@ -666,7 +673,7 @@ class ModelVersion(BaseModel):
|
|
666
673
|
self.version = pipeline_mv.version
|
667
674
|
else:
|
668
675
|
for step in context.pipeline_run.steps.values():
|
669
|
-
step_mv = step.config.
|
676
|
+
step_mv = step.config.model
|
670
677
|
if (
|
671
678
|
step_mv
|
672
679
|
and step_mv.was_created_in_this_run
|
@@ -714,21 +721,21 @@ class ModelVersion(BaseModel):
|
|
714
721
|
self._number = model_version.number
|
715
722
|
return model_version
|
716
723
|
|
717
|
-
def _merge(self,
|
718
|
-
self.license = self.license or
|
719
|
-
self.description = self.description or
|
720
|
-
self.audience = self.audience or
|
721
|
-
self.use_cases = self.use_cases or
|
722
|
-
self.limitations = self.limitations or
|
723
|
-
self.trade_offs = self.trade_offs or
|
724
|
-
self.ethics = self.ethics or
|
725
|
-
if
|
724
|
+
def _merge(self, model: "Model") -> None:
|
725
|
+
self.license = self.license or model.license
|
726
|
+
self.description = self.description or model.description
|
727
|
+
self.audience = self.audience or model.audience
|
728
|
+
self.use_cases = self.use_cases or model.use_cases
|
729
|
+
self.limitations = self.limitations or model.limitations
|
730
|
+
self.trade_offs = self.trade_offs or model.trade_offs
|
731
|
+
self.ethics = self.ethics or model.ethics
|
732
|
+
if model.tags is not None:
|
726
733
|
self.tags = list(
|
727
|
-
{t for t in self.tags or []}.union(set(
|
734
|
+
{t for t in self.tags or []}.union(set(model.tags))
|
728
735
|
)
|
729
736
|
|
730
737
|
def __hash__(self) -> int:
|
731
|
-
"""Get hash of the `
|
738
|
+
"""Get hash of the `Model`.
|
732
739
|
|
733
740
|
Returns:
|
734
741
|
Hash function results
|