zenml-nightly 0.70.0.dev20241119__py3-none-any.whl → 0.70.0.dev20241121__py3-none-any.whl
This diff is derived from the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifacts/artifact_config.py +32 -4
- zenml/artifacts/utils.py +12 -24
- zenml/cli/base.py +1 -1
- zenml/client.py +4 -19
- zenml/model/utils.py +0 -24
- zenml/models/v2/core/artifact_version.py +25 -2
- zenml/models/v2/core/model_version.py +0 -4
- zenml/models/v2/core/model_version_artifact.py +19 -76
- zenml/models/v2/core/model_version_pipeline_run.py +6 -39
- zenml/orchestrators/step_launcher.py +0 -1
- zenml/orchestrators/step_run_utils.py +4 -17
- zenml/orchestrators/step_runner.py +3 -1
- zenml/service_connectors/service_connector_registry.py +68 -57
- zenml/zen_server/routers/model_versions_endpoints.py +59 -0
- zenml/zen_server/routers/workspaces_endpoints.py +0 -130
- zenml/zen_stores/base_zen_store.py +2 -1
- zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_links.py +118 -0
- zenml/zen_stores/rest_zen_store.py +4 -4
- zenml/zen_stores/schemas/model_schemas.py +10 -94
- zenml/zen_stores/schemas/user_schemas.py +0 -8
- zenml/zen_stores/schemas/workspace_schemas.py +0 -14
- {zenml_nightly-0.70.0.dev20241119.dist-info → zenml_nightly-0.70.0.dev20241121.dist-info}/METADATA +1 -1
- {zenml_nightly-0.70.0.dev20241119.dist-info → zenml_nightly-0.70.0.dev20241121.dist-info}/RECORD +27 -26
- {zenml_nightly-0.70.0.dev20241119.dist-info → zenml_nightly-0.70.0.dev20241121.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.70.0.dev20241119.dist-info → zenml_nightly-0.70.0.dev20241121.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.70.0.dev20241119.dist-info → zenml_nightly-0.70.0.dev20241121.dist-info}/entry_points.txt +0 -0
@@ -23,23 +23,19 @@ from zenml.enums import GenericFilterOps
|
|
23
23
|
from zenml.models.v2.base.base import (
|
24
24
|
BaseDatedResponseBody,
|
25
25
|
BaseIdentifiedResponse,
|
26
|
+
BaseRequest,
|
26
27
|
BaseResponseMetadata,
|
27
28
|
BaseResponseResources,
|
28
29
|
)
|
29
|
-
from zenml.models.v2.base.filter import StrFilter
|
30
|
-
from zenml.models.v2.base.scoped import (
|
31
|
-
WorkspaceScopedFilter,
|
32
|
-
WorkspaceScopedRequest,
|
33
|
-
)
|
30
|
+
from zenml.models.v2.base.filter import BaseFilter, StrFilter
|
34
31
|
from zenml.models.v2.core.pipeline_run import PipelineRunResponse
|
35
32
|
|
36
33
|
# ------------------ Request Model ------------------
|
37
34
|
|
38
35
|
|
39
|
-
class ModelVersionPipelineRunRequest(
|
36
|
+
class ModelVersionPipelineRunRequest(BaseRequest):
|
40
37
|
"""Request model for links between model versions and pipeline runs."""
|
41
38
|
|
42
|
-
model: UUID
|
43
39
|
model_version: UUID
|
44
40
|
pipeline_run: UUID
|
45
41
|
|
@@ -62,7 +58,6 @@ class ModelVersionPipelineRunRequest(WorkspaceScopedRequest):
|
|
62
58
|
class ModelVersionPipelineRunResponseBody(BaseDatedResponseBody):
|
63
59
|
"""Response body for links between model versions and pipeline runs."""
|
64
60
|
|
65
|
-
model: UUID
|
66
61
|
model_version: UUID
|
67
62
|
pipeline_run: PipelineRunResponse
|
68
63
|
|
@@ -88,16 +83,6 @@ class ModelVersionPipelineRunResponse(
|
|
88
83
|
):
|
89
84
|
"""Response model for links between model versions and pipeline runs."""
|
90
85
|
|
91
|
-
# Body and metadata properties
|
92
|
-
@property
|
93
|
-
def model(self) -> UUID:
|
94
|
-
"""The `model` property.
|
95
|
-
|
96
|
-
Returns:
|
97
|
-
the value of the property.
|
98
|
-
"""
|
99
|
-
return self.get_body().model
|
100
|
-
|
101
86
|
@property
|
102
87
|
def model_version(self) -> UUID:
|
103
88
|
"""The `model_version` property.
|
@@ -120,39 +105,21 @@ class ModelVersionPipelineRunResponse(
|
|
120
105
|
# ------------------ Filter Model ------------------
|
121
106
|
|
122
107
|
|
123
|
-
class ModelVersionPipelineRunFilter(
|
108
|
+
class ModelVersionPipelineRunFilter(BaseFilter):
|
124
109
|
"""Model version pipeline run links filter model."""
|
125
110
|
|
126
111
|
FILTER_EXCLUDE_FIELDS = [
|
127
|
-
*
|
112
|
+
*BaseFilter.FILTER_EXCLUDE_FIELDS,
|
128
113
|
"pipeline_run_name",
|
129
114
|
"user",
|
130
115
|
]
|
131
116
|
CLI_EXCLUDE_FIELDS = [
|
132
|
-
*
|
133
|
-
"model_id",
|
117
|
+
*BaseFilter.CLI_EXCLUDE_FIELDS,
|
134
118
|
"model_version_id",
|
135
|
-
"user_id",
|
136
|
-
"workspace_id",
|
137
119
|
"updated",
|
138
120
|
"id",
|
139
121
|
]
|
140
122
|
|
141
|
-
workspace_id: Optional[Union[UUID, str]] = Field(
|
142
|
-
default=None,
|
143
|
-
description="The workspace of the Model Version",
|
144
|
-
union_mode="left_to_right",
|
145
|
-
)
|
146
|
-
user_id: Optional[Union[UUID, str]] = Field(
|
147
|
-
default=None,
|
148
|
-
description="The user of the Model Version",
|
149
|
-
union_mode="left_to_right",
|
150
|
-
)
|
151
|
-
model_id: Optional[Union[UUID, str]] = Field(
|
152
|
-
default=None,
|
153
|
-
description="Filter by model ID",
|
154
|
-
union_mode="left_to_right",
|
155
|
-
)
|
156
123
|
model_version_id: Optional[Union[UUID, str]] = Field(
|
157
124
|
default=None,
|
158
125
|
description="Filter by model version ID",
|
@@ -14,12 +14,12 @@
|
|
14
14
|
"""Utilities for creating step runs."""
|
15
15
|
|
16
16
|
from datetime import datetime
|
17
|
-
from typing import TYPE_CHECKING, Dict, List,
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
18
18
|
|
19
19
|
from zenml.client import Client
|
20
|
-
from zenml.config.step_configurations import
|
20
|
+
from zenml.config.step_configurations import Step
|
21
21
|
from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH
|
22
|
-
from zenml.enums import
|
22
|
+
from zenml.enums import ExecutionStatus
|
23
23
|
from zenml.logger import get_logger
|
24
24
|
from zenml.model.utils import link_artifact_version_to_model_version
|
25
25
|
from zenml.models import (
|
@@ -344,7 +344,6 @@ def create_cached_step_runs(
|
|
344
344
|
if model_version := step_model_version or pipeline_model_version:
|
345
345
|
link_output_artifacts_to_model_version(
|
346
346
|
artifacts=step_run.outputs,
|
347
|
-
output_configurations=step_run.config.outputs,
|
348
347
|
model_version=model_version,
|
349
348
|
)
|
350
349
|
|
@@ -542,10 +541,7 @@ def link_pipeline_run_to_model_version(
|
|
542
541
|
client = Client()
|
543
542
|
client.zen_store.create_model_version_pipeline_run_link(
|
544
543
|
ModelVersionPipelineRunRequest(
|
545
|
-
user=client.active_user.id,
|
546
|
-
workspace=client.active_workspace.id,
|
547
544
|
pipeline_run=pipeline_run.id,
|
548
|
-
model=model_version.model.id,
|
549
545
|
model_version=model_version.id,
|
550
546
|
)
|
551
547
|
)
|
@@ -553,26 +549,17 @@ def link_pipeline_run_to_model_version(
|
|
553
549
|
|
554
550
|
def link_output_artifacts_to_model_version(
|
555
551
|
artifacts: Dict[str, List[ArtifactVersionResponse]],
|
556
|
-
output_configurations: Mapping[str, ArtifactConfiguration],
|
557
552
|
model_version: ModelVersionResponse,
|
558
553
|
) -> None:
|
559
554
|
"""Link the outputs of a step run to a model version.
|
560
555
|
|
561
556
|
Args:
|
562
557
|
artifacts: The step output artifacts.
|
563
|
-
output_configurations: The output configurations for the step.
|
564
558
|
model_version: The model version to link.
|
565
559
|
"""
|
566
|
-
for
|
560
|
+
for output_artifacts in artifacts.values():
|
567
561
|
for output_artifact in output_artifacts:
|
568
|
-
artifact_config = None
|
569
|
-
if output_artifact.save_type == ArtifactSaveType.STEP_OUTPUT and (
|
570
|
-
output_config := output_configurations.get(output_name, None)
|
571
|
-
):
|
572
|
-
artifact_config = output_config.artifact_config
|
573
|
-
|
574
562
|
link_artifact_version_to_model_version(
|
575
563
|
artifact_version=output_artifact,
|
576
564
|
model_version=model_version,
|
577
|
-
artifact_config=artifact_config,
|
578
565
|
)
|
@@ -247,7 +247,6 @@ class StepRunner:
|
|
247
247
|
artifacts={
|
248
248
|
k: [v] for k, v in output_artifacts.items()
|
249
249
|
},
|
250
|
-
output_configurations=step_run.config.outputs,
|
251
250
|
model_version=model_version,
|
252
251
|
)
|
253
252
|
finally:
|
@@ -577,9 +576,11 @@ class StepRunner:
|
|
577
576
|
uri = output_artifact_uris[output_name]
|
578
577
|
artifact_config = output_annotations[output_name].artifact_config
|
579
578
|
|
579
|
+
artifact_type = None
|
580
580
|
if artifact_config is not None:
|
581
581
|
has_custom_name = bool(artifact_config.name)
|
582
582
|
version = artifact_config.version
|
583
|
+
artifact_type = artifact_config.artifact_type
|
583
584
|
else:
|
584
585
|
has_custom_name, version = False, None
|
585
586
|
|
@@ -605,6 +606,7 @@ class StepRunner:
|
|
605
606
|
data=return_value,
|
606
607
|
materializer_class=materializer_class,
|
607
608
|
uri=uri,
|
609
|
+
artifact_type=artifact_type,
|
608
610
|
store_metadata=artifact_metadata_enabled,
|
609
611
|
store_visualizations=artifact_visualization_enabled,
|
610
612
|
has_custom_name=has_custom_name,
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Implementation of a service connector registry."""
|
15
15
|
|
16
|
+
import threading
|
16
17
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
17
18
|
|
18
19
|
from zenml.logger import get_logger
|
@@ -34,6 +35,7 @@ class ServiceConnectorRegistry:
|
|
34
35
|
"""Initialize the service connector registry."""
|
35
36
|
self.service_connector_types: Dict[str, ServiceConnectorTypeModel] = {}
|
36
37
|
self.initialized = False
|
38
|
+
self.lock = threading.RLock()
|
37
39
|
|
38
40
|
def register_service_connector_type(
|
39
41
|
self,
|
@@ -46,23 +48,25 @@ class ServiceConnectorRegistry:
|
|
46
48
|
service_connector_type: Service connector type.
|
47
49
|
overwrite: Whether to overwrite an existing service connector type.
|
48
50
|
"""
|
49
|
-
|
50
|
-
|
51
|
-
not in self.service_connector_types
|
52
|
-
or overwrite
|
53
|
-
):
|
54
|
-
self.service_connector_types[
|
51
|
+
with self.lock:
|
52
|
+
if (
|
55
53
|
service_connector_type.connector_type
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
54
|
+
not in self.service_connector_types
|
55
|
+
or overwrite
|
56
|
+
):
|
57
|
+
self.service_connector_types[
|
58
|
+
service_connector_type.connector_type
|
59
|
+
] = service_connector_type
|
60
|
+
logger.debug(
|
61
|
+
"Registered service connector type "
|
62
|
+
f"{service_connector_type.connector_type}."
|
63
|
+
)
|
64
|
+
else:
|
65
|
+
logger.debug(
|
66
|
+
f"Found existing service connector for type "
|
67
|
+
f"{service_connector_type.connector_type}: Skipping "
|
68
|
+
"registration."
|
69
|
+
)
|
66
70
|
|
67
71
|
def get_service_connector_type(
|
68
72
|
self,
|
@@ -201,54 +205,61 @@ class ServiceConnectorRegistry:
|
|
201
205
|
def register_builtin_service_connectors(self) -> None:
|
202
206
|
"""Registers the default built-in service connectors."""
|
203
207
|
# Only register built-in service connectors once
|
204
|
-
|
205
|
-
|
208
|
+
with self.lock:
|
209
|
+
if self.initialized:
|
210
|
+
return
|
206
211
|
|
207
|
-
|
212
|
+
self.initialized = True
|
208
213
|
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
214
|
+
try:
|
215
|
+
from zenml.integrations.aws.service_connectors.aws_service_connector import ( # noqa
|
216
|
+
AWSServiceConnector,
|
217
|
+
)
|
218
|
+
except ImportError as e:
|
219
|
+
logger.warning(f"Could not import AWS service connector: {e}.")
|
215
220
|
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
221
|
+
try:
|
222
|
+
from zenml.integrations.gcp.service_connectors.gcp_service_connector import ( # noqa
|
223
|
+
GCPServiceConnector,
|
224
|
+
)
|
225
|
+
except ImportError as e:
|
226
|
+
logger.warning(f"Could not import GCP service connector: {e}.")
|
222
227
|
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
228
|
+
try:
|
229
|
+
from zenml.integrations.azure.service_connectors.azure_service_connector import ( # noqa
|
230
|
+
AzureServiceConnector,
|
231
|
+
)
|
232
|
+
except ImportError as e:
|
233
|
+
logger.warning(
|
234
|
+
f"Could not import Azure service connector: {e}."
|
235
|
+
)
|
229
236
|
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
237
|
+
try:
|
238
|
+
from zenml.integrations.kubernetes.service_connectors.kubernetes_service_connector import ( # noqa
|
239
|
+
KubernetesServiceConnector,
|
240
|
+
)
|
241
|
+
except ImportError as e:
|
242
|
+
logger.warning(
|
243
|
+
f"Could not import Kubernetes service connector: {e}."
|
244
|
+
)
|
238
245
|
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
246
|
+
try:
|
247
|
+
from zenml.service_connectors.docker_service_connector import ( # noqa
|
248
|
+
DockerServiceConnector,
|
249
|
+
)
|
250
|
+
except ImportError as e:
|
251
|
+
logger.warning(
|
252
|
+
f"Could not import Docker service connector: {e}."
|
253
|
+
)
|
245
254
|
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
255
|
+
try:
|
256
|
+
from zenml.integrations.hyperai.service_connectors.hyperai_service_connector import ( # noqa
|
257
|
+
HyperAIServiceConnector,
|
258
|
+
)
|
259
|
+
except ImportError as e:
|
260
|
+
logger.warning(
|
261
|
+
f"Could not import HyperAI service connector: {e}."
|
262
|
+
)
|
252
263
|
|
253
264
|
|
254
265
|
service_connector_registry = ServiceConnectorRegistry()
|
@@ -29,9 +29,11 @@ from zenml.constants import (
|
|
29
29
|
)
|
30
30
|
from zenml.models import (
|
31
31
|
ModelVersionArtifactFilter,
|
32
|
+
ModelVersionArtifactRequest,
|
32
33
|
ModelVersionArtifactResponse,
|
33
34
|
ModelVersionFilter,
|
34
35
|
ModelVersionPipelineRunFilter,
|
36
|
+
ModelVersionPipelineRunRequest,
|
35
37
|
ModelVersionPipelineRunResponse,
|
36
38
|
ModelVersionResponse,
|
37
39
|
ModelVersionUpdate,
|
@@ -198,6 +200,34 @@ model_version_artifacts_router = APIRouter(
|
|
198
200
|
)
|
199
201
|
|
200
202
|
|
203
|
+
@model_version_artifacts_router.post(
|
204
|
+
"",
|
205
|
+
responses={401: error_response, 409: error_response, 422: error_response},
|
206
|
+
)
|
207
|
+
@handle_exceptions
|
208
|
+
def create_model_version_artifact_link(
|
209
|
+
model_version_artifact_link: ModelVersionArtifactRequest,
|
210
|
+
_: AuthContext = Security(authorize),
|
211
|
+
) -> ModelVersionArtifactResponse:
|
212
|
+
"""Create a new model version to artifact link.
|
213
|
+
|
214
|
+
Args:
|
215
|
+
model_version_artifact_link: The model version to artifact link to create.
|
216
|
+
|
217
|
+
Returns:
|
218
|
+
The created model version to artifact link.
|
219
|
+
"""
|
220
|
+
model_version = zen_store().get_model_version(
|
221
|
+
model_version_artifact_link.model_version
|
222
|
+
)
|
223
|
+
verify_permission_for_model(model_version, action=Action.UPDATE)
|
224
|
+
|
225
|
+
mv = zen_store().create_model_version_artifact_link(
|
226
|
+
model_version_artifact_link
|
227
|
+
)
|
228
|
+
return mv
|
229
|
+
|
230
|
+
|
201
231
|
@model_version_artifacts_router.get(
|
202
232
|
"",
|
203
233
|
response_model=Page[ModelVersionArtifactResponse],
|
@@ -291,6 +321,35 @@ model_version_pipeline_runs_router = APIRouter(
|
|
291
321
|
)
|
292
322
|
|
293
323
|
|
324
|
+
@model_version_pipeline_runs_router.post(
|
325
|
+
"",
|
326
|
+
responses={401: error_response, 409: error_response, 422: error_response},
|
327
|
+
)
|
328
|
+
@handle_exceptions
|
329
|
+
def create_model_version_pipeline_run_link(
|
330
|
+
model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
|
331
|
+
_: AuthContext = Security(authorize),
|
332
|
+
) -> ModelVersionPipelineRunResponse:
|
333
|
+
"""Create a new model version to pipeline run link.
|
334
|
+
|
335
|
+
Args:
|
336
|
+
model_version_pipeline_run_link: The model version to pipeline run link to create.
|
337
|
+
|
338
|
+
Returns:
|
339
|
+
- If Model Version to Pipeline Run Link already exists - returns the existing link.
|
340
|
+
- Otherwise, returns the newly created model version to pipeline run link.
|
341
|
+
"""
|
342
|
+
model_version = zen_store().get_model_version(
|
343
|
+
model_version_pipeline_run_link.model_version, hydrate=False
|
344
|
+
)
|
345
|
+
verify_permission_for_model(model_version, action=Action.UPDATE)
|
346
|
+
|
347
|
+
mv = zen_store().create_model_version_pipeline_run_link(
|
348
|
+
model_version_pipeline_run_link
|
349
|
+
)
|
350
|
+
return mv
|
351
|
+
|
352
|
+
|
294
353
|
@model_version_pipeline_runs_router.get(
|
295
354
|
"",
|
296
355
|
response_model=Page[ModelVersionPipelineRunResponse],
|
@@ -20,7 +20,6 @@ from fastapi import APIRouter, Depends, Security
|
|
20
20
|
|
21
21
|
from zenml.constants import (
|
22
22
|
API,
|
23
|
-
ARTIFACTS,
|
24
23
|
CODE_REPOSITORIES,
|
25
24
|
GET_OR_CREATE,
|
26
25
|
MODEL_VERSIONS,
|
@@ -54,10 +53,6 @@ from zenml.models import (
|
|
54
53
|
ComponentResponse,
|
55
54
|
ModelRequest,
|
56
55
|
ModelResponse,
|
57
|
-
ModelVersionArtifactRequest,
|
58
|
-
ModelVersionArtifactResponse,
|
59
|
-
ModelVersionPipelineRunRequest,
|
60
|
-
ModelVersionPipelineRunResponse,
|
61
56
|
ModelVersionRequest,
|
62
57
|
ModelVersionResponse,
|
63
58
|
Page,
|
@@ -1442,131 +1437,6 @@ def create_model_version(
|
|
1442
1437
|
)
|
1443
1438
|
|
1444
1439
|
|
1445
|
-
@router.post(
|
1446
|
-
WORKSPACES
|
1447
|
-
+ "/{workspace_name_or_id}"
|
1448
|
-
+ MODEL_VERSIONS
|
1449
|
-
+ "/{model_version_id}"
|
1450
|
-
+ ARTIFACTS,
|
1451
|
-
response_model=ModelVersionArtifactResponse,
|
1452
|
-
responses={401: error_response, 409: error_response, 422: error_response},
|
1453
|
-
)
|
1454
|
-
@handle_exceptions
|
1455
|
-
def create_model_version_artifact_link(
|
1456
|
-
workspace_name_or_id: Union[str, UUID],
|
1457
|
-
model_version_id: UUID,
|
1458
|
-
model_version_artifact_link: ModelVersionArtifactRequest,
|
1459
|
-
auth_context: AuthContext = Security(authorize),
|
1460
|
-
) -> ModelVersionArtifactResponse:
|
1461
|
-
"""Create a new model version to artifact link.
|
1462
|
-
|
1463
|
-
Args:
|
1464
|
-
workspace_name_or_id: Name or ID of the workspace.
|
1465
|
-
model_version_id: ID of the model version.
|
1466
|
-
model_version_artifact_link: The model version to artifact link to create.
|
1467
|
-
auth_context: Authentication context.
|
1468
|
-
|
1469
|
-
Returns:
|
1470
|
-
The created model version to artifact link.
|
1471
|
-
|
1472
|
-
Raises:
|
1473
|
-
IllegalOperationError: If the workspace or user specified in the
|
1474
|
-
model version does not match the current workspace or authenticated
|
1475
|
-
user.
|
1476
|
-
"""
|
1477
|
-
workspace = zen_store().get_workspace(workspace_name_or_id)
|
1478
|
-
if str(model_version_id) != str(model_version_artifact_link.model_version):
|
1479
|
-
raise IllegalOperationError(
|
1480
|
-
f"The model version id in your path `{model_version_id}` does not "
|
1481
|
-
f"match the model version specified in the request model "
|
1482
|
-
f"`{model_version_artifact_link.model_version}`"
|
1483
|
-
)
|
1484
|
-
|
1485
|
-
if model_version_artifact_link.workspace != workspace.id:
|
1486
|
-
raise IllegalOperationError(
|
1487
|
-
"Creating model version to artifact links outside of the workspace scope "
|
1488
|
-
f"of this endpoint `{workspace_name_or_id}` is "
|
1489
|
-
f"not supported."
|
1490
|
-
)
|
1491
|
-
if model_version_artifact_link.user != auth_context.user.id:
|
1492
|
-
raise IllegalOperationError(
|
1493
|
-
"Creating model to artifact links for a user other than yourself "
|
1494
|
-
"is not supported."
|
1495
|
-
)
|
1496
|
-
|
1497
|
-
model_version = zen_store().get_model_version(model_version_id)
|
1498
|
-
verify_permission_for_model(model_version, action=Action.UPDATE)
|
1499
|
-
|
1500
|
-
mv = zen_store().create_model_version_artifact_link(
|
1501
|
-
model_version_artifact_link
|
1502
|
-
)
|
1503
|
-
return mv
|
1504
|
-
|
1505
|
-
|
1506
|
-
@router.post(
|
1507
|
-
WORKSPACES
|
1508
|
-
+ "/{workspace_name_or_id}"
|
1509
|
-
+ MODEL_VERSIONS
|
1510
|
-
+ "/{model_version_id}"
|
1511
|
-
+ RUNS,
|
1512
|
-
response_model=ModelVersionPipelineRunResponse,
|
1513
|
-
responses={401: error_response, 409: error_response, 422: error_response},
|
1514
|
-
)
|
1515
|
-
@handle_exceptions
|
1516
|
-
def create_model_version_pipeline_run_link(
|
1517
|
-
workspace_name_or_id: Union[str, UUID],
|
1518
|
-
model_version_id: UUID,
|
1519
|
-
model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
|
1520
|
-
auth_context: AuthContext = Security(authorize),
|
1521
|
-
) -> ModelVersionPipelineRunResponse:
|
1522
|
-
"""Create a new model version to pipeline run link.
|
1523
|
-
|
1524
|
-
Args:
|
1525
|
-
workspace_name_or_id: Name or ID of the workspace.
|
1526
|
-
model_version_id: ID of the model version.
|
1527
|
-
model_version_pipeline_run_link: The model version to pipeline run link to create.
|
1528
|
-
auth_context: Authentication context.
|
1529
|
-
|
1530
|
-
Returns:
|
1531
|
-
- If Model Version to Pipeline Run Link already exists - returns the existing link.
|
1532
|
-
- Otherwise, returns the newly created model version to pipeline run link.
|
1533
|
-
|
1534
|
-
Raises:
|
1535
|
-
IllegalOperationError: If the workspace or user specified in the
|
1536
|
-
model version does not match the current workspace or authenticated
|
1537
|
-
user.
|
1538
|
-
"""
|
1539
|
-
workspace = zen_store().get_workspace(workspace_name_or_id)
|
1540
|
-
if str(model_version_id) != str(
|
1541
|
-
model_version_pipeline_run_link.model_version
|
1542
|
-
):
|
1543
|
-
raise IllegalOperationError(
|
1544
|
-
f"The model version id in your path `{model_version_id}` does not "
|
1545
|
-
f"match the model version specified in the request model "
|
1546
|
-
f"`{model_version_pipeline_run_link.model_version}`"
|
1547
|
-
)
|
1548
|
-
|
1549
|
-
if model_version_pipeline_run_link.workspace != workspace.id:
|
1550
|
-
raise IllegalOperationError(
|
1551
|
-
"Creating model versions outside of the workspace scope "
|
1552
|
-
f"of this endpoint `{workspace_name_or_id}` is "
|
1553
|
-
f"not supported."
|
1554
|
-
)
|
1555
|
-
if model_version_pipeline_run_link.user != auth_context.user.id:
|
1556
|
-
raise IllegalOperationError(
|
1557
|
-
"Creating models for a user other than yourself "
|
1558
|
-
"is not supported."
|
1559
|
-
)
|
1560
|
-
|
1561
|
-
model_version = zen_store().get_model_version(model_version_id)
|
1562
|
-
verify_permission_for_model(model_version, action=Action.UPDATE)
|
1563
|
-
|
1564
|
-
mv = zen_store().create_model_version_pipeline_run_link(
|
1565
|
-
model_version_pipeline_run_link
|
1566
|
-
)
|
1567
|
-
return mv
|
1568
|
-
|
1569
|
-
|
1570
1440
|
@router.post(
|
1571
1441
|
WORKSPACES + "/{workspace_name_or_id}" + SERVICES,
|
1572
1442
|
response_model=ServiceResponse,
|
@@ -42,6 +42,7 @@ from zenml.enums import (
|
|
42
42
|
SecretsStoreType,
|
43
43
|
StoreType,
|
44
44
|
)
|
45
|
+
from zenml.exceptions import IllegalOperationError
|
45
46
|
from zenml.logger import get_logger
|
46
47
|
from zenml.models import (
|
47
48
|
ServerDatabaseType,
|
@@ -335,7 +336,7 @@ class BaseZenStore(
|
|
335
336
|
# Ensure that the active stack is still valid
|
336
337
|
try:
|
337
338
|
active_stack = self.get_stack(stack_id=active_stack_id)
|
338
|
-
except KeyError:
|
339
|
+
except (KeyError, IllegalOperationError):
|
339
340
|
logger.warning(
|
340
341
|
"The current %s active stack is no longer available. "
|
341
342
|
"Resetting the active stack to default.",
|