sqlspec 0.10.1__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; consult the package's advisory page on the registry for more details.

sqlspec/filters.py CHANGED
@@ -1,16 +1,20 @@
1
1
  """Collection filter datastructures."""
2
2
 
3
- from abc import ABC
3
+ from abc import ABC, abstractmethod
4
4
  from collections import abc
5
5
  from dataclasses import dataclass
6
6
  from datetime import datetime
7
- from typing import Generic, Literal, Optional, Protocol, Union
7
+ from typing import Any, Generic, Literal, Optional, Protocol, Union, cast, runtime_checkable
8
8
 
9
- from typing_extensions import TypeVar
9
+ from sqlglot import exp
10
+ from typing_extensions import TypeAlias, TypeVar
11
+
12
+ from sqlspec.statement import SQLStatement
10
13
 
11
14
  __all__ = (
12
15
  "BeforeAfter",
13
16
  "CollectionFilter",
17
+ "FilterTypes",
14
18
  "InAnyFilter",
15
19
  "LimitOffset",
16
20
  "NotInCollectionFilter",
@@ -20,18 +24,27 @@ __all__ = (
20
24
  "PaginationFilter",
21
25
  "SearchFilter",
22
26
  "StatementFilter",
27
+ "apply_filter",
23
28
  )
24
29
 
25
30
  T = TypeVar("T")
26
- StatementT = TypeVar("StatementT", bound="str")
27
31
 
28
32
 
33
+ @runtime_checkable
29
34
  class StatementFilter(Protocol):
30
35
  """Protocol for filters that can be appended to a statement."""
31
36
 
32
- def append_to_statement(self, statement: StatementT) -> StatementT:
33
- """Append the filter to the statement."""
34
- return statement
37
+ @abstractmethod
38
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
39
+ """Append the filter to the statement.
40
+
41
+ Args:
42
+ statement: The SQL statement to modify.
43
+
44
+ Returns:
45
+ The modified statement.
46
+ """
47
+ raise NotImplementedError
35
48
 
36
49
 
37
50
  @dataclass
@@ -45,6 +58,27 @@ class BeforeAfter(StatementFilter):
45
58
  after: Optional[datetime] = None
46
59
  """Filter results where field later than this."""
47
60
 
61
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
62
+ conditions = []
63
+ params: dict[str, Any] = {}
64
+ col_expr = exp.column(self.field_name)
65
+
66
+ if self.before:
67
+ param_name = statement.generate_param_name(f"{self.field_name}_before")
68
+ conditions.append(exp.LT(this=col_expr, expression=exp.Placeholder(this=param_name)))
69
+ params[param_name] = self.before
70
+ if self.after:
71
+ param_name = statement.generate_param_name(f"{self.field_name}_after")
72
+ conditions.append(exp.GT(this=col_expr, expression=exp.Placeholder(this=param_name))) # type: ignore[arg-type]
73
+ params[param_name] = self.after
74
+
75
+ if conditions:
76
+ final_condition = conditions[0]
77
+ for cond in conditions[1:]:
78
+ final_condition = exp.And(this=final_condition, expression=cond) # type: ignore[assignment]
79
+ statement.add_condition(final_condition, params)
80
+ return statement
81
+
48
82
 
49
83
  @dataclass
50
84
  class OnBeforeAfter(StatementFilter):
@@ -57,13 +91,38 @@ class OnBeforeAfter(StatementFilter):
57
91
  on_or_after: Optional[datetime] = None
58
92
  """Filter results where field on or later than this."""
59
93
 
94
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
95
+ conditions = []
96
+ params: dict[str, Any] = {}
97
+ col_expr = exp.column(self.field_name)
98
+
99
+ if self.on_or_before:
100
+ param_name = statement.generate_param_name(f"{self.field_name}_on_or_before")
101
+ conditions.append(exp.LTE(this=col_expr, expression=exp.Placeholder(this=param_name)))
102
+ params[param_name] = self.on_or_before
103
+ if self.on_or_after:
104
+ param_name = statement.generate_param_name(f"{self.field_name}_on_or_after")
105
+ conditions.append(exp.GTE(this=col_expr, expression=exp.Placeholder(this=param_name))) # type: ignore[arg-type]
106
+ params[param_name] = self.on_or_after
107
+
108
+ if conditions:
109
+ final_condition = conditions[0]
110
+ for cond in conditions[1:]:
111
+ final_condition = exp.And(this=final_condition, expression=cond) # type: ignore[assignment]
112
+ statement.add_condition(final_condition, params)
113
+ return statement
60
114
 
61
- class InAnyFilter(StatementFilter, ABC):
115
+
116
+ class InAnyFilter(StatementFilter, ABC, Generic[T]):
62
117
  """Subclass for methods that have a `prefer_any` attribute."""
63
118
 
119
+ @abstractmethod
120
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
121
+ raise NotImplementedError
122
+
64
123
 
65
124
  @dataclass
66
- class CollectionFilter(InAnyFilter, Generic[T]):
125
+ class CollectionFilter(InAnyFilter[T]):
67
126
  """Data required to construct a ``WHERE ... IN (...)`` clause."""
68
127
 
69
128
  field_name: str
@@ -73,9 +132,30 @@ class CollectionFilter(InAnyFilter, Generic[T]):
73
132
 
74
133
  An empty list will return an empty result set, however, if ``None``, the filter is not applied to the query, and all rows are returned. """
75
134
 
135
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
136
+ if self.values is None:
137
+ return statement
138
+
139
+ if not self.values: # Empty collection
140
+ # Add a condition that is always false
141
+ statement.add_condition(exp.false())
142
+ return statement
143
+
144
+ placeholder_expressions: list[exp.Placeholder] = []
145
+ current_params: dict[str, Any] = {}
146
+
147
+ for i, value_item in enumerate(self.values):
148
+ param_key = statement.generate_param_name(f"{self.field_name}_in_{i}")
149
+ placeholder_expressions.append(exp.Placeholder(this=param_key))
150
+ current_params[param_key] = value_item
151
+
152
+ in_condition = exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions)
153
+ statement.add_condition(in_condition, current_params)
154
+ return statement
155
+
76
156
 
77
157
  @dataclass
78
- class NotInCollectionFilter(InAnyFilter, Generic[T]):
158
+ class NotInCollectionFilter(InAnyFilter[T]):
79
159
  """Data required to construct a ``WHERE ... NOT IN (...)`` clause."""
80
160
 
81
161
  field_name: str
@@ -85,10 +165,31 @@ class NotInCollectionFilter(InAnyFilter, Generic[T]):
85
165
 
86
166
  An empty list or ``None`` will return all rows."""
87
167
 
168
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
169
+ if self.values is None or not self.values: # Empty list or None, no filter applied
170
+ return statement
171
+
172
+ placeholder_expressions: list[exp.Placeholder] = []
173
+ current_params: dict[str, Any] = {}
174
+
175
+ for i, value_item in enumerate(self.values):
176
+ param_key = statement.generate_param_name(f"{self.field_name}_notin_{i}")
177
+ placeholder_expressions.append(exp.Placeholder(this=param_key))
178
+ current_params[param_key] = value_item
179
+
180
+ in_expr = exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions)
181
+ not_in_condition = exp.Not(this=in_expr)
182
+ statement.add_condition(not_in_condition, current_params)
183
+ return statement
184
+
88
185
 
89
186
  class PaginationFilter(StatementFilter, ABC):
90
187
  """Subclass for methods that function as a pagination type."""
91
188
 
189
+ @abstractmethod
190
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
191
+ raise NotImplementedError
192
+
92
193
 
93
194
  @dataclass
94
195
  class LimitOffset(PaginationFilter):
@@ -99,6 +200,16 @@ class LimitOffset(PaginationFilter):
99
200
  offset: int
100
201
  """Value for ``OFFSET`` clause of query."""
101
202
 
203
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
204
+ # Generate parameter names for limit and offset
205
+ limit_param_name = statement.generate_param_name("limit_val")
206
+ offset_param_name = statement.generate_param_name("offset_val")
207
+
208
+ statement.add_limit(self.limit, param_name=limit_param_name)
209
+ statement.add_offset(self.offset, param_name=offset_param_name)
210
+
211
+ return statement
212
+
102
213
 
103
214
  @dataclass
104
215
  class OrderBy(StatementFilter):
@@ -109,6 +220,16 @@ class OrderBy(StatementFilter):
109
220
  sort_order: Literal["asc", "desc"] = "asc"
110
221
  """Sort ascending or descending"""
111
222
 
223
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
224
+ # Basic validation for sort_order, though Literal helps at type checking time
225
+ normalized_sort_order = self.sort_order.lower()
226
+ if normalized_sort_order not in {"asc", "desc"}:
227
+ normalized_sort_order = "asc"
228
+
229
+ statement.add_order_by(self.field_name, direction=cast("Literal['asc', 'desc']", normalized_sort_order))
230
+
231
+ return statement
232
+
112
233
 
113
234
  @dataclass
114
235
  class SearchFilter(StatementFilter):
@@ -121,7 +242,90 @@ class SearchFilter(StatementFilter):
121
242
  ignore_case: Optional[bool] = False
122
243
  """Should the search be case insensitive."""
123
244
 
245
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
246
+ if not self.value:
247
+ return statement
248
+
249
+ search_val_param_name = statement.generate_param_name("search_val")
250
+
251
+ # The pattern %value% needs to be handled carefully.
252
+ params = {search_val_param_name: f"%{self.value}%"}
253
+ pattern_expr = exp.Placeholder(this=search_val_param_name)
254
+
255
+ like_op = exp.ILike if self.ignore_case else exp.Like
256
+
257
+ if isinstance(self.field_name, str):
258
+ condition = like_op(this=exp.column(self.field_name), expression=pattern_expr)
259
+ statement.add_condition(condition, params)
260
+ elif isinstance(self.field_name, set) and self.field_name:
261
+ field_conditions = [like_op(this=exp.column(field), expression=pattern_expr) for field in self.field_name]
262
+ if not field_conditions:
263
+ return statement
264
+
265
+ final_condition = field_conditions[0]
266
+ for cond in field_conditions[1:]:
267
+ final_condition = exp.Or(this=final_condition, expression=cond) # type: ignore[assignment]
268
+ statement.add_condition(final_condition, params)
269
+
270
+ return statement
271
+
124
272
 
125
273
  @dataclass
126
- class NotInSearchFilter(SearchFilter):
274
+ class NotInSearchFilter(SearchFilter): # Inherits field_name, value, ignore_case
127
275
  """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""
276
+
277
+ def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
278
+ if not self.value:
279
+ return statement
280
+
281
+ search_val_param_name = statement.generate_param_name("not_search_val")
282
+
283
+ params = {search_val_param_name: f"%{self.value}%"}
284
+ pattern_expr = exp.Placeholder(this=search_val_param_name)
285
+
286
+ like_op = exp.ILike if self.ignore_case else exp.Like
287
+
288
+ if isinstance(self.field_name, str):
289
+ condition = exp.Not(this=like_op(this=exp.column(self.field_name), expression=pattern_expr))
290
+ statement.add_condition(condition, params)
291
+ elif isinstance(self.field_name, set) and self.field_name:
292
+ field_conditions = [
293
+ exp.Not(this=like_op(this=exp.column(field), expression=pattern_expr)) for field in self.field_name
294
+ ]
295
+ if not field_conditions:
296
+ return statement
297
+
298
+ # Combine with AND: (field1 NOT LIKE pattern) AND (field2 NOT LIKE pattern) ...
299
+ final_condition = field_conditions[0]
300
+ for cond in field_conditions[1:]:
301
+ final_condition = exp.And(this=final_condition, expression=cond) # type: ignore[assignment]
302
+ statement.add_condition(final_condition, params)
303
+
304
+ return statement
305
+
306
+
307
+ # Function to be imported in SQLStatement module
308
+ def apply_filter(statement: SQLStatement, filter_obj: StatementFilter) -> SQLStatement:
309
+ """Apply a statement filter to a SQL statement.
310
+
311
+ Args:
312
+ statement: The SQL statement to modify.
313
+ filter_obj: The filter to apply.
314
+
315
+ Returns:
316
+ The modified statement.
317
+ """
318
+ return filter_obj.append_to_statement(statement)
319
+
320
+
321
+ FilterTypes: TypeAlias = Union[
322
+ BeforeAfter,
323
+ OnBeforeAfter,
324
+ CollectionFilter[Any],
325
+ LimitOffset,
326
+ OrderBy,
327
+ SearchFilter,
328
+ NotInCollectionFilter[Any],
329
+ NotInSearchFilter,
330
+ ]
331
+ """Aggregate type alias of the types supported for collection filtering."""
sqlspec/mixins.py CHANGED
@@ -1,19 +1,40 @@
1
+ import datetime
1
2
  from abc import abstractmethod
3
+ from collections.abc import Sequence
4
+ from enum import Enum
5
+ from functools import partial
6
+ from pathlib import Path, PurePath
2
7
  from typing import (
3
8
  TYPE_CHECKING,
4
9
  Any,
10
+ Callable,
5
11
  ClassVar,
6
12
  Generic,
7
13
  Optional,
14
+ Union,
15
+ cast,
16
+ overload,
8
17
  )
18
+ from uuid import UUID
9
19
 
10
20
  from sqlglot import parse_one
11
21
  from sqlglot.dialects.dialect import DialectType
12
22
 
13
- from sqlspec.exceptions import SQLConversionError, SQLParsingError
14
- from sqlspec.typing import ConnectionT, StatementParameterType
23
+ from sqlspec.exceptions import SQLConversionError, SQLParsingError, SQLSpecError
24
+ from sqlspec.typing import (
25
+ ConnectionT,
26
+ ModelDTOT,
27
+ ModelT,
28
+ StatementParameterType,
29
+ convert,
30
+ get_type_adapter,
31
+ is_dataclass,
32
+ is_msgspec_struct,
33
+ is_pydantic_model,
34
+ )
15
35
 
16
36
  if TYPE_CHECKING:
37
+ from sqlspec.filters import StatementFilter
17
38
  from sqlspec.typing import ArrowTable
18
39
 
19
40
  __all__ = (
@@ -31,12 +52,11 @@ class SyncArrowBulkOperationsMixin(Generic[ConnectionT]):
31
52
  __supports_arrow__: "ClassVar[bool]" = True
32
53
 
33
54
  @abstractmethod
34
- def select_arrow( # pyright: ignore[reportUnknownParameterType]
55
+ def select_arrow(
35
56
  self,
36
57
  sql: str,
37
58
  parameters: "Optional[StatementParameterType]" = None,
38
- /,
39
- *,
59
+ *filters: "StatementFilter",
40
60
  connection: "Optional[ConnectionT]" = None,
41
61
  **kwargs: Any,
42
62
  ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType]
@@ -45,6 +65,7 @@ class SyncArrowBulkOperationsMixin(Generic[ConnectionT]):
45
65
  Args:
46
66
  sql: The SQL query string.
47
67
  parameters: Parameters for the query.
68
+ filters: Optional filters to apply to the query.
48
69
  connection: Optional connection override.
49
70
  **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
50
71
 
@@ -60,12 +81,11 @@ class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]):
60
81
  __supports_arrow__: "ClassVar[bool]" = True
61
82
 
62
83
  @abstractmethod
63
- async def select_arrow( # pyright: ignore[reportUnknownParameterType]
84
+ async def select_arrow(
64
85
  self,
65
86
  sql: str,
66
87
  parameters: "Optional[StatementParameterType]" = None,
67
- /,
68
- *,
88
+ *filters: "StatementFilter",
69
89
  connection: "Optional[ConnectionT]" = None,
70
90
  **kwargs: Any,
71
91
  ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType]
@@ -74,6 +94,7 @@ class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]):
74
94
  Args:
75
95
  sql: The SQL query string.
76
96
  parameters: Parameters for the query.
97
+ filters: Optional filters to apply to the query.
77
98
  connection: Optional connection override.
78
99
  **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
79
100
 
@@ -91,8 +112,7 @@ class SyncParquetExportMixin(Generic[ConnectionT]):
91
112
  self,
92
113
  sql: str,
93
114
  parameters: "Optional[StatementParameterType]" = None,
94
- /,
95
- *,
115
+ *filters: "StatementFilter",
96
116
  connection: "Optional[ConnectionT]" = None,
97
117
  **kwargs: Any,
98
118
  ) -> None:
@@ -108,8 +128,7 @@ class AsyncParquetExportMixin(Generic[ConnectionT]):
108
128
  self,
109
129
  sql: str,
110
130
  parameters: "Optional[StatementParameterType]" = None,
111
- /,
112
- *,
131
+ *filters: "StatementFilter",
113
132
  connection: "Optional[ConnectionT]" = None,
114
133
  **kwargs: Any,
115
134
  ) -> None:
@@ -154,3 +173,133 @@ class SQLTranslatorMixin(Generic[ConnectionT]):
154
173
  except Exception as e:
155
174
  error_msg = f"Failed to convert SQL to {to_dialect}: {e!s}"
156
175
  raise SQLConversionError(error_msg) from e
176
+
177
+
178
+ _DEFAULT_TYPE_DECODERS = [ # pyright: ignore[reportUnknownVariableType]
179
+ (lambda x: x is UUID, lambda t, v: t(v.hex)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
180
+ (lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
181
+ (lambda x: x is datetime.date, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
182
+ (lambda x: x is datetime.time, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
183
+ (lambda x: x is Enum, lambda t, v: t(v.value)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
184
+ ]
185
+
186
+
187
+ def _default_msgspec_deserializer(
188
+ target_type: Any,
189
+ value: Any,
190
+ type_decoders: "Union[Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]], None]" = None,
191
+ ) -> Any: # pragma: no cover
192
+ """Transform values non-natively supported by ``msgspec``
193
+
194
+ Args:
195
+ target_type: Encountered type
196
+ value: Value to coerce
197
+ type_decoders: Optional sequence of type decoders
198
+
199
+ Raises:
200
+ TypeError: If the value cannot be coerced to the target type
201
+
202
+ Returns:
203
+ A ``msgspec``-supported type
204
+ """
205
+
206
+ if isinstance(value, target_type):
207
+ return value
208
+
209
+ if type_decoders:
210
+ for predicate, decoder in type_decoders:
211
+ if predicate(target_type):
212
+ return decoder(target_type, value)
213
+
214
+ if issubclass(target_type, (Path, PurePath, UUID)):
215
+ return target_type(value)
216
+
217
+ try:
218
+ return target_type(value)
219
+ except Exception as e:
220
+ msg = f"Unsupported type: {type(value)!r}"
221
+ raise TypeError(msg) from e
222
+
223
+
224
+ class ResultConverter:
225
+ """Simple mixin to help convert to dictionary or list of dictionaries to specified schema type.
226
+
227
+ Single objects are transformed to the supplied schema type, and lists of objects are transformed into a list of the supplied schema type.
228
+
229
+ Args:
230
+ data: A database model instance or row mapping.
231
+ Type: :class:`~sqlspec.typing.ModelDictT`
232
+
233
+ Returns:
234
+ The converted schema object.
235
+ """
236
+
237
+ @overload
238
+ @staticmethod
239
+ def to_schema(data: "ModelT", *, schema_type: None = None) -> "ModelT": ...
240
+ @overload
241
+ @staticmethod
242
+ def to_schema(data: "dict[str, Any]", *, schema_type: "type[ModelDTOT]") -> "ModelDTOT": ...
243
+ @overload
244
+ @staticmethod
245
+ def to_schema(data: "Sequence[ModelT]", *, schema_type: None = None) -> "Sequence[ModelT]": ...
246
+ @overload
247
+ @staticmethod
248
+ def to_schema(data: "Sequence[dict[str, Any]]", *, schema_type: "type[ModelDTOT]") -> "Sequence[ModelDTOT]": ...
249
+
250
+ @staticmethod
251
+ def to_schema(
252
+ data: "Union[ModelT, Sequence[ModelT], dict[str, Any], Sequence[dict[str, Any]], ModelDTOT, Sequence[ModelDTOT]]",
253
+ *,
254
+ schema_type: "Optional[type[ModelDTOT]]" = None,
255
+ ) -> "Union[ModelT, Sequence[ModelT], ModelDTOT, Sequence[ModelDTOT]]":
256
+ if schema_type is None:
257
+ if not isinstance(data, Sequence):
258
+ return cast("ModelT", data)
259
+ return cast("Sequence[ModelT]", data)
260
+ if is_dataclass(schema_type):
261
+ if not isinstance(data, Sequence):
262
+ # data is assumed to be dict[str, Any] as per the method's overloads
263
+ return cast("ModelDTOT", schema_type(**data)) # type: ignore[operator]
264
+ # data is assumed to be Sequence[dict[str, Any]]
265
+ return cast("Sequence[ModelDTOT]", [schema_type(**item) for item in data]) # type: ignore[operator]
266
+ if is_msgspec_struct(schema_type):
267
+ if not isinstance(data, Sequence):
268
+ return cast(
269
+ "ModelDTOT",
270
+ convert(
271
+ obj=data,
272
+ type=schema_type,
273
+ from_attributes=True,
274
+ dec_hook=partial(
275
+ _default_msgspec_deserializer,
276
+ type_decoders=_DEFAULT_TYPE_DECODERS,
277
+ ),
278
+ ),
279
+ )
280
+ return cast(
281
+ "Sequence[ModelDTOT]",
282
+ convert(
283
+ obj=data,
284
+ type=list[schema_type], # type: ignore[valid-type]
285
+ from_attributes=True,
286
+ dec_hook=partial(
287
+ _default_msgspec_deserializer,
288
+ type_decoders=_DEFAULT_TYPE_DECODERS,
289
+ ),
290
+ ),
291
+ )
292
+
293
+ if schema_type is not None and is_pydantic_model(schema_type):
294
+ if not isinstance(data, Sequence):
295
+ return cast(
296
+ "ModelDTOT",
297
+ get_type_adapter(schema_type).validate_python(data, from_attributes=True), # pyright: ignore
298
+ )
299
+ return cast(
300
+ "Sequence[ModelDTOT]",
301
+ get_type_adapter(list[schema_type]).validate_python(data, from_attributes=True), # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
302
+ )
303
+
304
+ msg = "`schema_type` should be a valid Dataclass, Pydantic model or Msgspec struct"
305
+ raise SQLSpecError(msg)