zenml-nightly 0.68.0.dev20241027__py3-none-any.whl → 0.68.1.dev20241101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- README.md +17 -11
- RELEASE_NOTES.md +9 -0
- zenml/VERSION +1 -1
- zenml/__init__.py +1 -1
- zenml/analytics/context.py +16 -1
- zenml/analytics/utils.py +18 -7
- zenml/artifacts/utils.py +40 -216
- zenml/cli/__init__.py +63 -90
- zenml/cli/base.py +3 -3
- zenml/cli/login.py +951 -0
- zenml/cli/server.py +462 -353
- zenml/cli/service_accounts.py +4 -4
- zenml/cli/stack.py +77 -2
- zenml/cli/stack_components.py +5 -16
- zenml/cli/user_management.py +0 -12
- zenml/cli/utils.py +24 -77
- zenml/client.py +46 -14
- zenml/config/compiler.py +1 -0
- zenml/config/global_config.py +9 -0
- zenml/config/pipeline_configurations.py +2 -1
- zenml/config/pipeline_run_configuration.py +2 -1
- zenml/constants.py +3 -9
- zenml/enums.py +1 -1
- zenml/exceptions.py +11 -0
- zenml/integrations/github/code_repositories/github_code_repository.py +1 -1
- zenml/login/__init__.py +16 -0
- zenml/login/credentials.py +346 -0
- zenml/login/credentials_store.py +603 -0
- zenml/login/pro/__init__.py +16 -0
- zenml/login/pro/client.py +496 -0
- zenml/login/pro/constants.py +34 -0
- zenml/login/pro/models.py +25 -0
- zenml/login/pro/organization/__init__.py +14 -0
- zenml/login/pro/organization/client.py +79 -0
- zenml/login/pro/organization/models.py +32 -0
- zenml/login/pro/tenant/__init__.py +14 -0
- zenml/login/pro/tenant/client.py +92 -0
- zenml/login/pro/tenant/models.py +174 -0
- zenml/login/pro/utils.py +121 -0
- zenml/{cli → login}/web_login.py +64 -28
- zenml/materializers/base_materializer.py +43 -9
- zenml/materializers/built_in_materializer.py +1 -1
- zenml/metadata/metadata_types.py +49 -0
- zenml/model/model.py +0 -38
- zenml/models/__init__.py +3 -0
- zenml/models/v2/base/base.py +12 -8
- zenml/models/v2/base/filter.py +9 -0
- zenml/models/v2/core/artifact_version.py +49 -10
- zenml/models/v2/core/component.py +54 -19
- zenml/models/v2/core/flavor.py +13 -13
- zenml/models/v2/core/model.py +3 -1
- zenml/models/v2/core/model_version.py +3 -5
- zenml/models/v2/core/model_version_artifact.py +3 -1
- zenml/models/v2/core/model_version_pipeline_run.py +3 -1
- zenml/models/v2/core/pipeline.py +3 -1
- zenml/models/v2/core/pipeline_run.py +23 -1
- zenml/models/v2/core/run_template.py +3 -1
- zenml/models/v2/core/stack.py +7 -3
- zenml/models/v2/core/step_run.py +43 -2
- zenml/models/v2/misc/auth_models.py +11 -2
- zenml/models/v2/misc/server_models.py +2 -0
- zenml/orchestrators/base_orchestrator.py +8 -4
- zenml/orchestrators/step_launcher.py +1 -0
- zenml/orchestrators/step_run_utils.py +10 -2
- zenml/orchestrators/step_runner.py +67 -55
- zenml/orchestrators/utils.py +45 -22
- zenml/pipelines/pipeline_decorator.py +5 -0
- zenml/pipelines/pipeline_definition.py +206 -160
- zenml/pipelines/run_utils.py +11 -10
- zenml/services/local/local_daemon_entrypoint.py +4 -4
- zenml/services/service.py +2 -2
- zenml/stack/stack.py +2 -6
- zenml/stack/stack_component.py +2 -7
- zenml/stack/utils.py +26 -14
- zenml/steps/base_step.py +8 -2
- zenml/steps/step_context.py +0 -3
- zenml/steps/step_invocation.py +14 -5
- zenml/steps/utils.py +1 -0
- zenml/utils/materializer_utils.py +1 -1
- zenml/utils/requirements_utils.py +71 -0
- zenml/utils/singleton.py +15 -3
- zenml/utils/source_utils.py +39 -2
- zenml/utils/visualization_utils.py +1 -1
- zenml/zen_server/auth.py +44 -39
- zenml/zen_server/deploy/__init__.py +7 -7
- zenml/zen_server/deploy/base_provider.py +46 -73
- zenml/zen_server/deploy/{local → daemon}/__init__.py +3 -3
- zenml/zen_server/deploy/{local/local_provider.py → daemon/daemon_provider.py} +44 -63
- zenml/zen_server/deploy/{local/local_zen_server.py → daemon/daemon_zen_server.py} +50 -22
- zenml/zen_server/deploy/deployer.py +90 -171
- zenml/zen_server/deploy/deployment.py +20 -12
- zenml/zen_server/deploy/docker/docker_provider.py +9 -28
- zenml/zen_server/deploy/docker/docker_zen_server.py +19 -3
- zenml/zen_server/deploy/helm/Chart.yaml +1 -1
- zenml/zen_server/deploy/helm/README.md +2 -2
- zenml/zen_server/exceptions.py +11 -0
- zenml/zen_server/jwt.py +9 -9
- zenml/zen_server/routers/auth_endpoints.py +30 -8
- zenml/zen_server/routers/stack_components_endpoints.py +1 -1
- zenml/zen_server/routers/workspaces_endpoints.py +1 -1
- zenml/zen_server/template_execution/runner_entrypoint_configuration.py +7 -4
- zenml/zen_server/template_execution/utils.py +6 -61
- zenml/zen_server/utils.py +64 -36
- zenml/zen_stores/base_zen_store.py +4 -49
- zenml/zen_stores/migrations/versions/0.68.1_release.py +23 -0
- zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py +86 -0
- zenml/zen_stores/rest_zen_store.py +325 -147
- zenml/zen_stores/schemas/api_key_schemas.py +9 -4
- zenml/zen_stores/schemas/artifact_schemas.py +21 -2
- zenml/zen_stores/schemas/artifact_visualization_schemas.py +1 -1
- zenml/zen_stores/schemas/component_schemas.py +49 -6
- zenml/zen_stores/schemas/device_schemas.py +9 -4
- zenml/zen_stores/schemas/flavor_schemas.py +1 -1
- zenml/zen_stores/schemas/model_schemas.py +1 -1
- zenml/zen_stores/schemas/service_schemas.py +1 -1
- zenml/zen_stores/schemas/step_run_schemas.py +1 -1
- zenml/zen_stores/schemas/trigger_schemas.py +1 -1
- zenml/zen_stores/sql_zen_store.py +393 -140
- zenml/zen_stores/template_utils.py +3 -1
- {zenml_nightly-0.68.0.dev20241027.dist-info → zenml_nightly-0.68.1.dev20241101.dist-info}/METADATA +18 -12
- {zenml_nightly-0.68.0.dev20241027.dist-info → zenml_nightly-0.68.1.dev20241101.dist-info}/RECORD +124 -107
- zenml/api.py +0 -60
- {zenml_nightly-0.68.0.dev20241027.dist-info → zenml_nightly-0.68.1.dev20241101.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.68.0.dev20241027.dist-info → zenml_nightly-0.68.1.dev20241101.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.68.0.dev20241027.dist-info → zenml_nightly-0.68.1.dev20241101.dist-info}/entry_points.txt +0 -0
zenml/metadata/metadata_types.py
CHANGED
@@ -13,13 +13,18 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Custom types that can be used as metadata of ZenML artifacts."""
|
15
15
|
|
16
|
+
import json
|
16
17
|
from typing import Any, Dict, List, Set, Tuple, Union
|
17
18
|
|
18
19
|
from pydantic import GetCoreSchemaHandler
|
19
20
|
from pydantic_core import CoreSchema, core_schema
|
20
21
|
|
22
|
+
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
23
|
+
from zenml.logger import get_logger
|
21
24
|
from zenml.utils.enum_utils import StrEnum
|
22
25
|
|
26
|
+
logger = get_logger(__name__)
|
27
|
+
|
23
28
|
|
24
29
|
class Uri(str):
|
25
30
|
"""Special string class to indicate a URI."""
|
@@ -203,3 +208,47 @@ def cast_to_metadata_type(
|
|
203
208
|
metadata_type = metadata_enum_to_type_mapping[type_]
|
204
209
|
typed_value = metadata_type(value)
|
205
210
|
return typed_value # type: ignore[no-any-return]
|
211
|
+
|
212
|
+
|
213
|
+
def validate_metadata(
|
214
|
+
metadata: Dict[str, MetadataType],
|
215
|
+
) -> Dict[str, MetadataType]:
|
216
|
+
"""Validate metadata.
|
217
|
+
|
218
|
+
This function excludes and warns about metadata values that are too long
|
219
|
+
or of an unsupported type.
|
220
|
+
|
221
|
+
Args:
|
222
|
+
metadata: The metadata to validate.
|
223
|
+
|
224
|
+
Returns:
|
225
|
+
The validated metadata.
|
226
|
+
"""
|
227
|
+
validated_metadata = {}
|
228
|
+
|
229
|
+
for key, value in metadata.items():
|
230
|
+
if len(key) > STR_FIELD_MAX_LENGTH:
|
231
|
+
logger.warning(
|
232
|
+
f"Metadata key '{key}' is too large to be "
|
233
|
+
"stored in the database. Skipping."
|
234
|
+
)
|
235
|
+
continue
|
236
|
+
|
237
|
+
if len(json.dumps(value)) > TEXT_FIELD_MAX_LENGTH:
|
238
|
+
logger.warning(
|
239
|
+
f"Metadata value for key '{key}' is too large to be "
|
240
|
+
"stored in the database. Skipping."
|
241
|
+
)
|
242
|
+
continue
|
243
|
+
|
244
|
+
try:
|
245
|
+
get_metadata_type(value)
|
246
|
+
except ValueError as e:
|
247
|
+
logger.warning(
|
248
|
+
f"Metadata value for key '{key}' is not of a supported "
|
249
|
+
f"type. Skipping. Full error: {e}"
|
250
|
+
)
|
251
|
+
|
252
|
+
validated_metadata[key] = value
|
253
|
+
|
254
|
+
return validated_metadata
|
zenml/model/model.py
CHANGED
@@ -85,7 +85,6 @@ class Model(BaseModel):
|
|
85
85
|
# technical attributes
|
86
86
|
model_version_id: Optional[UUID] = None
|
87
87
|
suppress_class_validation_warnings: bool = False
|
88
|
-
was_created_in_this_run: bool = False
|
89
88
|
_model_id: UUID = PrivateAttr(None)
|
90
89
|
_number: Optional[int] = PrivateAttr(None)
|
91
90
|
_created_model_version: bool = PrivateAttr(False)
|
@@ -691,42 +690,6 @@ class Model(BaseModel):
|
|
691
690
|
)
|
692
691
|
mv_request = ModelVersionRequest.model_validate(model_version_request)
|
693
692
|
try:
|
694
|
-
if not self.version:
|
695
|
-
try:
|
696
|
-
from zenml import get_step_context
|
697
|
-
|
698
|
-
context = get_step_context()
|
699
|
-
except RuntimeError:
|
700
|
-
pass
|
701
|
-
else:
|
702
|
-
# if inside a step context we loop over all
|
703
|
-
# model version configuration to find, if the
|
704
|
-
# model version for current model was already
|
705
|
-
# created in the current run, not to create
|
706
|
-
# new model versions
|
707
|
-
pipeline_mv = context.pipeline_run.config.model
|
708
|
-
if (
|
709
|
-
pipeline_mv
|
710
|
-
and pipeline_mv.was_created_in_this_run
|
711
|
-
and pipeline_mv.name == self.name
|
712
|
-
and pipeline_mv.version is not None
|
713
|
-
):
|
714
|
-
self.version = pipeline_mv.version
|
715
|
-
self.model_version_id = pipeline_mv.model_version_id
|
716
|
-
else:
|
717
|
-
for step in context.pipeline_run.steps.values():
|
718
|
-
step_mv = step.config.model
|
719
|
-
if (
|
720
|
-
step_mv
|
721
|
-
and step_mv.was_created_in_this_run
|
722
|
-
and step_mv.name == self.name
|
723
|
-
and step_mv.version is not None
|
724
|
-
):
|
725
|
-
self.version = step_mv.version
|
726
|
-
self.model_version_id = (
|
727
|
-
step_mv.model_version_id
|
728
|
-
)
|
729
|
-
break
|
730
693
|
if self.version or self.model_version_id:
|
731
694
|
model_version = self._get_model_version()
|
732
695
|
else:
|
@@ -783,7 +746,6 @@ class Model(BaseModel):
|
|
783
746
|
time.sleep(sleep)
|
784
747
|
retries_made += 1
|
785
748
|
self.version = model_version.name
|
786
|
-
self.was_created_in_this_run = True
|
787
749
|
self._created_model_version = True
|
788
750
|
|
789
751
|
logger.info(
|
zenml/models/__init__.py
CHANGED
@@ -132,6 +132,7 @@ from zenml.models.v2.core.component import (
|
|
132
132
|
ComponentResponse,
|
133
133
|
ComponentResponseBody,
|
134
134
|
ComponentResponseMetadata,
|
135
|
+
ComponentResponseResources
|
135
136
|
)
|
136
137
|
from zenml.models.v2.core.event_source_flavor import (
|
137
138
|
EventSourceFlavorResponse,
|
@@ -410,6 +411,7 @@ CodeRepositoryResponseBody.model_rebuild()
|
|
410
411
|
CodeRepositoryResponseMetadata.model_rebuild()
|
411
412
|
ComponentResponseBody.model_rebuild()
|
412
413
|
ComponentResponseMetadata.model_rebuild()
|
414
|
+
ComponentResponseResources.model_rebuild()
|
413
415
|
EventSourceResponseBody.model_rebuild()
|
414
416
|
EventSourceResponseMetadata.model_rebuild()
|
415
417
|
EventSourceResponseResources.model_rebuild()
|
@@ -557,6 +559,7 @@ __all__ = [
|
|
557
559
|
"ComponentResponse",
|
558
560
|
"ComponentResponseBody",
|
559
561
|
"ComponentResponseMetadata",
|
562
|
+
"ComponentResponseResources",
|
560
563
|
"EventSourceFlavorResponse",
|
561
564
|
"EventSourceFlavorResponseBody",
|
562
565
|
"EventSourceFlavorResponseMetadata",
|
zenml/models/v2/base/base.py
CHANGED
@@ -215,6 +215,14 @@ class BaseResponse(BaseZenModel, Generic[AnyBody, AnyMetadata, AnyResources]):
|
|
215
215
|
f"`{hydrated_value}`"
|
216
216
|
)
|
217
217
|
|
218
|
+
def hydrate(self) -> None:
|
219
|
+
"""Hydrate the response."""
|
220
|
+
hydrated_version = self.get_hydrated_version()
|
221
|
+
self._validate_hydrated_version(hydrated_version)
|
222
|
+
|
223
|
+
self.resources = hydrated_version.resources
|
224
|
+
self.metadata = hydrated_version.metadata
|
225
|
+
|
218
226
|
def get_hydrated_version(
|
219
227
|
self,
|
220
228
|
) -> "BaseResponse[AnyBody, AnyMetadata, AnyResources]":
|
@@ -269,9 +277,7 @@ class BaseResponse(BaseZenModel, Generic[AnyBody, AnyMetadata, AnyResources]):
|
|
269
277
|
if len(metadata_type.model_fields):
|
270
278
|
# If the metadata class defines any fields, fetch the metadata
|
271
279
|
# through the hydrated version.
|
272
|
-
|
273
|
-
self._validate_hydrated_version(hydrated_version)
|
274
|
-
self.metadata = hydrated_version.metadata
|
280
|
+
self.hydrate()
|
275
281
|
else:
|
276
282
|
# Otherwise, use the metadata class to create an empty metadata
|
277
283
|
# object.
|
@@ -298,8 +304,8 @@ class BaseResponse(BaseZenModel, Generic[AnyBody, AnyMetadata, AnyResources]):
|
|
298
304
|
"field should exist."
|
299
305
|
)
|
300
306
|
|
301
|
-
#
|
302
|
-
#
|
307
|
+
# resources is defined as:
|
308
|
+
# resources: Optional[....ResponseResources] = Field(default=None)
|
303
309
|
# We need to find the actual class inside the Optional annotation.
|
304
310
|
from zenml.utils.typing_utils import get_args
|
305
311
|
|
@@ -309,9 +315,7 @@ class BaseResponse(BaseZenModel, Generic[AnyBody, AnyMetadata, AnyResources]):
|
|
309
315
|
if len(resources_type.model_fields):
|
310
316
|
# If the resources class defines any fields, fetch the resources
|
311
317
|
# through the hydrated version.
|
312
|
-
|
313
|
-
self._validate_hydrated_version(hydrated_version)
|
314
|
-
self.resources = hydrated_version.resources
|
318
|
+
self.hydrate()
|
315
319
|
else:
|
316
320
|
# Otherwise, use the resources class to create an empty
|
317
321
|
# resources object.
|
zenml/models/v2/base/filter.py
CHANGED
@@ -604,12 +604,15 @@ class BaseFilter(BaseModel):
|
|
604
604
|
self,
|
605
605
|
value: Union[UUID, str],
|
606
606
|
table: Type["NamedSchema"],
|
607
|
+
additional_columns: Optional[List[str]] = None,
|
607
608
|
) -> "ColumnElement[bool]":
|
608
609
|
"""Generate filter conditions for name or id of a table.
|
609
610
|
|
610
611
|
Args:
|
611
612
|
value: The filter value.
|
612
613
|
table: The table to filter.
|
614
|
+
additional_columns: Additional table columns that should also
|
615
|
+
filtered for the given value as part of the or condition.
|
613
616
|
|
614
617
|
Returns:
|
615
618
|
The query conditions.
|
@@ -637,6 +640,12 @@ class BaseFilter(BaseModel):
|
|
637
640
|
)
|
638
641
|
conditions.append(filter_.generate_query_conditions(table=table))
|
639
642
|
|
643
|
+
for column in additional_columns or []:
|
644
|
+
filter_ = FilterGenerator(table).define_filter(
|
645
|
+
column=column, value=value, operator=operator
|
646
|
+
)
|
647
|
+
conditions.append(filter_.generate_query_conditions(table=table))
|
648
|
+
|
640
649
|
return or_(*conditions)
|
641
650
|
|
642
651
|
def generate_custom_query_conditions_for_column(
|
@@ -24,12 +24,19 @@ from typing import (
|
|
24
24
|
)
|
25
25
|
from uuid import UUID
|
26
26
|
|
27
|
-
from pydantic import
|
27
|
+
from pydantic import (
|
28
|
+
BaseModel,
|
29
|
+
ConfigDict,
|
30
|
+
Field,
|
31
|
+
field_validator,
|
32
|
+
model_validator,
|
33
|
+
)
|
28
34
|
|
29
35
|
from zenml.config.source import Source, SourceWithValidator
|
30
36
|
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
31
37
|
from zenml.enums import ArtifactType, GenericFilterOps
|
32
38
|
from zenml.logger import get_logger
|
39
|
+
from zenml.metadata.metadata_types import MetadataType
|
33
40
|
from zenml.models.v2.base.filter import StrFilter
|
34
41
|
from zenml.models.v2.base.scoped import (
|
35
42
|
WorkspaceScopedRequest,
|
@@ -63,11 +70,16 @@ logger = get_logger(__name__)
|
|
63
70
|
class ArtifactVersionRequest(WorkspaceScopedRequest):
|
64
71
|
"""Request model for artifact versions."""
|
65
72
|
|
66
|
-
artifact_id: UUID = Field(
|
73
|
+
artifact_id: Optional[UUID] = Field(
|
74
|
+
default=None,
|
67
75
|
title="ID of the artifact to which this version belongs.",
|
68
76
|
)
|
69
|
-
|
70
|
-
|
77
|
+
artifact_name: Optional[str] = Field(
|
78
|
+
default=None,
|
79
|
+
title="Name of the artifact to which this version belongs.",
|
80
|
+
)
|
81
|
+
version: Optional[Union[int, str]] = Field(
|
82
|
+
default=None, title="Version of the artifact."
|
71
83
|
)
|
72
84
|
has_custom_name: bool = Field(
|
73
85
|
title="Whether the name is custom (True) or auto-generated (False).",
|
@@ -95,6 +107,9 @@ class ArtifactVersionRequest(WorkspaceScopedRequest):
|
|
95
107
|
visualizations: Optional[List["ArtifactVisualizationRequest"]] = Field(
|
96
108
|
default=None, title="Visualizations of the artifact."
|
97
109
|
)
|
110
|
+
metadata: Optional[Dict[str, MetadataType]] = Field(
|
111
|
+
default=None, title="Metadata of the artifact version."
|
112
|
+
)
|
98
113
|
|
99
114
|
@field_validator("version")
|
100
115
|
@classmethod
|
@@ -117,6 +132,28 @@ class ArtifactVersionRequest(WorkspaceScopedRequest):
|
|
117
132
|
)
|
118
133
|
return value
|
119
134
|
|
135
|
+
@model_validator(mode="after")
|
136
|
+
def _validate_request(self) -> "ArtifactVersionRequest":
|
137
|
+
"""Validate the request values.
|
138
|
+
|
139
|
+
Raises:
|
140
|
+
ValueError: If the request is invalid.
|
141
|
+
|
142
|
+
Returns:
|
143
|
+
The validated request.
|
144
|
+
"""
|
145
|
+
if self.artifact_id and self.artifact_name:
|
146
|
+
raise ValueError(
|
147
|
+
"Only one of artifact_name and artifact_id can be set."
|
148
|
+
)
|
149
|
+
|
150
|
+
if not (self.artifact_id or self.artifact_name):
|
151
|
+
raise ValueError(
|
152
|
+
"Either artifact_name or artifact_id must be set."
|
153
|
+
)
|
154
|
+
|
155
|
+
return self
|
156
|
+
|
120
157
|
|
121
158
|
# ------------------ Update Model ------------------
|
122
159
|
|
@@ -156,6 +193,10 @@ class ArtifactVersionResponseBody(WorkspaceScopedResponseBody):
|
|
156
193
|
title="The ID of the pipeline run that generated this artifact version.",
|
157
194
|
default=None,
|
158
195
|
)
|
196
|
+
artifact_store_id: Optional[UUID] = Field(
|
197
|
+
title="ID of the artifact store in which this artifact is stored.",
|
198
|
+
default=None,
|
199
|
+
)
|
159
200
|
|
160
201
|
@field_validator("version")
|
161
202
|
@classmethod
|
@@ -182,10 +223,6 @@ class ArtifactVersionResponseBody(WorkspaceScopedResponseBody):
|
|
182
223
|
class ArtifactVersionResponseMetadata(WorkspaceScopedResponseMetadata):
|
183
224
|
"""Response metadata for artifact versions."""
|
184
225
|
|
185
|
-
artifact_store_id: Optional[UUID] = Field(
|
186
|
-
title="ID of the artifact store in which this artifact is stored.",
|
187
|
-
default=None,
|
188
|
-
)
|
189
226
|
producer_step_run_id: Optional[UUID] = Field(
|
190
227
|
title="ID of the step run that produced this artifact.",
|
191
228
|
default=None,
|
@@ -283,7 +320,7 @@ class ArtifactVersionResponse(
|
|
283
320
|
Returns:
|
284
321
|
the value of the property.
|
285
322
|
"""
|
286
|
-
return self.
|
323
|
+
return self.get_body().artifact_store_id
|
287
324
|
|
288
325
|
@property
|
289
326
|
def producer_step_run_id(self) -> Optional[UUID]:
|
@@ -556,7 +593,9 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
|
|
556
593
|
user_filter = and_(
|
557
594
|
ArtifactVersionSchema.user_id == UserSchema.id,
|
558
595
|
self.generate_name_or_id_query_conditions(
|
559
|
-
value=self.user,
|
596
|
+
value=self.user,
|
597
|
+
table=UserSchema,
|
598
|
+
additional_columns=["full_name"],
|
560
599
|
),
|
561
600
|
)
|
562
601
|
custom_filters.append(user_filter)
|
@@ -44,9 +44,7 @@ if TYPE_CHECKING:
|
|
44
44
|
from sqlalchemy.sql.elements import ColumnElement
|
45
45
|
from sqlmodel import SQLModel
|
46
46
|
|
47
|
-
from zenml.models
|
48
|
-
ServiceConnectorResponse,
|
49
|
-
)
|
47
|
+
from zenml.models import FlavorResponse, ServiceConnectorResponse
|
50
48
|
|
51
49
|
# ------------------ Base Model ------------------
|
52
50
|
|
@@ -139,22 +137,11 @@ class InternalComponentRequest(ComponentRequest):
|
|
139
137
|
class ComponentUpdate(BaseUpdate):
|
140
138
|
"""Update model for stack components."""
|
141
139
|
|
142
|
-
ANALYTICS_FIELDS: ClassVar[List[str]] = ["type", "flavor"]
|
143
|
-
|
144
140
|
name: Optional[str] = Field(
|
145
141
|
title="The name of the stack component.",
|
146
142
|
max_length=STR_FIELD_MAX_LENGTH,
|
147
143
|
default=None,
|
148
144
|
)
|
149
|
-
type: Optional[StackComponentType] = Field(
|
150
|
-
title="The type of the stack component.",
|
151
|
-
default=None,
|
152
|
-
)
|
153
|
-
flavor: Optional[str] = Field(
|
154
|
-
title="The flavor of the stack component.",
|
155
|
-
max_length=STR_FIELD_MAX_LENGTH,
|
156
|
-
default=None,
|
157
|
-
)
|
158
145
|
configuration: Optional[Dict[str, Any]] = Field(
|
159
146
|
title="The stack component configuration.",
|
160
147
|
default=None,
|
@@ -187,7 +174,7 @@ class ComponentResponseBody(WorkspaceScopedResponseBody):
|
|
187
174
|
type: StackComponentType = Field(
|
188
175
|
title="The type of the stack component.",
|
189
176
|
)
|
190
|
-
|
177
|
+
flavor_name: str = Field(
|
191
178
|
title="The flavor of the stack component.",
|
192
179
|
max_length=STR_FIELD_MAX_LENGTH,
|
193
180
|
)
|
@@ -232,6 +219,10 @@ class ComponentResponseMetadata(WorkspaceScopedResponseMetadata):
|
|
232
219
|
class ComponentResponseResources(WorkspaceScopedResponseResources):
|
233
220
|
"""Class for all resource models associated with the component entity."""
|
234
221
|
|
222
|
+
flavor: "FlavorResponse" = Field(
|
223
|
+
title="The flavor of this stack component.",
|
224
|
+
)
|
225
|
+
|
235
226
|
|
236
227
|
class ComponentResponse(
|
237
228
|
WorkspaceScopedResponse[
|
@@ -242,7 +233,7 @@ class ComponentResponse(
|
|
242
233
|
):
|
243
234
|
"""Response model for components."""
|
244
235
|
|
245
|
-
ANALYTICS_FIELDS: ClassVar[List[str]] = ["type"
|
236
|
+
ANALYTICS_FIELDS: ClassVar[List[str]] = ["type"]
|
246
237
|
|
247
238
|
name: str = Field(
|
248
239
|
title="The name of the stack component.",
|
@@ -265,6 +256,8 @@ class ComponentResponse(
|
|
265
256
|
if label.startswith("zenml:")
|
266
257
|
}
|
267
258
|
)
|
259
|
+
metadata["flavor"] = self.flavor_name
|
260
|
+
|
268
261
|
return metadata
|
269
262
|
|
270
263
|
def get_hydrated_version(self) -> "ComponentResponse":
|
@@ -288,13 +281,13 @@ class ComponentResponse(
|
|
288
281
|
return self.get_body().type
|
289
282
|
|
290
283
|
@property
|
291
|
-
def
|
292
|
-
"""The `
|
284
|
+
def flavor_name(self) -> str:
|
285
|
+
"""The `flavor_name` property.
|
293
286
|
|
294
287
|
Returns:
|
295
288
|
the value of the property.
|
296
289
|
"""
|
297
|
-
return self.get_body().
|
290
|
+
return self.get_body().flavor_name
|
298
291
|
|
299
292
|
@property
|
300
293
|
def integration(self) -> Optional[str]:
|
@@ -359,6 +352,15 @@ class ComponentResponse(
|
|
359
352
|
"""
|
360
353
|
return self.get_metadata().connector
|
361
354
|
|
355
|
+
@property
|
356
|
+
def flavor(self) -> "FlavorResponse":
|
357
|
+
"""The `flavor` property.
|
358
|
+
|
359
|
+
Returns:
|
360
|
+
the value of the property.
|
361
|
+
"""
|
362
|
+
return self.get_resources().flavor
|
363
|
+
|
362
364
|
|
363
365
|
# ------------------ Filter Model ------------------
|
364
366
|
|
@@ -376,6 +378,7 @@ class ComponentFilter(WorkspaceScopedFilter):
|
|
376
378
|
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
|
377
379
|
"scope_type",
|
378
380
|
"stack_id",
|
381
|
+
"user",
|
379
382
|
]
|
380
383
|
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
381
384
|
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
|
@@ -418,6 +421,10 @@ class ComponentFilter(WorkspaceScopedFilter):
|
|
418
421
|
description="Stack of the stack component",
|
419
422
|
union_mode="left_to_right",
|
420
423
|
)
|
424
|
+
user: Optional[Union[UUID, str]] = Field(
|
425
|
+
default=None,
|
426
|
+
description="Name/ID of the user that created the component.",
|
427
|
+
)
|
421
428
|
|
422
429
|
def set_scope_type(self, component_type: str) -> None:
|
423
430
|
"""Set the type of component on which to perform the filtering to scope the response.
|
@@ -464,3 +471,31 @@ class ComponentFilter(WorkspaceScopedFilter):
|
|
464
471
|
base_filter = operator(base_filter, stack_filter)
|
465
472
|
|
466
473
|
return base_filter
|
474
|
+
|
475
|
+
def get_custom_filters(self) -> List["ColumnElement[bool]"]:
|
476
|
+
"""Get custom filters.
|
477
|
+
|
478
|
+
Returns:
|
479
|
+
A list of custom filters.
|
480
|
+
"""
|
481
|
+
from sqlmodel import and_
|
482
|
+
|
483
|
+
from zenml.zen_stores.schemas import (
|
484
|
+
StackComponentSchema,
|
485
|
+
UserSchema,
|
486
|
+
)
|
487
|
+
|
488
|
+
custom_filters = super().get_custom_filters()
|
489
|
+
|
490
|
+
if self.user:
|
491
|
+
user_filter = and_(
|
492
|
+
StackComponentSchema.user_id == UserSchema.id,
|
493
|
+
self.generate_name_or_id_query_conditions(
|
494
|
+
value=self.user,
|
495
|
+
table=UserSchema,
|
496
|
+
additional_columns=["full_name"],
|
497
|
+
),
|
498
|
+
)
|
499
|
+
custom_filters.append(user_filter)
|
500
|
+
|
501
|
+
return custom_filters
|
zenml/models/v2/core/flavor.py
CHANGED
@@ -192,6 +192,10 @@ class FlavorResponseBody(UserScopedResponseBody):
|
|
192
192
|
title="The name of the integration that the Flavor belongs to.",
|
193
193
|
max_length=STR_FIELD_MAX_LENGTH,
|
194
194
|
)
|
195
|
+
source: str = Field(
|
196
|
+
title="The path to the module which contains this Flavor.",
|
197
|
+
max_length=STR_FIELD_MAX_LENGTH,
|
198
|
+
)
|
195
199
|
logo_url: Optional[str] = Field(
|
196
200
|
default=None,
|
197
201
|
title="Optionally, a url pointing to a png,"
|
@@ -225,10 +229,6 @@ class FlavorResponseMetadata(UserScopedResponseMetadata):
|
|
225
229
|
"connector.",
|
226
230
|
max_length=STR_FIELD_MAX_LENGTH,
|
227
231
|
)
|
228
|
-
source: str = Field(
|
229
|
-
title="The path to the module which contains this Flavor.",
|
230
|
-
max_length=STR_FIELD_MAX_LENGTH,
|
231
|
-
)
|
232
232
|
docs_url: Optional[str] = Field(
|
233
233
|
default=None,
|
234
234
|
title="Optionally, a url pointing to docs, within docs.zenml.io.",
|
@@ -319,6 +319,15 @@ class FlavorResponse(
|
|
319
319
|
"""
|
320
320
|
return self.get_body().integration
|
321
321
|
|
322
|
+
@property
|
323
|
+
def source(self) -> str:
|
324
|
+
"""The `source` property.
|
325
|
+
|
326
|
+
Returns:
|
327
|
+
the value of the property.
|
328
|
+
"""
|
329
|
+
return self.get_body().source
|
330
|
+
|
322
331
|
@property
|
323
332
|
def logo_url(self) -> Optional[str]:
|
324
333
|
"""The `logo_url` property.
|
@@ -373,15 +382,6 @@ class FlavorResponse(
|
|
373
382
|
"""
|
374
383
|
return self.get_metadata().connector_resource_id_attr
|
375
384
|
|
376
|
-
@property
|
377
|
-
def source(self) -> str:
|
378
|
-
"""The `source` property.
|
379
|
-
|
380
|
-
Returns:
|
381
|
-
the value of the property.
|
382
|
-
"""
|
383
|
-
return self.get_metadata().source
|
384
|
-
|
385
385
|
@property
|
386
386
|
def docs_url(self) -> Optional[str]:
|
387
387
|
"""The `docs_url` property.
|
zenml/models/v2/core/model.py
CHANGED
@@ -368,7 +368,9 @@ class ModelFilter(WorkspaceScopedTaggableFilter):
|
|
368
368
|
user_filter = and_(
|
369
369
|
ModelSchema.user_id == UserSchema.id,
|
370
370
|
self.generate_name_or_id_query_conditions(
|
371
|
-
value=self.user,
|
371
|
+
value=self.user,
|
372
|
+
table=UserSchema,
|
373
|
+
additional_columns=["full_name"],
|
372
374
|
),
|
373
375
|
)
|
374
376
|
custom_filters.append(user_filter)
|
@@ -325,14 +325,11 @@ class ModelVersionResponse(
|
|
325
325
|
# Helper functions
|
326
326
|
def to_model_class(
|
327
327
|
self,
|
328
|
-
was_created_in_this_run: bool = False,
|
329
328
|
suppress_class_validation_warnings: bool = True,
|
330
329
|
) -> "Model":
|
331
330
|
"""Convert response model to Model object.
|
332
331
|
|
333
332
|
Args:
|
334
|
-
was_created_in_this_run: Whether model version was created during
|
335
|
-
the current run.
|
336
333
|
suppress_class_validation_warnings: internally used to suppress
|
337
334
|
repeated warnings.
|
338
335
|
|
@@ -352,7 +349,6 @@ class ModelVersionResponse(
|
|
352
349
|
ethics=self.model.ethics,
|
353
350
|
tags=[t.name for t in self.tags],
|
354
351
|
version=self.name,
|
355
|
-
was_created_in_this_run=was_created_in_this_run,
|
356
352
|
suppress_class_validation_warnings=suppress_class_validation_warnings,
|
357
353
|
model_version_id=self.id,
|
358
354
|
)
|
@@ -664,7 +660,9 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter):
|
|
664
660
|
user_filter = and_(
|
665
661
|
ModelVersionSchema.user_id == UserSchema.id,
|
666
662
|
self.generate_name_or_id_query_conditions(
|
667
|
-
value=self.user,
|
663
|
+
value=self.user,
|
664
|
+
table=UserSchema,
|
665
|
+
additional_columns=["full_name"],
|
668
666
|
),
|
669
667
|
)
|
670
668
|
custom_filters.append(user_filter)
|
@@ -294,7 +294,9 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter):
|
|
294
294
|
== ArtifactVersionSchema.id,
|
295
295
|
ArtifactVersionSchema.user_id == UserSchema.id,
|
296
296
|
self.generate_name_or_id_query_conditions(
|
297
|
-
value=self.user,
|
297
|
+
value=self.user,
|
298
|
+
table=UserSchema,
|
299
|
+
additional_columns=["full_name"],
|
298
300
|
),
|
299
301
|
)
|
300
302
|
custom_filters.append(user_filter)
|
@@ -218,7 +218,9 @@ class ModelVersionPipelineRunFilter(WorkspaceScopedFilter):
|
|
218
218
|
== PipelineRunSchema.id,
|
219
219
|
PipelineRunSchema.user_id == UserSchema.id,
|
220
220
|
self.generate_name_or_id_query_conditions(
|
221
|
-
value=self.user,
|
221
|
+
value=self.user,
|
222
|
+
table=UserSchema,
|
223
|
+
additional_columns=["full_name"],
|
222
224
|
),
|
223
225
|
)
|
224
226
|
custom_filters.append(user_filter)
|
zenml/models/v2/core/pipeline.py
CHANGED
@@ -364,7 +364,9 @@ class PipelineFilter(WorkspaceScopedTaggableFilter):
|
|
364
364
|
user_filter = and_(
|
365
365
|
PipelineSchema.user_id == UserSchema.id,
|
366
366
|
self.generate_name_or_id_query_conditions(
|
367
|
-
value=self.user,
|
367
|
+
value=self.user,
|
368
|
+
table=UserSchema,
|
369
|
+
additional_columns=["full_name"],
|
368
370
|
),
|
369
371
|
)
|
370
372
|
custom_filters.append(user_filter)
|