sqlspec 0.10.1__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

sqlspec/filters.py CHANGED
@@ -1,16 +1,20 @@
1
1
  """Collection filter datastructures."""
2
2
 
3
- from abc import ABC
3
+ from abc import ABC, abstractmethod
4
4
  from collections import abc
5
5
  from dataclasses import dataclass
6
6
  from datetime import datetime
7
- from typing import Generic, Literal, Optional, Protocol, Union
7
+ from typing import Any, Generic, Literal, Optional, Protocol, Union, cast
8
8
 
9
- from typing_extensions import TypeVar
9
+ from sqlglot import exp
10
+ from typing_extensions import TypeAlias, TypeVar
11
+
12
+ from sqlspec.statement import SQLStatement
10
13
 
11
14
  __all__ = (
12
15
  "BeforeAfter",
13
16
  "CollectionFilter",
17
+ "FilterTypes",
14
18
  "InAnyFilter",
15
19
  "LimitOffset",
16
20
  "NotInCollectionFilter",
@@ -20,18 +24,26 @@ __all__ = (
20
24
  "PaginationFilter",
21
25
  "SearchFilter",
22
26
  "StatementFilter",
27
+ "apply_filter",
23
28
  )
24
29
 
25
30
  # Generic type parameter for the element type of membership filters
  # (used by ``InAnyFilter``, ``CollectionFilter`` and ``NotInCollectionFilter``).
  T = TypeVar("T")
26
- StatementT = TypeVar("StatementT", bound="str")
27
31
 
28
32
 
29
33
  class StatementFilter(Protocol):
30
34
  """Protocol for filters that can be appended to a statement."""
31
35
 
32
- def append_to_statement(self, statement: StatementT) -> StatementT:
33
- """Append the filter to the statement."""
34
- return statement
36
+ @abstractmethod
37
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
38
+ """Append the filter to the statement.
39
+
40
+ Args:
41
+ statement: The SQL statement to modify.
42
+
43
+ Returns:
44
+ The modified statement.
45
+ """
46
+ raise NotImplementedError
35
47
 
36
48
 
37
49
  @dataclass
@@ -45,6 +57,27 @@ class BeforeAfter(StatementFilter):
45
57
  after: Optional[datetime] = None
46
58
  """Filter results where field later than this."""
47
59
 
60
def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
    """Restrict ``field_name`` to values strictly before/after the configured bounds.

    Adds ``field < :<field>_before`` when ``before`` is set and ``field > :<field>_after``
    when ``after`` is set. When both are set, the two comparisons are combined with ``AND``.
    Bounds that are ``None`` are skipped, and nothing is added when both are ``None``.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The same ``statement`` instance, with the condition added if any bound was set.
    """
    conditions: list[exp.Expression] = []
    params: dict[str, Any] = {}

    if self.before is not None:
        param_name = statement.generate_param_name(f"{self.field_name}_before")
        # Build a fresh column node for every comparison. sqlglot nodes hold a single
        # ``parent`` reference, so sharing one node between two comparisons leaves the
        # tree with an inconsistent parent pointer.
        conditions.append(exp.LT(this=exp.column(self.field_name), expression=exp.Placeholder(this=param_name)))
        params[param_name] = self.before
    if self.after is not None:
        param_name = statement.generate_param_name(f"{self.field_name}_after")
        conditions.append(exp.GT(this=exp.column(self.field_name), expression=exp.Placeholder(this=param_name)))
        params[param_name] = self.after

    if conditions:
        final_condition: exp.Expression = conditions[0]
        for cond in conditions[1:]:
            final_condition = exp.And(this=final_condition, expression=cond)
        statement.add_condition(final_condition, params)
    return statement
80
+
48
81
 
49
82
  @dataclass
50
83
  class OnBeforeAfter(StatementFilter):
@@ -57,13 +90,38 @@ class OnBeforeAfter(StatementFilter):
57
90
  on_or_after: Optional[datetime] = None
58
91
  """Filter results where field on or later than this."""
59
92
 
93
def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
    """Restrict ``field_name`` to values on-or-before / on-or-after the configured bounds.

    Adds ``field <= :<field>_on_or_before`` when ``on_or_before`` is set and
    ``field >= :<field>_on_or_after`` when ``on_or_after`` is set. When both are set,
    the two comparisons are combined with ``AND``. Bounds that are ``None`` are
    skipped, and nothing is added when both are ``None``.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The same ``statement`` instance, with the condition added if any bound was set.
    """
    conditions: list[exp.Expression] = []
    params: dict[str, Any] = {}

    if self.on_or_before is not None:
        param_name = statement.generate_param_name(f"{self.field_name}_on_or_before")
        # Fresh column node per comparison: a sqlglot node can only have one parent,
        # so one node must not be shared between the two comparisons.
        conditions.append(exp.LTE(this=exp.column(self.field_name), expression=exp.Placeholder(this=param_name)))
        params[param_name] = self.on_or_before
    if self.on_or_after is not None:
        param_name = statement.generate_param_name(f"{self.field_name}_on_or_after")
        conditions.append(exp.GTE(this=exp.column(self.field_name), expression=exp.Placeholder(this=param_name)))
        params[param_name] = self.on_or_after

    if conditions:
        final_condition: exp.Expression = conditions[0]
        for cond in conditions[1:]:
            final_condition = exp.And(this=final_condition, expression=cond)
        statement.add_condition(final_condition, params)
    return statement
60
113
 
61
class InAnyFilter(StatementFilter, ABC, Generic[T]):
    """Abstract, generic base for collection-membership filters (``IN`` / ``NOT IN``).

    Originally described as the base for filters that carry a ``prefer_any`` attribute.
    """

    @abstractmethod
    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Apply the membership condition to ``statement``; subclasses must override."""
        raise NotImplementedError
121
+
64
122
 
65
123
  @dataclass
66
- class CollectionFilter(InAnyFilter, Generic[T]):
124
+ class CollectionFilter(InAnyFilter[T]):
67
125
  """Data required to construct a ``WHERE ... IN (...)`` clause."""
68
126
 
69
127
  field_name: str
@@ -73,9 +131,30 @@ class CollectionFilter(InAnyFilter, Generic[T]):
73
131
 
74
132
  An empty list will return an empty result set, however, if ``None``, the filter is not applied to the query, and all rows are returned. """
75
133
 
134
def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
    """Add ``field_name IN (:p0, :p1, ...)`` with one bound parameter per value.

    ``values is None`` leaves the statement untouched. An empty collection adds a
    constant false condition, so no rows match.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The same ``statement`` instance.
    """
    values = self.values
    if values is None:
        return statement
    if not values:
        # Empty collection: nothing can match, so add an always-false predicate.
        statement.add_condition(exp.false())
        return statement

    keys: list[str] = []
    bound: dict[str, Any] = {}
    for position, item in enumerate(values):
        key = statement.generate_param_name(f"{self.field_name}_in_{position}")
        keys.append(key)
        bound[key] = item

    membership = exp.In(
        this=exp.column(self.field_name),
        expressions=[exp.Placeholder(this=key) for key in keys],
    )
    statement.add_condition(membership, bound)
    return statement
154
+
76
155
 
77
156
  @dataclass
78
- class NotInCollectionFilter(InAnyFilter, Generic[T]):
157
+ class NotInCollectionFilter(InAnyFilter[T]):
79
158
  """Data required to construct a ``WHERE ... NOT IN (...)`` clause."""
80
159
 
81
160
  field_name: str
@@ -85,10 +164,31 @@ class NotInCollectionFilter(InAnyFilter, Generic[T]):
85
164
 
86
165
  An empty list or ``None`` will return all rows."""
87
166
 
167
def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
    """Add ``NOT (field_name IN (:p0, :p1, ...))`` with one bound parameter per value.

    ``None`` and an empty collection both leave the statement untouched.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The same ``statement`` instance.
    """
    values = self.values
    if not values:
        # ``None`` or empty: there is nothing to exclude.
        return statement

    keys: list[str] = []
    bound: dict[str, Any] = {}
    for position, item in enumerate(values):
        key = statement.generate_param_name(f"{self.field_name}_notin_{position}")
        keys.append(key)
        bound[key] = item

    excluded = exp.Not(
        this=exp.In(
            this=exp.column(self.field_name),
            expressions=[exp.Placeholder(this=key) for key in keys],
        )
    )
    statement.add_condition(excluded, bound)
    return statement
183
+
88
184
 
89
185
  class PaginationFilter(StatementFilter, ABC):
90
186
  """Subclass for methods that function as a pagination type."""
91
187
 
188
+ @abstractmethod
189
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
190
+ raise NotImplementedError
191
+
92
192
 
93
193
  @dataclass
94
194
  class LimitOffset(PaginationFilter):
@@ -99,6 +199,16 @@ class LimitOffset(PaginationFilter):
99
199
  offset: int
100
200
  """Value for ``OFFSET`` clause of query."""
101
201
 
202
def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
    """Apply ``LIMIT``/``OFFSET`` to ``statement`` using generated parameter names.

    Both parameter names are generated before either clause is added.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The same ``statement`` instance.
    """
    # Tuple items evaluate left to right: the limit name is requested before the offset name.
    limit_key, offset_key = (
        statement.generate_param_name("limit_val"),
        statement.generate_param_name("offset_val"),
    )
    statement.add_limit(self.limit, param_name=limit_key)
    statement.add_offset(self.offset, param_name=offset_key)
    return statement
211
+
102
212
 
103
213
  @dataclass
104
214
  class OrderBy(StatementFilter):
@@ -109,6 +219,16 @@ class OrderBy(StatementFilter):
109
219
  sort_order: Literal["asc", "desc"] = "asc"
110
220
  """Sort ascending or descending"""
111
221
 
222
def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
    """Add ``ORDER BY field_name <direction>`` to ``statement``.

    ``sort_order`` is compared case-insensitively. Any value other than ``asc`` or
    ``desc`` falls back to ascending.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The same ``statement`` instance.
    """
    requested = self.sort_order.lower()
    # ``Literal`` only constrains type checkers; an unexpected runtime value sorts ascending.
    direction = requested if requested in ("asc", "desc") else "asc"
    statement.add_order_by(self.field_name, direction=cast("Literal['asc', 'desc']", direction))
    return statement
231
+
112
232
 
113
233
  @dataclass
114
234
  class SearchFilter(StatementFilter):
@@ -121,7 +241,90 @@ class SearchFilter(StatementFilter):
121
241
  ignore_case: Optional[bool] = False
122
242
  """Should the search be case insensitive."""
123
243
 
244
def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
    """Keep rows where ``field_name`` (or any field in a set of names) contains ``value``.

    A single field adds ``field LIKE :p``. A set of fields adds
    ``f1 LIKE :p OR f2 LIKE :p ...``, with fields in sorted order. ``ILIKE`` is used
    instead of ``LIKE`` when ``ignore_case`` is truthy. The bound pattern is
    ``%value%``. Nothing is added when ``value`` is falsy or ``field_name`` is an
    empty (or unsupported) collection.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The same ``statement`` instance.
    """
    if not self.value:
        return statement

    search_val_param_name = statement.generate_param_name("search_val")
    # NOTE(review): ``%`` and ``_`` inside ``value`` are not escaped, so they act as LIKE wildcards.
    params = {search_val_param_name: f"%{self.value}%"}
    like_op = exp.ILike if self.ignore_case else exp.Like

    fields: list[str]
    if isinstance(self.field_name, str):
        fields = [self.field_name]
    elif isinstance(self.field_name, (set, frozenset)) and self.field_name:
        # A set of str iterates in hash-randomized order; sorting keeps the SQL text stable across runs.
        fields = sorted(self.field_name)
    else:
        return statement

    # Each LIKE gets its own column and placeholder node: a sqlglot node holds a single
    # parent reference, so one node must not be shared between comparisons.
    conditions = [
        like_op(this=exp.column(field), expression=exp.Placeholder(this=search_val_param_name))
        for field in fields
    ]
    final_condition: exp.Expression = conditions[0]
    for cond in conditions[1:]:
        final_condition = exp.Or(this=final_condition, expression=cond)
    statement.add_condition(final_condition, params)
    return statement
270
+
124
271
 
125
272
  @dataclass
126
- class NotInSearchFilter(SearchFilter):
273
+ class NotInSearchFilter(SearchFilter): # Inherits field_name, value, ignore_case
127
274
  """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""
275
+
276
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
277
+ if not self.value:
278
+ return statement
279
+
280
+ search_val_param_name = statement.generate_param_name("not_search_val")
281
+
282
+ params = {search_val_param_name: f"%{self.value}%"}
283
+ pattern_expr = exp.Placeholder(this=search_val_param_name)
284
+
285
+ like_op = exp.ILike if self.ignore_case else exp.Like
286
+
287
+ if isinstance(self.field_name, str):
288
+ condition = exp.Not(this=like_op(this=exp.column(self.field_name), expression=pattern_expr))
289
+ statement.add_condition(condition, params)
290
+ elif isinstance(self.field_name, set) and self.field_name:
291
+ field_conditions = [
292
+ exp.Not(this=like_op(this=exp.column(field), expression=pattern_expr)) for field in self.field_name
293
+ ]
294
+ if not field_conditions:
295
+ return statement
296
+
297
+ # Combine with AND: (field1 NOT LIKE pattern) AND (field2 NOT LIKE pattern) ...
298
+ final_condition = field_conditions[0]
299
+ for cond in field_conditions[1:]:
300
+ final_condition = exp.And(this=final_condition, expression=cond) # type: ignore[assignment]
301
+ statement.add_condition(final_condition, params)
302
+
303
+ return statement
304
+
305
+
306
# Functional entry point for applying a filter; intended to be imported by the SQLStatement module.
def apply_filter(statement: SQLStatement, filter_obj: StatementFilter) -> SQLStatement:
    """Run ``filter_obj`` against ``statement``.

    Thin wrapper that delegates to ``filter_obj.append_to_statement``.

    Args:
        statement: The SQL statement to modify.
        filter_obj: The filter to apply.

    Returns:
        Whatever ``filter_obj.append_to_statement`` returns for ``statement``.
    """
    append = filter_obj.append_to_statement
    return append(statement)
318
+
319
+
320
# Union of every concrete filter dataclass defined in this module. The abstract bases
# (``StatementFilter``, ``InAnyFilter``, ``PaginationFilter``) are deliberately not members.
# Member order is kept as-is because it determines ``FilterTypes.__args__``.
FilterTypes: TypeAlias = Union[
    BeforeAfter,
    OnBeforeAfter,
    CollectionFilter[Any],
    LimitOffset,
    OrderBy,
    SearchFilter,
    NotInCollectionFilter[Any],
    NotInSearchFilter,
]
"""Aggregate type alias of the types supported for collection filtering."""
sqlspec/mixins.py CHANGED
@@ -1,17 +1,37 @@
1
+ import datetime
1
2
  from abc import abstractmethod
3
+ from collections.abc import Sequence
4
+ from enum import Enum
5
+ from functools import partial
6
+ from pathlib import Path, PurePath
2
7
  from typing import (
3
8
  TYPE_CHECKING,
4
9
  Any,
10
+ Callable,
5
11
  ClassVar,
6
12
  Generic,
7
13
  Optional,
14
+ Union,
15
+ cast,
16
+ overload,
8
17
  )
18
+ from uuid import UUID
9
19
 
10
20
  from sqlglot import parse_one
11
21
  from sqlglot.dialects.dialect import DialectType
12
22
 
13
- from sqlspec.exceptions import SQLConversionError, SQLParsingError
14
- from sqlspec.typing import ConnectionT, StatementParameterType
23
+ from sqlspec.exceptions import SQLConversionError, SQLParsingError, SQLSpecError
24
+ from sqlspec.typing import (
25
+ ConnectionT,
26
+ ModelDTOT,
27
+ ModelT,
28
+ StatementParameterType,
29
+ convert,
30
+ get_type_adapter,
31
+ is_dataclass,
32
+ is_msgspec_struct,
33
+ is_pydantic_model,
34
+ )
15
35
 
16
36
  if TYPE_CHECKING:
17
37
  from sqlspec.typing import ArrowTable
@@ -154,3 +174,133 @@ class SQLTranslatorMixin(Generic[ConnectionT]):
154
174
  except Exception as e:
155
175
  error_msg = f"Failed to convert SQL to {to_dialect}: {e!s}"
156
176
  raise SQLConversionError(error_msg) from e
177
+
178
+
179
# Fallback ``(predicate, decoder)`` pairs passed as ``type_decoders`` to
# ``_default_msgspec_deserializer``. Each predicate receives the target type; each
# decoder receives ``(target_type, value)`` and returns an instance of the target type.
_DEFAULT_TYPE_DECODERS = [  # pyright: ignore[reportUnknownVariableType]
    # ``str(v)`` works for both canonical UUID strings and UUID instances
    # (``v.hex`` only existed on UUID instances).
    (lambda x: x is UUID, lambda t, v: t(str(v))),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    # The datetime/date/time constructors take numeric fields, not a string, so parse the
    # ISO form with ``fromisoformat`` (serializing non-str values first).
    (lambda x: x is datetime.datetime, lambda t, v: t.fromisoformat(v if isinstance(v, str) else v.isoformat())),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    (lambda x: x is datetime.date, lambda t, v: t.fromisoformat(v if isinstance(v, str) else v.isoformat())),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    (lambda x: x is datetime.time, lambda t, v: t.fromisoformat(v if isinstance(v, str) else v.isoformat())),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    # Match concrete Enum subclasses; ``x is Enum`` never matched them. Accept both
    # enum members and raw member values.
    (lambda x: isinstance(x, type) and issubclass(x, Enum), lambda t, v: t(v.value if isinstance(v, Enum) else v)),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
]
186
+
187
+
188
def _default_msgspec_deserializer(
    target_type: Any,
    value: Any,
    type_decoders: "Union[Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]], None]" = None,
) -> Any:  # pragma: no cover
    """Coerce ``value`` into ``target_type`` when ``msgspec`` cannot do so natively.

    Resolution order:

    1. ``value`` is already an instance of ``target_type``: return it unchanged.
    2. The first entry of ``type_decoders`` whose predicate accepts ``target_type``
       decodes the value.
    3. ``Path``/``PurePath``/``UUID`` subclasses are built directly from ``value``.
    4. Otherwise ``target_type(value)`` is attempted.

    Args:
        target_type: The type ``msgspec`` asked for.
        value: The raw value to coerce.
        type_decoders: Optional ``(predicate, decoder)`` pairs, checked in order.

    Raises:
        TypeError: If the final ``target_type(value)`` attempt fails.

    Returns:
        A value of ``target_type`` that ``msgspec`` can accept.
    """
    if isinstance(value, target_type):
        return value

    # Predicates are evaluated lazily, stopping at the first match.
    chosen = next(
        (decode for accepts, decode in (type_decoders or ()) if accepts(target_type)),
        None,
    )
    if chosen is not None:
        return chosen(target_type, value)

    # These constructors take the raw value directly; their errors propagate unchanged.
    if issubclass(target_type, (Path, PurePath, UUID)):
        return target_type(value)

    try:
        coerced = target_type(value)
    except Exception as e:
        msg = f"Unsupported type: {type(value)!r}"
        raise TypeError(msg) from e
    else:
        return coerced
223
+
224
+
225
class ResultConverter:
    """Mixin that turns row data into instances of a requested schema type.

    ``to_schema`` accepts either a single item or a sequence of items. Without a
    ``schema_type`` the data is returned unchanged. Otherwise it is converted into a
    dataclass, ``msgspec`` struct or Pydantic model (or a list of them).
    """

    @overload
    @staticmethod
    def to_schema(data: "ModelT", *, schema_type: None = None) -> "ModelT": ...
    @overload
    @staticmethod
    def to_schema(data: "dict[str, Any]", *, schema_type: "type[ModelDTOT]") -> "ModelDTOT": ...
    @overload
    @staticmethod
    def to_schema(data: "Sequence[ModelT]", *, schema_type: None = None) -> "Sequence[ModelT]": ...
    @overload
    @staticmethod
    def to_schema(data: "Sequence[dict[str, Any]]", *, schema_type: "type[ModelDTOT]") -> "Sequence[ModelDTOT]": ...

    @staticmethod
    def to_schema(
        data: "Union[ModelT, Sequence[ModelT], dict[str, Any], Sequence[dict[str, Any]], ModelDTOT, Sequence[ModelDTOT]]",
        *,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelT, Sequence[ModelT], ModelDTOT, Sequence[ModelDTOT]]":
        """Convert ``data`` into ``schema_type``.

        Args:
            data: A single item, or a ``Sequence`` of items. For dataclass targets each
                item must be a mapping whose keys match the constructor's parameters.
            schema_type: Target type. ``None`` returns ``data`` unchanged.

        Raises:
            SQLSpecError: If ``schema_type`` is not a dataclass, msgspec struct or Pydantic model.

        Returns:
            ``data`` unchanged, one converted instance, or a list of converted instances.
        """
        many = isinstance(data, Sequence)

        if schema_type is None:
            return cast("Sequence[ModelT]", data) if many else cast("ModelT", data)

        if is_dataclass(schema_type):
            if many:
                return cast("Sequence[ModelDTOT]", [schema_type(**row) for row in data])  # type: ignore[operator]
            return cast("ModelDTOT", schema_type(**data))  # type: ignore[operator]

        if is_msgspec_struct(schema_type):
            hook = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
            if many:
                converted_many = convert(
                    obj=data,
                    type=list[schema_type],  # type: ignore[valid-type]
                    from_attributes=True,
                    dec_hook=hook,
                )
                return cast("Sequence[ModelDTOT]", converted_many)
            converted_one = convert(obj=data, type=schema_type, from_attributes=True, dec_hook=hook)
            return cast("ModelDTOT", converted_one)

        if is_pydantic_model(schema_type):
            if many:
                adapter_many = get_type_adapter(list[schema_type])  # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
                return cast("Sequence[ModelDTOT]", adapter_many.validate_python(data, from_attributes=True))
            adapter_one = get_type_adapter(schema_type)  # pyright: ignore
            return cast("ModelDTOT", adapter_one.validate_python(data, from_attributes=True))

        msg = "`schema_type` should be a valid Dataclass, Pydantic model or Msgspec struct"
        raise SQLSpecError(msg)