zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240924__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. zenml/VERSION +1 -1
  2. zenml/cli/base.py +2 -2
  3. zenml/cli/utils.py +14 -11
  4. zenml/client.py +68 -3
  5. zenml/config/step_configurations.py +0 -5
  6. zenml/enums.py +2 -0
  7. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
  8. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +81 -43
  9. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
  10. zenml/models/v2/base/filter.py +315 -149
  11. zenml/models/v2/base/scoped.py +5 -2
  12. zenml/models/v2/core/artifact_version.py +69 -8
  13. zenml/models/v2/core/model.py +43 -6
  14. zenml/models/v2/core/model_version.py +49 -1
  15. zenml/models/v2/core/model_version_artifact.py +18 -3
  16. zenml/models/v2/core/model_version_pipeline_run.py +18 -4
  17. zenml/models/v2/core/pipeline.py +108 -1
  18. zenml/models/v2/core/pipeline_run.py +110 -20
  19. zenml/models/v2/core/run_template.py +53 -1
  20. zenml/models/v2/core/stack.py +33 -5
  21. zenml/models/v2/core/step_run.py +7 -0
  22. zenml/new/pipelines/pipeline.py +4 -0
  23. zenml/utils/env_utils.py +54 -1
  24. zenml/utils/string_utils.py +50 -0
  25. zenml/zen_stores/sql_zen_store.py +1 -0
  26. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/METADATA +1 -1
  27. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/RECORD +30 -30
  28. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/LICENSE +0 -0
  29. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/WHEEL +0 -0
  30. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240924.dist-info}/entry_points.txt +0 -0
@@ -28,8 +28,7 @@ from pydantic import BaseModel, ConfigDict, Field
28
28
 
29
29
  from zenml.config.pipeline_configurations import PipelineConfiguration
30
30
  from zenml.constants import STR_FIELD_MAX_LENGTH
31
- from zenml.enums import ExecutionStatus, GenericFilterOps
32
- from zenml.models.v2.base.filter import StrFilter
31
+ from zenml.enums import ExecutionStatus
33
32
  from zenml.models.v2.base.scoped import (
34
33
  WorkspaceScopedFilter,
35
34
  WorkspaceScopedRequest,
@@ -522,6 +521,11 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
522
521
  "schedule_id",
523
522
  "stack_id",
524
523
  "template_id",
524
+ "user",
525
+ "pipeline",
526
+ "stack",
527
+ "code_repository",
528
+ "model",
525
529
  "pipeline_name",
526
530
  "templatable",
527
531
  ]
@@ -538,10 +542,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
538
542
  description="Pipeline associated with the Pipeline Run",
539
543
  union_mode="left_to_right",
540
544
  )
541
- pipeline_name: Optional[str] = Field(
542
- default=None,
543
- description="Name of the pipeline associated with the run",
544
- )
545
545
  workspace_id: Optional[Union[UUID, str]] = Field(
546
546
  default=None,
547
547
  description="Workspace of the Pipeline Run",
@@ -582,6 +582,11 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
582
582
  description="Template used for the pipeline run.",
583
583
  union_mode="left_to_right",
584
584
  )
585
+ model_version_id: Optional[Union[UUID, str]] = Field(
586
+ default=None,
587
+ description="Model version associated with the pipeline run.",
588
+ union_mode="left_to_right",
589
+ )
585
590
  status: Optional[str] = Field(
586
591
  default=None,
587
592
  description="Name of the Pipeline Run",
@@ -597,7 +602,37 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
597
602
  union_mode="left_to_right",
598
603
  )
599
604
  unlisted: Optional[bool] = None
600
- templatable: Optional[bool] = None
605
+ user: Optional[Union[UUID, str]] = Field(
606
+ default=None,
607
+ description="Name/ID of the user that created the run.",
608
+ )
609
+ # TODO: Remove once frontend is ready for it. This is replaced by the more
610
+ # generic `pipeline` filter below.
611
+ pipeline_name: Optional[str] = Field(
612
+ default=None,
613
+ description="Name of the pipeline associated with the run",
614
+ )
615
+ pipeline: Optional[Union[UUID, str]] = Field(
616
+ default=None,
617
+ description="Name/ID of the pipeline associated with the run.",
618
+ )
619
+ stack: Optional[Union[UUID, str]] = Field(
620
+ default=None,
621
+ description="Name/ID of the stack associated with the run.",
622
+ )
623
+ code_repository: Optional[Union[UUID, str]] = Field(
624
+ default=None,
625
+ description="Name/ID of the code repository associated with the run.",
626
+ )
627
+ model: Optional[Union[UUID, str]] = Field(
628
+ default=None,
629
+ description="Name/ID of the model associated with the run.",
630
+ )
631
+ templatable: Optional[bool] = Field(
632
+ default=None, description="Whether the run is templatable."
633
+ )
634
+
635
+ model_config = ConfigDict(protected_namespaces=())
601
636
 
602
637
  def get_custom_filters(
603
638
  self,
@@ -613,12 +648,16 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
613
648
 
614
649
  from zenml.zen_stores.schemas import (
615
650
  CodeReferenceSchema,
651
+ CodeRepositorySchema,
652
+ ModelSchema,
653
+ ModelVersionSchema,
616
654
  PipelineBuildSchema,
617
655
  PipelineDeploymentSchema,
618
656
  PipelineRunSchema,
619
657
  PipelineSchema,
620
658
  ScheduleSchema,
621
659
  StackSchema,
660
+ UserSchema,
622
661
  )
623
662
 
624
663
  if self.unlisted is not None:
@@ -628,19 +667,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
628
667
  unlisted_filter = PipelineRunSchema.pipeline_id.is_not(None) # type: ignore[union-attr]
629
668
  custom_filters.append(unlisted_filter)
630
669
 
631
- if self.pipeline_name is not None:
632
- value, filter_operator = self._resolve_operator(self.pipeline_name)
633
- filter_ = StrFilter(
634
- operation=GenericFilterOps(filter_operator),
635
- column="name",
636
- value=value,
637
- )
638
- pipeline_name_filter = and_(
639
- PipelineRunSchema.pipeline_id == PipelineSchema.id,
640
- filter_.generate_query_conditions(PipelineSchema),
641
- )
642
- custom_filters.append(pipeline_name_filter)
643
-
644
670
  if self.code_repository_id:
645
671
  code_repo_filter = and_(
646
672
  PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id,
@@ -682,6 +708,70 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
682
708
  )
683
709
  custom_filters.append(run_template_filter)
684
710
 
711
+ if self.user:
712
+ user_filter = and_(
713
+ PipelineRunSchema.user_id == UserSchema.id,
714
+ self.generate_name_or_id_query_conditions(
715
+ value=self.user, table=UserSchema
716
+ ),
717
+ )
718
+ custom_filters.append(user_filter)
719
+
720
+ if self.pipeline:
721
+ pipeline_filter = and_(
722
+ PipelineRunSchema.pipeline_id == PipelineSchema.id,
723
+ self.generate_name_or_id_query_conditions(
724
+ value=self.pipeline, table=PipelineSchema
725
+ ),
726
+ )
727
+ custom_filters.append(pipeline_filter)
728
+
729
+ if self.stack:
730
+ stack_filter = and_(
731
+ PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id,
732
+ PipelineDeploymentSchema.stack_id == StackSchema.id,
733
+ self.generate_name_or_id_query_conditions(
734
+ value=self.stack,
735
+ table=StackSchema,
736
+ ),
737
+ )
738
+ custom_filters.append(stack_filter)
739
+
740
+ if self.code_repository:
741
+ code_repo_filter = and_(
742
+ PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id,
743
+ PipelineDeploymentSchema.code_reference_id
744
+ == CodeReferenceSchema.id,
745
+ CodeReferenceSchema.code_repository_id
746
+ == CodeRepositorySchema.id,
747
+ self.generate_name_or_id_query_conditions(
748
+ value=self.code_repository,
749
+ table=CodeRepositorySchema,
750
+ ),
751
+ )
752
+ custom_filters.append(code_repo_filter)
753
+
754
+ if self.model:
755
+ model_filter = and_(
756
+ PipelineRunSchema.model_version_id == ModelVersionSchema.id,
757
+ ModelVersionSchema.model_id == ModelSchema.id,
758
+ self.generate_name_or_id_query_conditions(
759
+ value=self.model, table=ModelSchema
760
+ ),
761
+ )
762
+ custom_filters.append(model_filter)
763
+
764
+ if self.pipeline_name:
765
+ pipeline_name_filter = and_(
766
+ PipelineRunSchema.pipeline_id == PipelineSchema.id,
767
+ self.generate_custom_query_conditions_for_column(
768
+ value=self.pipeline_name,
769
+ table=PipelineSchema,
770
+ column="name",
771
+ ),
772
+ )
773
+ custom_filters.append(pipeline_name_filter)
774
+
685
775
  if self.templatable is not None:
686
776
  if self.templatable is True:
687
777
  templatable_filter = and_(
@@ -299,7 +299,11 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
299
299
  *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
300
300
  "code_repository_id",
301
301
  "stack_id",
302
- "build_id" "pipeline_id",
302
+ "build_id",
303
+ "pipeline_id",
304
+ "user",
305
+ "pipeline",
306
+ "stack",
303
307
  ]
304
308
 
305
309
  name: Optional[str] = Field(
@@ -336,6 +340,18 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
336
340
  description="Code repository associated with the template.",
337
341
  union_mode="left_to_right",
338
342
  )
343
+ user: Optional[Union[UUID, str]] = Field(
344
+ default=None,
345
+ description="Name/ID of the user that created the template.",
346
+ )
347
+ pipeline: Optional[Union[UUID, str]] = Field(
348
+ default=None,
349
+ description="Name/ID of the pipeline associated with the template.",
350
+ )
351
+ stack: Optional[Union[UUID, str]] = Field(
352
+ default=None,
353
+ description="Name/ID of the stack associated with the template.",
354
+ )
339
355
 
340
356
  def get_custom_filters(
341
357
  self,
@@ -352,7 +368,10 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
352
368
  from zenml.zen_stores.schemas import (
353
369
  CodeReferenceSchema,
354
370
  PipelineDeploymentSchema,
371
+ PipelineSchema,
355
372
  RunTemplateSchema,
373
+ StackSchema,
374
+ UserSchema,
356
375
  )
357
376
 
358
377
  if self.code_repository_id:
@@ -390,4 +409,37 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
390
409
  )
391
410
  custom_filters.append(pipeline_filter)
392
411
 
412
+ if self.user:
413
+ user_filter = and_(
414
+ RunTemplateSchema.user_id == UserSchema.id,
415
+ self.generate_name_or_id_query_conditions(
416
+ value=self.user, table=UserSchema
417
+ ),
418
+ )
419
+ custom_filters.append(user_filter)
420
+
421
+ if self.pipeline:
422
+ pipeline_filter = and_(
423
+ RunTemplateSchema.source_deployment_id
424
+ == PipelineDeploymentSchema.id,
425
+ PipelineDeploymentSchema.pipeline_id == PipelineSchema.id,
426
+ self.generate_name_or_id_query_conditions(
427
+ value=self.pipeline,
428
+ table=PipelineSchema,
429
+ ),
430
+ )
431
+ custom_filters.append(pipeline_filter)
432
+
433
+ if self.stack:
434
+ stack_filter = and_(
435
+ RunTemplateSchema.source_deployment_id
436
+ == PipelineDeploymentSchema.id,
437
+ PipelineDeploymentSchema.stack_id == StackSchema.id,
438
+ self.generate_name_or_id_query_conditions(
439
+ value=self.stack,
440
+ table=StackSchema,
441
+ ),
442
+ )
443
+ custom_filters.append(stack_filter)
444
+
393
445
  return custom_filters
@@ -318,12 +318,11 @@ class StackFilter(WorkspaceScopedFilter):
318
318
  scoping.
319
319
  """
320
320
 
321
- # `component_id` refers to a relationship through a link-table
322
- # rather than a field in the db, hence it needs to be handled
323
- # explicitly
324
321
  FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
325
322
  *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
326
- "component_id", # This is a relationship, not a field
323
+ "component_id",
324
+ "user",
325
+ "component",
327
326
  ]
328
327
 
329
328
  name: Optional[str] = Field(
@@ -348,6 +347,13 @@ class StackFilter(WorkspaceScopedFilter):
348
347
  description="Component in the stack",
349
348
  union_mode="left_to_right",
350
349
  )
350
+ user: Optional[Union[UUID, str]] = Field(
351
+ default=None,
352
+ description="Name/ID of the user that created the stack.",
353
+ )
354
+ component: Optional[Union[UUID, str]] = Field(
355
+ default=None, description="Name/ID of a component in the stack."
356
+ )
351
357
 
352
358
  def get_custom_filters(self) -> List["ColumnElement[bool]"]:
353
359
  """Get custom filters.
@@ -357,9 +363,11 @@ class StackFilter(WorkspaceScopedFilter):
357
363
  """
358
364
  custom_filters = super().get_custom_filters()
359
365
 
360
- from zenml.zen_stores.schemas.stack_schemas import (
366
+ from zenml.zen_stores.schemas import (
367
+ StackComponentSchema,
361
368
  StackCompositionSchema,
362
369
  StackSchema,
370
+ UserSchema,
363
371
  )
364
372
 
365
373
  if self.component_id:
@@ -369,4 +377,24 @@ class StackFilter(WorkspaceScopedFilter):
369
377
  )
370
378
  custom_filters.append(component_id_filter)
371
379
 
380
+ if self.user:
381
+ user_filter = and_(
382
+ StackSchema.user_id == UserSchema.id,
383
+ self.generate_name_or_id_query_conditions(
384
+ value=self.user, table=UserSchema
385
+ ),
386
+ )
387
+ custom_filters.append(user_filter)
388
+
389
+ if self.component:
390
+ component_filter = and_(
391
+ StackCompositionSchema.stack_id == StackSchema.id,
392
+ StackCompositionSchema.component_id == StackComponentSchema.id,
393
+ self.generate_name_or_id_query_conditions(
394
+ value=self.component,
395
+ table=StackComponentSchema,
396
+ ),
397
+ )
398
+ custom_filters.append(component_filter)
399
+
372
400
  return custom_filters
@@ -536,3 +536,10 @@ class StepRunFilter(WorkspaceScopedFilter):
536
536
  description="Workspace of this step run",
537
537
  union_mode="left_to_right",
538
538
  )
539
+ model_version_id: Optional[Union[UUID, str]] = Field(
540
+ default=None,
541
+ description="Model version associated with the pipeline run.",
542
+ union_mode="left_to_right",
543
+ )
544
+
545
+ model_config = ConfigDict(protected_namespaces=())
@@ -84,6 +84,7 @@ from zenml.utils import (
84
84
  code_utils,
85
85
  dashboard_utils,
86
86
  dict_utils,
87
+ env_utils,
87
88
  pydantic_utils,
88
89
  settings_utils,
89
90
  source_utils,
@@ -1030,12 +1031,14 @@ To avoid this consider setting pipeline parameters only in one place (config or
1030
1031
 
1031
1032
  # Update with the values in code so they take precedence
1032
1033
  run_config = pydantic_utils.update_model(run_config, update=update)
1034
+ run_config = env_utils.substitute_env_variable_placeholders(run_config)
1033
1035
 
1034
1036
  deployment = Compiler().compile(
1035
1037
  pipeline=self,
1036
1038
  stack=Client().active_stack,
1037
1039
  run_configuration=run_config,
1038
1040
  )
1041
+ deployment = env_utils.substitute_env_variable_placeholders(deployment)
1039
1042
 
1040
1043
  return deployment, run_config.schedule, run_config.build
1041
1044
 
@@ -1252,6 +1255,7 @@ To avoid this consider setting pipeline parameters only in one place (config or
1252
1255
  if config_path:
1253
1256
  with open(config_path, "r") as f:
1254
1257
  _from_config_file = yaml.load(f, Loader=yaml.SafeLoader)
1258
+
1255
1259
  _from_config_file = dict_utils.remove_none_values(
1256
1260
  {k: v for k, v in _from_config_file.items() if k in matcher}
1257
1261
  )
zenml/utils/env_utils.py CHANGED
@@ -14,7 +14,16 @@
14
14
  """Utility functions for handling environment variables."""
15
15
 
16
16
  import os
17
- from typing import Dict, List, Optional, cast
17
+ import re
18
+ from typing import Any, Dict, List, Match, Optional, TypeVar, cast
19
+
20
+ from zenml.logger import get_logger
21
+ from zenml.utils import string_utils
22
+
23
+ logger = get_logger(__name__)
24
+
25
+ V = TypeVar("V", bound=Any)
26
+ ENV_VARIABLE_PLACEHOLDER_PATTERN = re.compile(pattern=r"\$\{([a-zA-Z0-9_]+)\}")
18
27
 
19
28
  ENV_VAR_CHUNK_SUFFIX = "_CHUNK_"
20
29
 
@@ -99,3 +108,47 @@ def reconstruct_environment_variables(
99
108
  # Remove the chunk environment variables
100
109
  for key in chunk_keys:
101
110
  env.pop(key)
111
+
112
+
113
+ def substitute_env_variable_placeholders(
114
+ value: V, raise_when_missing: bool = True
115
+ ) -> V:
116
+ """Substitute environment variable placeholders in an object.
117
+
118
+ Args:
119
+ value: The object in which to substitute the placeholders.
120
+ raise_when_missing: If True, an exception will be raised when an
121
+ environment variable is missing. Otherwise, a warning will be logged
122
+ instead.
123
+
124
+ Returns:
125
+ The object with placeholders substituted.
126
+ """
127
+
128
+ def _replace_with_env_variable_value(match: Match[str]) -> str:
129
+ key = match.group(1)
130
+ if key in os.environ:
131
+ return os.environ[key]
132
+ else:
133
+ if raise_when_missing:
134
+ raise KeyError(
135
+ "Unable to substitute environment variable placeholder "
136
+ f"'{key}' because the environment variable is not set."
137
+ )
138
+ else:
139
+ logger.warning(
140
+ "Unable to substitute environment variable placeholder %s "
141
+ "because the environment variable is not set, using an "
142
+ "empty string instead.",
143
+ key,
144
+ )
145
+ return ""
146
+
147
+ def _substitution_func(v: str) -> str:
148
+ return ENV_VARIABLE_PLACEHOLDER_PATTERN.sub(
149
+ _replace_with_env_variable_value, v
150
+ )
151
+
152
+ return string_utils.substitute_string(
153
+ value=value, substitution_func=_substitution_func
154
+ )
@@ -15,13 +15,17 @@
15
15
 
16
16
  import base64
17
17
  import datetime
18
+ import functools
18
19
  import random
19
20
  import string
21
+ from typing import Any, Callable, Dict, TypeVar, cast
20
22
 
21
23
  from pydantic import BaseModel
22
24
 
23
25
  from zenml.constants import BANNED_NAME_CHARACTERS
24
26
 
27
+ V = TypeVar("V", bound=Any)
28
+
25
29
 
26
30
  def get_human_readable_time(seconds: float) -> str:
27
31
  """Convert seconds into a human-readable string.
@@ -167,3 +171,49 @@ def format_name_template(
167
171
  datetime.datetime.now(datetime.timezone.utc).strftime("%H_%M_%S_%f"),
168
172
  )
169
173
  return name_template.format(**kwargs)
174
+
175
+
176
+ def substitute_string(value: V, substitution_func: Callable[[str], str]) -> V:
177
+ """Recursively substitute strings in objects.
178
+
179
+ Args:
180
+ value: An object in which the strings should be recursively substituted.
181
+ This can be a pydantic model, dict, set, list, tuple or any
182
+ primitive type.
183
+ substitution_func: The function that does the actual string
184
+ substitution.
185
+
186
+ Returns:
187
+ The object with the substitution function applied to all string values.
188
+ """
189
+ substitute_ = functools.partial(
190
+ substitute_string, substitution_func=substitution_func
191
+ )
192
+
193
+ if isinstance(value, BaseModel):
194
+ model_values = {}
195
+
196
+ for k, v in value.__iter__():
197
+ new_value = substitute_(v)
198
+
199
+ if k not in value.model_fields_set and new_value == getattr(
200
+ value, k
201
+ ):
202
+ # This is a default value on the model and was not set
203
+ # explicitly. In this case, we don't include it in the model
204
+ # values to keep the `exclude_unset` behavior the same
205
+ continue
206
+
207
+ model_values[k] = new_value
208
+
209
+ return cast(V, type(value).model_validate(model_values))
210
+ elif isinstance(value, Dict):
211
+ return cast(
212
+ V, {substitute_(k): substitute_(v) for k, v in value.items()}
213
+ )
214
+ elif isinstance(value, (list, set, tuple)):
215
+ return cast(V, type(value)(substitute_(v) for v in value))
216
+ elif isinstance(value, str):
217
+ return cast(V, substitution_func(value))
218
+
219
+ return value
@@ -973,6 +973,7 @@ class SqlZenStore(BaseZenStore):
973
973
  ValueError: if the filtered page number is out of bounds.
974
974
  RuntimeError: if the schema does not have a `to_model` method.
975
975
  """
976
+ query = query.distinct()
976
977
  query = filter_model.apply_filter(query=query, table=table)
977
978
  query = query.distinct()
978
979
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: zenml-nightly
3
- Version: 0.66.0.dev20240923
3
+ Version: 0.66.0.dev20240924
4
4
  Summary: ZenML: Write production-ready ML code.
5
5
  Home-page: https://zenml.io
6
6
  License: Apache-2.0