zenml-nightly 0.57.1.dev20240522__py3-none-any.whl → 0.57.1.dev20240527__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/base.py +15 -16
- zenml/client.py +11 -1
- zenml/config/__init__.py +2 -0
- zenml/config/pipeline_configurations.py +2 -0
- zenml/config/pipeline_run_configuration.py +2 -0
- zenml/config/retry_config.py +27 -0
- zenml/config/server_config.py +13 -9
- zenml/config/step_configurations.py +2 -0
- zenml/constants.py +1 -0
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +2 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +14 -0
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +4 -0
- zenml/integrations/slack/alerters/slack_alerter.py +0 -2
- zenml/model/model.py +77 -45
- zenml/models/v2/core/model_version.py +1 -1
- zenml/models/v2/core/pipeline_run.py +12 -0
- zenml/models/v2/core/step_run.py +12 -0
- zenml/models/v2/misc/server_models.py +9 -3
- zenml/new/pipelines/run_utils.py +8 -2
- zenml/new/steps/step_decorator.py +5 -0
- zenml/orchestrators/step_launcher.py +71 -53
- zenml/orchestrators/step_runner.py +26 -132
- zenml/orchestrators/utils.py +158 -1
- zenml/steps/base_step.py +7 -0
- zenml/utils/dashboard_utils.py +4 -8
- zenml/zen_server/deploy/helm/templates/_environment.tpl +5 -5
- zenml/zen_server/deploy/helm/values.yaml +13 -9
- zenml/zen_server/pipeline_deployment/utils.py +6 -2
- zenml/zen_server/routers/auth_endpoints.py +4 -4
- zenml/zen_server/zen_server_api.py +1 -1
- zenml/zen_stores/base_zen_store.py +2 -2
- zenml/zen_stores/schemas/pipeline_run_schemas.py +12 -0
- zenml/zen_stores/schemas/step_run_schemas.py +14 -0
- zenml/zen_stores/sql_zen_store.py +4 -2
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/METADATA +3 -3
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/RECORD +40 -39
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.57.1.dev20240522.dist-info → zenml_nightly-0.57.1.dev20240527.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
|
|
1
|
-
0.57.1.dev20240522
|
1
|
+
0.57.1.dev20240527
|
zenml/cli/base.py
CHANGED
@@ -91,7 +91,7 @@ ZENML_PROJECT_TEMPLATES = dict(
|
|
91
91
|
),
|
92
92
|
llm_finetuning=ZenMLProjectTemplateLocation(
|
93
93
|
github_url="zenml-io/template-llm-finetuning",
|
94
|
-
github_tag="2024.
|
94
|
+
github_tag="2024.05.23", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
|
95
95
|
),
|
96
96
|
)
|
97
97
|
|
@@ -432,6 +432,13 @@ def go() -> None:
|
|
432
432
|
|
433
433
|
zenml_tutorial_path = os.path.join(os.getcwd(), "zenml_tutorial")
|
434
434
|
|
435
|
+
if not is_jupyter_installed():
|
436
|
+
cli_utils.error(
|
437
|
+
"Jupyter Notebook or JupyterLab is not installed. "
|
438
|
+
"Please install the 'notebook' package with `pip` "
|
439
|
+
"first so you can run the tutorial notebooks."
|
440
|
+
)
|
441
|
+
|
435
442
|
with track_handler(event=AnalyticsEvent.RUN_ZENML_GO, metadata=metadata):
|
436
443
|
console.print(zenml_cli_privacy_message, width=80)
|
437
444
|
|
@@ -459,6 +466,7 @@ def go() -> None:
|
|
459
466
|
TUTORIAL_REPO,
|
460
467
|
tmp_cloned_dir,
|
461
468
|
branch=f"release/{zenml_version}",
|
469
|
+
depth=1, # to prevent timeouts when downloading
|
462
470
|
)
|
463
471
|
example_dir = os.path.join(
|
464
472
|
tmp_cloned_dir, "examples/quickstart"
|
@@ -483,23 +491,14 @@ def go() -> None:
|
|
483
491
|
)
|
484
492
|
input("Press ENTER to continue...")
|
485
493
|
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
["jupyter", "notebook"], cwd=zenml_tutorial_path
|
490
|
-
)
|
491
|
-
except subprocess.CalledProcessError as e:
|
492
|
-
cli_utils.error(
|
493
|
-
"An error occurred while launching Jupyter Notebook. "
|
494
|
-
"Please make sure Jupyter is properly installed and try again."
|
495
|
-
)
|
496
|
-
raise e
|
497
|
-
else:
|
494
|
+
try:
|
495
|
+
subprocess.check_call(["jupyter", "notebook"], cwd=zenml_tutorial_path)
|
496
|
+
except subprocess.CalledProcessError as e:
|
498
497
|
cli_utils.error(
|
499
|
-
"
|
500
|
-
"Please
|
501
|
-
"to run the tutorial notebooks."
|
498
|
+
"An error occurred while launching Jupyter Notebook. "
|
499
|
+
"Please make sure Jupyter is properly installed and try again."
|
502
500
|
)
|
501
|
+
raise e
|
503
502
|
|
504
503
|
|
505
504
|
def _prompt_email(event_source: AnalyticsEventSource) -> bool:
|
zenml/client.py
CHANGED
@@ -5679,7 +5679,7 @@ class Client(metaclass=ClientMetaClass):
|
|
5679
5679
|
|
5680
5680
|
def get_model_version(
|
5681
5681
|
self,
|
5682
|
-
model_name_or_id: Union[str, UUID],
|
5682
|
+
model_name_or_id: Optional[Union[str, UUID]] = None,
|
5683
5683
|
model_version_name_or_number_or_id: Optional[
|
5684
5684
|
Union[str, int, ModelStages, UUID]
|
5685
5685
|
] = None,
|
@@ -5702,7 +5702,17 @@ class Client(metaclass=ClientMetaClass):
|
|
5702
5702
|
Raises:
|
5703
5703
|
RuntimeError: In case method inputs don't adhere to restrictions.
|
5704
5704
|
KeyError: In case no model version with the identifiers exists.
|
5705
|
+
ValueError: In case retrieval is attempted using non UUID model version
|
5706
|
+
identifier and no model identifier provided.
|
5705
5707
|
"""
|
5708
|
+
if (
|
5709
|
+
not is_valid_uuid(model_version_name_or_number_or_id)
|
5710
|
+
and model_name_or_id is None
|
5711
|
+
):
|
5712
|
+
raise ValueError(
|
5713
|
+
"No model identifier provided and model version identifier "
|
5714
|
+
f"`{model_version_name_or_number_or_id}` is not a valid UUID."
|
5715
|
+
)
|
5706
5716
|
if cll := client_lazy_loader(
|
5707
5717
|
"get_model_version",
|
5708
5718
|
model_name_or_id=model_name_or_id,
|
zenml/config/__init__.py
CHANGED
@@ -25,8 +25,10 @@ order to persist the configuration across sessions.
|
|
25
25
|
"""
|
26
26
|
from zenml.config.docker_settings import DockerSettings
|
27
27
|
from zenml.config.resource_settings import ResourceSettings
|
28
|
+
from zenml.config.retry_config import StepRetryConfig
|
28
29
|
|
29
30
|
__all__ = [
|
30
31
|
"DockerSettings",
|
31
32
|
"ResourceSettings",
|
33
|
+
"StepRetryConfig",
|
32
34
|
]
|
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
18
18
|
from pydantic import validator
|
19
19
|
|
20
20
|
from zenml.config.constants import DOCKER_SETTINGS_KEY
|
21
|
+
from zenml.config.retry_config import StepRetryConfig
|
21
22
|
from zenml.config.source import Source, convert_source_validator
|
22
23
|
from zenml.config.strict_base_model import StrictBaseModel
|
23
24
|
from zenml.model.model import Model
|
@@ -43,6 +44,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
|
|
43
44
|
success_hook_source: Optional[Source] = None
|
44
45
|
model: Optional[Model] = None
|
45
46
|
parameters: Optional[Dict[str, Any]] = None
|
47
|
+
retry: Optional[StepRetryConfig] = None
|
46
48
|
|
47
49
|
_convert_source = convert_source_validator(
|
48
50
|
"failure_hook_source", "success_hook_source"
|
@@ -17,6 +17,7 @@ from typing import Any, Dict, Optional, Union
|
|
17
17
|
from uuid import UUID
|
18
18
|
|
19
19
|
from zenml.config.base_settings import BaseSettings
|
20
|
+
from zenml.config.retry_config import StepRetryConfig
|
20
21
|
from zenml.config.schedule import Schedule
|
21
22
|
from zenml.config.step_configurations import StepConfigurationUpdate
|
22
23
|
from zenml.config.strict_base_model import StrictBaseModel
|
@@ -42,3 +43,4 @@ class PipelineRunConfiguration(
|
|
42
43
|
extra: Dict[str, Any] = {}
|
43
44
|
model: Optional[Model] = None
|
44
45
|
parameters: Optional[Dict[str, Any]] = None
|
46
|
+
retry: Optional[StepRetryConfig] = None
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Retry configuration for a step."""
|
15
|
+
|
16
|
+
from zenml.config.strict_base_model import StrictBaseModel
|
17
|
+
|
18
|
+
|
19
|
+
class StepRetryConfig(StrictBaseModel):
|
20
|
+
"""Retry configuration for a step.
|
21
|
+
|
22
|
+
Delay is an integer (specified in seconds).
|
23
|
+
"""
|
24
|
+
|
25
|
+
max_retries: int = 1
|
26
|
+
delay: int = 0 # in seconds
|
27
|
+
backoff: int = 0
|
zenml/config/server_config.py
CHANGED
@@ -65,8 +65,16 @@ class ServerConfiguration(BaseModel):
|
|
65
65
|
|
66
66
|
Attributes:
|
67
67
|
deployment_type: The type of ZenML server deployment that is running.
|
68
|
-
|
69
|
-
|
68
|
+
server_url: The URL where the ZenML server API is reachable. Must be
|
69
|
+
configured for features that involve triggering workloads from the
|
70
|
+
ZenML dashboard (e.g., running pipelines). If not specified, the
|
71
|
+
clients will use the same URL used to connect them to the ZenML
|
72
|
+
server.
|
73
|
+
dashboard_url: The URL where the ZenML dashboard is reachable.
|
74
|
+
If not specified, the `server_url` value is used. This should be
|
75
|
+
configured if the dashboard is served from a different URL than the
|
76
|
+
ZenML server.
|
77
|
+
root_url_path: The root URL path for the ZenML API and dashboard.
|
70
78
|
auth_scheme: The authentication scheme used by the ZenML server.
|
71
79
|
jwt_token_algorithm: The algorithm used to sign and verify JWT tokens.
|
72
80
|
jwt_token_issuer: The issuer of the JWT tokens. If not specified, the
|
@@ -93,10 +101,6 @@ class ServerConfiguration(BaseModel):
|
|
93
101
|
2.0 device authorization request expires.
|
94
102
|
device_auth_polling_interval: The polling interval in seconds used to
|
95
103
|
poll the OAuth 2.0 device authorization endpoint.
|
96
|
-
dashboard_url: The URL where the ZenML dashboard is hosted. Used to
|
97
|
-
construct the OAuth 2.0 device authorization endpoint. If not set,
|
98
|
-
a partial URL is returned to the client which is used to construct
|
99
|
-
the full URL based on the server's root URL path.
|
100
104
|
device_expiration_minutes: The time in minutes that an OAuth 2.0 device is
|
101
105
|
allowed to be used to authenticate with the ZenML server. If not
|
102
106
|
set or if `jwt_token_expire_minutes` is not set, the devices are
|
@@ -204,7 +208,7 @@ class ServerConfiguration(BaseModel):
|
|
204
208
|
of the reserved values `enabled`, `yes`, `true`, `on`, the
|
205
209
|
`Permissions-Policy` header will be set to the default value
|
206
210
|
(`accelerometer=(), camera=(), geolocation=(), gyroscope=(),
|
207
|
-
|
211
|
+
magnetometer=(), microphone=(), payment=(), usb=()`). If set to
|
208
212
|
one of the reserved values `disabled`, `no`, `none`, `false`, `off`
|
209
213
|
or to an empty string, the `Permissions-Policy` header will not be
|
210
214
|
included in responses.
|
@@ -225,7 +229,8 @@ class ServerConfiguration(BaseModel):
|
|
225
229
|
"""
|
226
230
|
|
227
231
|
deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER
|
228
|
-
|
232
|
+
server_url: Optional[str] = None
|
233
|
+
dashboard_url: Optional[str] = None
|
229
234
|
root_url_path: str = ""
|
230
235
|
metadata: Dict[str, Any] = {}
|
231
236
|
auth_scheme: AuthScheme = AuthScheme.OAUTH2_PASSWORD_BEARER
|
@@ -245,7 +250,6 @@ class ServerConfiguration(BaseModel):
|
|
245
250
|
device_auth_polling_interval: int = (
|
246
251
|
DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING
|
247
252
|
)
|
248
|
-
dashboard_url: Optional[str] = None
|
249
253
|
device_expiration_minutes: Optional[int] = None
|
250
254
|
trusted_device_expiration_minutes: Optional[int] = None
|
251
255
|
|
@@ -32,6 +32,7 @@ from zenml.artifacts.external_artifact_config import (
|
|
32
32
|
from zenml.client_lazy_loader import ClientLazyLoader
|
33
33
|
from zenml.config.base_settings import BaseSettings, SettingsOrDict
|
34
34
|
from zenml.config.constants import DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY
|
35
|
+
from zenml.config.retry_config import StepRetryConfig
|
35
36
|
from zenml.config.source import Source, convert_source_validator
|
36
37
|
from zenml.config.strict_base_model import StrictBaseModel
|
37
38
|
from zenml.logger import get_logger
|
@@ -137,6 +138,7 @@ class StepConfigurationUpdate(StrictBaseModel):
|
|
137
138
|
failure_hook_source: Optional[Source] = None
|
138
139
|
success_hook_source: Optional[Source] = None
|
139
140
|
model: Optional[Model] = None
|
141
|
+
retry: Optional[StepRetryConfig] = None
|
140
142
|
|
141
143
|
outputs: Mapping[str, PartialArtifactConfiguration] = {}
|
142
144
|
|
zenml/constants.py
CHANGED
@@ -148,6 +148,7 @@ ENV_ZENML_DISABLE_STEP_LOGS_STORAGE = "ZENML_DISABLE_STEP_LOGS_STORAGE"
|
|
148
148
|
ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES = (
|
149
149
|
"ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES"
|
150
150
|
)
|
151
|
+
ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
|
151
152
|
|
152
153
|
# ZenML Server environment variables
|
153
154
|
ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_"
|
@@ -43,6 +43,7 @@ class KubernetesOrchestratorSettings(BaseSettings):
|
|
43
43
|
permissions will be created.
|
44
44
|
step_pod_service_account_name: Name of the service account to use for the
|
45
45
|
step pods. If not provided, the default service account will be used.
|
46
|
+
privileged: If the container should be run in privileged mode.
|
46
47
|
pod_settings: Pod settings to apply.
|
47
48
|
"""
|
48
49
|
|
@@ -50,6 +51,7 @@ class KubernetesOrchestratorSettings(BaseSettings):
|
|
50
51
|
timeout: int = 0
|
51
52
|
service_account_name: Optional[str] = None
|
52
53
|
step_pod_service_account_name: Optional[str] = None
|
54
|
+
privileged: bool = False
|
53
55
|
pod_settings: Optional[KubernetesPodSettings] = None
|
54
56
|
|
55
57
|
|
@@ -391,6 +391,20 @@ class KubernetesOrchestrator(ContainerizedOrchestrator):
|
|
391
391
|
# Authorize pod to run Kubernetes commands inside the cluster.
|
392
392
|
service_account_name = self._get_service_account_name(settings)
|
393
393
|
|
394
|
+
if settings.pod_settings:
|
395
|
+
# Remove all settings that specify on which pod to run for the
|
396
|
+
# orchestrator pod. These settings should only be used
|
397
|
+
# for the pods executing the actual steps.
|
398
|
+
pod_settings = settings.pod_settings.copy(
|
399
|
+
update={
|
400
|
+
"resources": {},
|
401
|
+
"node_selectors": {},
|
402
|
+
"affinity": {},
|
403
|
+
"tolerations": [],
|
404
|
+
}
|
405
|
+
)
|
406
|
+
settings = settings.copy(update={"pod_settings": pod_settings})
|
407
|
+
|
394
408
|
# Schedule as CRON job if CRON schedule is given.
|
395
409
|
if deployment.schedule:
|
396
410
|
if not deployment.schedule.cron_expression:
|
@@ -128,6 +128,9 @@ def build_pod_manifest(
|
|
128
128
|
env = env.copy() if env else {}
|
129
129
|
env.setdefault(ENV_ZENML_ENABLE_REPO_INIT_WARNINGS, "False")
|
130
130
|
|
131
|
+
security_context = k8s_client.V1SecurityContext(
|
132
|
+
privileged=settings.privileged
|
133
|
+
)
|
131
134
|
container_spec = k8s_client.V1Container(
|
132
135
|
name="main",
|
133
136
|
image=image_name,
|
@@ -137,6 +140,7 @@ def build_pod_manifest(
|
|
137
140
|
k8s_client.V1EnvVar(name=name, value=value)
|
138
141
|
for name, value in env.items()
|
139
142
|
],
|
143
|
+
security_context=security_context,
|
140
144
|
)
|
141
145
|
|
142
146
|
pod_spec = k8s_client.V1PodSpec(
|
@@ -276,8 +276,6 @@ class SlackAlerter(BaseAlerter):
|
|
276
276
|
|
277
277
|
approved = False # will be modified by handle()
|
278
278
|
|
279
|
-
# breakpoint()
|
280
|
-
|
281
279
|
@RTMClient.run_on(event="hello") # type: ignore
|
282
280
|
def post_initial_message(**payload: Any) -> None:
|
283
281
|
"""Post an initial message in a channel and start listening.
|
zenml/model/model.py
CHANGED
@@ -60,6 +60,8 @@ class Model(BaseModel):
|
|
60
60
|
to a specific version/stage. If skipped new version will be created.
|
61
61
|
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
|
62
62
|
if available in active stack.
|
63
|
+
model_version_id: The ID of a specific Model Version, if given - it will override
|
64
|
+
`name` and `version` settings. Used mostly internally.
|
63
65
|
"""
|
64
66
|
|
65
67
|
name: str
|
@@ -73,12 +75,12 @@ class Model(BaseModel):
|
|
73
75
|
tags: Optional[List[str]] = None
|
74
76
|
version: Optional[Union[ModelStages, int, str]] = None
|
75
77
|
save_models_to_registry: bool = True
|
78
|
+
model_version_id: Optional[UUID] = None
|
76
79
|
|
77
80
|
suppress_class_validation_warnings: bool = False
|
78
81
|
was_created_in_this_run: bool = False
|
79
82
|
|
80
83
|
_model_id: UUID = PrivateAttr(None)
|
81
|
-
_id: UUID = PrivateAttr(None)
|
82
84
|
_number: int = PrivateAttr(None)
|
83
85
|
|
84
86
|
#########################
|
@@ -93,16 +95,21 @@ class Model(BaseModel):
|
|
93
95
|
doesn't exist and can only be read given current
|
94
96
|
config (you used stage name or number as
|
95
97
|
a version name).
|
98
|
+
|
99
|
+
Raises:
|
100
|
+
RuntimeError: if model version doesn't exist and
|
101
|
+
cannot be fetched from the Model Control Plane.
|
96
102
|
"""
|
97
|
-
if self.
|
103
|
+
if self.model_version_id is None:
|
98
104
|
try:
|
99
|
-
self._get_or_create_model_version()
|
100
|
-
|
101
|
-
|
105
|
+
mv = self._get_or_create_model_version()
|
106
|
+
self.model_version_id = mv.id
|
107
|
+
except RuntimeError as e:
|
108
|
+
raise RuntimeError(
|
102
109
|
f"Version `{self.version}` of `{self.name}` model doesn't exist "
|
103
110
|
"and cannot be fetched from the Model Control Plane."
|
104
|
-
)
|
105
|
-
return self.
|
111
|
+
) from e
|
112
|
+
return self.model_version_id
|
106
113
|
|
107
114
|
@property
|
108
115
|
def model_id(self) -> UUID:
|
@@ -523,61 +530,82 @@ class Model(BaseModel):
|
|
523
530
|
from zenml.models import ModelRequest
|
524
531
|
|
525
532
|
zenml_client = Client()
|
526
|
-
|
527
|
-
|
528
|
-
|
533
|
+
if self.model_version_id:
|
534
|
+
mv = zenml_client.get_model_version(
|
535
|
+
model_version_name_or_number_or_id=self.model_version_id,
|
529
536
|
)
|
530
|
-
|
531
|
-
|
532
|
-
name=self.name,
|
533
|
-
license=self.license,
|
534
|
-
description=self.description,
|
535
|
-
audience=self.audience,
|
536
|
-
use_cases=self.use_cases,
|
537
|
-
limitations=self.limitations,
|
538
|
-
trade_offs=self.trade_offs,
|
539
|
-
ethics=self.ethics,
|
540
|
-
tags=self.tags,
|
541
|
-
user=zenml_client.active_user.id,
|
542
|
-
workspace=zenml_client.active_workspace.id,
|
543
|
-
save_models_to_registry=self.save_models_to_registry,
|
544
|
-
)
|
545
|
-
model_request = ModelRequest.parse_obj(model_request)
|
537
|
+
model = mv.model
|
538
|
+
else:
|
546
539
|
try:
|
547
|
-
model = zenml_client.zen_store.create_model(
|
548
|
-
model=model_request
|
549
|
-
)
|
550
|
-
logger.info(f"New model `{self.name}` was created implicitly.")
|
551
|
-
except EntityExistsError:
|
552
540
|
model = zenml_client.zen_store.get_model(
|
553
541
|
model_name_or_id=self.name
|
554
542
|
)
|
543
|
+
except KeyError:
|
544
|
+
model_request = ModelRequest(
|
545
|
+
name=self.name,
|
546
|
+
license=self.license,
|
547
|
+
description=self.description,
|
548
|
+
audience=self.audience,
|
549
|
+
use_cases=self.use_cases,
|
550
|
+
limitations=self.limitations,
|
551
|
+
trade_offs=self.trade_offs,
|
552
|
+
ethics=self.ethics,
|
553
|
+
tags=self.tags,
|
554
|
+
user=zenml_client.active_user.id,
|
555
|
+
workspace=zenml_client.active_workspace.id,
|
556
|
+
save_models_to_registry=self.save_models_to_registry,
|
557
|
+
)
|
558
|
+
model_request = ModelRequest.parse_obj(model_request)
|
559
|
+
try:
|
560
|
+
model = zenml_client.zen_store.create_model(
|
561
|
+
model=model_request
|
562
|
+
)
|
563
|
+
logger.info(
|
564
|
+
f"New model `{self.name}` was created implicitly."
|
565
|
+
)
|
566
|
+
except EntityExistsError:
|
567
|
+
model = zenml_client.zen_store.get_model(
|
568
|
+
model_name_or_id=self.name
|
569
|
+
)
|
555
570
|
|
556
571
|
self._model_id = model.id
|
557
572
|
return model
|
558
573
|
|
559
|
-
def _get_model_version(
|
574
|
+
def _get_model_version(
|
575
|
+
self, hydrate: bool = True
|
576
|
+
) -> "ModelVersionResponse":
|
560
577
|
"""This method gets a model version from Model Control Plane.
|
561
578
|
|
579
|
+
Args:
|
580
|
+
hydrate: Flag deciding whether to hydrate the output model(s)
|
581
|
+
by including metadata fields in the response.
|
582
|
+
|
562
583
|
Returns:
|
563
584
|
The model version based on configuration.
|
564
585
|
"""
|
565
586
|
from zenml.client import Client
|
566
587
|
|
567
588
|
zenml_client = Client()
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
589
|
+
if self.model_version_id:
|
590
|
+
mv = zenml_client.get_model_version(
|
591
|
+
model_version_name_or_number_or_id=self.model_version_id,
|
592
|
+
hydrate=hydrate,
|
593
|
+
)
|
594
|
+
else:
|
595
|
+
mv = zenml_client.get_model_version(
|
596
|
+
model_name_or_id=self.name,
|
597
|
+
model_version_name_or_number_or_id=self.version,
|
598
|
+
hydrate=hydrate,
|
599
|
+
)
|
600
|
+
self.model_version_id = mv.id
|
574
601
|
|
575
602
|
difference: Dict[str, Any] = {}
|
576
|
-
if
|
577
|
-
|
578
|
-
"
|
579
|
-
|
580
|
-
|
603
|
+
if mv.metadata:
|
604
|
+
if self.description and mv.description != self.description:
|
605
|
+
difference["description"] = {
|
606
|
+
"config": self.description,
|
607
|
+
"db": mv.description,
|
608
|
+
}
|
581
609
|
if self.tags:
|
582
610
|
configured_tags = set(self.tags)
|
583
611
|
db_tags = {t.name for t in mv.tags}
|
@@ -656,6 +684,7 @@ class Model(BaseModel):
|
|
656
684
|
and pipeline_mv.version is not None
|
657
685
|
):
|
658
686
|
self.version = pipeline_mv.version
|
687
|
+
self.model_version_id = pipeline_mv.model_version_id
|
659
688
|
else:
|
660
689
|
for step in context.pipeline_run.steps.values():
|
661
690
|
step_mv = step.config.model
|
@@ -666,8 +695,11 @@ class Model(BaseModel):
|
|
666
695
|
and step_mv.version is not None
|
667
696
|
):
|
668
697
|
self.version = step_mv.version
|
698
|
+
self.model_version_id = (
|
699
|
+
step_mv.model_version_id
|
700
|
+
)
|
669
701
|
break
|
670
|
-
if self.version:
|
702
|
+
if self.version or self.model_version_id:
|
671
703
|
model_version = self._get_model_version()
|
672
704
|
else:
|
673
705
|
raise KeyError
|
@@ -727,7 +759,7 @@ class Model(BaseModel):
|
|
727
759
|
|
728
760
|
logger.info(f"New model version `{self.version}` was created.")
|
729
761
|
|
730
|
-
self.
|
762
|
+
self.model_version_id = model_version.id
|
731
763
|
self._model_id = model_version.model.id
|
732
764
|
self._number = model_version.number
|
733
765
|
return model_version
|
@@ -39,6 +39,7 @@ from zenml.models.v2.base.scoped import (
|
|
39
39
|
WorkspaceScopedResponseMetadata,
|
40
40
|
WorkspaceScopedResponseResources,
|
41
41
|
)
|
42
|
+
from zenml.models.v2.core.model_version import ModelVersionResponse
|
42
43
|
|
43
44
|
if TYPE_CHECKING:
|
44
45
|
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
|
@@ -197,6 +198,8 @@ class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
|
|
197
198
|
class PipelineRunResponseResources(WorkspaceScopedResponseResources):
|
198
199
|
"""Class for all resource models associated with the pipeline run entity."""
|
199
200
|
|
201
|
+
model_version: Optional[ModelVersionResponse]
|
202
|
+
|
200
203
|
|
201
204
|
class PipelineRunResponse(
|
202
205
|
WorkspaceScopedResponse[
|
@@ -394,6 +397,15 @@ class PipelineRunResponse(
|
|
394
397
|
"""
|
395
398
|
return self.get_metadata().orchestrator_run_id
|
396
399
|
|
400
|
+
@property
|
401
|
+
def model_version(self) -> Optional[ModelVersionResponse]:
|
402
|
+
"""The `model_version` property.
|
403
|
+
|
404
|
+
Returns:
|
405
|
+
the value of the property.
|
406
|
+
"""
|
407
|
+
return self.get_resources().model_version
|
408
|
+
|
397
409
|
|
398
410
|
# ------------------ Filter Model ------------------
|
399
411
|
|
zenml/models/v2/core/step_run.py
CHANGED
@@ -30,6 +30,7 @@ from zenml.models.v2.base.scoped import (
|
|
30
30
|
WorkspaceScopedResponseMetadata,
|
31
31
|
WorkspaceScopedResponseResources,
|
32
32
|
)
|
33
|
+
from zenml.models.v2.core.model_version import ModelVersionResponse
|
33
34
|
|
34
35
|
if TYPE_CHECKING:
|
35
36
|
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
@@ -219,6 +220,8 @@ class StepRunResponseMetadata(WorkspaceScopedResponseMetadata):
|
|
219
220
|
class StepRunResponseResources(WorkspaceScopedResponseResources):
|
220
221
|
"""Class for all resource models associated with the step run entity."""
|
221
222
|
|
223
|
+
model_version: Optional[ModelVersionResponse]
|
224
|
+
|
222
225
|
|
223
226
|
class StepRunResponse(
|
224
227
|
WorkspaceScopedResponse[
|
@@ -435,6 +438,15 @@ class StepRunResponse(
|
|
435
438
|
"""
|
436
439
|
return self.get_metadata().run_metadata
|
437
440
|
|
441
|
+
@property
|
442
|
+
def model_version(self) -> Optional[ModelVersionResponse]:
|
443
|
+
"""The `model_version` property.
|
444
|
+
|
445
|
+
Returns:
|
446
|
+
the value of the property.
|
447
|
+
"""
|
448
|
+
return self.get_resources().model_version
|
449
|
+
|
438
450
|
|
439
451
|
# ------------------ Filter Model ------------------
|
440
452
|
|
@@ -78,11 +78,17 @@ class ServerModel(BaseModel):
|
|
78
78
|
auth_scheme: AuthScheme = Field(
|
79
79
|
title="The authentication scheme that the server is using.",
|
80
80
|
)
|
81
|
-
|
81
|
+
server_url: str = Field(
|
82
82
|
"",
|
83
|
-
title="The
|
83
|
+
title="The URL where the ZenML server API is reachable. If not "
|
84
|
+
"specified, the clients will use the same URL used to connect them to "
|
85
|
+
"the ZenML server.",
|
86
|
+
)
|
87
|
+
dashboard_url: str = Field(
|
88
|
+
"",
|
89
|
+
title="The URL where the ZenML dashboard is reachable. If "
|
90
|
+
"not specified, the `server_url` value will be used instead.",
|
84
91
|
)
|
85
|
-
|
86
92
|
analytics_enabled: bool = Field(
|
87
93
|
default=True, # We set a default for migrations from < 0.57.0
|
88
94
|
title="Enable server-side analytics.",
|
zenml/new/pipelines/run_utils.py
CHANGED
@@ -17,7 +17,7 @@ from uuid import UUID
|
|
17
17
|
from zenml import constants
|
18
18
|
from zenml.client import Client
|
19
19
|
from zenml.config.step_configurations import StepConfigurationUpdate
|
20
|
-
from zenml.enums import ExecutionStatus
|
20
|
+
from zenml.enums import ExecutionStatus, ModelStages
|
21
21
|
from zenml.logger import get_logger
|
22
22
|
from zenml.models import (
|
23
23
|
PipelineDeploymentBase,
|
@@ -166,7 +166,13 @@ def _update_new_requesters(
|
|
166
166
|
try:
|
167
167
|
model._get_model_version()
|
168
168
|
version_existed = key not in new_versions_requested
|
169
|
-
except KeyError:
|
169
|
+
except KeyError as e:
|
170
|
+
if model.version in ModelStages.values():
|
171
|
+
raise KeyError(
|
172
|
+
f"Unable to get model `{model.name}` using stage "
|
173
|
+
f"`{model.version}`, please check that the model "
|
174
|
+
"version in given stage exists before running a pipeline."
|
175
|
+
) from e
|
170
176
|
version_existed = False
|
171
177
|
if not version_existed:
|
172
178
|
model.was_created_in_this_run = True
|