zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240924__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/base.py +2 -2
- zenml/cli/utils.py +14 -11
- zenml/client.py +68 -3
- zenml/config/step_configurations.py +0 -5
- zenml/enums.py +2 -0
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +81 -43
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
- zenml/models/v2/base/filter.py +315 -149
- zenml/models/v2/base/scoped.py +5 -2
- zenml/models/v2/core/artifact_version.py +69 -8
- zenml/models/v2/core/model.py +43 -6
- zenml/models/v2/core/model_version.py +49 -1
- zenml/models/v2/core/model_version_artifact.py +18 -3
- zenml/models/v2/core/model_version_pipeline_run.py +18 -4
- zenml/models/v2/core/pipeline.py +108 -1
- zenml/models/v2/core/pipeline_run.py +110 -20
- zenml/models/v2/core/run_template.py +53 -1
- zenml/models/v2/core/stack.py +33 -5
- zenml/models/v2/core/step_run.py +7 -0
- zenml/new/pipelines/pipeline.py +4 -0
- zenml/utils/env_utils.py +54 -1
- zenml/utils/string_utils.py +50 -0
- zenml/zen_stores/sql_zen_store.py +1 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/METADATA +1 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/RECORD +30 -30
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/entry_points.txt +0 -0
@@ -28,8 +28,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
|
28
28
|
|
29
29
|
from zenml.config.pipeline_configurations import PipelineConfiguration
|
30
30
|
from zenml.constants import STR_FIELD_MAX_LENGTH
|
31
|
-
from zenml.enums import ExecutionStatus
|
32
|
-
from zenml.models.v2.base.filter import StrFilter
|
31
|
+
from zenml.enums import ExecutionStatus
|
33
32
|
from zenml.models.v2.base.scoped import (
|
34
33
|
WorkspaceScopedFilter,
|
35
34
|
WorkspaceScopedRequest,
|
@@ -522,6 +521,11 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
|
|
522
521
|
"schedule_id",
|
523
522
|
"stack_id",
|
524
523
|
"template_id",
|
524
|
+
"user",
|
525
|
+
"pipeline",
|
526
|
+
"stack",
|
527
|
+
"code_repository",
|
528
|
+
"model",
|
525
529
|
"pipeline_name",
|
526
530
|
"templatable",
|
527
531
|
]
|
@@ -538,10 +542,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
|
|
538
542
|
description="Pipeline associated with the Pipeline Run",
|
539
543
|
union_mode="left_to_right",
|
540
544
|
)
|
541
|
-
pipeline_name: Optional[str] = Field(
|
542
|
-
default=None,
|
543
|
-
description="Name of the pipeline associated with the run",
|
544
|
-
)
|
545
545
|
workspace_id: Optional[Union[UUID, str]] = Field(
|
546
546
|
default=None,
|
547
547
|
description="Workspace of the Pipeline Run",
|
@@ -582,6 +582,11 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
|
|
582
582
|
description="Template used for the pipeline run.",
|
583
583
|
union_mode="left_to_right",
|
584
584
|
)
|
585
|
+
model_version_id: Optional[Union[UUID, str]] = Field(
|
586
|
+
default=None,
|
587
|
+
description="Model version associated with the pipeline run.",
|
588
|
+
union_mode="left_to_right",
|
589
|
+
)
|
585
590
|
status: Optional[str] = Field(
|
586
591
|
default=None,
|
587
592
|
description="Name of the Pipeline Run",
|
@@ -597,7 +602,37 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
|
|
597
602
|
union_mode="left_to_right",
|
598
603
|
)
|
599
604
|
unlisted: Optional[bool] = None
|
600
|
-
|
605
|
+
user: Optional[Union[UUID, str]] = Field(
|
606
|
+
default=None,
|
607
|
+
description="Name/ID of the user that created the run.",
|
608
|
+
)
|
609
|
+
# TODO: Remove once frontend is ready for it. This is replaced by the more
|
610
|
+
# generic `pipeline` filter below.
|
611
|
+
pipeline_name: Optional[str] = Field(
|
612
|
+
default=None,
|
613
|
+
description="Name of the pipeline associated with the run",
|
614
|
+
)
|
615
|
+
pipeline: Optional[Union[UUID, str]] = Field(
|
616
|
+
default=None,
|
617
|
+
description="Name/ID of the pipeline associated with the run.",
|
618
|
+
)
|
619
|
+
stack: Optional[Union[UUID, str]] = Field(
|
620
|
+
default=None,
|
621
|
+
description="Name/ID of the stack associated with the run.",
|
622
|
+
)
|
623
|
+
code_repository: Optional[Union[UUID, str]] = Field(
|
624
|
+
default=None,
|
625
|
+
description="Name/ID of the code repository associated with the run.",
|
626
|
+
)
|
627
|
+
model: Optional[Union[UUID, str]] = Field(
|
628
|
+
default=None,
|
629
|
+
description="Name/ID of the model associated with the run.",
|
630
|
+
)
|
631
|
+
templatable: Optional[bool] = Field(
|
632
|
+
default=None, description="Whether the run is templatable."
|
633
|
+
)
|
634
|
+
|
635
|
+
model_config = ConfigDict(protected_namespaces=())
|
601
636
|
|
602
637
|
def get_custom_filters(
|
603
638
|
self,
|
@@ -613,12 +648,16 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
|
|
613
648
|
|
614
649
|
from zenml.zen_stores.schemas import (
|
615
650
|
CodeReferenceSchema,
|
651
|
+
CodeRepositorySchema,
|
652
|
+
ModelSchema,
|
653
|
+
ModelVersionSchema,
|
616
654
|
PipelineBuildSchema,
|
617
655
|
PipelineDeploymentSchema,
|
618
656
|
PipelineRunSchema,
|
619
657
|
PipelineSchema,
|
620
658
|
ScheduleSchema,
|
621
659
|
StackSchema,
|
660
|
+
UserSchema,
|
622
661
|
)
|
623
662
|
|
624
663
|
if self.unlisted is not None:
|
@@ -628,19 +667,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
|
|
628
667
|
unlisted_filter = PipelineRunSchema.pipeline_id.is_not(None) # type: ignore[union-attr]
|
629
668
|
custom_filters.append(unlisted_filter)
|
630
669
|
|
631
|
-
if self.pipeline_name is not None:
|
632
|
-
value, filter_operator = self._resolve_operator(self.pipeline_name)
|
633
|
-
filter_ = StrFilter(
|
634
|
-
operation=GenericFilterOps(filter_operator),
|
635
|
-
column="name",
|
636
|
-
value=value,
|
637
|
-
)
|
638
|
-
pipeline_name_filter = and_(
|
639
|
-
PipelineRunSchema.pipeline_id == PipelineSchema.id,
|
640
|
-
filter_.generate_query_conditions(PipelineSchema),
|
641
|
-
)
|
642
|
-
custom_filters.append(pipeline_name_filter)
|
643
|
-
|
644
670
|
if self.code_repository_id:
|
645
671
|
code_repo_filter = and_(
|
646
672
|
PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id,
|
@@ -682,6 +708,70 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
|
|
682
708
|
)
|
683
709
|
custom_filters.append(run_template_filter)
|
684
710
|
|
711
|
+
if self.user:
|
712
|
+
user_filter = and_(
|
713
|
+
PipelineRunSchema.user_id == UserSchema.id,
|
714
|
+
self.generate_name_or_id_query_conditions(
|
715
|
+
value=self.user, table=UserSchema
|
716
|
+
),
|
717
|
+
)
|
718
|
+
custom_filters.append(user_filter)
|
719
|
+
|
720
|
+
if self.pipeline:
|
721
|
+
pipeline_filter = and_(
|
722
|
+
PipelineRunSchema.pipeline_id == PipelineSchema.id,
|
723
|
+
self.generate_name_or_id_query_conditions(
|
724
|
+
value=self.pipeline, table=PipelineSchema
|
725
|
+
),
|
726
|
+
)
|
727
|
+
custom_filters.append(pipeline_filter)
|
728
|
+
|
729
|
+
if self.stack:
|
730
|
+
stack_filter = and_(
|
731
|
+
PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id,
|
732
|
+
PipelineDeploymentSchema.stack_id == StackSchema.id,
|
733
|
+
self.generate_name_or_id_query_conditions(
|
734
|
+
value=self.stack,
|
735
|
+
table=StackSchema,
|
736
|
+
),
|
737
|
+
)
|
738
|
+
custom_filters.append(stack_filter)
|
739
|
+
|
740
|
+
if self.code_repository:
|
741
|
+
code_repo_filter = and_(
|
742
|
+
PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id,
|
743
|
+
PipelineDeploymentSchema.code_reference_id
|
744
|
+
== CodeReferenceSchema.id,
|
745
|
+
CodeReferenceSchema.code_repository_id
|
746
|
+
== CodeRepositorySchema.id,
|
747
|
+
self.generate_name_or_id_query_conditions(
|
748
|
+
value=self.code_repository,
|
749
|
+
table=CodeRepositorySchema,
|
750
|
+
),
|
751
|
+
)
|
752
|
+
custom_filters.append(code_repo_filter)
|
753
|
+
|
754
|
+
if self.model:
|
755
|
+
model_filter = and_(
|
756
|
+
PipelineRunSchema.model_version_id == ModelVersionSchema.id,
|
757
|
+
ModelVersionSchema.model_id == ModelSchema.id,
|
758
|
+
self.generate_name_or_id_query_conditions(
|
759
|
+
value=self.model, table=ModelSchema
|
760
|
+
),
|
761
|
+
)
|
762
|
+
custom_filters.append(model_filter)
|
763
|
+
|
764
|
+
if self.pipeline_name:
|
765
|
+
pipeline_name_filter = and_(
|
766
|
+
PipelineRunSchema.pipeline_id == PipelineSchema.id,
|
767
|
+
self.generate_custom_query_conditions_for_column(
|
768
|
+
value=self.pipeline_name,
|
769
|
+
table=PipelineSchema,
|
770
|
+
column="name",
|
771
|
+
),
|
772
|
+
)
|
773
|
+
custom_filters.append(pipeline_name_filter)
|
774
|
+
|
685
775
|
if self.templatable is not None:
|
686
776
|
if self.templatable is True:
|
687
777
|
templatable_filter = and_(
|
@@ -299,7 +299,11 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
|
|
299
299
|
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
|
300
300
|
"code_repository_id",
|
301
301
|
"stack_id",
|
302
|
-
"build_id"
|
302
|
+
"build_id",
|
303
|
+
"pipeline_id",
|
304
|
+
"user",
|
305
|
+
"pipeline",
|
306
|
+
"stack",
|
303
307
|
]
|
304
308
|
|
305
309
|
name: Optional[str] = Field(
|
@@ -336,6 +340,18 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
|
|
336
340
|
description="Code repository associated with the template.",
|
337
341
|
union_mode="left_to_right",
|
338
342
|
)
|
343
|
+
user: Optional[Union[UUID, str]] = Field(
|
344
|
+
default=None,
|
345
|
+
description="Name/ID of the user that created the template.",
|
346
|
+
)
|
347
|
+
pipeline: Optional[Union[UUID, str]] = Field(
|
348
|
+
default=None,
|
349
|
+
description="Name/ID of the pipeline associated with the template.",
|
350
|
+
)
|
351
|
+
stack: Optional[Union[UUID, str]] = Field(
|
352
|
+
default=None,
|
353
|
+
description="Name/ID of the stack associated with the template.",
|
354
|
+
)
|
339
355
|
|
340
356
|
def get_custom_filters(
|
341
357
|
self,
|
@@ -352,7 +368,10 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
|
|
352
368
|
from zenml.zen_stores.schemas import (
|
353
369
|
CodeReferenceSchema,
|
354
370
|
PipelineDeploymentSchema,
|
371
|
+
PipelineSchema,
|
355
372
|
RunTemplateSchema,
|
373
|
+
StackSchema,
|
374
|
+
UserSchema,
|
356
375
|
)
|
357
376
|
|
358
377
|
if self.code_repository_id:
|
@@ -390,4 +409,37 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
|
|
390
409
|
)
|
391
410
|
custom_filters.append(pipeline_filter)
|
392
411
|
|
412
|
+
if self.user:
|
413
|
+
user_filter = and_(
|
414
|
+
RunTemplateSchema.user_id == UserSchema.id,
|
415
|
+
self.generate_name_or_id_query_conditions(
|
416
|
+
value=self.user, table=UserSchema
|
417
|
+
),
|
418
|
+
)
|
419
|
+
custom_filters.append(user_filter)
|
420
|
+
|
421
|
+
if self.pipeline:
|
422
|
+
pipeline_filter = and_(
|
423
|
+
RunTemplateSchema.source_deployment_id
|
424
|
+
== PipelineDeploymentSchema.id,
|
425
|
+
PipelineDeploymentSchema.pipeline_id == PipelineSchema.id,
|
426
|
+
self.generate_name_or_id_query_conditions(
|
427
|
+
value=self.pipeline,
|
428
|
+
table=PipelineSchema,
|
429
|
+
),
|
430
|
+
)
|
431
|
+
custom_filters.append(pipeline_filter)
|
432
|
+
|
433
|
+
if self.stack:
|
434
|
+
stack_filter = and_(
|
435
|
+
RunTemplateSchema.source_deployment_id
|
436
|
+
== PipelineDeploymentSchema.id,
|
437
|
+
PipelineDeploymentSchema.stack_id == StackSchema.id,
|
438
|
+
self.generate_name_or_id_query_conditions(
|
439
|
+
value=self.stack,
|
440
|
+
table=StackSchema,
|
441
|
+
),
|
442
|
+
)
|
443
|
+
custom_filters.append(stack_filter)
|
444
|
+
|
393
445
|
return custom_filters
|
zenml/models/v2/core/stack.py
CHANGED
@@ -318,12 +318,11 @@ class StackFilter(WorkspaceScopedFilter):
|
|
318
318
|
scoping.
|
319
319
|
"""
|
320
320
|
|
321
|
-
# `component_id` refers to a relationship through a link-table
|
322
|
-
# rather than a field in the db, hence it needs to be handled
|
323
|
-
# explicitly
|
324
321
|
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
325
322
|
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
|
326
|
-
"component_id",
|
323
|
+
"component_id",
|
324
|
+
"user",
|
325
|
+
"component",
|
327
326
|
]
|
328
327
|
|
329
328
|
name: Optional[str] = Field(
|
@@ -348,6 +347,13 @@ class StackFilter(WorkspaceScopedFilter):
|
|
348
347
|
description="Component in the stack",
|
349
348
|
union_mode="left_to_right",
|
350
349
|
)
|
350
|
+
user: Optional[Union[UUID, str]] = Field(
|
351
|
+
default=None,
|
352
|
+
description="Name/ID of the user that created the stack.",
|
353
|
+
)
|
354
|
+
component: Optional[Union[UUID, str]] = Field(
|
355
|
+
default=None, description="Name/ID of a component in the stack."
|
356
|
+
)
|
351
357
|
|
352
358
|
def get_custom_filters(self) -> List["ColumnElement[bool]"]:
|
353
359
|
"""Get custom filters.
|
@@ -357,9 +363,11 @@ class StackFilter(WorkspaceScopedFilter):
|
|
357
363
|
"""
|
358
364
|
custom_filters = super().get_custom_filters()
|
359
365
|
|
360
|
-
from zenml.zen_stores.schemas
|
366
|
+
from zenml.zen_stores.schemas import (
|
367
|
+
StackComponentSchema,
|
361
368
|
StackCompositionSchema,
|
362
369
|
StackSchema,
|
370
|
+
UserSchema,
|
363
371
|
)
|
364
372
|
|
365
373
|
if self.component_id:
|
@@ -369,4 +377,24 @@ class StackFilter(WorkspaceScopedFilter):
|
|
369
377
|
)
|
370
378
|
custom_filters.append(component_id_filter)
|
371
379
|
|
380
|
+
if self.user:
|
381
|
+
user_filter = and_(
|
382
|
+
StackSchema.user_id == UserSchema.id,
|
383
|
+
self.generate_name_or_id_query_conditions(
|
384
|
+
value=self.user, table=UserSchema
|
385
|
+
),
|
386
|
+
)
|
387
|
+
custom_filters.append(user_filter)
|
388
|
+
|
389
|
+
if self.component:
|
390
|
+
component_filter = and_(
|
391
|
+
StackCompositionSchema.stack_id == StackSchema.id,
|
392
|
+
StackCompositionSchema.component_id == StackComponentSchema.id,
|
393
|
+
self.generate_name_or_id_query_conditions(
|
394
|
+
value=self.component,
|
395
|
+
table=StackComponentSchema,
|
396
|
+
),
|
397
|
+
)
|
398
|
+
custom_filters.append(component_filter)
|
399
|
+
|
372
400
|
return custom_filters
|
zenml/models/v2/core/step_run.py
CHANGED
@@ -536,3 +536,10 @@ class StepRunFilter(WorkspaceScopedFilter):
|
|
536
536
|
description="Workspace of this step run",
|
537
537
|
union_mode="left_to_right",
|
538
538
|
)
|
539
|
+
model_version_id: Optional[Union[UUID, str]] = Field(
|
540
|
+
default=None,
|
541
|
+
description="Model version associated with the pipeline run.",
|
542
|
+
union_mode="left_to_right",
|
543
|
+
)
|
544
|
+
|
545
|
+
model_config = ConfigDict(protected_namespaces=())
|
zenml/new/pipelines/pipeline.py
CHANGED
@@ -84,6 +84,7 @@ from zenml.utils import (
|
|
84
84
|
code_utils,
|
85
85
|
dashboard_utils,
|
86
86
|
dict_utils,
|
87
|
+
env_utils,
|
87
88
|
pydantic_utils,
|
88
89
|
settings_utils,
|
89
90
|
source_utils,
|
@@ -1030,12 +1031,14 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
1030
1031
|
|
1031
1032
|
# Update with the values in code so they take precedence
|
1032
1033
|
run_config = pydantic_utils.update_model(run_config, update=update)
|
1034
|
+
run_config = env_utils.substitute_env_variable_placeholders(run_config)
|
1033
1035
|
|
1034
1036
|
deployment = Compiler().compile(
|
1035
1037
|
pipeline=self,
|
1036
1038
|
stack=Client().active_stack,
|
1037
1039
|
run_configuration=run_config,
|
1038
1040
|
)
|
1041
|
+
deployment = env_utils.substitute_env_variable_placeholders(deployment)
|
1039
1042
|
|
1040
1043
|
return deployment, run_config.schedule, run_config.build
|
1041
1044
|
|
@@ -1252,6 +1255,7 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
1252
1255
|
if config_path:
|
1253
1256
|
with open(config_path, "r") as f:
|
1254
1257
|
_from_config_file = yaml.load(f, Loader=yaml.SafeLoader)
|
1258
|
+
|
1255
1259
|
_from_config_file = dict_utils.remove_none_values(
|
1256
1260
|
{k: v for k, v in _from_config_file.items() if k in matcher}
|
1257
1261
|
)
|
zenml/utils/env_utils.py
CHANGED
@@ -14,7 +14,16 @@
|
|
14
14
|
"""Utility functions for handling environment variables."""
|
15
15
|
|
16
16
|
import os
|
17
|
-
|
17
|
+
import re
|
18
|
+
from typing import Any, Dict, List, Match, Optional, TypeVar, cast
|
19
|
+
|
20
|
+
from zenml.logger import get_logger
|
21
|
+
from zenml.utils import string_utils
|
22
|
+
|
23
|
+
logger = get_logger(__name__)
|
24
|
+
|
25
|
+
V = TypeVar("V", bound=Any)
|
26
|
+
ENV_VARIABLE_PLACEHOLDER_PATTERN = re.compile(pattern=r"\$\{([a-zA-Z0-9_]+)\}")
|
18
27
|
|
19
28
|
ENV_VAR_CHUNK_SUFFIX = "_CHUNK_"
|
20
29
|
|
@@ -99,3 +108,47 @@ def reconstruct_environment_variables(
|
|
99
108
|
# Remove the chunk environment variables
|
100
109
|
for key in chunk_keys:
|
101
110
|
env.pop(key)
|
111
|
+
|
112
|
+
|
113
|
+
def substitute_env_variable_placeholders(
    value: V, raise_when_missing: bool = True
) -> V:
    """Replace `${VAR}` environment variable placeholders inside an object.

    Each string reachable through `string_utils.substitute_string` is scanned
    for matches of `ENV_VARIABLE_PLACEHOLDER_PATTERN`. Every match is replaced
    by the value of the environment variable whose name is captured by the
    pattern's first group.

    Args:
        value: The object whose string values should have their placeholders
            substituted.
        raise_when_missing: Controls the handling of placeholders that refer
            to an unset environment variable. If True, a `KeyError` is raised.
            Otherwise a warning is logged and the placeholder is replaced by
            an empty string.

    Returns:
        The object with placeholders substituted.

    Raises:
        KeyError: If a referenced environment variable is not set and
            `raise_when_missing` is True.
    """

    def _lookup(match: Match[str]) -> str:
        # Group 1 of the placeholder pattern holds the variable name.
        key = match.group(1)
        if key not in os.environ:
            if raise_when_missing:
                raise KeyError(
                    "Unable to substitute environment variable placeholder "
                    f"'{key}' because the environment variable is not set."
                )
            logger.warning(
                "Unable to substitute environment variable placeholder %s "
                "because the environment variable is not set, using an "
                "empty string instead.",
                key,
            )
            return ""
        return os.environ[key]

    return string_utils.substitute_string(
        value=value,
        substitution_func=lambda text: ENV_VARIABLE_PLACEHOLDER_PATTERN.sub(
            _lookup, text
        ),
    )
|
zenml/utils/string_utils.py
CHANGED
@@ -15,13 +15,17 @@
|
|
15
15
|
|
16
16
|
import base64
|
17
17
|
import datetime
|
18
|
+
import functools
|
18
19
|
import random
|
19
20
|
import string
|
21
|
+
from typing import Any, Callable, Dict, TypeVar, cast
|
20
22
|
|
21
23
|
from pydantic import BaseModel
|
22
24
|
|
23
25
|
from zenml.constants import BANNED_NAME_CHARACTERS
|
24
26
|
|
27
|
+
V = TypeVar("V", bound=Any)
|
28
|
+
|
25
29
|
|
26
30
|
def get_human_readable_time(seconds: float) -> str:
|
27
31
|
"""Convert seconds into a human-readable string.
|
@@ -167,3 +171,49 @@ def format_name_template(
|
|
167
171
|
datetime.datetime.now(datetime.timezone.utc).strftime("%H_%M_%S_%f"),
|
168
172
|
)
|
169
173
|
return name_template.format(**kwargs)
|
174
|
+
|
175
|
+
|
176
|
+
def substitute_string(value: V, substitution_func: Callable[[str], str]) -> V:
    """Recursively substitute strings in objects.

    Args:
        value: An object in which the strings should be recursively
            substituted. This can be a pydantic model, dict, list, set,
            frozenset, tuple (including named tuples) or any primitive type.
            Objects of any other type are returned unchanged.
        substitution_func: The function that does the actual string
            substitution.

    Returns:
        The object with the substitution function applied to all string
        values. Pydantic models and containers are rebuilt rather than
        modified in place; dict subclasses are returned as plain dicts.
    """
    substitute_ = functools.partial(
        substitute_string, substitution_func=substitution_func
    )

    if isinstance(value, BaseModel):
        model_values = {}

        # Iterating a pydantic model yields (field_name, field_value) pairs.
        for k, v in value:
            new_value = substitute_(v)

            if k not in value.model_fields_set and new_value == v:
                # This is a default value on the model and was not set
                # explicitly. In this case, we don't include it in the model
                # values to keep the `exclude_unset` behavior the same
                continue

            model_values[k] = new_value

        return cast(V, type(value).model_validate(model_values))
    elif isinstance(value, Dict):
        return cast(
            V, {substitute_(k): substitute_(v) for k, v in value.items()}
        )
    elif isinstance(value, tuple) and hasattr(value, "_fields"):
        # Named tuples take their fields as positional arguments, so the
        # generic `type(value)(iterable)` construction below would fail
        # (or store the generator itself as the first field).
        return cast(V, type(value)(*(substitute_(v) for v in value)))
    elif isinstance(value, (list, set, frozenset, tuple)):
        # `frozenset` is not a subclass of `set`, so it is listed explicitly.
        return cast(V, type(value)(substitute_(v) for v in value))
    elif isinstance(value, str):
        return cast(V, substitution_func(value))

    return value
|
@@ -973,6 +973,7 @@ class SqlZenStore(BaseZenStore):
|
|
973
973
|
ValueError: if the filtered page number is out of bounds.
|
974
974
|
RuntimeError: if the schema does not have a `to_model` method.
|
975
975
|
"""
|
976
|
+
query = query.distinct()
|
976
977
|
query = filter_model.apply_filter(query=query, table=table)
|
977
978
|
query = query.distinct()
|
978
979
|
|