zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240928__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/__init__.py +7 -0
- zenml/cli/base.py +2 -2
- zenml/cli/pipeline.py +21 -0
- zenml/cli/utils.py +14 -11
- zenml/client.py +68 -3
- zenml/config/step_configurations.py +0 -5
- zenml/constants.py +3 -0
- zenml/enums.py +2 -0
- zenml/integrations/__init__.py +1 -0
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
- zenml/integrations/azure/__init__.py +6 -2
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
- zenml/integrations/constants.py +1 -0
- zenml/integrations/deepchecks/__init__.py +1 -1
- zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
- zenml/integrations/deepchecks/validation_checks.py +62 -5
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
- zenml/integrations/lightning/__init__.py +1 -1
- zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
- zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
- zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
- zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
- zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
- zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
- zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
- zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
- zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
- zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
- zenml/models/v2/base/filter.py +315 -149
- zenml/models/v2/base/scoped.py +5 -2
- zenml/models/v2/core/artifact_version.py +69 -8
- zenml/models/v2/core/model.py +43 -6
- zenml/models/v2/core/model_version.py +49 -1
- zenml/models/v2/core/model_version_artifact.py +18 -3
- zenml/models/v2/core/model_version_pipeline_run.py +18 -4
- zenml/models/v2/core/pipeline.py +108 -1
- zenml/models/v2/core/pipeline_run.py +172 -21
- zenml/models/v2/core/run_template.py +53 -1
- zenml/models/v2/core/stack.py +33 -5
- zenml/models/v2/core/step_run.py +7 -0
- zenml/new/pipelines/pipeline.py +4 -0
- zenml/new/pipelines/run_utils.py +4 -1
- zenml/orchestrators/base_orchestrator.py +41 -12
- zenml/stack/stack.py +11 -2
- zenml/utils/env_utils.py +54 -1
- zenml/utils/string_utils.py +50 -0
- zenml/zen_server/cloud_utils.py +33 -8
- zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-Y50hSt65.js} +1 -1
- zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-ytavUpwh.js} +1 -1
- zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-xLR9a1iw.js +1 -0
- zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-IxXNxUDa.js} +2 -2
- zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-BhutZbBL.js} +1 -1
- zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-Bf-rd1z8.js} +1 -1
- zenml/zen_server/dashboard/assets/ComponentBadge-gKR1OIwG.js +1 -0
- zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-DcFHidFJ.js} +1 -1
- zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-QSbjrfxw.js} +1 -1
- zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-Cd3HqST4.js} +1 -1
- zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BhwdmqK7.js} +1 -1
- zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-D6r6aK8J.js} +1 -1
- zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-0pBpTwTm.js} +1 -1
- zenml/zen_server/dashboard/assets/Infobox-BTK_EUKT.js +1 -0
- zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-CA3DFMcM.js} +1 -1
- zenml/zen_server/dashboard/assets/Partials-QLOZw624.js +1 -0
- zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C16CCIN4.js} +1 -1
- zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-D3FuCHf3.js} +1 -1
- zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BzmfxS0L.js} +1 -1
- zenml/zen_server/dashboard/assets/SecretTooltip-BaMwHF-Q.js +1 -0
- zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-DuIC65H9.js} +1 -1
- zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-DJTCF0Re.js} +1 -1
- zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-CUm-DMpw.js} +1 -1
- zenml/zen_server/dashboard/assets/UsageReason-CKw0juLF.js +1 -0
- zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-Cv9ApYWU.js} +1 -1
- zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-BA3R2Sey.js} +1 -1
- zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-BB4BVa6E.js} +1 -1
- zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-F29xAUEx.js} +1 -1
- zenml/zen_server/dashboard/assets/delete-run-CP0pcJ3U.js +1 -0
- zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-BKXwSDK2.js} +1 -1
- zenml/zen_server/dashboard/assets/index-BhJ6ZJxv.css +1 -0
- zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-Ci0nJ8EZ.js} +5 -5
- zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-D-mtoBj3.js} +1 -1
- zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-ax6iL2Mb.js} +1 -1
- zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-DbjllLY_.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-3qPX9WYH.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-6mfzecin.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-8kYmrh0B.js} +1 -1
- zenml/zen_server/dashboard/assets/page-B1n7_W7z.js +1 -0
- zenml/zen_server/dashboard/assets/page-BDg1F-Ug.js +6 -0
- zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-BXarY9K2.js} +1 -1
- zenml/zen_server/dashboard/assets/page-BZZhLo2u.js +1 -0
- zenml/zen_server/dashboard/assets/page-Bbf_oBjn.js +1 -0
- zenml/zen_server/dashboard/assets/page-BjjuBvZG.js +9 -0
- zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-BukXK1Aa.js} +1 -1
- zenml/zen_server/dashboard/assets/page-CHaQkFK5.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-CKHNAq7z.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-CS0SYFK8.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-CvKnNK1S.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DGM1CbYT.js} +2 -2
- zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DMSLXKGT.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DOmIZ2ra.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-DRfcRK1w.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DYVmJ9_w.js +3 -0
- zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-DcTjHmYZ.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DuqYMYmH.js +1 -0
- zenml/zen_server/dashboard/assets/page-Dwow2doB.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-HkVBdZl6.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-MAXyfXBq.js} +1 -1
- zenml/zen_server/dashboard/assets/page-miU2rhYG.js +1 -0
- zenml/zen_server/dashboard/assets/page-p0BhSAWx.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-uORspyRu.js} +1 -1
- zenml/zen_server/dashboard/assets/persist-BxIR2XZs.js +1 -0
- zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-CfJMar_k.js} +1 -1
- zenml/zen_server/dashboard/assets/sharedSchema-vub0rii3.js +14 -0
- zenml/zen_server/dashboard/assets/stack-detail-query-DQcyzG-2.js +1 -0
- zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
- zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-FGVP7X2U.js} +1 -1
- zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-CbAPzsmT.js} +1 -1
- zenml/zen_server/dashboard/index.html +4 -4
- zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
- zenml/zen_server/dashboard_legacy/index.html +1 -1
- zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.6d320abb70db612019dda6c4948e7a90.js} +4 -4
- zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
- zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.fa9299d5.chunk.js} +2 -2
- zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.fa9299d5.chunk.js.map} +1 -1
- zenml/zen_server/routers/runs_endpoints.py +89 -3
- zenml/zen_stores/sql_zen_store.py +1 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/METADATA +8 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/RECORD +133 -125
- zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
- zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
- zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
- zenml/zen_server/dashboard/assets/Partials-DX-8iEa1.js +0 -1
- zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
- zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
- zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
- zenml/zen_server/dashboard/assets/page-CDOQLrPC.js +0 -1
- zenml/zen_server/dashboard/assets/page-CEJWu1YO.js +0 -1
- zenml/zen_server/dashboard/assets/page-CIbehp7V.js +0 -1
- zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
- zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
- zenml/zen_server/dashboard/assets/page-D5F3DJjm.js +0 -1
- zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
- zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
- zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
- zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
- zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
- zenml/zen_server/dashboard/assets/persist-mEZN_fgH.js +0 -1
- zenml/zen_server/dashboard/assets/sharedSchema-BfZcy7aP.js +0 -14
- zenml/zen_server/dashboard/assets/stack-detail-query-CU4egfhp.js +0 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/entry_points.txt +0 -0
@@ -24,7 +24,7 @@ from typing import (
|
|
24
24
|
)
|
25
25
|
from uuid import UUID
|
26
26
|
|
27
|
-
from pydantic import BaseModel, Field, field_validator
|
27
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
28
28
|
|
29
29
|
from zenml.config.source import Source, SourceWithValidator
|
30
30
|
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
@@ -430,14 +430,14 @@ class ArtifactVersionResponse(
|
|
430
430
|
class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
|
431
431
|
"""Model to enable advanced filtering of artifact versions."""
|
432
432
|
|
433
|
-
# `name` and `only_unused` refer to properties related to other entities
|
434
|
-
# rather than a field in the db, hence they need to be handled
|
435
|
-
# explicitly
|
436
433
|
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
437
434
|
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
|
438
435
|
"name",
|
439
436
|
"only_unused",
|
440
437
|
"has_custom_name",
|
438
|
+
"user",
|
439
|
+
"model",
|
440
|
+
"pipeline_run",
|
441
441
|
]
|
442
442
|
artifact_id: Optional[Union[UUID, str]] = Field(
|
443
443
|
default=None,
|
@@ -495,6 +495,22 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
|
|
495
495
|
default=None,
|
496
496
|
description="Filter only artifacts with/without custom names.",
|
497
497
|
)
|
498
|
+
user: Optional[Union[UUID, str]] = Field(
|
499
|
+
default=None,
|
500
|
+
description="Name/ID of the user that created the artifact version.",
|
501
|
+
)
|
502
|
+
model: Optional[Union[UUID, str]] = Field(
|
503
|
+
default=None,
|
504
|
+
description="Name/ID of the model that is associated with this "
|
505
|
+
"artifact version.",
|
506
|
+
)
|
507
|
+
pipeline_run: Optional[Union[UUID, str]] = Field(
|
508
|
+
default=None,
|
509
|
+
description="Name/ID of a pipeline run that is associated with this "
|
510
|
+
"artifact version.",
|
511
|
+
)
|
512
|
+
|
513
|
+
model_config = ConfigDict(protected_namespaces=())
|
498
514
|
|
499
515
|
def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]:
|
500
516
|
"""Get custom filters.
|
@@ -504,15 +520,18 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
|
|
504
520
|
"""
|
505
521
|
custom_filters = super().get_custom_filters()
|
506
522
|
|
507
|
-
from sqlmodel import and_, select
|
523
|
+
from sqlmodel import and_, or_, select
|
508
524
|
|
509
|
-
from zenml.zen_stores.schemas
|
525
|
+
from zenml.zen_stores.schemas import (
|
510
526
|
ArtifactSchema,
|
511
527
|
ArtifactVersionSchema,
|
512
|
-
|
513
|
-
|
528
|
+
ModelSchema,
|
529
|
+
ModelVersionArtifactSchema,
|
530
|
+
PipelineRunSchema,
|
514
531
|
StepRunInputArtifactSchema,
|
515
532
|
StepRunOutputArtifactSchema,
|
533
|
+
StepRunSchema,
|
534
|
+
UserSchema,
|
516
535
|
)
|
517
536
|
|
518
537
|
if self.name:
|
@@ -546,6 +565,48 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
|
|
546
565
|
)
|
547
566
|
custom_filters.append(custom_name_filter)
|
548
567
|
|
568
|
+
if self.user:
|
569
|
+
user_filter = and_(
|
570
|
+
ArtifactVersionSchema.user_id == UserSchema.id,
|
571
|
+
self.generate_name_or_id_query_conditions(
|
572
|
+
value=self.user, table=UserSchema
|
573
|
+
),
|
574
|
+
)
|
575
|
+
custom_filters.append(user_filter)
|
576
|
+
|
577
|
+
if self.model:
|
578
|
+
model_filter = and_(
|
579
|
+
ArtifactVersionSchema.id
|
580
|
+
== ModelVersionArtifactSchema.artifact_version_id,
|
581
|
+
ModelVersionArtifactSchema.model_id == ModelSchema.id,
|
582
|
+
self.generate_name_or_id_query_conditions(
|
583
|
+
value=self.model, table=ModelSchema
|
584
|
+
),
|
585
|
+
)
|
586
|
+
custom_filters.append(model_filter)
|
587
|
+
|
588
|
+
if self.pipeline_run:
|
589
|
+
pipeline_run_filter = and_(
|
590
|
+
or_(
|
591
|
+
and_(
|
592
|
+
ArtifactVersionSchema.id
|
593
|
+
== StepRunOutputArtifactSchema.artifact_id,
|
594
|
+
StepRunOutputArtifactSchema.step_id
|
595
|
+
== StepRunSchema.id,
|
596
|
+
),
|
597
|
+
and_(
|
598
|
+
ArtifactVersionSchema.id
|
599
|
+
== StepRunInputArtifactSchema.artifact_id,
|
600
|
+
StepRunInputArtifactSchema.step_id == StepRunSchema.id,
|
601
|
+
),
|
602
|
+
),
|
603
|
+
StepRunSchema.pipeline_run_id == PipelineRunSchema.id,
|
604
|
+
self.generate_name_or_id_query_conditions(
|
605
|
+
value=self.pipeline_run, table=PipelineRunSchema
|
606
|
+
),
|
607
|
+
)
|
608
|
+
custom_filters.append(pipeline_run_filter)
|
609
|
+
|
549
610
|
return custom_filters
|
550
611
|
|
551
612
|
|
zenml/models/v2/core/model.py
CHANGED
@@ -30,10 +30,11 @@ from zenml.models.v2.base.scoped import (
|
|
30
30
|
from zenml.utils.pagination_utils import depaginate
|
31
31
|
|
32
32
|
if TYPE_CHECKING:
|
33
|
+
from sqlalchemy.sql.elements import ColumnElement
|
34
|
+
|
33
35
|
from zenml.model.model import Model
|
34
36
|
from zenml.models.v2.core.tag import TagResponse
|
35
37
|
|
36
|
-
|
37
38
|
# ------------------ Request Model ------------------
|
38
39
|
|
39
40
|
|
@@ -316,6 +317,16 @@ class ModelResponse(
|
|
316
317
|
class ModelFilter(WorkspaceScopedTaggableFilter):
|
317
318
|
"""Model to enable advanced filtering of all Workspaces."""
|
318
319
|
|
320
|
+
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
321
|
+
*WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS,
|
322
|
+
"workspace_id",
|
323
|
+
"user_id",
|
324
|
+
]
|
325
|
+
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
326
|
+
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
|
327
|
+
"user",
|
328
|
+
]
|
329
|
+
|
319
330
|
name: Optional[str] = Field(
|
320
331
|
default=None,
|
321
332
|
description="Name of the Model",
|
@@ -330,9 +341,35 @@ class ModelFilter(WorkspaceScopedTaggableFilter):
|
|
330
341
|
description="User of the Model",
|
331
342
|
union_mode="left_to_right",
|
332
343
|
)
|
344
|
+
user: Optional[Union[UUID, str]] = Field(
|
345
|
+
default=None,
|
346
|
+
description="Name/ID of the user that created the model.",
|
347
|
+
)
|
333
348
|
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
"
|
338
|
-
|
349
|
+
def get_custom_filters(
|
350
|
+
self,
|
351
|
+
) -> List["ColumnElement[bool]"]:
|
352
|
+
"""Get custom filters.
|
353
|
+
|
354
|
+
Returns:
|
355
|
+
A list of custom filters.
|
356
|
+
"""
|
357
|
+
custom_filters = super().get_custom_filters()
|
358
|
+
|
359
|
+
from sqlmodel import and_
|
360
|
+
|
361
|
+
from zenml.zen_stores.schemas import (
|
362
|
+
ModelSchema,
|
363
|
+
UserSchema,
|
364
|
+
)
|
365
|
+
|
366
|
+
if self.user:
|
367
|
+
user_filter = and_(
|
368
|
+
ModelSchema.user_id == UserSchema.id,
|
369
|
+
self.generate_name_or_id_query_conditions(
|
370
|
+
value=self.user, table=UserSchema
|
371
|
+
),
|
372
|
+
)
|
373
|
+
custom_filters.append(user_filter)
|
374
|
+
|
375
|
+
return custom_filters
|
@@ -13,7 +13,16 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Models representing model versions."""
|
15
15
|
|
16
|
-
from typing import
|
16
|
+
from typing import (
|
17
|
+
TYPE_CHECKING,
|
18
|
+
ClassVar,
|
19
|
+
Dict,
|
20
|
+
List,
|
21
|
+
Optional,
|
22
|
+
Type,
|
23
|
+
TypeVar,
|
24
|
+
Union,
|
25
|
+
)
|
17
26
|
from uuid import UUID
|
18
27
|
|
19
28
|
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator
|
@@ -34,6 +43,8 @@ from zenml.models.v2.core.service import ServiceResponse
|
|
34
43
|
from zenml.models.v2.core.tag import TagResponse
|
35
44
|
|
36
45
|
if TYPE_CHECKING:
|
46
|
+
from sqlalchemy.sql.elements import ColumnElement
|
47
|
+
|
37
48
|
from zenml.model.model import Model
|
38
49
|
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
39
50
|
from zenml.models.v2.core.model import ModelResponse
|
@@ -582,6 +593,11 @@ class ModelVersionResponse(
|
|
582
593
|
class ModelVersionFilter(WorkspaceScopedTaggableFilter):
|
583
594
|
"""Filter model for model versions."""
|
584
595
|
|
596
|
+
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
597
|
+
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
|
598
|
+
"user",
|
599
|
+
]
|
600
|
+
|
585
601
|
name: Optional[str] = Field(
|
586
602
|
default=None,
|
587
603
|
description="The name of the Model Version",
|
@@ -605,6 +621,10 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter):
|
|
605
621
|
default=None,
|
606
622
|
union_mode="left_to_right",
|
607
623
|
)
|
624
|
+
user: Optional[Union[UUID, str]] = Field(
|
625
|
+
default=None,
|
626
|
+
description="Name/ID of the user that created the model version.",
|
627
|
+
)
|
608
628
|
|
609
629
|
_model_id: UUID = PrivateAttr(None)
|
610
630
|
|
@@ -623,6 +643,34 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter):
|
|
623
643
|
|
624
644
|
self._model_id = model_id
|
625
645
|
|
646
|
+
def get_custom_filters(
|
647
|
+
self,
|
648
|
+
) -> List["ColumnElement[bool]"]:
|
649
|
+
"""Get custom filters.
|
650
|
+
|
651
|
+
Returns:
|
652
|
+
A list of custom filters.
|
653
|
+
"""
|
654
|
+
custom_filters = super().get_custom_filters()
|
655
|
+
|
656
|
+
from sqlmodel import and_
|
657
|
+
|
658
|
+
from zenml.zen_stores.schemas import (
|
659
|
+
ModelVersionSchema,
|
660
|
+
UserSchema,
|
661
|
+
)
|
662
|
+
|
663
|
+
if self.user:
|
664
|
+
user_filter = and_(
|
665
|
+
ModelVersionSchema.user_id == UserSchema.id,
|
666
|
+
self.generate_name_or_id_query_conditions(
|
667
|
+
value=self.user, table=UserSchema
|
668
|
+
),
|
669
|
+
)
|
670
|
+
custom_filters.append(user_filter)
|
671
|
+
|
672
|
+
return custom_filters
|
673
|
+
|
626
674
|
def apply_filter(
|
627
675
|
self,
|
628
676
|
query: AnyQuery,
|
@@ -166,6 +166,7 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter):
|
|
166
166
|
"only_model_artifacts",
|
167
167
|
"only_deployment_artifacts",
|
168
168
|
"has_custom_name",
|
169
|
+
"user",
|
169
170
|
]
|
170
171
|
CLI_EXCLUDE_FIELDS = [
|
171
172
|
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
|
@@ -214,6 +215,10 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter):
|
|
214
215
|
only_model_artifacts: Optional[bool] = False
|
215
216
|
only_deployment_artifacts: Optional[bool] = False
|
216
217
|
has_custom_name: Optional[bool] = None
|
218
|
+
user: Optional[Union[UUID, str]] = Field(
|
219
|
+
default=None,
|
220
|
+
description="Name/ID of the user that created the artifact.",
|
221
|
+
)
|
217
222
|
|
218
223
|
# TODO: In Pydantic v2, the `model_` is a protected namespaces for all
|
219
224
|
# fields defined under base models. If not handled, this raises a warning.
|
@@ -233,12 +238,11 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter):
|
|
233
238
|
|
234
239
|
from sqlmodel import and_
|
235
240
|
|
236
|
-
from zenml.zen_stores.schemas
|
241
|
+
from zenml.zen_stores.schemas import (
|
237
242
|
ArtifactSchema,
|
238
243
|
ArtifactVersionSchema,
|
239
|
-
)
|
240
|
-
from zenml.zen_stores.schemas.model_schemas import (
|
241
244
|
ModelVersionArtifactSchema,
|
245
|
+
UserSchema,
|
242
246
|
)
|
243
247
|
|
244
248
|
if self.artifact_name:
|
@@ -284,4 +288,15 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter):
|
|
284
288
|
)
|
285
289
|
custom_filters.append(custom_name_filter)
|
286
290
|
|
291
|
+
if self.user:
|
292
|
+
user_filter = and_(
|
293
|
+
ModelVersionArtifactSchema.artifact_version_id
|
294
|
+
== ArtifactVersionSchema.id,
|
295
|
+
ArtifactVersionSchema.user_id == UserSchema.id,
|
296
|
+
self.generate_name_or_id_query_conditions(
|
297
|
+
value=self.user, table=UserSchema
|
298
|
+
),
|
299
|
+
)
|
300
|
+
custom_filters.append(user_filter)
|
301
|
+
|
287
302
|
return custom_filters
|
@@ -123,10 +123,10 @@ class ModelVersionPipelineRunResponse(
|
|
123
123
|
class ModelVersionPipelineRunFilter(WorkspaceScopedFilter):
|
124
124
|
"""Model version pipeline run links filter model."""
|
125
125
|
|
126
|
-
# Pipeline run name is not a DB field and needs to be handled separately
|
127
126
|
FILTER_EXCLUDE_FIELDS = [
|
128
127
|
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
|
129
128
|
"pipeline_run_name",
|
129
|
+
"user",
|
130
130
|
]
|
131
131
|
CLI_EXCLUDE_FIELDS = [
|
132
132
|
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
|
@@ -167,6 +167,10 @@ class ModelVersionPipelineRunFilter(WorkspaceScopedFilter):
|
|
167
167
|
default=None,
|
168
168
|
description="Name of the pipeline run",
|
169
169
|
)
|
170
|
+
user: Optional[Union[UUID, str]] = Field(
|
171
|
+
default=None,
|
172
|
+
description="Name/ID of the user that created the pipeline run.",
|
173
|
+
)
|
170
174
|
|
171
175
|
# TODO: In Pydantic v2, the `model_` is a protected namespaces for all
|
172
176
|
# fields defined under base models. If not handled, this raises a warning.
|
@@ -186,11 +190,10 @@ class ModelVersionPipelineRunFilter(WorkspaceScopedFilter):
|
|
186
190
|
|
187
191
|
from sqlmodel import and_
|
188
192
|
|
189
|
-
from zenml.zen_stores.schemas
|
193
|
+
from zenml.zen_stores.schemas import (
|
190
194
|
ModelVersionPipelineRunSchema,
|
191
|
-
)
|
192
|
-
from zenml.zen_stores.schemas.pipeline_run_schemas import (
|
193
195
|
PipelineRunSchema,
|
196
|
+
UserSchema,
|
194
197
|
)
|
195
198
|
|
196
199
|
if self.pipeline_run_name:
|
@@ -209,4 +212,15 @@ class ModelVersionPipelineRunFilter(WorkspaceScopedFilter):
|
|
209
212
|
)
|
210
213
|
custom_filters.append(pipeline_run_name_filter)
|
211
214
|
|
215
|
+
if self.user:
|
216
|
+
user_filter = and_(
|
217
|
+
ModelVersionPipelineRunSchema.pipeline_run_id
|
218
|
+
== PipelineRunSchema.id,
|
219
|
+
PipelineRunSchema.user_id == UserSchema.id,
|
220
|
+
self.generate_name_or_id_query_conditions(
|
221
|
+
value=self.user, table=UserSchema
|
222
|
+
),
|
223
|
+
)
|
224
|
+
custom_filters.append(user_filter)
|
225
|
+
|
212
226
|
return custom_filters
|
zenml/models/v2/core/pipeline.py
CHANGED
@@ -13,7 +13,16 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Models representing pipelines."""
|
15
15
|
|
16
|
-
from typing import
|
16
|
+
from typing import (
|
17
|
+
TYPE_CHECKING,
|
18
|
+
Any,
|
19
|
+
ClassVar,
|
20
|
+
List,
|
21
|
+
Optional,
|
22
|
+
Type,
|
23
|
+
TypeVar,
|
24
|
+
Union,
|
25
|
+
)
|
17
26
|
from uuid import UUID
|
18
27
|
|
19
28
|
from pydantic import Field
|
@@ -36,6 +45,8 @@ from zenml.models.v2.base.scoped import (
|
|
36
45
|
from zenml.models.v2.core.tag import TagResponse
|
37
46
|
|
38
47
|
if TYPE_CHECKING:
|
48
|
+
from sqlalchemy.sql.elements import ColumnElement
|
49
|
+
|
39
50
|
from zenml.models.v2.core.pipeline_run import PipelineRunResponse
|
40
51
|
from zenml.zen_stores.schemas import BaseSchema
|
41
52
|
|
@@ -248,11 +259,21 @@ class PipelineFilter(WorkspaceScopedTaggableFilter):
|
|
248
259
|
"""Pipeline filter model."""
|
249
260
|
|
250
261
|
CUSTOM_SORTING_OPTIONS = [SORT_PIPELINES_BY_LATEST_RUN_KEY]
|
262
|
+
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
263
|
+
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
|
264
|
+
"user",
|
265
|
+
"latest_run_status",
|
266
|
+
]
|
251
267
|
|
252
268
|
name: Optional[str] = Field(
|
253
269
|
default=None,
|
254
270
|
description="Name of the Pipeline",
|
255
271
|
)
|
272
|
+
latest_run_status: Optional[str] = Field(
|
273
|
+
default=None,
|
274
|
+
description="Filter by the status of the latest run of a pipeline. "
|
275
|
+
"This will always be applied as an `AND` filter for now.",
|
276
|
+
)
|
256
277
|
workspace_id: Optional[Union[UUID, str]] = Field(
|
257
278
|
default=None,
|
258
279
|
description="Workspace of the Pipeline",
|
@@ -263,6 +284,92 @@ class PipelineFilter(WorkspaceScopedTaggableFilter):
|
|
263
284
|
description="User of the Pipeline",
|
264
285
|
union_mode="left_to_right",
|
265
286
|
)
|
287
|
+
user: Optional[Union[UUID, str]] = Field(
|
288
|
+
default=None,
|
289
|
+
description="Name/ID of the user that created the pipeline.",
|
290
|
+
)
|
291
|
+
|
292
|
+
def apply_filter(
|
293
|
+
self, query: AnyQuery, table: Type["AnySchema"]
|
294
|
+
) -> AnyQuery:
|
295
|
+
"""Applies the filter to a query.
|
296
|
+
|
297
|
+
Args:
|
298
|
+
query: The query to which to apply the filter.
|
299
|
+
table: The query table.
|
300
|
+
|
301
|
+
Returns:
|
302
|
+
The query with filter applied.
|
303
|
+
"""
|
304
|
+
query = super().apply_filter(query, table)
|
305
|
+
|
306
|
+
from sqlmodel import and_, col, func, select
|
307
|
+
|
308
|
+
from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema
|
309
|
+
|
310
|
+
if self.latest_run_status:
|
311
|
+
latest_pipeline_run_subquery = (
|
312
|
+
select(
|
313
|
+
PipelineRunSchema.pipeline_id,
|
314
|
+
func.max(PipelineRunSchema.created).label("created"),
|
315
|
+
)
|
316
|
+
.where(col(PipelineRunSchema.pipeline_id).is_not(None))
|
317
|
+
.group_by(col(PipelineRunSchema.pipeline_id))
|
318
|
+
.subquery()
|
319
|
+
)
|
320
|
+
|
321
|
+
query = (
|
322
|
+
query.join(
|
323
|
+
PipelineRunSchema,
|
324
|
+
PipelineSchema.id == PipelineRunSchema.pipeline_id,
|
325
|
+
)
|
326
|
+
.join(
|
327
|
+
latest_pipeline_run_subquery,
|
328
|
+
and_(
|
329
|
+
PipelineRunSchema.pipeline_id
|
330
|
+
== latest_pipeline_run_subquery.c.pipeline_id,
|
331
|
+
PipelineRunSchema.created
|
332
|
+
== latest_pipeline_run_subquery.c.created,
|
333
|
+
),
|
334
|
+
)
|
335
|
+
.where(
|
336
|
+
self.generate_custom_query_conditions_for_column(
|
337
|
+
value=self.latest_run_status,
|
338
|
+
table=PipelineRunSchema,
|
339
|
+
column="status",
|
340
|
+
)
|
341
|
+
)
|
342
|
+
)
|
343
|
+
|
344
|
+
return query
|
345
|
+
|
346
|
+
def get_custom_filters(
|
347
|
+
self,
|
348
|
+
) -> List["ColumnElement[bool]"]:
|
349
|
+
"""Get custom filters.
|
350
|
+
|
351
|
+
Returns:
|
352
|
+
A list of custom filters.
|
353
|
+
"""
|
354
|
+
custom_filters = super().get_custom_filters()
|
355
|
+
|
356
|
+
from sqlmodel import and_
|
357
|
+
|
358
|
+
from zenml.zen_stores.schemas import (
|
359
|
+
PipelineSchema,
|
360
|
+
UserSchema,
|
361
|
+
)
|
362
|
+
|
363
|
+
if self.user:
|
364
|
+
user_filter = and_(
|
365
|
+
PipelineSchema.user_id == UserSchema.id,
|
366
|
+
self.generate_name_or_id_query_conditions(
|
367
|
+
value=self.user, table=UserSchema
|
368
|
+
),
|
369
|
+
)
|
370
|
+
custom_filters.append(user_filter)
|
371
|
+
|
372
|
+
return custom_filters
|
266
373
|
|
267
374
|
def apply_sorting(
|
268
375
|
self,
|