zenml-nightly 0.71.0.dev20241213__py3-none-any.whl → 0.71.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/client.py +44 -2
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +0 -1
- zenml/model/model.py +12 -16
- zenml/models/v2/base/filter.py +26 -30
- zenml/models/v2/base/scoped.py +258 -5
- zenml/models/v2/core/artifact_version.py +15 -26
- zenml/models/v2/core/code_repository.py +1 -12
- zenml/models/v2/core/component.py +5 -46
- zenml/models/v2/core/flavor.py +1 -11
- zenml/models/v2/core/model.py +1 -57
- zenml/models/v2/core/model_version.py +5 -33
- zenml/models/v2/core/model_version_artifact.py +11 -3
- zenml/models/v2/core/model_version_pipeline_run.py +14 -3
- zenml/models/v2/core/pipeline.py +47 -55
- zenml/models/v2/core/pipeline_build.py +19 -12
- zenml/models/v2/core/pipeline_deployment.py +0 -10
- zenml/models/v2/core/pipeline_run.py +91 -29
- zenml/models/v2/core/run_template.py +21 -29
- zenml/models/v2/core/schedule.py +0 -10
- zenml/models/v2/core/secret.py +0 -14
- zenml/models/v2/core/service.py +9 -16
- zenml/models/v2/core/service_connector.py +0 -11
- zenml/models/v2/core/stack.py +21 -30
- zenml/models/v2/core/step_run.py +18 -14
- zenml/models/v2/core/trigger.py +19 -3
- zenml/orchestrators/step_launcher.py +9 -13
- zenml/orchestrators/step_run_utils.py +8 -204
- zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
- zenml/zen_server/utils.py +4 -3
- zenml/zen_stores/base_zen_store.py +10 -2
- zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
- zenml/zen_stores/schemas/model_schemas.py +42 -6
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
- zenml/zen_stores/sql_zen_store.py +322 -86
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/METADATA +1 -1
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/RECORD +41 -39
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,173 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""RBAC SQL Zen Store implementation."""
|
15
|
+
|
16
|
+
from typing import (
|
17
|
+
Optional,
|
18
|
+
Tuple,
|
19
|
+
)
|
20
|
+
from uuid import UUID
|
21
|
+
|
22
|
+
from zenml.logger import get_logger
|
23
|
+
from zenml.models import (
|
24
|
+
ModelRequest,
|
25
|
+
ModelResponse,
|
26
|
+
ModelVersionRequest,
|
27
|
+
ModelVersionResponse,
|
28
|
+
)
|
29
|
+
from zenml.zen_server.feature_gate.endpoint_utils import (
|
30
|
+
check_entitlement,
|
31
|
+
report_usage,
|
32
|
+
)
|
33
|
+
from zenml.zen_server.rbac.models import Action, ResourceType
|
34
|
+
from zenml.zen_server.rbac.utils import (
|
35
|
+
verify_permission,
|
36
|
+
verify_permission_for_model,
|
37
|
+
)
|
38
|
+
from zenml.zen_stores.sql_zen_store import SqlZenStore
|
39
|
+
|
40
|
+
# Module-level logger for the RBAC SQL Zen Store.
# NOTE(review): not referenced anywhere in this module — confirm whether it
# is still needed.
logger = get_logger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
class RBACSqlZenStore(SqlZenStore):
    """SQL Zen Store variant that applies RBAC checks to model get-or-create.

    All persistence logic is inherited from `SqlZenStore`. This subclass only
    decides whether the authenticated user may create models / model
    versions. When creation is not allowed it falls back to looking up an
    existing entity, and surfaces the permission (or entitlement) error only
    if no such entity exists.
    """

    def _get_or_create_model(
        self, model_request: ModelRequest
    ) -> Tuple[bool, ModelResponse]:
        """Get or create a model, respecting RBAC and entitlements.

        Args:
            model_request: The model request.

        # noqa: DAR401
        Raises:
            Exception: If the user is not allowed to create a model and no
                model with the requested name exists.

        Returns:
            A boolean whether the model was created or not, and the model.
        """
        try:
            verify_permission(
                resource_type=ResourceType.MODEL, action=Action.CREATE
            )
            check_entitlement(resource_type=ResourceType.MODEL)
        except Exception as creation_error:
            # Creation is forbidden, but an already existing model may still
            # be returned. If it does not exist either, raise the error that
            # explains *why* it could not be created instead of the KeyError.
            try:
                model = self.get_model(model_request.name)
            except KeyError:
                raise creation_error from None
            was_created = False
        else:
            was_created, model = super()._get_or_create_model(model_request)

        if not was_created:
            # Returning an existing model requires read access to it.
            verify_permission_for_model(model, action=Action.READ)
        else:
            report_usage(
                resource_type=ResourceType.MODEL, resource_id=model.id
            )

        return was_created, model

    def _get_model_version(
        self,
        model_id: UUID,
        version_name: Optional[str] = None,
        producer_run_id: Optional[UUID] = None,
    ) -> ModelVersionResponse:
        """Fetch a model version and check read access to it.

        Args:
            model_id: The ID of the model.
            version_name: The name of the model version.
            producer_run_id: The ID of the producer pipeline run. If this is
                set, only numeric versions created as part of the pipeline run
                will be returned.

        Returns:
            The model version.
        """
        version = super()._get_model_version(
            model_id=model_id,
            version_name=version_name,
            producer_run_id=producer_run_id,
        )
        verify_permission_for_model(version, action=Action.READ)
        return version

    def _get_or_create_model_version(
        self,
        model_version_request: ModelVersionRequest,
        producer_run_id: Optional[UUID] = None,
    ) -> Tuple[bool, ModelVersionResponse]:
        """Get or create a model version, respecting RBAC.

        Args:
            model_version_request: The model version request.
            producer_run_id: ID of the producer pipeline run.

        # noqa: DAR401
        Raises:
            Exception: If the authenticated user is not allowed to
                create a model version and no matching version exists.

        Returns:
            A boolean whether the model version was created or not, and the
            model version.
        """
        try:
            verify_permission(
                resource_type=ResourceType.MODEL_VERSION, action=Action.CREATE
            )
        except Exception as creation_error:
            # Not allowed to create: try to return an existing version (the
            # lookup itself enforces read permission). If there is none,
            # surface the permission error rather than the KeyError.
            try:
                version = self._get_model_version(
                    model_id=model_version_request.model,
                    version_name=model_version_request.name,
                    producer_run_id=producer_run_id,
                )
            except KeyError:
                raise creation_error from None
            return False, version

        was_created, version = super()._get_or_create_model_version(
            model_version_request, producer_run_id=producer_run_id
        )
        return was_created, version
|
zenml/zen_server/utils.py
CHANGED
@@ -421,6 +421,8 @@ def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
|
|
421
421
|
"""
|
422
422
|
from fastapi import Query
|
423
423
|
|
424
|
+
from zenml.zen_server.exceptions import error_detail
|
425
|
+
|
424
426
|
def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel:
|
425
427
|
from fastapi import HTTPException
|
426
428
|
|
@@ -428,9 +430,8 @@ def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
|
|
428
430
|
inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs)
|
429
431
|
return cls(*args, **kwargs)
|
430
432
|
except ValidationError as e:
|
431
|
-
|
432
|
-
|
433
|
-
raise HTTPException(422, detail=e.errors())
|
433
|
+
detail = error_detail(e, exception_type=ValueError)
|
434
|
+
raise HTTPException(422, detail=detail)
|
434
435
|
|
435
436
|
params = {v.name: v for v in inspect.signature(cls).parameters.values()}
|
436
437
|
query_params = getattr(cls, "API_MULTI_INPUT_PARAMS", [])
|
@@ -36,6 +36,7 @@ from zenml.constants import (
|
|
36
36
|
DEFAULT_STACK_AND_COMPONENT_NAME,
|
37
37
|
DEFAULT_WORKSPACE_NAME,
|
38
38
|
ENV_ZENML_DEFAULT_WORKSPACE_NAME,
|
39
|
+
ENV_ZENML_SERVER,
|
39
40
|
IS_DEBUG_ENV,
|
40
41
|
)
|
41
42
|
from zenml.enums import (
|
@@ -155,9 +156,16 @@ class BaseZenStore(
|
|
155
156
|
TypeError: If the store type is unsupported.
|
156
157
|
"""
|
157
158
|
if store_type == StoreType.SQL:
|
158
|
-
|
159
|
+
if os.environ.get(ENV_ZENML_SERVER):
|
160
|
+
from zenml.zen_server.rbac.rbac_sql_zen_store import (
|
161
|
+
RBACSqlZenStore,
|
162
|
+
)
|
163
|
+
|
164
|
+
return RBACSqlZenStore
|
165
|
+
else:
|
166
|
+
from zenml.zen_stores.sql_zen_store import SqlZenStore
|
159
167
|
|
160
|
-
|
168
|
+
return SqlZenStore
|
161
169
|
elif store_type == StoreType.REST:
|
162
170
|
from zenml.zen_stores.rest_zen_store import RestZenStore
|
163
171
|
|
@@ -0,0 +1,68 @@
|
|
1
|
+
"""Add model version producer run unique constraint [a1237ba94fd8].
|
2
|
+
|
3
|
+
Revision ID: a1237ba94fd8
|
4
|
+
Revises: 26351d482b9e
|
5
|
+
Create Date: 2024-12-13 10:28:55.432414
|
6
|
+
|
7
|
+
"""
|
8
|
+
|
9
|
+
import sqlalchemy as sa
|
10
|
+
import sqlmodel
|
11
|
+
from alembic import op
|
12
|
+
|
13
|
+
# revision identifiers, used by Alembic.
# This revision adds the `producer_run_id_if_numeric` column and the
# (model_id, producer_run_id_if_numeric) unique constraint on `model_version`.
revision = "a1237ba94fd8"
# The revision this migration is applied on top of.
down_revision = "26351d482b9e"
branch_labels = None
depends_on = None
|
18
|
+
|
19
|
+
|
20
|
+
def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Adds `model_version.producer_run_id_if_numeric`, backfills it for all
    existing rows, makes it non-nullable and finally adds the
    `unique_numeric_version_for_pipeline_run` constraint on
    (model_id, producer_run_id_if_numeric).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("model_version", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                "producer_run_id_if_numeric",
                sqlmodel.sql.sqltypes.GUID(),
                nullable=True,
            )
        )

    # Backfill: every pre-existing row gets its own primary key as value.
    # Primary keys are unique, so the unique constraint added below cannot
    # be violated by data written before this revision.
    bind = op.get_bind()
    reflected = sa.MetaData()
    reflected.reflect(only=("model_version",), bind=bind)
    versions = reflected.tables["model_version"]
    bind.execute(
        sa.update(versions).values(producer_run_id_if_numeric=versions.c.id)
    )

    # Only now that every row has a value can the column become NOT NULL.
    with op.batch_alter_table("model_version", schema=None) as batch_op:
        batch_op.alter_column(
            "producer_run_id_if_numeric",
            existing_type=sqlmodel.sql.sqltypes.GUID(),
            nullable=False,
        )
        batch_op.create_unique_constraint(
            "unique_numeric_version_for_pipeline_run",
            ["model_id", "producer_run_id_if_numeric"],
        )

    # ### end Alembic commands ###
|
57
|
+
|
58
|
+
|
59
|
+
def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Removes the `unique_numeric_version_for_pipeline_run` constraint and the
    `producer_run_id_if_numeric` column from `model_version`.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("model_version", schema=None) as batch:
        # The unique constraint covers the column being dropped, so the
        # constraint has to be removed first.
        batch.drop_constraint(
            "unique_numeric_version_for_pipeline_run",
            type_="unique",
        )
        batch.drop_column("producer_run_id_if_numeric")

    # ### end Alembic commands ###
|
@@ -15,10 +15,16 @@
|
|
15
15
|
|
16
16
|
from datetime import datetime
|
17
17
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
18
|
-
from uuid import UUID
|
18
|
+
from uuid import UUID, uuid4
|
19
19
|
|
20
20
|
from pydantic import ConfigDict
|
21
|
-
from sqlalchemy import
|
21
|
+
from sqlalchemy import (
|
22
|
+
BOOLEAN,
|
23
|
+
INTEGER,
|
24
|
+
TEXT,
|
25
|
+
Column,
|
26
|
+
UniqueConstraint,
|
27
|
+
)
|
22
28
|
from sqlmodel import Field, Relationship
|
23
29
|
|
24
30
|
from zenml.enums import (
|
@@ -228,11 +234,13 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
|
|
228
234
|
|
229
235
|
__tablename__ = MODEL_VERSION_TABLENAME
|
230
236
|
__table_args__ = (
|
231
|
-
# We need
|
237
|
+
# We need three unique constraints here:
|
232
238
|
# - The first to ensure that each model version for a
|
233
239
|
# model has a unique version number
|
234
240
|
# - The second one to ensure that explicit names given by
|
235
241
|
# users are unique
|
242
|
+
# - The third one to ensure that a pipeline run only produces a single
|
243
|
+
# auto-incremented version per model
|
236
244
|
UniqueConstraint(
|
237
245
|
"number",
|
238
246
|
"model_id",
|
@@ -243,6 +251,11 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
|
|
243
251
|
"model_id",
|
244
252
|
name="unique_version_for_model_id",
|
245
253
|
),
|
254
|
+
UniqueConstraint(
|
255
|
+
"model_id",
|
256
|
+
"producer_run_id_if_numeric",
|
257
|
+
name="unique_numeric_version_for_pipeline_run",
|
258
|
+
),
|
246
259
|
)
|
247
260
|
|
248
261
|
workspace_id: UUID = build_foreign_key_field(
|
@@ -312,12 +325,23 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
|
|
312
325
|
),
|
313
326
|
)
|
314
327
|
pipeline_runs: List["PipelineRunSchema"] = Relationship(
|
315
|
-
back_populates="model_version"
|
328
|
+
back_populates="model_version",
|
316
329
|
)
|
317
330
|
step_runs: List["StepRunSchema"] = Relationship(
|
318
331
|
back_populates="model_version"
|
319
332
|
)
|
320
333
|
|
334
|
+
# We want to make sure each pipeline run only creates a single numeric
|
335
|
+
# version for each model. To solve this, we need to add a unique constraint.
|
336
|
+
# If a value of a unique constraint is NULL it is ignored and the
|
337
|
+
# remaining values in the unique constraint have to be unique. In
|
338
|
+
# our case however, we only want the unique constraint applied in
|
339
|
+
# case there is a producer run and only for numeric versions. To solve this,
|
340
|
+
# we fall back to the model version ID (which is the primary key and
|
341
|
+
# therefore unique) in case there is no producer run or the version is not
|
342
|
+
# numeric.
|
343
|
+
producer_run_id_if_numeric: UUID
|
344
|
+
|
321
345
|
# TODO: In Pydantic v2, the `model_` is a protected namespaces for all
|
322
346
|
# fields defined under base models. If not handled, this raises a warning.
|
323
347
|
# It is possible to suppress this warning message with the following
|
@@ -328,24 +352,36 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
|
|
328
352
|
|
329
353
|
@classmethod
|
330
354
|
def from_request(
|
331
|
-
cls,
|
355
|
+
cls,
|
356
|
+
model_version_request: ModelVersionRequest,
|
357
|
+
model_version_number: int,
|
358
|
+
producer_run_id: Optional[UUID] = None,
|
332
359
|
) -> "ModelVersionSchema":
|
333
360
|
"""Convert an `ModelVersionRequest` to an `ModelVersionSchema`.
|
334
361
|
|
335
362
|
Args:
|
336
363
|
model_version_request: The request model version to convert.
|
364
|
+
model_version_number: The model version number.
|
365
|
+
producer_run_id: The ID of the producer run.
|
337
366
|
|
338
367
|
Returns:
|
339
368
|
The converted schema.
|
340
369
|
"""
|
370
|
+
id_ = uuid4()
|
371
|
+
is_numeric = str(model_version_number) == model_version_request.name
|
372
|
+
|
341
373
|
return cls(
|
374
|
+
id=id_,
|
342
375
|
workspace_id=model_version_request.workspace,
|
343
376
|
user_id=model_version_request.user,
|
344
377
|
model_id=model_version_request.model,
|
345
378
|
name=model_version_request.name,
|
346
|
-
number=
|
379
|
+
number=model_version_number,
|
347
380
|
description=model_version_request.description,
|
348
381
|
stage=model_version_request.stage,
|
382
|
+
producer_run_id_if_numeric=producer_run_id
|
383
|
+
if (producer_run_id and is_numeric)
|
384
|
+
else id_,
|
349
385
|
)
|
350
386
|
|
351
387
|
def to_model(
|
@@ -228,13 +228,6 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
228
228
|
Returns:
|
229
229
|
The created `PipelineDeploymentResponse`.
|
230
230
|
"""
|
231
|
-
pipeline_configuration = PipelineConfiguration.model_validate_json(
|
232
|
-
self.pipeline_configuration
|
233
|
-
)
|
234
|
-
step_configurations = json.loads(self.step_configurations)
|
235
|
-
for s, c in step_configurations.items():
|
236
|
-
step_configurations[s] = Step.model_validate(c)
|
237
|
-
|
238
231
|
body = PipelineDeploymentResponseBody(
|
239
232
|
user=self.user.to_model() if self.user else None,
|
240
233
|
created=self.created,
|
@@ -242,6 +235,13 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
242
235
|
)
|
243
236
|
metadata = None
|
244
237
|
if include_metadata:
|
238
|
+
pipeline_configuration = PipelineConfiguration.model_validate_json(
|
239
|
+
self.pipeline_configuration
|
240
|
+
)
|
241
|
+
step_configurations = json.loads(self.step_configurations)
|
242
|
+
for s, c in step_configurations.items():
|
243
|
+
step_configurations[s] = Step.model_validate(c)
|
244
|
+
|
245
245
|
metadata = PipelineDeploymentResponseMetadata(
|
246
246
|
workspace=self.workspace.to_model(),
|
247
247
|
run_name_template=self.run_name_template,
|
@@ -156,7 +156,12 @@ class PipelineSchema(NamedSchema, table=True):
|
|
156
156
|
|
157
157
|
resources = None
|
158
158
|
if include_resources:
|
159
|
+
latest_run_user = self.runs[-1].user if self.runs else None
|
160
|
+
|
159
161
|
resources = PipelineResponseResources(
|
162
|
+
latest_run_user=latest_run_user.to_model()
|
163
|
+
if latest_run_user
|
164
|
+
else None,
|
160
165
|
tags=[t.tag.to_model() for t in self.tags],
|
161
166
|
)
|
162
167
|
|