sqlspec 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_typing.py +24 -32
- sqlspec/adapters/adbc/config.py +1 -1
- sqlspec/adapters/adbc/driver.py +336 -165
- sqlspec/adapters/aiosqlite/driver.py +211 -126
- sqlspec/adapters/asyncmy/driver.py +164 -68
- sqlspec/adapters/asyncpg/config.py +3 -1
- sqlspec/adapters/asyncpg/driver.py +190 -231
- sqlspec/adapters/bigquery/driver.py +178 -169
- sqlspec/adapters/duckdb/driver.py +175 -84
- sqlspec/adapters/oracledb/driver.py +224 -90
- sqlspec/adapters/psqlpy/driver.py +267 -187
- sqlspec/adapters/psycopg/driver.py +138 -184
- sqlspec/adapters/sqlite/driver.py +153 -121
- sqlspec/base.py +57 -45
- sqlspec/extensions/litestar/__init__.py +3 -12
- sqlspec/extensions/litestar/config.py +22 -7
- sqlspec/extensions/litestar/handlers.py +142 -85
- sqlspec/extensions/litestar/plugin.py +9 -8
- sqlspec/extensions/litestar/providers.py +521 -0
- sqlspec/filters.py +214 -11
- sqlspec/mixins.py +152 -2
- sqlspec/statement.py +276 -271
- sqlspec/typing.py +18 -1
- sqlspec/utils/__init__.py +2 -2
- sqlspec/utils/singleton.py +35 -0
- sqlspec/utils/sync_tools.py +90 -151
- sqlspec/utils/text.py +68 -5
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/METADATA +5 -1
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/RECORD +32 -30
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/filters.py
CHANGED
|
@@ -1,16 +1,20 @@
|
|
|
1
1
|
"""Collection filter datastructures."""
|
|
2
2
|
|
|
3
|
-
from abc import ABC
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
4
|
from collections import abc
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from datetime import datetime
|
|
7
|
-
from typing import Generic, Literal, Optional, Protocol, Union
|
|
7
|
+
from typing import Any, Generic, Literal, Optional, Protocol, Union, cast
|
|
8
8
|
|
|
9
|
-
from
|
|
9
|
+
from sqlglot import exp
|
|
10
|
+
from typing_extensions import TypeAlias, TypeVar
|
|
11
|
+
|
|
12
|
+
from sqlspec.statement import SQLStatement
|
|
10
13
|
|
|
11
14
|
__all__ = (
|
|
12
15
|
"BeforeAfter",
|
|
13
16
|
"CollectionFilter",
|
|
17
|
+
"FilterTypes",
|
|
14
18
|
"InAnyFilter",
|
|
15
19
|
"LimitOffset",
|
|
16
20
|
"NotInCollectionFilter",
|
|
@@ -20,18 +24,26 @@ __all__ = (
|
|
|
20
24
|
"PaginationFilter",
|
|
21
25
|
"SearchFilter",
|
|
22
26
|
"StatementFilter",
|
|
27
|
+
"apply_filter",
|
|
23
28
|
)
|
|
24
29
|
|
|
25
30
|
T = TypeVar("T")
|
|
26
|
-
StatementT = TypeVar("StatementT", bound="str")
|
|
27
31
|
|
|
28
32
|
|
|
29
33
|
class StatementFilter(Protocol):
    """Protocol for filters that can be appended to a statement."""

    # NOTE: concrete filters implement this structurally; subclassing is optional.
    @abstractmethod
    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Append the filter to the statement.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The modified statement.
        """
        raise NotImplementedError
|
|
|
36
48
|
|
|
37
49
|
@dataclass
|
|
@@ -45,6 +57,27 @@ class BeforeAfter(StatementFilter):
|
|
|
45
57
|
after: Optional[datetime] = None
|
|
46
58
|
"""Filter results where field later than this."""
|
|
47
59
|
|
|
60
|
+
def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
    """Add ``field < :before`` / ``field > :after`` bounds to *statement*.

    Args:
        statement: The SQL statement to receive the datetime-range condition.

    Returns:
        The statement, with any applicable bounds attached.
    """
    column_node = exp.column(self.field_name)
    clauses: list[Any] = []
    bound_params: dict[str, Any] = {}

    # Drive both optional bounds through one table instead of two near-identical ifs.
    for suffix, comparison, bound in (
        ("before", exp.LT, self.before),
        ("after", exp.GT, self.after),
    ):
        if bound is None:
            continue
        key = statement.generate_param_name(f"{self.field_name}_{suffix}")
        clauses.append(comparison(this=column_node, expression=exp.Placeholder(this=key)))
        bound_params[key] = bound

    if clauses:
        combined = clauses[0]
        for extra in clauses[1:]:
            combined = exp.And(this=combined, expression=extra)
        statement.add_condition(combined, bound_params)
    return statement
|
80
|
+
|
|
48
81
|
|
|
49
82
|
@dataclass
|
|
50
83
|
class OnBeforeAfter(StatementFilter):
|
|
@@ -57,13 +90,38 @@ class OnBeforeAfter(StatementFilter):
|
|
|
57
90
|
on_or_after: Optional[datetime] = None
|
|
58
91
|
"""Filter results where field on or later than this."""
|
|
59
92
|
|
|
93
|
+
def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
    """Add inclusive ``field <= :on_or_before`` / ``field >= :on_or_after`` bounds.

    Args:
        statement: The SQL statement to receive the datetime-range condition.

    Returns:
        The statement, with any applicable inclusive bounds attached.
    """
    column_node = exp.column(self.field_name)
    clauses: list[Any] = []
    bound_params: dict[str, Any] = {}

    # Both optional bounds share identical plumbing; iterate a small spec table.
    for suffix, comparison, bound in (
        ("on_or_before", exp.LTE, self.on_or_before),
        ("on_or_after", exp.GTE, self.on_or_after),
    ):
        if bound is None:
            continue
        key = statement.generate_param_name(f"{self.field_name}_{suffix}")
        clauses.append(comparison(this=column_node, expression=exp.Placeholder(this=key)))
        bound_params[key] = bound

    if clauses:
        combined = clauses[0]
        for extra in clauses[1:]:
            combined = exp.And(this=combined, expression=extra)
        statement.add_condition(combined, bound_params)
    return statement
|
|
60
113
|
|
|
61
|
-
|
|
114
|
+
|
|
115
|
+
class InAnyFilter(StatementFilter, ABC, Generic[T]):
    """Subclass for methods that have a `prefer_any` attribute."""

    @abstractmethod
    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Append this collection-membership filter to *statement* and return it."""
        raise NotImplementedError
|
|
121
|
+
|
|
64
122
|
|
|
65
123
|
@dataclass
|
|
66
|
-
class CollectionFilter(InAnyFilter
|
|
124
|
+
class CollectionFilter(InAnyFilter[T]):
|
|
67
125
|
"""Data required to construct a ``WHERE ... IN (...)`` clause."""
|
|
68
126
|
|
|
69
127
|
field_name: str
|
|
@@ -73,9 +131,30 @@ class CollectionFilter(InAnyFilter, Generic[T]):
|
|
|
73
131
|
|
|
74
132
|
An empty list will return an empty result set, however, if ``None``, the filter is not applied to the query, and all rows are returned. """
|
|
75
133
|
|
|
134
|
+
def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
    """Constrain *statement* with ``field IN (...)`` over the collected values.

    ``values is None`` leaves the statement untouched; an empty collection
    must match nothing, so a constant FALSE predicate is emitted instead.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The (possibly) modified statement.
    """
    if self.values is None:
        return statement

    if not self.values:
        # Empty collection: always-false condition yields an empty result set.
        statement.add_condition(exp.false())
        return statement

    holes: list[exp.Placeholder] = []
    bound: dict[str, Any] = {}
    for index, item in enumerate(self.values):
        key = statement.generate_param_name(f"{self.field_name}_in_{index}")
        holes.append(exp.Placeholder(this=key))
        bound[key] = item

    membership = exp.In(this=exp.column(self.field_name), expressions=holes)
    statement.add_condition(membership, bound)
    return statement
|
|
154
|
+
|
|
76
155
|
|
|
77
156
|
@dataclass
|
|
78
|
-
class NotInCollectionFilter(InAnyFilter
|
|
157
|
+
class NotInCollectionFilter(InAnyFilter[T]):
|
|
79
158
|
"""Data required to construct a ``WHERE ... NOT IN (...)`` clause."""
|
|
80
159
|
|
|
81
160
|
field_name: str
|
|
@@ -85,10 +164,31 @@ class NotInCollectionFilter(InAnyFilter, Generic[T]):
|
|
|
85
164
|
|
|
86
165
|
An empty list or ``None`` will return all rows."""
|
|
87
166
|
|
|
167
|
+
def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
    """Constrain *statement* with ``NOT (field IN (...))`` over the values.

    Both ``None`` and an empty collection are no-ops: all rows are returned.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The (possibly) modified statement.
    """
    # `not None` and `not []` are both true, so one guard covers both no-op cases.
    if not self.values:
        return statement

    holes: list[exp.Placeholder] = []
    bound: dict[str, Any] = {}
    for index, item in enumerate(self.values):
        key = statement.generate_param_name(f"{self.field_name}_notin_{index}")
        holes.append(exp.Placeholder(this=key))
        bound[key] = item

    membership = exp.In(this=exp.column(self.field_name), expressions=holes)
    statement.add_condition(exp.Not(this=membership), bound)
    return statement
|
|
183
|
+
|
|
88
184
|
|
|
89
185
|
class PaginationFilter(StatementFilter, ABC):
    """Subclass for methods that function as a pagination type."""

    @abstractmethod
    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Append pagination clauses (e.g. LIMIT/OFFSET) to *statement* and return it."""
        raise NotImplementedError
|
|
191
|
+
|
|
92
192
|
|
|
93
193
|
@dataclass
|
|
94
194
|
class LimitOffset(PaginationFilter):
|
|
@@ -99,6 +199,16 @@ class LimitOffset(PaginationFilter):
|
|
|
99
199
|
offset: int
|
|
100
200
|
"""Value for ``OFFSET`` clause of query."""
|
|
101
201
|
|
|
202
|
+
def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
    """Attach parameterized ``LIMIT`` and ``OFFSET`` clauses to *statement*.

    Args:
        statement: The SQL statement to paginate.

    Returns:
        The statement with limit/offset applied.
    """
    # Reserve both parameter names up front, then bind the clause values.
    limit_key = statement.generate_param_name("limit_val")
    offset_key = statement.generate_param_name("offset_val")
    statement.add_limit(self.limit, param_name=limit_key)
    statement.add_offset(self.offset, param_name=offset_key)
    return statement
|
|
211
|
+
|
|
102
212
|
|
|
103
213
|
@dataclass
|
|
104
214
|
class OrderBy(StatementFilter):
|
|
@@ -109,6 +219,16 @@ class OrderBy(StatementFilter):
|
|
|
109
219
|
sort_order: Literal["asc", "desc"] = "asc"
|
|
110
220
|
"""Sort ascending or descending"""
|
|
111
221
|
|
|
222
|
+
def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
    """Append an ``ORDER BY`` clause for ``field_name``.

    Any sort order other than ``asc``/``desc`` (case-insensitive) silently
    falls back to ascending; the ``Literal`` annotation already guards this
    at type-check time.

    Args:
        statement: The SQL statement to order.

    Returns:
        The ordered statement.
    """
    direction = self.sort_order.lower()
    if direction not in {"asc", "desc"}:
        direction = "asc"
    statement.add_order_by(self.field_name, direction=cast("Literal['asc', 'desc']", direction))
    return statement
|
|
231
|
+
|
|
112
232
|
|
|
113
233
|
@dataclass
|
|
114
234
|
class SearchFilter(StatementFilter):
|
|
@@ -121,7 +241,90 @@ class SearchFilter(StatementFilter):
|
|
|
121
241
|
ignore_case: Optional[bool] = False
|
|
122
242
|
"""Should the search be case insensitive."""
|
|
123
243
|
|
|
244
|
+
def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
    """Add a ``LIKE``/``ILIKE`` substring match for one or many columns.

    A single column yields one comparison; a set of columns yields the
    comparisons OR-ed together. An empty/falsy search value is a no-op.

    Args:
        statement: The SQL statement to modify.

    Returns:
        The (possibly) modified statement.
    """
    if not self.value:
        return statement

    key = statement.generate_param_name("search_val")
    # Bind the wildcard-wrapped pattern once; every comparison shares it.
    bound = {key: f"%{self.value}%"}
    hole = exp.Placeholder(this=key)
    comparison = exp.ILike if self.ignore_case else exp.Like

    if isinstance(self.field_name, str):
        statement.add_condition(comparison(this=exp.column(self.field_name), expression=hole), bound)
    elif isinstance(self.field_name, set) and self.field_name:
        per_field = [comparison(this=exp.column(name), expression=hole) for name in self.field_name]
        combined = per_field[0]
        for clause in per_field[1:]:
            combined = exp.Or(this=combined, expression=clause)
        statement.add_condition(combined, bound)

    return statement
|
|
270
|
+
|
|
124
271
|
|
|
125
272
|
@dataclass
class NotInSearchFilter(SearchFilter):  # Inherits field_name, value, ignore_case
    """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""

    def append_to_statement(self, statement: "SQLStatement") -> "SQLStatement":
        """Negated counterpart of :class:`SearchFilter`.

        A single column yields one negated comparison; for a set of columns
        the negated comparisons are AND-ed — every column must fail to match.
        An empty/falsy search value is a no-op.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The (possibly) modified statement.
        """
        if not self.value:
            return statement

        key = statement.generate_param_name("not_search_val")
        # One shared wildcard pattern parameter for all columns.
        bound = {key: f"%{self.value}%"}
        hole = exp.Placeholder(this=key)
        comparison = exp.ILike if self.ignore_case else exp.Like

        if isinstance(self.field_name, str):
            negated = exp.Not(this=comparison(this=exp.column(self.field_name), expression=hole))
            statement.add_condition(negated, bound)
        elif isinstance(self.field_name, set) and self.field_name:
            per_field = [
                exp.Not(this=comparison(this=exp.column(name), expression=hole)) for name in self.field_name
            ]
            combined = per_field[0]
            for clause in per_field[1:]:
                combined = exp.And(this=combined, expression=clause)
            statement.add_condition(combined, bound)

        return statement
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
# Function to be imported in SQLStatement module
|
|
307
|
+
def apply_filter(statement: "SQLStatement", filter_obj: "StatementFilter") -> "SQLStatement":
    """Apply a statement filter to a SQL statement.

    Thin dispatch helper: delegates to the filter's own
    ``append_to_statement`` hook.

    Args:
        statement: The SQL statement to modify.
        filter_obj: The filter to apply.

    Returns:
        The modified statement.
    """
    return filter_obj.append_to_statement(statement)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
# Closed set of concrete filters accepted by statement/driver helpers.
FilterTypes: TypeAlias = Union[
    BeforeAfter,
    OnBeforeAfter,
    CollectionFilter[Any],
    LimitOffset,
    OrderBy,
    SearchFilter,
    NotInCollectionFilter[Any],
    NotInSearchFilter,
]
"""Aggregate type alias of the types supported for collection filtering."""
|
sqlspec/mixins.py
CHANGED
|
@@ -1,17 +1,37 @@
|
|
|
1
|
+
import datetime
|
|
1
2
|
from abc import abstractmethod
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from functools import partial
|
|
6
|
+
from pathlib import Path, PurePath
|
|
2
7
|
from typing import (
|
|
3
8
|
TYPE_CHECKING,
|
|
4
9
|
Any,
|
|
10
|
+
Callable,
|
|
5
11
|
ClassVar,
|
|
6
12
|
Generic,
|
|
7
13
|
Optional,
|
|
14
|
+
Union,
|
|
15
|
+
cast,
|
|
16
|
+
overload,
|
|
8
17
|
)
|
|
18
|
+
from uuid import UUID
|
|
9
19
|
|
|
10
20
|
from sqlglot import parse_one
|
|
11
21
|
from sqlglot.dialects.dialect import DialectType
|
|
12
22
|
|
|
13
|
-
from sqlspec.exceptions import SQLConversionError, SQLParsingError
|
|
14
|
-
from sqlspec.typing import
|
|
23
|
+
from sqlspec.exceptions import SQLConversionError, SQLParsingError, SQLSpecError
|
|
24
|
+
from sqlspec.typing import (
|
|
25
|
+
ConnectionT,
|
|
26
|
+
ModelDTOT,
|
|
27
|
+
ModelT,
|
|
28
|
+
StatementParameterType,
|
|
29
|
+
convert,
|
|
30
|
+
get_type_adapter,
|
|
31
|
+
is_dataclass,
|
|
32
|
+
is_msgspec_struct,
|
|
33
|
+
is_pydantic_model,
|
|
34
|
+
)
|
|
15
35
|
|
|
16
36
|
if TYPE_CHECKING:
|
|
17
37
|
from sqlspec.typing import ArrowTable
|
|
@@ -154,3 +174,133 @@ class SQLTranslatorMixin(Generic[ConnectionT]):
|
|
|
154
174
|
except Exception as e:
|
|
155
175
|
error_msg = f"Failed to convert SQL to {to_dialect}: {e!s}"
|
|
156
176
|
raise SQLConversionError(error_msg) from e
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# Default (predicate, decoder) pairs for the msgspec ``dec_hook``: each
# predicate receives the target type; on a match the paired decoder rebuilds
# the target from the raw value (UUID from hex, temporal types from ISO text).
# NOTE(review): the Enum entry tests ``x is Enum``, which only matches the
# ``Enum`` base class itself, never a concrete subclass — confirm intent.
_DEFAULT_TYPE_DECODERS = [  # pyright: ignore[reportUnknownVariableType]
    (lambda x: x is UUID, lambda t, v: t(v.hex)),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    (lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    (lambda x: x is datetime.date, lambda t, v: t(v.isoformat())),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    (lambda x: x is datetime.time, lambda t, v: t(v.isoformat())),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
    (lambda x: x is Enum, lambda t, v: t(v.value)),  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
]
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _default_msgspec_deserializer(
|
|
189
|
+
target_type: Any,
|
|
190
|
+
value: Any,
|
|
191
|
+
type_decoders: "Union[Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]], None]" = None,
|
|
192
|
+
) -> Any: # pragma: no cover
|
|
193
|
+
"""Transform values non-natively supported by ``msgspec``
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
target_type: Encountered type
|
|
197
|
+
value: Value to coerce
|
|
198
|
+
type_decoders: Optional sequence of type decoders
|
|
199
|
+
|
|
200
|
+
Raises:
|
|
201
|
+
TypeError: If the value cannot be coerced to the target type
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
A ``msgspec``-supported type
|
|
205
|
+
"""
|
|
206
|
+
|
|
207
|
+
if isinstance(value, target_type):
|
|
208
|
+
return value
|
|
209
|
+
|
|
210
|
+
if type_decoders:
|
|
211
|
+
for predicate, decoder in type_decoders:
|
|
212
|
+
if predicate(target_type):
|
|
213
|
+
return decoder(target_type, value)
|
|
214
|
+
|
|
215
|
+
if issubclass(target_type, (Path, PurePath, UUID)):
|
|
216
|
+
return target_type(value)
|
|
217
|
+
|
|
218
|
+
try:
|
|
219
|
+
return target_type(value)
|
|
220
|
+
except Exception as e:
|
|
221
|
+
msg = f"Unsupported type: {type(value)!r}"
|
|
222
|
+
raise TypeError(msg) from e
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class ResultConverter:
    """Simple mixin to help convert to dictionary or list of dictionaries to specified schema type.

    Single objects are transformed to the supplied schema type, and lists of objects are transformed into a list of the supplied schema type.

    Args:
        data: A database model instance or row mapping.
              Type: :class:`~sqlspec.typing.ModelDictT`

    Returns:
        The converted schema object.
    """

    @overload
    @staticmethod
    def to_schema(data: "ModelT", *, schema_type: None = None) -> "ModelT": ...
    @overload
    @staticmethod
    def to_schema(data: "dict[str, Any]", *, schema_type: "type[ModelDTOT]") -> "ModelDTOT": ...
    @overload
    @staticmethod
    def to_schema(data: "Sequence[ModelT]", *, schema_type: None = None) -> "Sequence[ModelT]": ...
    @overload
    @staticmethod
    def to_schema(data: "Sequence[dict[str, Any]]", *, schema_type: "type[ModelDTOT]") -> "Sequence[ModelDTOT]": ...

    @staticmethod
    def to_schema(
        data: "Union[ModelT, Sequence[ModelT], dict[str, Any], Sequence[dict[str, Any]], ModelDTOT, Sequence[ModelDTOT]]",
        *,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelT, Sequence[ModelT], ModelDTOT, Sequence[ModelDTOT]]":
        # No target schema: pass the data through; casts only narrow for typing.
        if schema_type is None:
            if not isinstance(data, Sequence):
                return cast("ModelT", data)
            return cast("Sequence[ModelT]", data)
        # Dataclass target: construct instances directly from mapping kwargs.
        if is_dataclass(schema_type):
            if not isinstance(data, Sequence):
                # data is assumed to be dict[str, Any] as per the method's overloads
                return cast("ModelDTOT", schema_type(**data))  # type: ignore[operator]
            # data is assumed to be Sequence[dict[str, Any]]
            return cast("Sequence[ModelDTOT]", [schema_type(**item) for item in data])  # type: ignore[operator]
        # msgspec Struct target: delegate to msgspec ``convert`` with the
        # default decoder hook for UUID/temporal/Enum values.
        if is_msgspec_struct(schema_type):
            if not isinstance(data, Sequence):
                return cast(
                    "ModelDTOT",
                    convert(
                        obj=data,
                        type=schema_type,
                        from_attributes=True,
                        dec_hook=partial(
                            _default_msgspec_deserializer,
                            type_decoders=_DEFAULT_TYPE_DECODERS,
                        ),
                    ),
                )
            return cast(
                "Sequence[ModelDTOT]",
                convert(
                    obj=data,
                    type=list[schema_type],  # type: ignore[valid-type]
                    from_attributes=True,
                    dec_hook=partial(
                        _default_msgspec_deserializer,
                        type_decoders=_DEFAULT_TYPE_DECODERS,
                    ),
                ),
            )

        # Pydantic model target: validate through a cached TypeAdapter.
        if schema_type is not None and is_pydantic_model(schema_type):
            if not isinstance(data, Sequence):
                return cast(
                    "ModelDTOT",
                    get_type_adapter(schema_type).validate_python(data, from_attributes=True),  # pyright: ignore
                )
            return cast(
                "Sequence[ModelDTOT]",
                get_type_adapter(list[schema_type]).validate_python(data, from_attributes=True),  # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
            )

        # Unsupported schema kind: fail loudly rather than guessing.
        msg = "`schema_type` should be a valid Dataclass, Pydantic model or Msgspec struct"
        raise SQLSpecError(msg)
|