zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. zenml/VERSION +1 -1
  2. zenml/cli/__init__.py +7 -0
  3. zenml/cli/base.py +2 -2
  4. zenml/cli/pipeline.py +21 -0
  5. zenml/cli/utils.py +14 -11
  6. zenml/client.py +68 -3
  7. zenml/config/step_configurations.py +0 -5
  8. zenml/constants.py +3 -0
  9. zenml/enums.py +2 -0
  10. zenml/integrations/__init__.py +1 -0
  11. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
  12. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
  13. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
  14. zenml/integrations/constants.py +1 -0
  15. zenml/integrations/deepchecks/__init__.py +1 -1
  16. zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
  17. zenml/integrations/deepchecks/validation_checks.py +62 -5
  18. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
  19. zenml/integrations/lightning/__init__.py +1 -1
  20. zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
  21. zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
  22. zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
  23. zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
  24. zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
  25. zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
  26. zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
  27. zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
  28. zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
  29. zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
  30. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
  31. zenml/models/v2/base/filter.py +315 -149
  32. zenml/models/v2/base/scoped.py +5 -2
  33. zenml/models/v2/core/artifact_version.py +69 -8
  34. zenml/models/v2/core/model.py +43 -6
  35. zenml/models/v2/core/model_version.py +49 -1
  36. zenml/models/v2/core/model_version_artifact.py +18 -3
  37. zenml/models/v2/core/model_version_pipeline_run.py +18 -4
  38. zenml/models/v2/core/pipeline.py +108 -1
  39. zenml/models/v2/core/pipeline_run.py +172 -21
  40. zenml/models/v2/core/run_template.py +53 -1
  41. zenml/models/v2/core/stack.py +33 -5
  42. zenml/models/v2/core/step_run.py +7 -0
  43. zenml/new/pipelines/pipeline.py +4 -0
  44. zenml/new/pipelines/run_utils.py +4 -1
  45. zenml/orchestrators/base_orchestrator.py +41 -12
  46. zenml/stack/stack.py +11 -2
  47. zenml/utils/env_utils.py +54 -1
  48. zenml/utils/string_utils.py +50 -0
  49. zenml/zen_server/cloud_utils.py +33 -8
  50. zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-Y50hSt65.js} +1 -1
  51. zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-ytavUpwh.js} +1 -1
  52. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-xLR9a1iw.js +1 -0
  53. zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-IxXNxUDa.js} +2 -2
  54. zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-BhutZbBL.js} +1 -1
  55. zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-Bf-rd1z8.js} +1 -1
  56. zenml/zen_server/dashboard/assets/ComponentBadge-gKR1OIwG.js +1 -0
  57. zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-DcFHidFJ.js} +1 -1
  58. zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-QSbjrfxw.js} +1 -1
  59. zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-Cd3HqST4.js} +1 -1
  60. zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BhwdmqK7.js} +1 -1
  61. zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-D6r6aK8J.js} +1 -1
  62. zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-0pBpTwTm.js} +1 -1
  63. zenml/zen_server/dashboard/assets/Infobox-BTK_EUKT.js +1 -0
  64. zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-CA3DFMcM.js} +1 -1
  65. zenml/zen_server/dashboard/assets/Partials-QLOZw624.js +1 -0
  66. zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C16CCIN4.js} +1 -1
  67. zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-D3FuCHf3.js} +1 -1
  68. zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BzmfxS0L.js} +1 -1
  69. zenml/zen_server/dashboard/assets/SecretTooltip-BaMwHF-Q.js +1 -0
  70. zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-DuIC65H9.js} +1 -1
  71. zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-DJTCF0Re.js} +1 -1
  72. zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-CUm-DMpw.js} +1 -1
  73. zenml/zen_server/dashboard/assets/UsageReason-CKw0juLF.js +1 -0
  74. zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-Cv9ApYWU.js} +1 -1
  75. zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-BA3R2Sey.js} +1 -1
  76. zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-BB4BVa6E.js} +1 -1
  77. zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-F29xAUEx.js} +1 -1
  78. zenml/zen_server/dashboard/assets/delete-run-CP0pcJ3U.js +1 -0
  79. zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-BKXwSDK2.js} +1 -1
  80. zenml/zen_server/dashboard/assets/index-BhJ6ZJxv.css +1 -0
  81. zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-Ci0nJ8EZ.js} +5 -5
  82. zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-D-mtoBj3.js} +1 -1
  83. zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-ax6iL2Mb.js} +1 -1
  84. zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-DbjllLY_.js} +1 -1
  85. zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-3qPX9WYH.js} +1 -1
  86. zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-6mfzecin.js} +1 -1
  87. zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-8kYmrh0B.js} +1 -1
  88. zenml/zen_server/dashboard/assets/page-B1n7_W7z.js +1 -0
  89. zenml/zen_server/dashboard/assets/page-BDg1F-Ug.js +6 -0
  90. zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-BXarY9K2.js} +1 -1
  91. zenml/zen_server/dashboard/assets/page-BZZhLo2u.js +1 -0
  92. zenml/zen_server/dashboard/assets/page-Bbf_oBjn.js +1 -0
  93. zenml/zen_server/dashboard/assets/page-BjjuBvZG.js +9 -0
  94. zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-BukXK1Aa.js} +1 -1
  95. zenml/zen_server/dashboard/assets/page-CHaQkFK5.js +1 -0
  96. zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-CKHNAq7z.js} +1 -1
  97. zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-CS0SYFK8.js} +1 -1
  98. zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-CvKnNK1S.js} +1 -1
  99. zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DGM1CbYT.js} +2 -2
  100. zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DMSLXKGT.js} +1 -1
  101. zenml/zen_server/dashboard/assets/page-DOmIZ2ra.js +1 -0
  102. zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-DRfcRK1w.js} +1 -1
  103. zenml/zen_server/dashboard/assets/page-DYVmJ9_w.js +3 -0
  104. zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-DcTjHmYZ.js} +1 -1
  105. zenml/zen_server/dashboard/assets/page-DuqYMYmH.js +1 -0
  106. zenml/zen_server/dashboard/assets/page-Dwow2doB.js +1 -0
  107. zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-HkVBdZl6.js} +1 -1
  108. zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-MAXyfXBq.js} +1 -1
  109. zenml/zen_server/dashboard/assets/page-miU2rhYG.js +1 -0
  110. zenml/zen_server/dashboard/assets/page-p0BhSAWx.js +1 -0
  111. zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-uORspyRu.js} +1 -1
  112. zenml/zen_server/dashboard/assets/persist-BxIR2XZs.js +1 -0
  113. zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-CfJMar_k.js} +1 -1
  114. zenml/zen_server/dashboard/assets/sharedSchema-vub0rii3.js +14 -0
  115. zenml/zen_server/dashboard/assets/stack-detail-query-DQcyzG-2.js +1 -0
  116. zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
  117. zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-FGVP7X2U.js} +1 -1
  118. zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-CbAPzsmT.js} +1 -1
  119. zenml/zen_server/dashboard/index.html +4 -4
  120. zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
  121. zenml/zen_server/dashboard_legacy/index.html +1 -1
  122. zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.6d320abb70db612019dda6c4948e7a90.js} +4 -4
  123. zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
  124. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.fa9299d5.chunk.js} +2 -2
  125. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.fa9299d5.chunk.js.map} +1 -1
  126. zenml/zen_server/routers/runs_endpoints.py +89 -3
  127. zenml/zen_stores/sql_zen_store.py +1 -0
  128. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/METADATA +8 -1
  129. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/RECORD +132 -124
  130. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
  131. zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
  132. zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
  133. zenml/zen_server/dashboard/assets/Partials-DX-8iEa1.js +0 -1
  134. zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
  135. zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
  136. zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
  137. zenml/zen_server/dashboard/assets/page-CDOQLrPC.js +0 -1
  138. zenml/zen_server/dashboard/assets/page-CEJWu1YO.js +0 -1
  139. zenml/zen_server/dashboard/assets/page-CIbehp7V.js +0 -1
  140. zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
  141. zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
  142. zenml/zen_server/dashboard/assets/page-D5F3DJjm.js +0 -1
  143. zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
  144. zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
  145. zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
  146. zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
  147. zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
  148. zenml/zen_server/dashboard/assets/persist-mEZN_fgH.js +0 -1
  149. zenml/zen_server/dashboard/assets/sharedSchema-BfZcy7aP.js +0 -14
  150. zenml/zen_server/dashboard/assets/stack-detail-query-CU4egfhp.js +0 -1
  151. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/LICENSE +0 -0
  152. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/WHEEL +0 -0
  153. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/entry_points.txt +0 -0
@@ -54,7 +54,7 @@ from zenml.utils.typing_utils import get_args
54
54
  if TYPE_CHECKING:
55
55
  from sqlalchemy.sql.elements import ColumnElement
56
56
 
57
- from zenml.zen_stores.schemas import BaseSchema
57
+ from zenml.zen_stores.schemas import BaseSchema, NamedSchema
58
58
 
59
59
  AnySchema = TypeVar("AnySchema", bound=BaseSchema)
60
60
 
@@ -142,7 +142,10 @@ class Filter(BaseModel, ABC):
142
142
  class BoolFilter(Filter):
143
143
  """Filter for all Boolean fields."""
144
144
 
145
- ALLOWED_OPS: ClassVar[List[str]] = [GenericFilterOps.EQUALS]
145
+ ALLOWED_OPS: ClassVar[List[str]] = [
146
+ GenericFilterOps.EQUALS,
147
+ GenericFilterOps.NOT_EQUALS,
148
+ ]
146
149
 
147
150
  def generate_query_conditions_from_column(self, column: Any) -> Any:
148
151
  """Generate query conditions for a boolean column.
@@ -153,6 +156,9 @@ class BoolFilter(Filter):
153
156
  Returns:
154
157
  A list of query conditions.
155
158
  """
159
+ if self.operation == GenericFilterOps.NOT_EQUALS:
160
+ return column != self.value
161
+
156
162
  return column == self.value
157
163
 
158
164
 
@@ -161,6 +167,7 @@ class StrFilter(Filter):
161
167
 
162
168
  ALLOWED_OPS: ClassVar[List[str]] = [
163
169
  GenericFilterOps.EQUALS,
170
+ GenericFilterOps.NOT_EQUALS,
164
171
  GenericFilterOps.STARTSWITH,
165
172
  GenericFilterOps.CONTAINS,
166
173
  GenericFilterOps.ENDSWITH,
@@ -181,12 +188,31 @@ class StrFilter(Filter):
181
188
  return column.startswith(f"{self.value}")
182
189
  if self.operation == GenericFilterOps.ENDSWITH:
183
190
  return column.endswith(f"{self.value}")
191
+ if self.operation == GenericFilterOps.NOT_EQUALS:
192
+ return column != self.value
193
+
184
194
  return column == self.value
185
195
 
186
196
 
187
197
  class UUIDFilter(StrFilter):
188
198
  """Filter for all uuid fields which are mostly treated like strings."""
189
199
 
200
+ @field_validator("value", mode="before")
201
+ @classmethod
202
+ def _remove_hyphens_from_value(cls, value: Any) -> Any:
203
+ """Remove hyphens from the value to enable string comparisons.
204
+
205
+ Args:
206
+ value: The filter value.
207
+
208
+ Returns:
209
+ The filter value with removed hyphens.
210
+ """
211
+ if isinstance(value, str):
212
+ return value.replace("-", "")
213
+
214
+ return value
215
+
190
216
  def generate_query_conditions_from_column(self, column: Any) -> Any:
191
217
  """Generate query conditions for a UUID column.
192
218
 
@@ -203,6 +229,9 @@ class UUIDFilter(StrFilter):
203
229
  if self.operation == GenericFilterOps.EQUALS:
204
230
  return column == self.value
205
231
 
232
+ if self.operation == GenericFilterOps.NOT_EQUALS:
233
+ return column != self.value
234
+
206
235
  # For all other operations, cast and handle the column as string
207
236
  return super().generate_query_conditions_from_column(
208
237
  column=cast_if(column, sqlalchemy.String)
@@ -216,6 +245,7 @@ class NumericFilter(Filter):
216
245
 
217
246
  ALLOWED_OPS: ClassVar[List[str]] = [
218
247
  GenericFilterOps.EQUALS,
248
+ GenericFilterOps.NOT_EQUALS,
219
249
  GenericFilterOps.GT,
220
250
  GenericFilterOps.GTE,
221
251
  GenericFilterOps.LT,
@@ -223,14 +253,59 @@ class NumericFilter(Filter):
223
253
  ]
224
254
 
225
255
  def generate_query_conditions_from_column(self, column: Any) -> Any:
226
- """Generate query conditions for a UUID column.
256
+ """Generate query conditions for a numeric column.
227
257
 
228
258
  Args:
229
- column: The UUID column of an SQLModel table on which to filter.
259
+ column: The numeric column of an SQLModel table on which to filter.
260
+
261
+ Returns:
262
+ A list of query conditions.
263
+ """
264
+ if self.operation == GenericFilterOps.GTE:
265
+ return column >= self.value
266
+ if self.operation == GenericFilterOps.GT:
267
+ return column > self.value
268
+ if self.operation == GenericFilterOps.LTE:
269
+ return column <= self.value
270
+ if self.operation == GenericFilterOps.LT:
271
+ return column < self.value
272
+ if self.operation == GenericFilterOps.NOT_EQUALS:
273
+ return column != self.value
274
+ return column == self.value
275
+
276
+
277
+ class DatetimeFilter(Filter):
278
+ """Filter for all datetime fields."""
279
+
280
+ value: Union[datetime, Tuple[datetime, datetime]] = Field(
281
+ union_mode="left_to_right"
282
+ )
283
+
284
+ ALLOWED_OPS: ClassVar[List[str]] = [
285
+ GenericFilterOps.EQUALS,
286
+ GenericFilterOps.NOT_EQUALS,
287
+ GenericFilterOps.GT,
288
+ GenericFilterOps.GTE,
289
+ GenericFilterOps.LT,
290
+ GenericFilterOps.LTE,
291
+ GenericFilterOps.IN,
292
+ ]
293
+
294
+ def generate_query_conditions_from_column(self, column: Any) -> Any:
295
+ """Generate query conditions for a datetime column.
296
+
297
+ Args:
298
+ column: The datetime column of an SQLModel table on which to filter.
230
299
 
231
300
  Returns:
232
301
  A list of query conditions.
233
302
  """
303
+ if self.operation == GenericFilterOps.IN:
304
+ assert isinstance(self.value, tuple)
305
+ lower_bound, upper_bound = self.value
306
+ return column.between(lower_bound, upper_bound)
307
+
308
+ assert isinstance(self.value, datetime)
234
309
  if self.operation == GenericFilterOps.GTE:
235
310
  return column >= self.value
236
311
  if self.operation == GenericFilterOps.GT:
@@ -239,6 +314,8 @@ class NumericFilter(Filter):
239
314
  return column <= self.value
240
315
  if self.operation == GenericFilterOps.LT:
241
316
  return column < self.value
317
+ if self.operation == GenericFilterOps.NOT_EQUALS:
318
+ return column != self.value
242
319
  return column == self.value
243
320
 
244
321
 
@@ -490,7 +567,7 @@ class BaseFilter(BaseModel):
490
567
  value, operator = cls._resolve_operator(value)
491
568
 
492
569
  # Define the filter
493
- filter = cls._define_filter(
570
+ filter = FilterGenerator(cls).define_filter(
494
571
  column=key, value=value, operator=operator
495
572
  )
496
573
  list_of_filters.append(filter)
@@ -523,9 +600,185 @@ class BaseFilter(BaseModel):
523
600
  operator = GenericFilterOps(split_value[0])
524
601
  return value, operator
525
602
 
526
- @classmethod
527
- def _define_filter(
528
- cls, column: str, value: Any, operator: GenericFilterOps
603
+ def generate_name_or_id_query_conditions(
604
+ self,
605
+ value: Union[UUID, str],
606
+ table: Type["NamedSchema"],
607
+ ) -> "ColumnElement[bool]":
608
+ """Generate filter conditions for name or id of a table.
609
+
610
+ Args:
611
+ value: The filter value.
612
+ table: The table to filter.
613
+
614
+ Returns:
615
+ The query conditions.
616
+ """
617
+ from sqlmodel import or_
618
+
619
+ value, operator = BaseFilter._resolve_operator(value)
620
+ value = str(value)
621
+
622
+ conditions = []
623
+
624
+ try:
625
+ filter_ = FilterGenerator(table).define_filter(
626
+ column="id", value=value, operator=operator
627
+ )
628
+ conditions.append(filter_.generate_query_conditions(table=table))
629
+ except ValueError:
630
+ # UUID filter with equal operators and no full UUID fail with
631
+ # a ValueError. In this case, we already know that the filter
632
+ # will not produce any result and can simply ignore it.
633
+ pass
634
+
635
+ filter_ = FilterGenerator(table).define_filter(
636
+ column="name", value=value, operator=operator
637
+ )
638
+ conditions.append(filter_.generate_query_conditions(table=table))
639
+
640
+ return or_(*conditions)
641
+
642
+ def generate_custom_query_conditions_for_column(
643
+ self,
644
+ value: Any,
645
+ table: Type[SQLModel],
646
+ column: str,
647
+ ) -> "ColumnElement[bool]":
648
+ """Generate custom filter conditions for a column of a table.
649
+
650
+ Args:
651
+ value: The filter value.
652
+ table: The table which contains the column.
653
+ column: The column name.
654
+
655
+ Returns:
656
+ The query conditions.
657
+ """
658
+ value, operator = BaseFilter._resolve_operator(value)
659
+ filter_ = FilterGenerator(table).define_filter(
660
+ column=column, value=value, operator=operator
661
+ )
662
+ return filter_.generate_query_conditions(table=table)
663
+
664
+ @property
665
+ def offset(self) -> int:
666
+ """Returns the offset needed for the query on the data persistence layer.
667
+
668
+ Returns:
669
+ The offset for the query.
670
+ """
671
+ return self.size * (self.page - 1)
672
+
673
+ def generate_filter(
674
+ self, table: Type[SQLModel]
675
+ ) -> Union["ColumnElement[bool]"]:
676
+ """Generate the filter for the query.
677
+
678
+ Args:
679
+ table: The Table that is being queried from.
680
+
681
+ Returns:
682
+ The filter expression for the query.
683
+
684
+ Raises:
685
+ RuntimeError: If a valid logical operator is not supplied.
686
+ """
687
+ from sqlmodel import and_, or_
688
+
689
+ filters = []
690
+ for column_filter in self.list_of_filters:
691
+ filters.append(
692
+ column_filter.generate_query_conditions(table=table)
693
+ )
694
+ for custom_filter in self.get_custom_filters():
695
+ filters.append(custom_filter)
696
+ if self.logical_operator == LogicalOperators.OR:
697
+ return or_(False, *filters)
698
+ elif self.logical_operator == LogicalOperators.AND:
699
+ return and_(True, *filters)
700
+ else:
701
+ raise RuntimeError("No valid logical operator was supplied.")
702
+
703
+ def get_custom_filters(self) -> List["ColumnElement[bool]"]:
704
+ """Get custom filters.
705
+
706
+ This can be overridden by subclasses to define custom filters that are
707
+ not based on the columns of the underlying table.
708
+
709
+ Returns:
710
+ A list of custom filters.
711
+ """
712
+ return []
713
+
714
+ def apply_filter(
715
+ self,
716
+ query: AnyQuery,
717
+ table: Type["AnySchema"],
718
+ ) -> AnyQuery:
719
+ """Applies the filter to a query.
720
+
721
+ Args:
722
+ query: The query to which to apply the filter.
723
+ table: The query table.
724
+
725
+ Returns:
726
+ The query with filter applied.
727
+ """
728
+ rbac_filter = self.generate_rbac_filter(table=table)
729
+
730
+ if rbac_filter is not None:
731
+ query = query.where(rbac_filter)
732
+
733
+ filters = self.generate_filter(table=table)
734
+
735
+ if filters is not None:
736
+ query = query.where(filters)
737
+
738
+ return query
739
+
740
+ def apply_sorting(
741
+ self,
742
+ query: AnyQuery,
743
+ table: Type["AnySchema"],
744
+ ) -> AnyQuery:
745
+ """Apply sorting to the query.
746
+
747
+ Args:
748
+ query: The query to which to apply the sorting.
749
+ table: The query table.
750
+
751
+ Returns:
752
+ The query with sorting applied.
753
+ """
754
+ column, operand = self.sorting_params
755
+
756
+ if operand == SorterOps.DESCENDING:
757
+ sort_clause = desc(getattr(table, column)) # type: ignore[var-annotated]
758
+ else:
759
+ sort_clause = asc(getattr(table, column))
760
+
761
+ # We always add the `id` column as a tiebreaker to ensure a stable,
762
+ # repeatable order of items, otherwise subsequent pages might contain
763
+ # the same items.
764
+ query = query.order_by(sort_clause, asc(table.id)) # type: ignore[arg-type]
765
+
766
+ return query
767
+
768
+
769
+ class FilterGenerator:
770
+ """Helper class to define filters for a class."""
771
+
772
+ def __init__(self, model_class: Type[BaseModel]) -> None:
773
+ """Initialize the object.
774
+
775
+ Args:
776
+ model_class: The model class for which to define filters.
777
+ """
778
+ self._model_class = model_class
779
+
780
+ def define_filter(
781
+ self, column: str, value: Any, operator: GenericFilterOps
529
782
  ) -> Filter:
530
783
  """Define a filter for a given column.
531
784
 
@@ -538,23 +791,23 @@ class BaseFilter(BaseModel):
538
791
  A Filter object.
539
792
  """
540
793
  # Create datetime filters
541
- if cls.is_datetime_field(column):
542
- return cls._define_datetime_filter(
794
+ if self.is_datetime_field(column):
795
+ return self._define_datetime_filter(
543
796
  column=column,
544
797
  value=value,
545
798
  operator=operator,
546
799
  )
547
800
 
548
801
  # Create UUID filters
549
- if cls.is_uuid_field(column):
550
- return cls._define_uuid_filter(
802
+ if self.is_uuid_field(column):
803
+ return self._define_uuid_filter(
551
804
  column=column,
552
805
  value=value,
553
806
  operator=operator,
554
807
  )
555
808
 
556
809
  # Create int filters
557
- if cls.is_int_field(column):
810
+ if self.is_int_field(column):
558
811
  return NumericFilter(
559
812
  operation=GenericFilterOps(operator),
560
813
  column=column,
@@ -562,15 +815,15 @@ class BaseFilter(BaseModel):
562
815
  )
563
816
 
564
817
  # Create bool filters
565
- if cls.is_bool_field(column):
566
- return cls._define_bool_filter(
818
+ if self.is_bool_field(column):
819
+ return self._define_bool_filter(
567
820
  column=column,
568
821
  value=value,
569
822
  operator=operator,
570
823
  )
571
824
 
572
825
  # Create str filters
573
- if cls.is_str_field(column):
826
+ if self.is_str_field(column):
574
827
  return StrFilter(
575
828
  operation=GenericFilterOps(operator),
576
829
  column=column,
@@ -579,8 +832,8 @@ class BaseFilter(BaseModel):
579
832
 
580
833
  # Handle unsupported datatypes
581
834
  logger.warning(
582
- f"The Datatype {cls.model_fields[column].annotation} might not be "
583
- "supported for filtering. Defaulting to a string filter."
835
+ f"The Datatype {self._model_class.model_fields[column].annotation} might "
836
+ "not be supported for filtering. Defaulting to a string filter."
584
837
  )
585
838
  return StrFilter(
586
839
  operation=GenericFilterOps(operator),
@@ -588,8 +841,7 @@ class BaseFilter(BaseModel):
588
841
  value=str(value),
589
842
  )
590
843
 
591
- @classmethod
592
- def check_field_annotation(cls, k: str, type_: Any) -> bool:
844
+ def check_field_annotation(self, k: str, type_: Any) -> bool:
593
845
  """Checks whether a model field has a certain annotation.
594
846
 
595
847
  Args:
@@ -604,7 +856,7 @@ class BaseFilter(BaseModel):
604
856
  otherwise.
605
857
  """
606
858
  try:
607
- annotation = cls.model_fields[k].annotation
859
+ annotation = self._model_class.model_fields[k].annotation
608
860
 
609
861
  if annotation is not None:
610
862
  return (
@@ -613,14 +865,13 @@ class BaseFilter(BaseModel):
613
865
  )
614
866
  else:
615
867
  raise ValueError(
616
- f"The field '{k}' inside the model {cls.__name__} "
868
+ f"The field '{k}' inside the model {self._model_class.__name__} "
617
869
  "does not have an annotation."
618
870
  )
619
871
  except TypeError:
620
872
  return False
621
873
 
622
- @classmethod
623
- def is_datetime_field(cls, k: str) -> bool:
874
+ def is_datetime_field(self, k: str) -> bool:
624
875
  """Checks if it's a datetime field.
625
876
 
626
877
  Args:
@@ -629,10 +880,9 @@ class BaseFilter(BaseModel):
629
880
  Returns:
630
881
  True if the field is a datetime field, False otherwise.
631
882
  """
632
- return cls.check_field_annotation(k=k, type_=datetime)
883
+ return self.check_field_annotation(k=k, type_=datetime)
633
884
 
634
- @classmethod
635
- def is_uuid_field(cls, k: str) -> bool:
885
+ def is_uuid_field(self, k: str) -> bool:
636
886
  """Checks if it's a UUID field.
637
887
 
638
888
  Args:
@@ -641,10 +891,9 @@ class BaseFilter(BaseModel):
641
891
  Returns:
642
892
  True if the field is a UUID field, False otherwise.
643
893
  """
644
- return cls.check_field_annotation(k=k, type_=UUID)
894
+ return self.check_field_annotation(k=k, type_=UUID)
645
895
 
646
- @classmethod
647
- def is_int_field(cls, k: str) -> bool:
896
+ def is_int_field(self, k: str) -> bool:
648
897
  """Checks if it's an int field.
649
898
 
650
899
  Args:
@@ -653,10 +902,9 @@ class BaseFilter(BaseModel):
653
902
  Returns:
654
903
  True if the field is an int field, False otherwise.
655
904
  """
656
- return cls.check_field_annotation(k=k, type_=int)
905
+ return self.check_field_annotation(k=k, type_=int)
657
906
 
658
- @classmethod
659
- def is_bool_field(cls, k: str) -> bool:
907
+ def is_bool_field(self, k: str) -> bool:
660
908
  """Checks if it's a bool field.
661
909
 
662
910
  Args:
@@ -665,10 +913,9 @@ class BaseFilter(BaseModel):
665
913
  Returns:
666
914
  True if the field is a bool field, False otherwise.
667
915
  """
668
- return cls.check_field_annotation(k=k, type_=bool)
916
+ return self.check_field_annotation(k=k, type_=bool)
669
917
 
670
- @classmethod
671
- def is_str_field(cls, k: str) -> bool:
918
+ def is_str_field(self, k: str) -> bool:
672
919
  """Checks if it's a string field.
673
920
 
674
921
  Args:
@@ -677,10 +924,9 @@ class BaseFilter(BaseModel):
677
924
  Returns:
678
925
  True if the field is a string field, False otherwise.
679
926
  """
680
- return cls.check_field_annotation(k=k, type_=str)
927
+ return self.check_field_annotation(k=k, type_=str)
681
928
 
682
- @classmethod
683
- def is_sort_by_field(cls, k: str) -> bool:
929
+ def is_sort_by_field(self, k: str) -> bool:
684
930
  """Checks if it's a sort by field.
685
931
 
686
932
  Args:
@@ -689,12 +935,12 @@ class BaseFilter(BaseModel):
689
935
  Returns:
690
936
  True if the field is a sort by field, False otherwise.
691
937
  """
692
- return cls.check_field_annotation(k=k, type_=str) and k == "sort_by"
938
+ return self.check_field_annotation(k=k, type_=str) and k == "sort_by"
693
939
 
694
940
  @staticmethod
695
941
  def _define_datetime_filter(
696
942
  column: str, value: Any, operator: GenericFilterOps
697
- ) -> NumericFilter:
943
+ ) -> DatetimeFilter:
698
944
  """Define a datetime filter for a given column.
699
945
 
700
946
  Args:
@@ -709,10 +955,17 @@ class BaseFilter(BaseModel):
709
955
  ValueError: If the value is not a valid datetime.
710
956
  """
711
957
  try:
958
+ filter_value: Union[datetime, Tuple[datetime, datetime]]
712
959
  if isinstance(value, datetime):
713
- datetime_value = value
960
+ filter_value = value
961
+ elif "," in value:
962
+ lower_bound, upper_bound = value.split(",", 1)
963
+ filter_value = (
964
+ datetime.strptime(lower_bound, FILTERING_DATETIME_FORMAT),
965
+ datetime.strptime(upper_bound, FILTERING_DATETIME_FORMAT),
966
+ )
714
967
  else:
715
- datetime_value = datetime.strptime(
968
+ filter_value = datetime.strptime(
716
969
  value, FILTERING_DATETIME_FORMAT
717
970
  )
718
971
  except ValueError as e:
@@ -720,10 +973,27 @@ class BaseFilter(BaseModel):
720
973
  "The datetime filter only works with values in the following "
721
974
  f"format: {FILTERING_DATETIME_FORMAT}"
722
975
  ) from e
723
- datetime_filter = NumericFilter(
976
+
977
+ if operator == GenericFilterOps.IN and not isinstance(
978
+ filter_value, tuple
979
+ ):
980
+ raise ValueError(
981
+ "Two comma separated datetime values are required for the `in` "
982
+ "operator."
983
+ )
984
+
985
+ if operator != GenericFilterOps.IN and not isinstance(
986
+ filter_value, datetime
987
+ ):
988
+ raise ValueError(
989
+ "Only a single datetime value is allowed for operator "
990
+ f"{operator}."
991
+ )
992
+
993
+ datetime_filter = DatetimeFilter(
724
994
  operation=GenericFilterOps(operator),
725
995
  column=column,
726
- value=datetime_value,
996
+ value=filter_value,
727
997
  )
728
998
  return datetime_filter
729
999
 
@@ -789,107 +1059,3 @@ class BaseFilter(BaseModel):
789
1059
  column=column,
790
1060
  value=bool(value),
791
1061
  )
792
-
793
- @property
794
- def offset(self) -> int:
795
- """Returns the offset needed for the query on the data persistence layer.
796
-
797
- Returns:
798
- The offset for the query.
799
- """
800
- return self.size * (self.page - 1)
801
-
802
- def generate_filter(
803
- self, table: Type[SQLModel]
804
- ) -> Union["ColumnElement[bool]"]:
805
- """Generate the filter for the query.
806
-
807
- Args:
808
- table: The Table that is being queried from.
809
-
810
- Returns:
811
- The filter expression for the query.
812
-
813
- Raises:
814
- RuntimeError: If a valid logical operator is not supplied.
815
- """
816
- from sqlmodel import and_, or_
817
-
818
- filters = []
819
- for column_filter in self.list_of_filters:
820
- filters.append(
821
- column_filter.generate_query_conditions(table=table)
822
- )
823
- for custom_filter in self.get_custom_filters():
824
- filters.append(custom_filter)
825
- if self.logical_operator == LogicalOperators.OR:
826
- return or_(False, *filters)
827
- elif self.logical_operator == LogicalOperators.AND:
828
- return and_(True, *filters)
829
- else:
830
- raise RuntimeError("No valid logical operator was supplied.")
831
-
832
- def get_custom_filters(self) -> List["ColumnElement[bool]"]:
833
- """Get custom filters.
834
-
835
- This can be overridden by subclasses to define custom filters that are
836
- not based on the columns of the underlying table.
837
-
838
- Returns:
839
- A list of custom filters.
840
- """
841
- return []
842
-
843
- def apply_filter(
844
- self,
845
- query: AnyQuery,
846
- table: Type["AnySchema"],
847
- ) -> AnyQuery:
848
- """Applies the filter to a query.
849
-
850
- Args:
851
- query: The query to which to apply the filter.
852
- table: The query table.
853
-
854
- Returns:
855
- The query with filter applied.
856
- """
857
- rbac_filter = self.generate_rbac_filter(table=table)
858
-
859
- if rbac_filter is not None:
860
- query = query.where(rbac_filter)
861
-
862
- filters = self.generate_filter(table=table)
863
-
864
- if filters is not None:
865
- query = query.where(filters)
866
-
867
- return query
868
-
869
- def apply_sorting(
870
- self,
871
- query: AnyQuery,
872
- table: Type["AnySchema"],
873
- ) -> AnyQuery:
874
- """Apply sorting to the query.
875
-
876
- Args:
877
- query: The query to which to apply the sorting.
878
- table: The query table.
879
-
880
- Returns:
881
- The query with sorting applied.
882
- """
883
- column, operand = self.sorting_params
884
-
885
- if operand == SorterOps.DESCENDING:
886
- sort_clause = desc(getattr(table, column)) # type: ignore[var-annotated]
887
- else:
888
- sort_clause = asc(getattr(table, column))
889
-
890
- # We always add the `id` column as a tiebreaker to ensure a stable,
891
- # repeatable order of items, otherwise subsequent pages might contain
892
- # the same items.
893
- query = query.order_by(sort_clause, asc(table.id)) # type: ignore[arg-type]
894
-
895
- return query
@@ -27,7 +27,6 @@ from typing import (
27
27
  from uuid import UUID
28
28
 
29
29
  from pydantic import Field
30
- from sqlmodel import col
31
30
 
32
31
  from zenml.models.v2.base.base import (
33
32
  BaseDatedResponseBody,
@@ -341,6 +340,10 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
341
340
 
342
341
  custom_filters = super().get_custom_filters()
343
342
  if self.tag:
344
- custom_filters.append(col(TagSchema.name) == self.tag)
343
+ custom_filters.append(
344
+ self.generate_custom_query_conditions_for_column(
345
+ value=self.tag, table=TagSchema, column="name"
346
+ )
347
+ )
345
348
 
346
349
  return custom_filters