zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240925__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/__init__.py +7 -0
- zenml/cli/base.py +2 -2
- zenml/cli/pipeline.py +21 -0
- zenml/cli/utils.py +14 -11
- zenml/client.py +68 -3
- zenml/config/step_configurations.py +0 -5
- zenml/constants.py +3 -0
- zenml/enums.py +2 -0
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
- zenml/integrations/lightning/__init__.py +1 -1
- zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
- zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
- zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
- zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
- zenml/models/v2/base/filter.py +315 -149
- zenml/models/v2/base/scoped.py +5 -2
- zenml/models/v2/core/artifact_version.py +69 -8
- zenml/models/v2/core/model.py +43 -6
- zenml/models/v2/core/model_version.py +49 -1
- zenml/models/v2/core/model_version_artifact.py +18 -3
- zenml/models/v2/core/model_version_pipeline_run.py +18 -4
- zenml/models/v2/core/pipeline.py +108 -1
- zenml/models/v2/core/pipeline_run.py +172 -21
- zenml/models/v2/core/run_template.py +53 -1
- zenml/models/v2/core/stack.py +33 -5
- zenml/models/v2/core/step_run.py +7 -0
- zenml/new/pipelines/pipeline.py +4 -0
- zenml/new/pipelines/run_utils.py +4 -1
- zenml/orchestrators/base_orchestrator.py +41 -12
- zenml/stack/stack.py +11 -2
- zenml/utils/env_utils.py +54 -1
- zenml/utils/string_utils.py +50 -0
- zenml/zen_server/cloud_utils.py +33 -8
- zenml/zen_server/routers/runs_endpoints.py +89 -3
- zenml/zen_stores/sql_zen_store.py +1 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/METADATA +8 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/RECORD +45 -45
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240925.dist-info}/entry_points.txt +0 -0
@@ -194,7 +194,7 @@ def get_step(pipeline_name: str, step_name: str) -> "StepRunResponse":
|
|
194
194
|
Raises:
|
195
195
|
RuntimeError: If the step is not found.
|
196
196
|
"""
|
197
|
-
runs = Client().list_pipeline_runs(
|
197
|
+
runs = Client().list_pipeline_runs(pipeline=pipeline_name)
|
198
198
|
if runs.total == 0:
|
199
199
|
raise RuntimeError(
|
200
200
|
f"No pipeline runs for pipeline `{pipeline_name}` were found"
|
zenml/models/v2/base/filter.py
CHANGED
@@ -54,7 +54,7 @@ from zenml.utils.typing_utils import get_args
|
|
54
54
|
if TYPE_CHECKING:
|
55
55
|
from sqlalchemy.sql.elements import ColumnElement
|
56
56
|
|
57
|
-
from zenml.zen_stores.schemas import BaseSchema
|
57
|
+
from zenml.zen_stores.schemas import BaseSchema, NamedSchema
|
58
58
|
|
59
59
|
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
|
60
60
|
|
@@ -142,7 +142,10 @@ class Filter(BaseModel, ABC):
|
|
142
142
|
class BoolFilter(Filter):
|
143
143
|
"""Filter for all Boolean fields."""
|
144
144
|
|
145
|
-
ALLOWED_OPS: ClassVar[List[str]] = [
|
145
|
+
ALLOWED_OPS: ClassVar[List[str]] = [
|
146
|
+
GenericFilterOps.EQUALS,
|
147
|
+
GenericFilterOps.NOT_EQUALS,
|
148
|
+
]
|
146
149
|
|
147
150
|
def generate_query_conditions_from_column(self, column: Any) -> Any:
|
148
151
|
"""Generate query conditions for a boolean column.
|
@@ -153,6 +156,9 @@ class BoolFilter(Filter):
|
|
153
156
|
Returns:
|
154
157
|
A list of query conditions.
|
155
158
|
"""
|
159
|
+
if self.operation == GenericFilterOps.NOT_EQUALS:
|
160
|
+
return column != self.value
|
161
|
+
|
156
162
|
return column == self.value
|
157
163
|
|
158
164
|
|
@@ -161,6 +167,7 @@ class StrFilter(Filter):
|
|
161
167
|
|
162
168
|
ALLOWED_OPS: ClassVar[List[str]] = [
|
163
169
|
GenericFilterOps.EQUALS,
|
170
|
+
GenericFilterOps.NOT_EQUALS,
|
164
171
|
GenericFilterOps.STARTSWITH,
|
165
172
|
GenericFilterOps.CONTAINS,
|
166
173
|
GenericFilterOps.ENDSWITH,
|
@@ -181,12 +188,31 @@ class StrFilter(Filter):
|
|
181
188
|
return column.startswith(f"{self.value}")
|
182
189
|
if self.operation == GenericFilterOps.ENDSWITH:
|
183
190
|
return column.endswith(f"{self.value}")
|
191
|
+
if self.operation == GenericFilterOps.NOT_EQUALS:
|
192
|
+
return column != self.value
|
193
|
+
|
184
194
|
return column == self.value
|
185
195
|
|
186
196
|
|
187
197
|
class UUIDFilter(StrFilter):
|
188
198
|
"""Filter for all uuid fields which are mostly treated like strings."""
|
189
199
|
|
200
|
+
@field_validator("value", mode="before")
|
201
|
+
@classmethod
|
202
|
+
def _remove_hyphens_from_value(cls, value: Any) -> Any:
|
203
|
+
"""Remove hyphens from the value to enable string comparisons.
|
204
|
+
|
205
|
+
Args:
|
206
|
+
value: The filter value.
|
207
|
+
|
208
|
+
Returns:
|
209
|
+
The filter value with removed hyphens.
|
210
|
+
"""
|
211
|
+
if isinstance(value, str):
|
212
|
+
return value.replace("-", "")
|
213
|
+
|
214
|
+
return value
|
215
|
+
|
190
216
|
def generate_query_conditions_from_column(self, column: Any) -> Any:
|
191
217
|
"""Generate query conditions for a UUID column.
|
192
218
|
|
@@ -203,6 +229,9 @@ class UUIDFilter(StrFilter):
|
|
203
229
|
if self.operation == GenericFilterOps.EQUALS:
|
204
230
|
return column == self.value
|
205
231
|
|
232
|
+
if self.operation == GenericFilterOps.NOT_EQUALS:
|
233
|
+
return column != self.value
|
234
|
+
|
206
235
|
# For all other operations, cast and handle the column as string
|
207
236
|
return super().generate_query_conditions_from_column(
|
208
237
|
column=cast_if(column, sqlalchemy.String)
|
@@ -216,6 +245,7 @@ class NumericFilter(Filter):
|
|
216
245
|
|
217
246
|
ALLOWED_OPS: ClassVar[List[str]] = [
|
218
247
|
GenericFilterOps.EQUALS,
|
248
|
+
GenericFilterOps.NOT_EQUALS,
|
219
249
|
GenericFilterOps.GT,
|
220
250
|
GenericFilterOps.GTE,
|
221
251
|
GenericFilterOps.LT,
|
@@ -223,14 +253,59 @@ class NumericFilter(Filter):
|
|
223
253
|
]
|
224
254
|
|
225
255
|
def generate_query_conditions_from_column(self, column: Any) -> Any:
|
226
|
-
"""Generate query conditions for a
|
256
|
+
"""Generate query conditions for a numeric column.
|
227
257
|
|
228
258
|
Args:
|
229
|
-
column: The
|
259
|
+
column: The numeric column of an SQLModel table on which to filter.
|
260
|
+
|
261
|
+
Returns:
|
262
|
+
A list of query conditions.
|
263
|
+
"""
|
264
|
+
if self.operation == GenericFilterOps.GTE:
|
265
|
+
return column >= self.value
|
266
|
+
if self.operation == GenericFilterOps.GT:
|
267
|
+
return column > self.value
|
268
|
+
if self.operation == GenericFilterOps.LTE:
|
269
|
+
return column <= self.value
|
270
|
+
if self.operation == GenericFilterOps.LT:
|
271
|
+
return column < self.value
|
272
|
+
if self.operation == GenericFilterOps.NOT_EQUALS:
|
273
|
+
return column != self.value
|
274
|
+
return column == self.value
|
275
|
+
|
276
|
+
|
277
|
+
class DatetimeFilter(Filter):
|
278
|
+
"""Filter for all datetime fields."""
|
279
|
+
|
280
|
+
value: Union[datetime, Tuple[datetime, datetime]] = Field(
|
281
|
+
union_mode="left_to_right"
|
282
|
+
)
|
283
|
+
|
284
|
+
ALLOWED_OPS: ClassVar[List[str]] = [
|
285
|
+
GenericFilterOps.EQUALS,
|
286
|
+
GenericFilterOps.NOT_EQUALS,
|
287
|
+
GenericFilterOps.GT,
|
288
|
+
GenericFilterOps.GTE,
|
289
|
+
GenericFilterOps.LT,
|
290
|
+
GenericFilterOps.LTE,
|
291
|
+
GenericFilterOps.IN,
|
292
|
+
]
|
293
|
+
|
294
|
+
def generate_query_conditions_from_column(self, column: Any) -> Any:
|
295
|
+
"""Generate query conditions for a datetime column.
|
296
|
+
|
297
|
+
Args:
|
298
|
+
column: The datetime column of an SQLModel table on which to filter.
|
230
299
|
|
231
300
|
Returns:
|
232
301
|
A list of query conditions.
|
233
302
|
"""
|
303
|
+
if self.operation == GenericFilterOps.IN:
|
304
|
+
assert isinstance(self.value, tuple)
|
305
|
+
lower_bound, upper_bound = self.value
|
306
|
+
return column.between(lower_bound, upper_bound)
|
307
|
+
|
308
|
+
assert isinstance(self.value, datetime)
|
234
309
|
if self.operation == GenericFilterOps.GTE:
|
235
310
|
return column >= self.value
|
236
311
|
if self.operation == GenericFilterOps.GT:
|
@@ -239,6 +314,8 @@ class NumericFilter(Filter):
|
|
239
314
|
return column <= self.value
|
240
315
|
if self.operation == GenericFilterOps.LT:
|
241
316
|
return column < self.value
|
317
|
+
if self.operation == GenericFilterOps.NOT_EQUALS:
|
318
|
+
return column != self.value
|
242
319
|
return column == self.value
|
243
320
|
|
244
321
|
|
@@ -490,7 +567,7 @@ class BaseFilter(BaseModel):
|
|
490
567
|
value, operator = cls._resolve_operator(value)
|
491
568
|
|
492
569
|
# Define the filter
|
493
|
-
filter = cls.
|
570
|
+
filter = FilterGenerator(cls).define_filter(
|
494
571
|
column=key, value=value, operator=operator
|
495
572
|
)
|
496
573
|
list_of_filters.append(filter)
|
@@ -523,9 +600,185 @@ class BaseFilter(BaseModel):
|
|
523
600
|
operator = GenericFilterOps(split_value[0])
|
524
601
|
return value, operator
|
525
602
|
|
526
|
-
|
527
|
-
|
528
|
-
|
603
|
+
def generate_name_or_id_query_conditions(
|
604
|
+
self,
|
605
|
+
value: Union[UUID, str],
|
606
|
+
table: Type["NamedSchema"],
|
607
|
+
) -> "ColumnElement[bool]":
|
608
|
+
"""Generate filter conditions for name or id of a table.
|
609
|
+
|
610
|
+
Args:
|
611
|
+
value: The filter value.
|
612
|
+
table: The table to filter.
|
613
|
+
|
614
|
+
Returns:
|
615
|
+
The query conditions.
|
616
|
+
"""
|
617
|
+
from sqlmodel import or_
|
618
|
+
|
619
|
+
value, operator = BaseFilter._resolve_operator(value)
|
620
|
+
value = str(value)
|
621
|
+
|
622
|
+
conditions = []
|
623
|
+
|
624
|
+
try:
|
625
|
+
filter_ = FilterGenerator(table).define_filter(
|
626
|
+
column="id", value=value, operator=operator
|
627
|
+
)
|
628
|
+
conditions.append(filter_.generate_query_conditions(table=table))
|
629
|
+
except ValueError:
|
630
|
+
# UUID filter with equal operators and no full UUID fail with
|
631
|
+
# a ValueError. In this case, we already know that the filter
|
632
|
+
# will not produce any result and can simply ignore it.
|
633
|
+
pass
|
634
|
+
|
635
|
+
filter_ = FilterGenerator(table).define_filter(
|
636
|
+
column="name", value=value, operator=operator
|
637
|
+
)
|
638
|
+
conditions.append(filter_.generate_query_conditions(table=table))
|
639
|
+
|
640
|
+
return or_(*conditions)
|
641
|
+
|
642
|
+
def generate_custom_query_conditions_for_column(
|
643
|
+
self,
|
644
|
+
value: Any,
|
645
|
+
table: Type[SQLModel],
|
646
|
+
column: str,
|
647
|
+
) -> "ColumnElement[bool]":
|
648
|
+
"""Generate custom filter conditions for a column of a table.
|
649
|
+
|
650
|
+
Args:
|
651
|
+
value: The filter value.
|
652
|
+
table: The table which contains the column.
|
653
|
+
column: The column name.
|
654
|
+
|
655
|
+
Returns:
|
656
|
+
The query conditions.
|
657
|
+
"""
|
658
|
+
value, operator = BaseFilter._resolve_operator(value)
|
659
|
+
filter_ = FilterGenerator(table).define_filter(
|
660
|
+
column=column, value=value, operator=operator
|
661
|
+
)
|
662
|
+
return filter_.generate_query_conditions(table=table)
|
663
|
+
|
664
|
+
@property
|
665
|
+
def offset(self) -> int:
|
666
|
+
"""Returns the offset needed for the query on the data persistence layer.
|
667
|
+
|
668
|
+
Returns:
|
669
|
+
The offset for the query.
|
670
|
+
"""
|
671
|
+
return self.size * (self.page - 1)
|
672
|
+
|
673
|
+
def generate_filter(
|
674
|
+
self, table: Type[SQLModel]
|
675
|
+
) -> Union["ColumnElement[bool]"]:
|
676
|
+
"""Generate the filter for the query.
|
677
|
+
|
678
|
+
Args:
|
679
|
+
table: The Table that is being queried from.
|
680
|
+
|
681
|
+
Returns:
|
682
|
+
The filter expression for the query.
|
683
|
+
|
684
|
+
Raises:
|
685
|
+
RuntimeError: If a valid logical operator is not supplied.
|
686
|
+
"""
|
687
|
+
from sqlmodel import and_, or_
|
688
|
+
|
689
|
+
filters = []
|
690
|
+
for column_filter in self.list_of_filters:
|
691
|
+
filters.append(
|
692
|
+
column_filter.generate_query_conditions(table=table)
|
693
|
+
)
|
694
|
+
for custom_filter in self.get_custom_filters():
|
695
|
+
filters.append(custom_filter)
|
696
|
+
if self.logical_operator == LogicalOperators.OR:
|
697
|
+
return or_(False, *filters)
|
698
|
+
elif self.logical_operator == LogicalOperators.AND:
|
699
|
+
return and_(True, *filters)
|
700
|
+
else:
|
701
|
+
raise RuntimeError("No valid logical operator was supplied.")
|
702
|
+
|
703
|
+
def get_custom_filters(self) -> List["ColumnElement[bool]"]:
|
704
|
+
"""Get custom filters.
|
705
|
+
|
706
|
+
This can be overridden by subclasses to define custom filters that are
|
707
|
+
not based on the columns of the underlying table.
|
708
|
+
|
709
|
+
Returns:
|
710
|
+
A list of custom filters.
|
711
|
+
"""
|
712
|
+
return []
|
713
|
+
|
714
|
+
def apply_filter(
|
715
|
+
self,
|
716
|
+
query: AnyQuery,
|
717
|
+
table: Type["AnySchema"],
|
718
|
+
) -> AnyQuery:
|
719
|
+
"""Applies the filter to a query.
|
720
|
+
|
721
|
+
Args:
|
722
|
+
query: The query to which to apply the filter.
|
723
|
+
table: The query table.
|
724
|
+
|
725
|
+
Returns:
|
726
|
+
The query with filter applied.
|
727
|
+
"""
|
728
|
+
rbac_filter = self.generate_rbac_filter(table=table)
|
729
|
+
|
730
|
+
if rbac_filter is not None:
|
731
|
+
query = query.where(rbac_filter)
|
732
|
+
|
733
|
+
filters = self.generate_filter(table=table)
|
734
|
+
|
735
|
+
if filters is not None:
|
736
|
+
query = query.where(filters)
|
737
|
+
|
738
|
+
return query
|
739
|
+
|
740
|
+
def apply_sorting(
|
741
|
+
self,
|
742
|
+
query: AnyQuery,
|
743
|
+
table: Type["AnySchema"],
|
744
|
+
) -> AnyQuery:
|
745
|
+
"""Apply sorting to the query.
|
746
|
+
|
747
|
+
Args:
|
748
|
+
query: The query to which to apply the sorting.
|
749
|
+
table: The query table.
|
750
|
+
|
751
|
+
Returns:
|
752
|
+
The query with sorting applied.
|
753
|
+
"""
|
754
|
+
column, operand = self.sorting_params
|
755
|
+
|
756
|
+
if operand == SorterOps.DESCENDING:
|
757
|
+
sort_clause = desc(getattr(table, column)) # type: ignore[var-annotated]
|
758
|
+
else:
|
759
|
+
sort_clause = asc(getattr(table, column))
|
760
|
+
|
761
|
+
# We always add the `id` column as a tiebreaker to ensure a stable,
|
762
|
+
# repeatable order of items, otherwise subsequent pages might contain
|
763
|
+
# the same items.
|
764
|
+
query = query.order_by(sort_clause, asc(table.id)) # type: ignore[arg-type]
|
765
|
+
|
766
|
+
return query
|
767
|
+
|
768
|
+
|
769
|
+
class FilterGenerator:
|
770
|
+
"""Helper class to define filters for a class."""
|
771
|
+
|
772
|
+
def __init__(self, model_class: Type[BaseModel]) -> None:
|
773
|
+
"""Initialize the object.
|
774
|
+
|
775
|
+
Args:
|
776
|
+
model_class: The model class for which to define filters.
|
777
|
+
"""
|
778
|
+
self._model_class = model_class
|
779
|
+
|
780
|
+
def define_filter(
|
781
|
+
self, column: str, value: Any, operator: GenericFilterOps
|
529
782
|
) -> Filter:
|
530
783
|
"""Define a filter for a given column.
|
531
784
|
|
@@ -538,23 +791,23 @@ class BaseFilter(BaseModel):
|
|
538
791
|
A Filter object.
|
539
792
|
"""
|
540
793
|
# Create datetime filters
|
541
|
-
if
|
542
|
-
return
|
794
|
+
if self.is_datetime_field(column):
|
795
|
+
return self._define_datetime_filter(
|
543
796
|
column=column,
|
544
797
|
value=value,
|
545
798
|
operator=operator,
|
546
799
|
)
|
547
800
|
|
548
801
|
# Create UUID filters
|
549
|
-
if
|
550
|
-
return
|
802
|
+
if self.is_uuid_field(column):
|
803
|
+
return self._define_uuid_filter(
|
551
804
|
column=column,
|
552
805
|
value=value,
|
553
806
|
operator=operator,
|
554
807
|
)
|
555
808
|
|
556
809
|
# Create int filters
|
557
|
-
if
|
810
|
+
if self.is_int_field(column):
|
558
811
|
return NumericFilter(
|
559
812
|
operation=GenericFilterOps(operator),
|
560
813
|
column=column,
|
@@ -562,15 +815,15 @@ class BaseFilter(BaseModel):
|
|
562
815
|
)
|
563
816
|
|
564
817
|
# Create bool filters
|
565
|
-
if
|
566
|
-
return
|
818
|
+
if self.is_bool_field(column):
|
819
|
+
return self._define_bool_filter(
|
567
820
|
column=column,
|
568
821
|
value=value,
|
569
822
|
operator=operator,
|
570
823
|
)
|
571
824
|
|
572
825
|
# Create str filters
|
573
|
-
if
|
826
|
+
if self.is_str_field(column):
|
574
827
|
return StrFilter(
|
575
828
|
operation=GenericFilterOps(operator),
|
576
829
|
column=column,
|
@@ -579,8 +832,8 @@ class BaseFilter(BaseModel):
|
|
579
832
|
|
580
833
|
# Handle unsupported datatypes
|
581
834
|
logger.warning(
|
582
|
-
f"The Datatype {
|
583
|
-
"supported for filtering. Defaulting to a string filter."
|
835
|
+
f"The Datatype {self._model_class.model_fields[column].annotation} might "
|
836
|
+
"not be supported for filtering. Defaulting to a string filter."
|
584
837
|
)
|
585
838
|
return StrFilter(
|
586
839
|
operation=GenericFilterOps(operator),
|
@@ -588,8 +841,7 @@ class BaseFilter(BaseModel):
|
|
588
841
|
value=str(value),
|
589
842
|
)
|
590
843
|
|
591
|
-
|
592
|
-
def check_field_annotation(cls, k: str, type_: Any) -> bool:
|
844
|
+
def check_field_annotation(self, k: str, type_: Any) -> bool:
|
593
845
|
"""Checks whether a model field has a certain annotation.
|
594
846
|
|
595
847
|
Args:
|
@@ -604,7 +856,7 @@ class BaseFilter(BaseModel):
|
|
604
856
|
otherwise.
|
605
857
|
"""
|
606
858
|
try:
|
607
|
-
annotation =
|
859
|
+
annotation = self._model_class.model_fields[k].annotation
|
608
860
|
|
609
861
|
if annotation is not None:
|
610
862
|
return (
|
@@ -613,14 +865,13 @@ class BaseFilter(BaseModel):
|
|
613
865
|
)
|
614
866
|
else:
|
615
867
|
raise ValueError(
|
616
|
-
f"The field '{k}' inside the model {
|
868
|
+
f"The field '{k}' inside the model {self._model_class.__name__} "
|
617
869
|
"does not have an annotation."
|
618
870
|
)
|
619
871
|
except TypeError:
|
620
872
|
return False
|
621
873
|
|
622
|
-
|
623
|
-
def is_datetime_field(cls, k: str) -> bool:
|
874
|
+
def is_datetime_field(self, k: str) -> bool:
|
624
875
|
"""Checks if it's a datetime field.
|
625
876
|
|
626
877
|
Args:
|
@@ -629,10 +880,9 @@ class BaseFilter(BaseModel):
|
|
629
880
|
Returns:
|
630
881
|
True if the field is a datetime field, False otherwise.
|
631
882
|
"""
|
632
|
-
return
|
883
|
+
return self.check_field_annotation(k=k, type_=datetime)
|
633
884
|
|
634
|
-
|
635
|
-
def is_uuid_field(cls, k: str) -> bool:
|
885
|
+
def is_uuid_field(self, k: str) -> bool:
|
636
886
|
"""Checks if it's a UUID field.
|
637
887
|
|
638
888
|
Args:
|
@@ -641,10 +891,9 @@ class BaseFilter(BaseModel):
|
|
641
891
|
Returns:
|
642
892
|
True if the field is a UUID field, False otherwise.
|
643
893
|
"""
|
644
|
-
return
|
894
|
+
return self.check_field_annotation(k=k, type_=UUID)
|
645
895
|
|
646
|
-
|
647
|
-
def is_int_field(cls, k: str) -> bool:
|
896
|
+
def is_int_field(self, k: str) -> bool:
|
648
897
|
"""Checks if it's an int field.
|
649
898
|
|
650
899
|
Args:
|
@@ -653,10 +902,9 @@ class BaseFilter(BaseModel):
|
|
653
902
|
Returns:
|
654
903
|
True if the field is an int field, False otherwise.
|
655
904
|
"""
|
656
|
-
return
|
905
|
+
return self.check_field_annotation(k=k, type_=int)
|
657
906
|
|
658
|
-
|
659
|
-
def is_bool_field(cls, k: str) -> bool:
|
907
|
+
def is_bool_field(self, k: str) -> bool:
|
660
908
|
"""Checks if it's a bool field.
|
661
909
|
|
662
910
|
Args:
|
@@ -665,10 +913,9 @@ class BaseFilter(BaseModel):
|
|
665
913
|
Returns:
|
666
914
|
True if the field is a bool field, False otherwise.
|
667
915
|
"""
|
668
|
-
return
|
916
|
+
return self.check_field_annotation(k=k, type_=bool)
|
669
917
|
|
670
|
-
|
671
|
-
def is_str_field(cls, k: str) -> bool:
|
918
|
+
def is_str_field(self, k: str) -> bool:
|
672
919
|
"""Checks if it's a string field.
|
673
920
|
|
674
921
|
Args:
|
@@ -677,10 +924,9 @@ class BaseFilter(BaseModel):
|
|
677
924
|
Returns:
|
678
925
|
True if the field is a string field, False otherwise.
|
679
926
|
"""
|
680
|
-
return
|
927
|
+
return self.check_field_annotation(k=k, type_=str)
|
681
928
|
|
682
|
-
|
683
|
-
def is_sort_by_field(cls, k: str) -> bool:
|
929
|
+
def is_sort_by_field(self, k: str) -> bool:
|
684
930
|
"""Checks if it's a sort by field.
|
685
931
|
|
686
932
|
Args:
|
@@ -689,12 +935,12 @@ class BaseFilter(BaseModel):
|
|
689
935
|
Returns:
|
690
936
|
True if the field is a sort by field, False otherwise.
|
691
937
|
"""
|
692
|
-
return
|
938
|
+
return self.check_field_annotation(k=k, type_=str) and k == "sort_by"
|
693
939
|
|
694
940
|
@staticmethod
|
695
941
|
def _define_datetime_filter(
|
696
942
|
column: str, value: Any, operator: GenericFilterOps
|
697
|
-
) ->
|
943
|
+
) -> DatetimeFilter:
|
698
944
|
"""Define a datetime filter for a given column.
|
699
945
|
|
700
946
|
Args:
|
@@ -709,10 +955,17 @@ class BaseFilter(BaseModel):
|
|
709
955
|
ValueError: If the value is not a valid datetime.
|
710
956
|
"""
|
711
957
|
try:
|
958
|
+
filter_value: Union[datetime, Tuple[datetime, datetime]]
|
712
959
|
if isinstance(value, datetime):
|
713
|
-
|
960
|
+
filter_value = value
|
961
|
+
elif "," in value:
|
962
|
+
lower_bound, upper_bound = value.split(",", 1)
|
963
|
+
filter_value = (
|
964
|
+
datetime.strptime(lower_bound, FILTERING_DATETIME_FORMAT),
|
965
|
+
datetime.strptime(upper_bound, FILTERING_DATETIME_FORMAT),
|
966
|
+
)
|
714
967
|
else:
|
715
|
-
|
968
|
+
filter_value = datetime.strptime(
|
716
969
|
value, FILTERING_DATETIME_FORMAT
|
717
970
|
)
|
718
971
|
except ValueError as e:
|
@@ -720,10 +973,27 @@ class BaseFilter(BaseModel):
|
|
720
973
|
"The datetime filter only works with values in the following "
|
721
974
|
f"format: {FILTERING_DATETIME_FORMAT}"
|
722
975
|
) from e
|
723
|
-
|
976
|
+
|
977
|
+
if operator == GenericFilterOps.IN and not isinstance(
|
978
|
+
filter_value, tuple
|
979
|
+
):
|
980
|
+
raise ValueError(
|
981
|
+
"Two comma separated datetime values are required for the `in` "
|
982
|
+
"operator."
|
983
|
+
)
|
984
|
+
|
985
|
+
if operator != GenericFilterOps.IN and not isinstance(
|
986
|
+
filter_value, datetime
|
987
|
+
):
|
988
|
+
raise ValueError(
|
989
|
+
"Only a single datetime value is allowed for operator "
|
990
|
+
f"{operator}."
|
991
|
+
)
|
992
|
+
|
993
|
+
datetime_filter = DatetimeFilter(
|
724
994
|
operation=GenericFilterOps(operator),
|
725
995
|
column=column,
|
726
|
-
value=
|
996
|
+
value=filter_value,
|
727
997
|
)
|
728
998
|
return datetime_filter
|
729
999
|
|
@@ -789,107 +1059,3 @@ class BaseFilter(BaseModel):
|
|
789
1059
|
column=column,
|
790
1060
|
value=bool(value),
|
791
1061
|
)
|
792
|
-
|
793
|
-
@property
|
794
|
-
def offset(self) -> int:
|
795
|
-
"""Returns the offset needed for the query on the data persistence layer.
|
796
|
-
|
797
|
-
Returns:
|
798
|
-
The offset for the query.
|
799
|
-
"""
|
800
|
-
return self.size * (self.page - 1)
|
801
|
-
|
802
|
-
def generate_filter(
|
803
|
-
self, table: Type[SQLModel]
|
804
|
-
) -> Union["ColumnElement[bool]"]:
|
805
|
-
"""Generate the filter for the query.
|
806
|
-
|
807
|
-
Args:
|
808
|
-
table: The Table that is being queried from.
|
809
|
-
|
810
|
-
Returns:
|
811
|
-
The filter expression for the query.
|
812
|
-
|
813
|
-
Raises:
|
814
|
-
RuntimeError: If a valid logical operator is not supplied.
|
815
|
-
"""
|
816
|
-
from sqlmodel import and_, or_
|
817
|
-
|
818
|
-
filters = []
|
819
|
-
for column_filter in self.list_of_filters:
|
820
|
-
filters.append(
|
821
|
-
column_filter.generate_query_conditions(table=table)
|
822
|
-
)
|
823
|
-
for custom_filter in self.get_custom_filters():
|
824
|
-
filters.append(custom_filter)
|
825
|
-
if self.logical_operator == LogicalOperators.OR:
|
826
|
-
return or_(False, *filters)
|
827
|
-
elif self.logical_operator == LogicalOperators.AND:
|
828
|
-
return and_(True, *filters)
|
829
|
-
else:
|
830
|
-
raise RuntimeError("No valid logical operator was supplied.")
|
831
|
-
|
832
|
-
def get_custom_filters(self) -> List["ColumnElement[bool]"]:
|
833
|
-
"""Get custom filters.
|
834
|
-
|
835
|
-
This can be overridden by subclasses to define custom filters that are
|
836
|
-
not based on the columns of the underlying table.
|
837
|
-
|
838
|
-
Returns:
|
839
|
-
A list of custom filters.
|
840
|
-
"""
|
841
|
-
return []
|
842
|
-
|
843
|
-
def apply_filter(
|
844
|
-
self,
|
845
|
-
query: AnyQuery,
|
846
|
-
table: Type["AnySchema"],
|
847
|
-
) -> AnyQuery:
|
848
|
-
"""Applies the filter to a query.
|
849
|
-
|
850
|
-
Args:
|
851
|
-
query: The query to which to apply the filter.
|
852
|
-
table: The query table.
|
853
|
-
|
854
|
-
Returns:
|
855
|
-
The query with filter applied.
|
856
|
-
"""
|
857
|
-
rbac_filter = self.generate_rbac_filter(table=table)
|
858
|
-
|
859
|
-
if rbac_filter is not None:
|
860
|
-
query = query.where(rbac_filter)
|
861
|
-
|
862
|
-
filters = self.generate_filter(table=table)
|
863
|
-
|
864
|
-
if filters is not None:
|
865
|
-
query = query.where(filters)
|
866
|
-
|
867
|
-
return query
|
868
|
-
|
869
|
-
def apply_sorting(
|
870
|
-
self,
|
871
|
-
query: AnyQuery,
|
872
|
-
table: Type["AnySchema"],
|
873
|
-
) -> AnyQuery:
|
874
|
-
"""Apply sorting to the query.
|
875
|
-
|
876
|
-
Args:
|
877
|
-
query: The query to which to apply the sorting.
|
878
|
-
table: The query table.
|
879
|
-
|
880
|
-
Returns:
|
881
|
-
The query with sorting applied.
|
882
|
-
"""
|
883
|
-
column, operand = self.sorting_params
|
884
|
-
|
885
|
-
if operand == SorterOps.DESCENDING:
|
886
|
-
sort_clause = desc(getattr(table, column)) # type: ignore[var-annotated]
|
887
|
-
else:
|
888
|
-
sort_clause = asc(getattr(table, column))
|
889
|
-
|
890
|
-
# We always add the `id` column as a tiebreaker to ensure a stable,
|
891
|
-
# repeatable order of items, otherwise subsequent pages might contain
|
892
|
-
# the same items.
|
893
|
-
query = query.order_by(sort_clause, asc(table.id)) # type: ignore[arg-type]
|
894
|
-
|
895
|
-
return query
|