zenml-nightly 0.71.0.dev20241213__py3-none-any.whl → 0.71.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. zenml/VERSION +1 -1
  2. zenml/client.py +44 -2
  3. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +0 -1
  4. zenml/model/model.py +12 -16
  5. zenml/models/v2/base/filter.py +26 -30
  6. zenml/models/v2/base/scoped.py +258 -5
  7. zenml/models/v2/core/artifact_version.py +15 -26
  8. zenml/models/v2/core/code_repository.py +1 -12
  9. zenml/models/v2/core/component.py +5 -46
  10. zenml/models/v2/core/flavor.py +1 -11
  11. zenml/models/v2/core/model.py +1 -57
  12. zenml/models/v2/core/model_version.py +5 -33
  13. zenml/models/v2/core/model_version_artifact.py +11 -3
  14. zenml/models/v2/core/model_version_pipeline_run.py +14 -3
  15. zenml/models/v2/core/pipeline.py +47 -55
  16. zenml/models/v2/core/pipeline_build.py +19 -12
  17. zenml/models/v2/core/pipeline_deployment.py +0 -10
  18. zenml/models/v2/core/pipeline_run.py +91 -29
  19. zenml/models/v2/core/run_template.py +21 -29
  20. zenml/models/v2/core/schedule.py +0 -10
  21. zenml/models/v2/core/secret.py +0 -14
  22. zenml/models/v2/core/service.py +9 -16
  23. zenml/models/v2/core/service_connector.py +0 -11
  24. zenml/models/v2/core/stack.py +21 -30
  25. zenml/models/v2/core/step_run.py +18 -14
  26. zenml/models/v2/core/trigger.py +19 -3
  27. zenml/orchestrators/step_launcher.py +9 -13
  28. zenml/orchestrators/step_run_utils.py +8 -204
  29. zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
  30. zenml/zen_server/utils.py +4 -3
  31. zenml/zen_stores/base_zen_store.py +10 -2
  32. zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
  33. zenml/zen_stores/schemas/model_schemas.py +42 -6
  34. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
  35. zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
  36. zenml/zen_stores/sql_zen_store.py +322 -86
  37. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/METADATA +1 -1
  38. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/RECORD +41 -39
  39. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/LICENSE +0 -0
  40. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/WHEEL +0 -0
  41. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,173 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """RBAC SQL Zen Store implementation."""
15
+
16
+ from typing import (
17
+ Optional,
18
+ Tuple,
19
+ )
20
+ from uuid import UUID
21
+
22
+ from zenml.logger import get_logger
23
+ from zenml.models import (
24
+ ModelRequest,
25
+ ModelResponse,
26
+ ModelVersionRequest,
27
+ ModelVersionResponse,
28
+ )
29
+ from zenml.zen_server.feature_gate.endpoint_utils import (
30
+ check_entitlement,
31
+ report_usage,
32
+ )
33
+ from zenml.zen_server.rbac.models import Action, ResourceType
34
+ from zenml.zen_server.rbac.utils import (
35
+ verify_permission,
36
+ verify_permission_for_model,
37
+ )
38
+ from zenml.zen_stores.sql_zen_store import SqlZenStore
39
+
40
+ logger = get_logger(__name__)
41
+
42
+
43
class RBACSqlZenStore(SqlZenStore):
    """Wrapper around the SQLZenStore that implements RBAC functionality."""

    def _get_or_create_model(
        self, model_request: ModelRequest
    ) -> Tuple[bool, ModelResponse]:
        """Get or create a model.

        Creation is only attempted when both the `CREATE` permission check
        and the entitlement check for models pass. If either of them fails,
        an already existing model with the requested name is looked up
        instead.

        Args:
            model_request: The model request.

        # noqa: DAR401
        Raises:
            Exception: If the user is not allowed to create a model and no
                model with the requested name exists.

        Returns:
            A boolean whether the model was created or not, and the model.
        """
        # Holds the reason why creation is not allowed. `None` means that
        # creating the model is permitted.
        creation_error: Optional[Exception] = None

        try:
            verify_permission(
                resource_type=ResourceType.MODEL, action=Action.CREATE
            )
            check_entitlement(resource_type=ResourceType.MODEL)
        except Exception as e:
            # Do not fail yet: if the model already exists, nothing needs to
            # be created and the request can still succeed.
            creation_error = e

        if creation_error is None:
            created, model_response = super()._get_or_create_model(
                model_request
            )
        else:
            try:
                model_response = self.get_model(model_request.name)
            except KeyError:
                # The model does not exist. We now raise the error that
                # explains why the model could not be created, instead of
                # just the KeyError that it doesn't exist
                raise creation_error from None
            created = False

        if created:
            report_usage(
                resource_type=ResourceType.MODEL, resource_id=model_response.id
            )
        else:
            # An existing model is returned, which requires read permission.
            verify_permission_for_model(model_response, action=Action.READ)

        return created, model_response

    def _get_model_version(
        self,
        model_id: UUID,
        version_name: Optional[str] = None,
        producer_run_id: Optional[UUID] = None,
    ) -> ModelVersionResponse:
        """Get a model version.

        Delegates the lookup to the parent store and then verifies the `READ`
        permission on the model version that was found.

        Args:
            model_id: The ID of the model.
            version_name: The name of the model version.
            producer_run_id: The ID of the producer pipeline run. If this is
                set, only numeric versions created as part of the pipeline run
                will be returned.

        Returns:
            The model version.
        """
        model_version = super()._get_model_version(
            model_id=model_id,
            version_name=version_name,
            producer_run_id=producer_run_id,
        )
        verify_permission_for_model(model_version, action=Action.READ)
        return model_version

    def _get_or_create_model_version(
        self,
        model_version_request: ModelVersionRequest,
        producer_run_id: Optional[UUID] = None,
    ) -> Tuple[bool, ModelVersionResponse]:
        """Get or create a model version.

        Creation is only attempted when the `CREATE` permission check for
        model versions passes. Otherwise an existing model version matching
        the request is looked up through `_get_model_version`.

        Args:
            model_version_request: The model version request.
            producer_run_id: ID of the producer pipeline run.

        # noqa: DAR401
        Raises:
            Exception: If the authenticated user is not allowed to
                create a model version and no matching model version exists.

        Returns:
            A boolean whether the model version was created or not, and the
            model version.
        """
        # Holds the reason why creation is not allowed. `None` means that
        # creating the model version is permitted.
        creation_error: Optional[Exception] = None

        try:
            verify_permission(
                resource_type=ResourceType.MODEL_VERSION, action=Action.CREATE
            )
        except Exception as e:
            # Do not fail yet: a matching model version may already exist.
            creation_error = e

        if creation_error is None:
            created, model_version_response = (
                super()._get_or_create_model_version(
                    model_version_request,
                    producer_run_id=producer_run_id,
                )
            )
        else:
            try:
                model_version_response = self._get_model_version(
                    model_id=model_version_request.model,
                    version_name=model_version_request.name,
                    producer_run_id=producer_run_id,
                )
            except KeyError:
                # The model version does not exist. We now raise the error
                # that explains why the version could not be created, instead
                # of just the KeyError that it doesn't exist
                raise creation_error from None
            created = False

        return created, model_version_response
zenml/zen_server/utils.py CHANGED
@@ -421,6 +421,8 @@ def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
421
421
  """
422
422
  from fastapi import Query
423
423
 
424
+ from zenml.zen_server.exceptions import error_detail
425
+
424
426
  def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel:
425
427
  from fastapi import HTTPException
426
428
 
@@ -428,9 +430,8 @@ def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
428
430
  inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs)
429
431
  return cls(*args, **kwargs)
430
432
  except ValidationError as e:
431
- for error in e.errors():
432
- error["loc"] = tuple(["query"] + list(error["loc"]))
433
- raise HTTPException(422, detail=e.errors())
433
+ detail = error_detail(e, exception_type=ValueError)
434
+ raise HTTPException(422, detail=detail)
434
435
 
435
436
  params = {v.name: v for v in inspect.signature(cls).parameters.values()}
436
437
  query_params = getattr(cls, "API_MULTI_INPUT_PARAMS", [])
@@ -36,6 +36,7 @@ from zenml.constants import (
36
36
  DEFAULT_STACK_AND_COMPONENT_NAME,
37
37
  DEFAULT_WORKSPACE_NAME,
38
38
  ENV_ZENML_DEFAULT_WORKSPACE_NAME,
39
+ ENV_ZENML_SERVER,
39
40
  IS_DEBUG_ENV,
40
41
  )
41
42
  from zenml.enums import (
@@ -155,9 +156,16 @@ class BaseZenStore(
155
156
  TypeError: If the store type is unsupported.
156
157
  """
157
158
  if store_type == StoreType.SQL:
158
- from zenml.zen_stores.sql_zen_store import SqlZenStore
159
+ if os.environ.get(ENV_ZENML_SERVER):
160
+ from zenml.zen_server.rbac.rbac_sql_zen_store import (
161
+ RBACSqlZenStore,
162
+ )
163
+
164
+ return RBACSqlZenStore
165
+ else:
166
+ from zenml.zen_stores.sql_zen_store import SqlZenStore
159
167
 
160
- return SqlZenStore
168
+ return SqlZenStore
161
169
  elif store_type == StoreType.REST:
162
170
  from zenml.zen_stores.rest_zen_store import RestZenStore
163
171
 
@@ -0,0 +1,68 @@
1
+ """Add model version producer run unique constraint [a1237ba94fd8].
2
+
3
+ Revision ID: a1237ba94fd8
4
+ Revises: 26351d482b9e
5
+ Create Date: 2024-12-13 10:28:55.432414
6
+
7
+ """
8
+
9
+ import sqlalchemy as sa
10
+ import sqlmodel
11
+ from alembic import op
12
+
13
+ # revision identifiers, used by Alembic.
14
+ revision = "a1237ba94fd8"
15
+ down_revision = "26351d482b9e"
16
+ branch_labels = None
17
+ depends_on = None
18
+
19
+
20
def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision."""
    # Step 1: add the new column as nullable, so that rows which already
    # exist remain valid until they are backfilled below.
    new_column = sa.Column(
        "producer_run_id_if_numeric",
        sqlmodel.sql.sqltypes.GUID(),
        nullable=True,
    )
    with op.batch_alter_table("model_version", schema=None) as add_op:
        add_op.add_column(new_column)

    # Step 2: backfill the new column of every existing row with the ID of
    # the model version itself.
    bind = op.get_bind()
    reflected = sa.MetaData()
    reflected.reflect(only=("model_version",), bind=bind)
    model_versions = sa.Table("model_version", reflected)
    backfill = model_versions.update().values(
        producer_run_id_if_numeric=model_versions.c.id
    )
    bind.execute(backfill)

    # Step 3: now that every row has a value, forbid NULLs and add the
    # unique constraint over the model ID and the new column.
    with op.batch_alter_table("model_version", schema=None) as constrain_op:
        constrain_op.alter_column(
            "producer_run_id_if_numeric",
            existing_type=sqlmodel.sql.sqltypes.GUID(),
            nullable=False,
        )
        constrain_op.create_unique_constraint(
            "unique_numeric_version_for_pipeline_run",
            ["model_id", "producer_run_id_if_numeric"],
        )
57
+
58
+
59
def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision."""
    with op.batch_alter_table("model_version", schema=None) as revert_op:
        # Remove the unique constraint first, then the column it covers.
        revert_op.drop_constraint(
            "unique_numeric_version_for_pipeline_run", type_="unique"
        )
        revert_op.drop_column("producer_run_id_if_numeric")
@@ -15,10 +15,16 @@
15
15
 
16
16
  from datetime import datetime
17
17
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
18
- from uuid import UUID
18
+ from uuid import UUID, uuid4
19
19
 
20
20
  from pydantic import ConfigDict
21
- from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint
21
+ from sqlalchemy import (
22
+ BOOLEAN,
23
+ INTEGER,
24
+ TEXT,
25
+ Column,
26
+ UniqueConstraint,
27
+ )
22
28
  from sqlmodel import Field, Relationship
23
29
 
24
30
  from zenml.enums import (
@@ -228,11 +234,13 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
228
234
 
229
235
  __tablename__ = MODEL_VERSION_TABLENAME
230
236
  __table_args__ = (
231
- # We need two unique constraints here:
237
+ # We need three unique constraints here:
232
238
  # - The first to ensure that each model version for a
233
239
  # model has a unique version number
234
240
  # - The second one to ensure that explicit names given by
235
241
  # users are unique
242
+ # - The third one to ensure that a pipeline run only produces a single
243
+ # auto-incremented version per model
236
244
  UniqueConstraint(
237
245
  "number",
238
246
  "model_id",
@@ -243,6 +251,11 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
243
251
  "model_id",
244
252
  name="unique_version_for_model_id",
245
253
  ),
254
+ UniqueConstraint(
255
+ "model_id",
256
+ "producer_run_id_if_numeric",
257
+ name="unique_numeric_version_for_pipeline_run",
258
+ ),
246
259
  )
247
260
 
248
261
  workspace_id: UUID = build_foreign_key_field(
@@ -312,12 +325,23 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
312
325
  ),
313
326
  )
314
327
  pipeline_runs: List["PipelineRunSchema"] = Relationship(
315
- back_populates="model_version"
328
+ back_populates="model_version",
316
329
  )
317
330
  step_runs: List["StepRunSchema"] = Relationship(
318
331
  back_populates="model_version"
319
332
  )
320
333
 
334
+ # We want to make sure each pipeline run only creates a single numeric
335
+ # version for each model. To solve this, we need to add a unique constraint.
336
+ # If a value of a unique constraint is NULL it is ignored and the
337
+ # remaining values in the unique constraint have to be unique. In
338
+ # our case however, we only want the unique constraint applied in
339
+ # case there is a producer run and only for numeric versions. To solve this,
340
+ # we fall back to the model version ID (which is the primary key and
341
+ # therefore unique) in case there is no producer run or the version is not
342
+ # numeric.
343
+ producer_run_id_if_numeric: UUID
344
+
321
345
  # TODO: In Pydantic v2, the `model_` is a protected namespaces for all
322
346
  # fields defined under base models. If not handled, this raises a warning.
323
347
  # It is possible to suppress this warning message with the following
@@ -328,24 +352,36 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
328
352
 
329
353
  @classmethod
330
354
  def from_request(
331
- cls, model_version_request: ModelVersionRequest
355
+ cls,
356
+ model_version_request: ModelVersionRequest,
357
+ model_version_number: int,
358
+ producer_run_id: Optional[UUID] = None,
332
359
  ) -> "ModelVersionSchema":
333
360
  """Convert an `ModelVersionRequest` to an `ModelVersionSchema`.
334
361
 
335
362
  Args:
336
363
  model_version_request: The request model version to convert.
364
+ model_version_number: The model version number.
365
+ producer_run_id: The ID of the producer run.
337
366
 
338
367
  Returns:
339
368
  The converted schema.
340
369
  """
370
+ id_ = uuid4()
371
+ is_numeric = str(model_version_number) == model_version_request.name
372
+
341
373
  return cls(
374
+ id=id_,
342
375
  workspace_id=model_version_request.workspace,
343
376
  user_id=model_version_request.user,
344
377
  model_id=model_version_request.model,
345
378
  name=model_version_request.name,
346
- number=model_version_request.number,
379
+ number=model_version_number,
347
380
  description=model_version_request.description,
348
381
  stage=model_version_request.stage,
382
+ producer_run_id_if_numeric=producer_run_id
383
+ if (producer_run_id and is_numeric)
384
+ else id_,
349
385
  )
350
386
 
351
387
  def to_model(
@@ -228,13 +228,6 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
228
228
  Returns:
229
229
  The created `PipelineDeploymentResponse`.
230
230
  """
231
- pipeline_configuration = PipelineConfiguration.model_validate_json(
232
- self.pipeline_configuration
233
- )
234
- step_configurations = json.loads(self.step_configurations)
235
- for s, c in step_configurations.items():
236
- step_configurations[s] = Step.model_validate(c)
237
-
238
231
  body = PipelineDeploymentResponseBody(
239
232
  user=self.user.to_model() if self.user else None,
240
233
  created=self.created,
@@ -242,6 +235,13 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
242
235
  )
243
236
  metadata = None
244
237
  if include_metadata:
238
+ pipeline_configuration = PipelineConfiguration.model_validate_json(
239
+ self.pipeline_configuration
240
+ )
241
+ step_configurations = json.loads(self.step_configurations)
242
+ for s, c in step_configurations.items():
243
+ step_configurations[s] = Step.model_validate(c)
244
+
245
245
  metadata = PipelineDeploymentResponseMetadata(
246
246
  workspace=self.workspace.to_model(),
247
247
  run_name_template=self.run_name_template,
@@ -156,7 +156,12 @@ class PipelineSchema(NamedSchema, table=True):
156
156
 
157
157
  resources = None
158
158
  if include_resources:
159
+ latest_run_user = self.runs[-1].user if self.runs else None
160
+
159
161
  resources = PipelineResponseResources(
162
+ latest_run_user=latest_run_user.to_model()
163
+ if latest_run_user
164
+ else None,
160
165
  tags=[t.tag.to_model() for t in self.tags],
161
166
  )
162
167