zenml-nightly 0.70.0.dev20241126__py3-none-any.whl → 0.70.0.dev20241128__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifact_stores/base_artifact_store.py +2 -2
- zenml/artifacts/utils.py +1 -1
- zenml/cli/__init__.py +3 -0
- zenml/cli/login.py +26 -0
- zenml/integrations/__init__.py +1 -0
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +14 -6
- zenml/integrations/constants.py +1 -0
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +46 -2
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +13 -2
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +3 -1
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +3 -2
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +3 -1
- zenml/integrations/modal/__init__.py +46 -0
- zenml/integrations/modal/flavors/__init__.py +26 -0
- zenml/integrations/modal/flavors/modal_step_operator_flavor.py +125 -0
- zenml/integrations/modal/step_operators/__init__.py +22 -0
- zenml/integrations/modal/step_operators/modal_step_operator.py +242 -0
- zenml/io/filesystem.py +2 -2
- zenml/io/local_filesystem.py +3 -3
- zenml/model/model.py +0 -82
- zenml/orchestrators/step_run_utils.py +8 -3
- zenml/orchestrators/step_runner.py +1 -1
- zenml/orchestrators/utils.py +24 -2
- zenml/steps/entrypoint_function_utils.py +3 -1
- zenml/zen_server/cloud_utils.py +3 -1
- zenml/zen_server/rbac/endpoint_utils.py +6 -4
- zenml/zen_server/rbac/models.py +3 -2
- zenml/zen_server/rbac/utils.py +4 -7
- zenml/zen_server/routers/users_endpoints.py +35 -37
- zenml/zen_server/routers/workspaces_endpoints.py +25 -36
- zenml/zen_stores/sql_zen_store.py +13 -0
- {zenml_nightly-0.70.0.dev20241126.dist-info → zenml_nightly-0.70.0.dev20241128.dist-info}/METADATA +1 -1
- {zenml_nightly-0.70.0.dev20241126.dist-info → zenml_nightly-0.70.0.dev20241128.dist-info}/RECORD +37 -33
- zenml/utils/cloud_utils.py +0 -40
- {zenml_nightly-0.70.0.dev20241126.dist-info → zenml_nightly-0.70.0.dev20241128.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.70.0.dev20241126.dist-info → zenml_nightly-0.70.0.dev20241128.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.70.0.dev20241126.dist-info → zenml_nightly-0.70.0.dev20241128.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,242 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Modal step operator implementation."""
|
15
|
+
|
16
|
+
import asyncio
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast
|
18
|
+
|
19
|
+
import modal
|
20
|
+
from modal_proto import api_pb2
|
21
|
+
|
22
|
+
from zenml.client import Client
|
23
|
+
from zenml.config.build_configuration import BuildConfiguration
|
24
|
+
from zenml.config.resource_settings import ByteUnit, ResourceSettings
|
25
|
+
from zenml.enums import StackComponentType
|
26
|
+
from zenml.integrations.modal.flavors import (
|
27
|
+
ModalStepOperatorConfig,
|
28
|
+
ModalStepOperatorSettings,
|
29
|
+
)
|
30
|
+
from zenml.logger import get_logger
|
31
|
+
from zenml.stack import Stack, StackValidator
|
32
|
+
from zenml.step_operators import BaseStepOperator
|
33
|
+
|
34
|
+
if TYPE_CHECKING:
|
35
|
+
from zenml.config.base_settings import BaseSettings
|
36
|
+
from zenml.config.step_run_info import StepRunInfo
|
37
|
+
from zenml.models import PipelineDeploymentBase
|
38
|
+
|
39
|
+
logger = get_logger(__name__)

# Build key shared by `get_docker_builds` (when requesting the step image)
# and `launch` (when looking up the built image for a step run).
MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY = "modal_step_operator"
|
42
|
+
|
43
|
+
|
44
|
+
def get_gpu_values(
    settings: ModalStepOperatorSettings, resource_settings: ResourceSettings
) -> Optional[str]:
    """Build the value passed as the `gpu` argument of a Modal sandbox.

    Args:
        settings: The Modal step operator settings; `settings.gpu` holds the
            GPU type.
        resource_settings: The step's resource settings; a truthy
            `gpu_count` is appended to the GPU type.

    Returns:
        `None` if no GPU type is configured, `"<type>:<count>"` if a truthy
        GPU count is configured, and the bare GPU type otherwise.
    """
    gpu_type = settings.gpu
    if not gpu_type:
        return None

    count = resource_settings.gpu_count
    if not count:
        # No (or zero) count configured: pass the GPU type on unchanged.
        return gpu_type
    return f"{gpu_type}:{count}"
|
60
|
+
|
61
|
+
|
62
|
+
class ModalStepOperator(BaseStepOperator):
    """Step operator to run a step on Modal.

    Each step assigned to this operator runs inside a Modal sandbox whose
    image is the step's ZenML Docker image, pulled from the stack's
    container registry with that registry's credentials.
    """

    @property
    def config(self) -> ModalStepOperatorConfig:
        """Get the Modal step operator configuration.

        Returns:
            The Modal step operator configuration.
        """
        return cast(ModalStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Get the settings class for the Modal step operator.

        Returns:
            The Modal step operator settings class.
        """
        return ModalStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Get the stack validator for the Modal step operator.

        The stack must contain a container registry and an image builder.
        Both the artifact store and the container registry must be remote,
        because the step code runs on Modal and not on the local machine.

        Returns:
            The stack validator.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.config.is_local:
                return False, (
                    "The Modal step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the Modal "
                    "step operator."
                )

            container_registry = stack.container_registry
            # NOTE: presumably guaranteed by `required_components` below,
            # which lists the container registry.
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The Modal step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "Modal step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBase"
    ) -> List["BuildConfiguration"]:
        """Get the Docker build configurations for the Modal step operator.

        One build is requested for every step whose configured step operator
        is this component.

        Args:
            deployment: The pipeline deployment.

        Returns:
            A list of Docker build configurations.
        """
        builds = []
        for step_name, step in deployment.step_configurations.items():
            if step.config.step_operator == self.name:
                build = BuildConfiguration(
                    key=MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY,
                    settings=step.config.docker_settings,
                    step_name=step_name,
                )
                builds.append(build)

        return builds

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launch a step run on Modal.

        Runs the step entrypoint in a Modal sandbox and blocks until the
        sandbox process exits.

        Args:
            info: The step run information.
            entrypoint_command: The entrypoint command for the step.
            environment: The environment variables for the step.

        Raises:
            RuntimeError: If no Docker credentials are found for the container
                registry, or if the step process in the Modal sandbox exits
                with a non-zero return code.
            ValueError: If no container registry is found in the stack.
        """
        import shlex

        settings = cast(ModalStepOperatorSettings, self.get_settings(info))
        image_name = info.get_image(key=MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY)
        stack = Client().active_stack

        if not stack.container_registry:
            raise ValueError(
                "No Container registry found in the stack. "
                "Please add a container registry and ensure "
                "it is correctly configured."
            )

        if docker_creds := stack.container_registry.credentials:
            docker_username, docker_password = docker_creds
        else:
            raise RuntimeError(
                "No Docker credentials found for the container registry."
            )

        # Registry credentials that Modal uses to pull the step image.
        my_secret = modal.secret._Secret.from_dict(
            {
                "REGISTRY_USERNAME": docker_username,
                "REGISTRY_PASSWORD": docker_password,
            }
        )

        # The Modal image is exactly the ZenML step image (a single `FROM`).
        spec = modal.image.DockerfileSpec(
            commands=[f"FROM {image_name}"], context_files={}
        )

        zenml_image = modal.Image._from_args(
            dockerfile_function=lambda *_, **__: spec,
            force_build=False,
            image_registry_config=modal.image._ImageRegistryConfig(
                api_pb2.REGISTRY_AUTH_TYPE_STATIC_CREDS, my_secret
            ),
        ).env(environment)

        resource_settings = info.config.resource_settings
        gpu_values = get_gpu_values(settings, resource_settings)
        memory_mb = resource_settings.get_memory(ByteUnit.MB)
        memory_int = int(memory_mb) if memory_mb is not None else None

        # The command is executed through `bash -c`, so each argument must be
        # shell-quoted; otherwise arguments containing spaces or shell
        # metacharacters would be split or interpreted by the shell.
        command = shlex.join(entrypoint_command)

        app = modal.App(
            f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}"
        )

        async def run_sandbox() -> Optional[int]:
            """Run the step in a sandbox and return its process exit code."""
            with modal.enable_output():
                async with app.run():
                    sb = await modal.Sandbox.create.aio(
                        "bash",
                        "-c",
                        command,
                        image=zenml_image,
                        gpu=gpu_values,
                        cpu=resource_settings.cpu_count,
                        memory=memory_int,
                        cloud=settings.cloud,
                        region=settings.region,
                        app=app,
                        timeout=86400,  # 24h, the max Modal allows
                    )

                    await sb.wait.aio()
                    return sb.returncode

        return_code = asyncio.run(run_sandbox())
        # Propagate remote failures so the step is not treated as successful.
        if return_code:
            raise RuntimeError(
                f"Step `{info.pipeline_step_name}` failed on Modal: the "
                f"sandbox process exited with return code {return_code}."
            )
|
zenml/io/filesystem.py
CHANGED
@@ -54,11 +54,11 @@ class BaseFilesystem(ABC):
|
|
54
54
|
|
55
55
|
@staticmethod
|
56
56
|
@abstractmethod
|
57
|
-
def open(
|
57
|
+
def open(path: PathType, mode: str = "r") -> Any:
|
58
58
|
"""Opens a file.
|
59
59
|
|
60
60
|
Args:
|
61
|
-
|
61
|
+
path: The path to the file.
|
62
62
|
mode: The mode to open the file in.
|
63
63
|
|
64
64
|
Returns:
|
zenml/io/local_filesystem.py
CHANGED
@@ -55,18 +55,18 @@ class LocalFilesystem(BaseFilesystem):
|
|
55
55
|
SUPPORTED_SCHEMES: ClassVar[Set[str]] = {""}
|
56
56
|
|
57
57
|
@staticmethod
|
58
|
-
def open(
|
58
|
+
def open(path: PathType, mode: str = "r") -> Any:
|
59
59
|
"""Open a file at the given path.
|
60
60
|
|
61
61
|
Args:
|
62
|
-
|
62
|
+
path: The path to the file.
|
63
63
|
mode: The mode to open the file.
|
64
64
|
|
65
65
|
Returns:
|
66
66
|
Any: The file object.
|
67
67
|
"""
|
68
68
|
encoding = "utf-8" if "b" not in mode else None
|
69
|
-
return open(
|
69
|
+
return open(path, mode=mode, encoding=encoding)
|
70
70
|
|
71
71
|
@staticmethod
|
72
72
|
def copyfile(
|
zenml/model/model.py
CHANGED
@@ -13,14 +13,12 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Model user facing interface to pass into pipeline or step."""
|
15
15
|
|
16
|
-
import datetime
|
17
16
|
from typing import (
|
18
17
|
TYPE_CHECKING,
|
19
18
|
Any,
|
20
19
|
Dict,
|
21
20
|
List,
|
22
21
|
Optional,
|
23
|
-
Tuple,
|
24
22
|
Union,
|
25
23
|
)
|
26
24
|
from uuid import UUID
|
@@ -41,7 +39,6 @@ if TYPE_CHECKING:
|
|
41
39
|
ModelResponse,
|
42
40
|
ModelVersionResponse,
|
43
41
|
PipelineRunResponse,
|
44
|
-
StepRunResponse,
|
45
42
|
)
|
46
43
|
|
47
44
|
logger = get_logger(__name__)
|
@@ -743,85 +740,6 @@ class Model(BaseModel):
|
|
743
740
|
)
|
744
741
|
)
|
745
742
|
|
746
|
-
def _prepare_model_version_before_step_launch(
|
747
|
-
self,
|
748
|
-
pipeline_run: "PipelineRunResponse",
|
749
|
-
step_run: Optional["StepRunResponse"],
|
750
|
-
return_logs: bool,
|
751
|
-
) -> Tuple[str, "PipelineRunResponse", Optional["StepRunResponse"]]:
|
752
|
-
"""Prepares model version inside pipeline run.
|
753
|
-
|
754
|
-
Args:
|
755
|
-
pipeline_run: pipeline run
|
756
|
-
step_run: step run (passed only if model version is defined in a step explicitly)
|
757
|
-
return_logs: whether to return logs or not
|
758
|
-
|
759
|
-
Returns:
|
760
|
-
Logs related to the Dashboard URL to show later.
|
761
|
-
"""
|
762
|
-
from zenml.client import Client
|
763
|
-
from zenml.models import PipelineRunUpdate, StepRunUpdate
|
764
|
-
|
765
|
-
logs = ""
|
766
|
-
|
767
|
-
# copy Model instance to prevent corrupting configs of the
|
768
|
-
# subsequent runs, if they share the same config object
|
769
|
-
self_copy = self.model_copy()
|
770
|
-
|
771
|
-
# in case request is within the step and no self-configuration is provided
|
772
|
-
# try reuse what's in the pipeline run first
|
773
|
-
if step_run is None and pipeline_run.model_version is not None:
|
774
|
-
self_copy.version = pipeline_run.model_version.name
|
775
|
-
self_copy.model_version_id = pipeline_run.model_version.id
|
776
|
-
# otherwise try to fill the templated name, if needed
|
777
|
-
elif isinstance(self_copy.version, str):
|
778
|
-
if pipeline_run.start_time:
|
779
|
-
start_time = pipeline_run.start_time
|
780
|
-
else:
|
781
|
-
start_time = datetime.datetime.now(datetime.timezone.utc)
|
782
|
-
self_copy.version = format_name_template(
|
783
|
-
self_copy.version,
|
784
|
-
date=start_time.strftime("%Y_%m_%d"),
|
785
|
-
time=start_time.strftime("%H_%M_%S_%f"),
|
786
|
-
)
|
787
|
-
|
788
|
-
# if exact model not yet defined - try to get/create and update it
|
789
|
-
# back to the run accordingly
|
790
|
-
if self_copy.model_version_id is None:
|
791
|
-
model_version_response = self_copy._get_or_create_model_version()
|
792
|
-
|
793
|
-
client = Client()
|
794
|
-
# update the configured model version id in runs accordingly
|
795
|
-
if step_run:
|
796
|
-
step_run = client.zen_store.update_run_step(
|
797
|
-
step_run_id=step_run.id,
|
798
|
-
step_run_update=StepRunUpdate(
|
799
|
-
model_version_id=model_version_response.id
|
800
|
-
),
|
801
|
-
)
|
802
|
-
else:
|
803
|
-
pipeline_run = client.zen_store.update_run(
|
804
|
-
run_id=pipeline_run.id,
|
805
|
-
run_update=PipelineRunUpdate(
|
806
|
-
model_version_id=model_version_response.id
|
807
|
-
),
|
808
|
-
)
|
809
|
-
|
810
|
-
if return_logs:
|
811
|
-
from zenml.utils.cloud_utils import try_get_model_version_url
|
812
|
-
|
813
|
-
if logs_to_show := try_get_model_version_url(
|
814
|
-
model_version_response
|
815
|
-
):
|
816
|
-
logs = logs_to_show
|
817
|
-
else:
|
818
|
-
logs = (
|
819
|
-
"Models can be viewed in the dashboard using ZenML Pro. Sign up "
|
820
|
-
"for a free trial at https://www.zenml.io/pro/"
|
821
|
-
)
|
822
|
-
self.model_version_id = self_copy.model_version_id
|
823
|
-
return logs, pipeline_run, step_run
|
824
|
-
|
825
743
|
@property
|
826
744
|
def _lazy_version(self) -> Optional[str]:
|
827
745
|
"""Get version name for lazy loader.
|
@@ -518,10 +518,15 @@ def log_model_version_dashboard_url(
|
|
518
518
|
Args:
|
519
519
|
model_version: The model version for which to log the dashboard URL.
|
520
520
|
"""
|
521
|
-
from zenml.utils.
|
521
|
+
from zenml.utils.dashboard_utils import get_model_version_url
|
522
522
|
|
523
|
-
if
|
524
|
-
logger.info(
|
523
|
+
if model_version_url := get_model_version_url(model_version.id):
|
524
|
+
logger.info(
|
525
|
+
"Dashboard URL for Model Version `%s (%s)`:\n%s",
|
526
|
+
model_version.model.name,
|
527
|
+
model_version.name,
|
528
|
+
model_version_url,
|
529
|
+
)
|
525
530
|
else:
|
526
531
|
logger.info(
|
527
532
|
"Models can be viewed in the dashboard using ZenML Pro. Sign up "
|
@@ -400,7 +400,7 @@ class StepRunner:
|
|
400
400
|
**artifact.get_hydrated_version().model_dump()
|
401
401
|
)
|
402
402
|
|
403
|
-
if data_type
|
403
|
+
if data_type in (None, Any) or is_union(get_origin(data_type)):
|
404
404
|
# Entrypoint function does not define a specific type for the input,
|
405
405
|
# we use the datatype of the stored artifact
|
406
406
|
data_type = source_utils.load(artifact.data_type)
|
zenml/orchestrators/utils.py
CHANGED
@@ -40,7 +40,9 @@ if TYPE_CHECKING:
|
|
40
40
|
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
|
41
41
|
|
42
42
|
|
43
|
-
def get_orchestrator_run_name(
|
43
|
+
def get_orchestrator_run_name(
|
44
|
+
pipeline_name: str, max_length: Optional[int] = None
|
45
|
+
) -> str:
|
44
46
|
"""Gets an orchestrator run name.
|
45
47
|
|
46
48
|
This run name is not the same as the ZenML run name but can instead be
|
@@ -48,11 +50,31 @@ def get_orchestrator_run_name(pipeline_name: str) -> str:
|
|
48
50
|
|
49
51
|
Args:
|
50
52
|
pipeline_name: Name of the pipeline that will run.
|
53
|
+
max_length: Maximum length of the generated name.
|
54
|
+
|
55
|
+
Raises:
|
56
|
+
ValueError: If the max length is below 8 characters.
|
51
57
|
|
52
58
|
Returns:
|
53
59
|
The orchestrator run name.
|
54
60
|
"""
|
55
|
-
|
61
|
+
suffix_length = 32
|
62
|
+
pipeline_name = f"{pipeline_name}_"
|
63
|
+
|
64
|
+
if max_length:
|
65
|
+
if max_length < 8:
|
66
|
+
raise ValueError(
|
67
|
+
"Maximum length for orchestrator run name must be 8 or above."
|
68
|
+
)
|
69
|
+
|
70
|
+
# Make sure we always have a certain suffix to guarantee no overlap
|
71
|
+
# with other runs
|
72
|
+
suffix_length = min(32, max(8, max_length - len(pipeline_name)))
|
73
|
+
pipeline_name = pipeline_name[: (max_length - suffix_length)]
|
74
|
+
|
75
|
+
suffix = "".join(random.choices("0123456789abcdef", k=suffix_length))
|
76
|
+
|
77
|
+
return f"{pipeline_name}{suffix}"
|
56
78
|
|
57
79
|
|
58
80
|
def is_setting_enabled(
|
@@ -195,7 +195,9 @@ class EntrypointFunctionDefinition(NamedTuple):
|
|
195
195
|
parameter: The function parameter for which the value was provided.
|
196
196
|
value: The input value.
|
197
197
|
"""
|
198
|
-
|
198
|
+
# We allow passing None for optional annotations that would otherwise
|
199
|
+
# not be allowed as a parameter
|
200
|
+
config_dict = ConfigDict(arbitrary_types_allowed=value is None)
|
199
201
|
|
200
202
|
# Create a pydantic model with just a single required field with the
|
201
203
|
# type annotation of the parameter to verify the input type including
|
zenml/zen_server/cloud_utils.py
CHANGED
@@ -170,7 +170,9 @@ class ZenMLCloudConnection:
|
|
170
170
|
token = self._fetch_auth_token()
|
171
171
|
self._session.headers.update({"Authorization": "Bearer " + token})
|
172
172
|
|
173
|
-
retries = Retry(
|
173
|
+
retries = Retry(
|
174
|
+
total=5, backoff_factor=0.1, status_forcelist=[502, 504]
|
175
|
+
)
|
174
176
|
self._session.mount(
|
175
177
|
"https://",
|
176
178
|
HTTPAdapter(
|
@@ -189,7 +189,7 @@ def verify_permissions_and_list_entities(
|
|
189
189
|
def verify_permissions_and_update_entity(
|
190
190
|
id: UUIDOrStr,
|
191
191
|
update_model: AnyUpdate,
|
192
|
-
get_method: Callable[[UUIDOrStr], AnyResponse],
|
192
|
+
get_method: Callable[[UUIDOrStr, bool], AnyResponse],
|
193
193
|
update_method: Callable[[UUIDOrStr, AnyUpdate], AnyResponse],
|
194
194
|
) -> AnyResponse:
|
195
195
|
"""Verify permissions and update an entity.
|
@@ -203,7 +203,8 @@ def verify_permissions_and_update_entity(
|
|
203
203
|
Returns:
|
204
204
|
A model of the updated entity.
|
205
205
|
"""
|
206
|
-
|
206
|
+
# We don't need the hydrated version here
|
207
|
+
model = get_method(id, False)
|
207
208
|
verify_permission_for_model(model, action=Action.UPDATE)
|
208
209
|
updated_model = update_method(model.id, update_model)
|
209
210
|
return dehydrate_response_model(updated_model)
|
@@ -211,7 +212,7 @@ def verify_permissions_and_update_entity(
|
|
211
212
|
|
212
213
|
def verify_permissions_and_delete_entity(
|
213
214
|
id: UUIDOrStr,
|
214
|
-
get_method: Callable[[UUIDOrStr], AnyResponse],
|
215
|
+
get_method: Callable[[UUIDOrStr, bool], AnyResponse],
|
215
216
|
delete_method: Callable[[UUIDOrStr], None],
|
216
217
|
) -> AnyResponse:
|
217
218
|
"""Verify permissions and delete an entity.
|
@@ -224,7 +225,8 @@ def verify_permissions_and_delete_entity(
|
|
224
225
|
Returns:
|
225
226
|
The deleted entity.
|
226
227
|
"""
|
227
|
-
|
228
|
+
# We don't need the hydrated version here
|
229
|
+
model = get_method(id, False)
|
228
230
|
verify_permission_for_model(model, action=Action.DELETE)
|
229
231
|
delete_method(model.id)
|
230
232
|
|
zenml/zen_server/rbac/models.py
CHANGED
@@ -59,7 +59,6 @@ class ResourceType(StrEnum):
|
|
59
59
|
PIPELINE_DEPLOYMENT = "pipeline_deployment"
|
60
60
|
PIPELINE_BUILD = "pipeline_build"
|
61
61
|
RUN_TEMPLATE = "run_template"
|
62
|
-
USER = "user"
|
63
62
|
SERVICE = "service"
|
64
63
|
RUN_METADATA = "run_metadata"
|
65
64
|
SECRET = "secret"
|
@@ -70,7 +69,9 @@ class ResourceType(StrEnum):
|
|
70
69
|
TAG = "tag"
|
71
70
|
TRIGGER = "trigger"
|
72
71
|
TRIGGER_EXECUTION = "trigger_execution"
|
73
|
-
|
72
|
+
# Deactivated for now
|
73
|
+
# USER = "user"
|
74
|
+
# WORKSPACE = "workspace"
|
74
75
|
|
75
76
|
|
76
77
|
class Resource(BaseModel):
|
zenml/zen_server/rbac/utils.py
CHANGED
@@ -413,8 +413,6 @@ def get_resource_type_for_model(
|
|
413
413
|
TagResponse,
|
414
414
|
TriggerExecutionResponse,
|
415
415
|
TriggerResponse,
|
416
|
-
UserResponse,
|
417
|
-
WorkspaceResponse,
|
418
416
|
)
|
419
417
|
|
420
418
|
mapping: Dict[
|
@@ -434,8 +432,8 @@ def get_resource_type_for_model(
|
|
434
432
|
ModelVersionResponse: ResourceType.MODEL_VERSION,
|
435
433
|
ArtifactResponse: ResourceType.ARTIFACT,
|
436
434
|
ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION,
|
437
|
-
WorkspaceResponse: ResourceType.WORKSPACE,
|
438
|
-
UserResponse: ResourceType.USER,
|
435
|
+
# WorkspaceResponse: ResourceType.WORKSPACE,
|
436
|
+
# UserResponse: ResourceType.USER,
|
439
437
|
PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT,
|
440
438
|
PipelineBuildResponse: ResourceType.PIPELINE_BUILD,
|
441
439
|
PipelineRunResponse: ResourceType.PIPELINE_RUN,
|
@@ -570,7 +568,6 @@ def get_schema_for_resource_type(
|
|
570
568
|
TriggerExecutionSchema,
|
571
569
|
TriggerSchema,
|
572
570
|
UserSchema,
|
573
|
-
WorkspaceSchema,
|
574
571
|
)
|
575
572
|
|
576
573
|
mapping: Dict[ResourceType, Type["BaseSchema"]] = {
|
@@ -588,13 +585,13 @@ def get_schema_for_resource_type(
|
|
588
585
|
ResourceType.SERVICE: ServiceSchema,
|
589
586
|
ResourceType.TAG: TagSchema,
|
590
587
|
ResourceType.SERVICE_ACCOUNT: UserSchema,
|
591
|
-
ResourceType.WORKSPACE: WorkspaceSchema,
|
588
|
+
# ResourceType.WORKSPACE: WorkspaceSchema,
|
592
589
|
ResourceType.PIPELINE_RUN: PipelineRunSchema,
|
593
590
|
ResourceType.PIPELINE_DEPLOYMENT: PipelineDeploymentSchema,
|
594
591
|
ResourceType.PIPELINE_BUILD: PipelineBuildSchema,
|
595
592
|
ResourceType.RUN_TEMPLATE: RunTemplateSchema,
|
596
593
|
ResourceType.RUN_METADATA: RunMetadataSchema,
|
597
|
-
ResourceType.USER: UserSchema,
|
594
|
+
# ResourceType.USER: UserSchema,
|
598
595
|
ResourceType.ACTION: ActionSchema,
|
599
596
|
ResourceType.EVENT_SOURCE: EventSourceSchema,
|
600
597
|
ResourceType.TRIGGER: TriggerSchema,
|