sqlspec 0.15.0__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +699 -43
- sqlspec/builder/_base.py +77 -44
- sqlspec/builder/_column.py +0 -4
- sqlspec/builder/_ddl.py +15 -52
- sqlspec/builder/_ddl_utils.py +0 -1
- sqlspec/builder/_delete.py +4 -5
- sqlspec/builder/_insert.py +61 -35
- sqlspec/builder/_merge.py +17 -2
- sqlspec/builder/_parsing_utils.py +16 -12
- sqlspec/builder/_select.py +29 -33
- sqlspec/builder/_update.py +4 -2
- sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
- sqlspec/builder/mixins/_delete_operations.py +6 -1
- sqlspec/builder/mixins/_insert_operations.py +126 -24
- sqlspec/builder/mixins/_join_operations.py +11 -4
- sqlspec/builder/mixins/_merge_operations.py +91 -19
- sqlspec/builder/mixins/_order_limit_operations.py +15 -3
- sqlspec/builder/mixins/_pivot_operations.py +11 -2
- sqlspec/builder/mixins/_select_operations.py +16 -10
- sqlspec/builder/mixins/_update_operations.py +43 -10
- sqlspec/builder/mixins/_where_clause.py +177 -65
- sqlspec/core/cache.py +26 -28
- sqlspec/core/compiler.py +58 -37
- sqlspec/core/filters.py +12 -10
- sqlspec/core/parameters.py +80 -52
- sqlspec/core/result.py +30 -17
- sqlspec/core/statement.py +47 -22
- sqlspec/driver/_async.py +76 -46
- sqlspec/driver/_common.py +25 -6
- sqlspec/driver/_sync.py +73 -43
- sqlspec/driver/mixins/_result_tools.py +62 -37
- sqlspec/driver/mixins/_sql_translator.py +61 -11
- sqlspec/extensions/litestar/cli.py +1 -1
- sqlspec/extensions/litestar/plugin.py +2 -2
- sqlspec/protocols.py +7 -0
- sqlspec/utils/sync_tools.py +1 -1
- sqlspec/utils/type_guards.py +7 -3
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/METADATA +1 -1
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/RECORD +43 -43
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/_sync.py
CHANGED
|
@@ -5,7 +5,7 @@ including connection management, transaction support, and result processing.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from abc import abstractmethod
|
|
8
|
-
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Final, NoReturn, Optional, Union, cast, overload
|
|
9
9
|
|
|
10
10
|
from sqlspec.core import SQL
|
|
11
11
|
from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult
|
|
@@ -20,14 +20,15 @@ if TYPE_CHECKING:
|
|
|
20
20
|
|
|
21
21
|
from sqlspec.builder import QueryBuilder
|
|
22
22
|
from sqlspec.core import SQLResult, Statement, StatementConfig, StatementFilter
|
|
23
|
-
from sqlspec.typing import ModelDTOT,
|
|
23
|
+
from sqlspec.typing import ModelDTOT, StatementParameters
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
_LOGGER_NAME: Final[str] = "sqlspec"
|
|
26
|
+
logger = get_logger(_LOGGER_NAME)
|
|
26
27
|
|
|
27
28
|
__all__ = ("SyncDriverAdapterBase",)
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
EMPTY_FILTERS: "list[StatementFilter]" = []
|
|
31
|
+
EMPTY_FILTERS: Final["list[StatementFilter]"] = []
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToSchemaMixin):
|
|
@@ -128,12 +129,16 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
128
129
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
129
130
|
statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
|
|
130
131
|
|
|
132
|
+
statement_count: int = len(statements)
|
|
133
|
+
successful_count: int = 0
|
|
134
|
+
|
|
131
135
|
for stmt in statements:
|
|
132
136
|
single_stmt = statement.copy(statement=stmt, parameters=prepared_parameters)
|
|
133
137
|
self._execute_statement(cursor, single_stmt)
|
|
138
|
+
successful_count += 1
|
|
134
139
|
|
|
135
140
|
return self.create_execution_result(
|
|
136
|
-
cursor, statement_count=
|
|
141
|
+
cursor, statement_count=statement_count, successful_statements=successful_count, is_script_result=True
|
|
137
142
|
)
|
|
138
143
|
|
|
139
144
|
@abstractmethod
|
|
@@ -214,8 +219,8 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
214
219
|
By default, validates each statement and logs warnings for dangerous
|
|
215
220
|
operations. Use suppress_warnings=True for migrations and admin scripts.
|
|
216
221
|
"""
|
|
217
|
-
|
|
218
|
-
sql_statement = self.prepare_statement(statement, parameters, statement_config=
|
|
222
|
+
config = statement_config or self.statement_config
|
|
223
|
+
sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
|
|
219
224
|
|
|
220
225
|
return self.dispatch_statement_execution(statement=sql_statement.as_script(), connection=self.connection)
|
|
221
226
|
|
|
@@ -239,7 +244,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
239
244
|
schema_type: None = None,
|
|
240
245
|
statement_config: "Optional[StatementConfig]" = None,
|
|
241
246
|
**kwargs: Any,
|
|
242
|
-
) -> "
|
|
247
|
+
) -> "dict[str, Any]": ...
|
|
243
248
|
|
|
244
249
|
def select_one(
|
|
245
250
|
self,
|
|
@@ -249,23 +254,20 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
249
254
|
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
250
255
|
statement_config: "Optional[StatementConfig]" = None,
|
|
251
256
|
**kwargs: Any,
|
|
252
|
-
) -> "Union[
|
|
257
|
+
) -> "Union[dict[str, Any], ModelDTOT]":
|
|
253
258
|
"""Execute a select statement and return exactly one row.
|
|
254
259
|
|
|
255
260
|
Raises an exception if no rows or more than one row is returned.
|
|
256
261
|
"""
|
|
257
262
|
result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
258
263
|
data = result.get_data()
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
if
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
return
|
|
266
|
-
"Union[ModelT, RowT, ModelDTOT]",
|
|
267
|
-
self.to_schema(data[0], schema_type=schema_type) if schema_type else data[0],
|
|
268
|
-
)
|
|
264
|
+
data_len: int = len(data)
|
|
265
|
+
if data_len == 0:
|
|
266
|
+
self._raise_no_rows_found()
|
|
267
|
+
if data_len > 1:
|
|
268
|
+
self._raise_expected_one_row(data_len)
|
|
269
|
+
first_row = data[0]
|
|
270
|
+
return self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row
|
|
269
271
|
|
|
270
272
|
@overload
|
|
271
273
|
def select_one_or_none(
|
|
@@ -287,7 +289,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
287
289
|
schema_type: None = None,
|
|
288
290
|
statement_config: "Optional[StatementConfig]" = None,
|
|
289
291
|
**kwargs: Any,
|
|
290
|
-
) -> "Optional[
|
|
292
|
+
) -> "Optional[dict[str, Any]]": ...
|
|
291
293
|
|
|
292
294
|
def select_one_or_none(
|
|
293
295
|
self,
|
|
@@ -297,7 +299,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
297
299
|
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
298
300
|
statement_config: "Optional[StatementConfig]" = None,
|
|
299
301
|
**kwargs: Any,
|
|
300
|
-
) -> "Optional[Union[
|
|
302
|
+
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
|
|
301
303
|
"""Execute a select statement and return at most one row.
|
|
302
304
|
|
|
303
305
|
Returns None if no rows are found.
|
|
@@ -305,12 +307,16 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
305
307
|
"""
|
|
306
308
|
result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
307
309
|
data = result.get_data()
|
|
308
|
-
|
|
310
|
+
data_len: int = len(data)
|
|
311
|
+
if data_len == 0:
|
|
309
312
|
return None
|
|
310
|
-
if
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
return cast(
|
|
313
|
+
if data_len > 1:
|
|
314
|
+
self._raise_expected_at_most_one_row(data_len)
|
|
315
|
+
first_row = data[0]
|
|
316
|
+
return cast(
|
|
317
|
+
"Optional[Union[dict[str, Any], ModelDTOT]]",
|
|
318
|
+
self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row,
|
|
319
|
+
)
|
|
314
320
|
|
|
315
321
|
@overload
|
|
316
322
|
def select(
|
|
@@ -332,7 +338,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
332
338
|
schema_type: None = None,
|
|
333
339
|
statement_config: "Optional[StatementConfig]" = None,
|
|
334
340
|
**kwargs: Any,
|
|
335
|
-
) -> "list[
|
|
341
|
+
) -> "list[dict[str, Any]]": ...
|
|
336
342
|
|
|
337
343
|
def select(
|
|
338
344
|
self,
|
|
@@ -342,12 +348,11 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
342
348
|
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
343
349
|
statement_config: "Optional[StatementConfig]" = None,
|
|
344
350
|
**kwargs: Any,
|
|
345
|
-
) -> "Union[list[
|
|
351
|
+
) -> "Union[list[dict[str, Any]], list[ModelDTOT]]":
|
|
346
352
|
"""Execute a select statement and return all rows."""
|
|
347
353
|
result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
348
354
|
return cast(
|
|
349
|
-
"Union[list[
|
|
350
|
-
self.to_schema(cast("list[ModelT]", result.get_data()), schema_type=schema_type),
|
|
355
|
+
"Union[list[dict[str, Any]], list[ModelDTOT]]", self.to_schema(result.get_data(), schema_type=schema_type)
|
|
351
356
|
)
|
|
352
357
|
|
|
353
358
|
def select_value(
|
|
@@ -367,23 +372,19 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
367
372
|
try:
|
|
368
373
|
row = result.one()
|
|
369
374
|
except ValueError as e:
|
|
370
|
-
|
|
371
|
-
raise NotFoundError(msg) from e
|
|
375
|
+
self._raise_no_rows_found_from_exception(e)
|
|
372
376
|
if not row:
|
|
373
|
-
|
|
374
|
-
raise NotFoundError(msg)
|
|
377
|
+
self._raise_no_rows_found()
|
|
375
378
|
if is_dict_row(row):
|
|
376
379
|
if not row:
|
|
377
|
-
|
|
378
|
-
raise ValueError(msg)
|
|
380
|
+
self._raise_row_no_columns()
|
|
379
381
|
return next(iter(row.values()))
|
|
380
382
|
if is_indexable_row(row):
|
|
381
383
|
if not row:
|
|
382
|
-
|
|
383
|
-
raise ValueError(msg)
|
|
384
|
+
self._raise_row_no_columns()
|
|
384
385
|
return row[0]
|
|
385
|
-
|
|
386
|
-
|
|
386
|
+
self._raise_unexpected_row_type(type(row))
|
|
387
|
+
return None
|
|
387
388
|
|
|
388
389
|
def select_value_or_none(
|
|
389
390
|
self,
|
|
@@ -401,10 +402,11 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
401
402
|
"""
|
|
402
403
|
result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
403
404
|
data = result.get_data()
|
|
404
|
-
|
|
405
|
+
data_len: int = len(data)
|
|
406
|
+
if data_len == 0:
|
|
405
407
|
return None
|
|
406
|
-
if
|
|
407
|
-
msg = f"Expected at most one row, found {
|
|
408
|
+
if data_len > 1:
|
|
409
|
+
msg = f"Expected at most one row, found {data_len}"
|
|
408
410
|
raise ValueError(msg)
|
|
409
411
|
row = data[0]
|
|
410
412
|
if isinstance(row, dict):
|
|
@@ -471,3 +473,31 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
|
|
|
471
473
|
select_result = self.execute(sql_statement)
|
|
472
474
|
|
|
473
475
|
return (self.to_schema(select_result.get_data(), schema_type=schema_type), count_result.scalar())
|
|
476
|
+
|
|
477
|
+
def _raise_no_rows_found(self) -> NoReturn:
|
|
478
|
+
msg = "No rows found"
|
|
479
|
+
raise NotFoundError(msg)
|
|
480
|
+
|
|
481
|
+
def _raise_no_rows_found_from_exception(self, e: ValueError) -> NoReturn:
|
|
482
|
+
msg = "No rows found"
|
|
483
|
+
raise NotFoundError(msg) from e
|
|
484
|
+
|
|
485
|
+
def _raise_expected_one_row(self, data_len: int) -> NoReturn:
|
|
486
|
+
msg = f"Expected exactly one row, found {data_len}"
|
|
487
|
+
raise ValueError(msg)
|
|
488
|
+
|
|
489
|
+
def _raise_expected_at_most_one_row(self, data_len: int) -> NoReturn:
|
|
490
|
+
msg = f"Expected at most one row, found {data_len}"
|
|
491
|
+
raise ValueError(msg)
|
|
492
|
+
|
|
493
|
+
def _raise_row_no_columns(self) -> NoReturn:
|
|
494
|
+
msg = "Row has no columns"
|
|
495
|
+
raise ValueError(msg)
|
|
496
|
+
|
|
497
|
+
def _raise_unexpected_row_type(self, row_type: type) -> NoReturn:
|
|
498
|
+
msg = f"Unexpected row type: {row_type}"
|
|
499
|
+
raise ValueError(msg)
|
|
500
|
+
|
|
501
|
+
def _raise_cannot_extract_value_from_row_type(self, type_name: str) -> NoReturn:
|
|
502
|
+
msg = f"Cannot extract value from row type {type_name}"
|
|
503
|
+
raise TypeError(msg)
|
|
@@ -5,16 +5,14 @@ from collections.abc import Sequence
|
|
|
5
5
|
from enum import Enum
|
|
6
6
|
from functools import partial
|
|
7
7
|
from pathlib import Path, PurePath
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import Any, Callable, Final, Optional, overload
|
|
9
9
|
from uuid import UUID
|
|
10
10
|
|
|
11
11
|
from mypy_extensions import trait
|
|
12
12
|
|
|
13
|
-
from sqlspec.exceptions import SQLSpecError
|
|
13
|
+
from sqlspec.exceptions import SQLSpecError
|
|
14
14
|
from sqlspec.typing import (
|
|
15
15
|
CATTRS_INSTALLED,
|
|
16
|
-
DataclassProtocol,
|
|
17
|
-
DictLike,
|
|
18
16
|
ModelDTOT,
|
|
19
17
|
ModelT,
|
|
20
18
|
attrs_asdict,
|
|
@@ -25,14 +23,16 @@ from sqlspec.typing import (
|
|
|
25
23
|
)
|
|
26
24
|
from sqlspec.utils.type_guards import is_attrs_schema, is_dataclass, is_msgspec_struct, is_pydantic_model
|
|
27
25
|
|
|
28
|
-
if TYPE_CHECKING:
|
|
29
|
-
from sqlspec._typing import AttrsInstanceStub, BaseModelStub, StructStub
|
|
30
|
-
|
|
31
26
|
__all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")
|
|
32
27
|
|
|
33
28
|
|
|
34
29
|
logger = logging.getLogger(__name__)
|
|
35
|
-
|
|
30
|
+
|
|
31
|
+
# Constants for performance optimization
|
|
32
|
+
_DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
|
|
33
|
+
_PATH_TYPES: Final[tuple[type, ...]] = (Path, PurePath, UUID)
|
|
34
|
+
|
|
35
|
+
_DEFAULT_TYPE_DECODERS: Final[list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]] = [
|
|
36
36
|
(lambda x: x is UUID, lambda t, v: t(v.hex)),
|
|
37
37
|
(lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())),
|
|
38
38
|
(lambda x: x is datetime.date, lambda t, v: t(v.isoformat())),
|
|
@@ -53,17 +53,32 @@ def _default_msgspec_deserializer(
|
|
|
53
53
|
for predicate, decoder in type_decoders:
|
|
54
54
|
if predicate(target_type):
|
|
55
55
|
return decoder(target_type, value)
|
|
56
|
+
|
|
57
|
+
# Fast path checks using type identity and isinstance
|
|
56
58
|
if target_type is UUID and isinstance(value, UUID):
|
|
57
59
|
return value.hex
|
|
58
|
-
|
|
59
|
-
|
|
60
|
+
|
|
61
|
+
# Use pre-computed set for faster lookup
|
|
62
|
+
if target_type in _DATETIME_TYPES:
|
|
63
|
+
try:
|
|
60
64
|
return value.isoformat()
|
|
65
|
+
except AttributeError:
|
|
66
|
+
pass
|
|
67
|
+
|
|
61
68
|
if isinstance(target_type, type) and issubclass(target_type, Enum) and isinstance(value, Enum):
|
|
62
69
|
return value.value
|
|
70
|
+
|
|
63
71
|
if isinstance(value, target_type):
|
|
64
72
|
return value
|
|
65
|
-
|
|
66
|
-
|
|
73
|
+
|
|
74
|
+
# Check for path types using pre-computed tuple
|
|
75
|
+
if isinstance(target_type, type):
|
|
76
|
+
try:
|
|
77
|
+
if issubclass(target_type, (Path, PurePath)) or issubclass(target_type, UUID):
|
|
78
|
+
return target_type(str(value))
|
|
79
|
+
except (TypeError, ValueError):
|
|
80
|
+
pass
|
|
81
|
+
|
|
67
82
|
return value
|
|
68
83
|
|
|
69
84
|
|
|
@@ -74,36 +89,37 @@ class ToSchemaMixin:
|
|
|
74
89
|
# Schema conversion overloads - handle common cases first
|
|
75
90
|
@overload
|
|
76
91
|
@staticmethod
|
|
92
|
+
def to_schema(data: "list[dict[str, Any]]") -> "list[dict[str, Any]]": ...
|
|
93
|
+
@overload
|
|
94
|
+
@staticmethod
|
|
77
95
|
def to_schema(data: "list[dict[str, Any]]", *, schema_type: "type[ModelDTOT]") -> "list[ModelDTOT]": ...
|
|
78
96
|
@overload
|
|
79
97
|
@staticmethod
|
|
80
98
|
def to_schema(data: "list[dict[str, Any]]", *, schema_type: None = None) -> "list[dict[str, Any]]": ...
|
|
81
99
|
@overload
|
|
82
100
|
@staticmethod
|
|
101
|
+
def to_schema(data: "dict[str, Any]") -> "dict[str, Any]": ...
|
|
102
|
+
@overload
|
|
103
|
+
@staticmethod
|
|
83
104
|
def to_schema(data: "dict[str, Any]", *, schema_type: "type[ModelDTOT]") -> "ModelDTOT": ...
|
|
84
105
|
@overload
|
|
85
106
|
@staticmethod
|
|
86
107
|
def to_schema(data: "dict[str, Any]", *, schema_type: None = None) -> "dict[str, Any]": ...
|
|
87
108
|
@overload
|
|
88
109
|
@staticmethod
|
|
110
|
+
def to_schema(data: "list[ModelT]") -> "list[ModelT]": ...
|
|
111
|
+
@overload
|
|
112
|
+
@staticmethod
|
|
89
113
|
def to_schema(data: "list[ModelT]", *, schema_type: "type[ModelDTOT]") -> "list[ModelDTOT]": ...
|
|
90
114
|
@overload
|
|
91
115
|
@staticmethod
|
|
92
116
|
def to_schema(data: "list[ModelT]", *, schema_type: None = None) -> "list[ModelT]": ...
|
|
93
117
|
@overload
|
|
94
118
|
@staticmethod
|
|
95
|
-
def to_schema(
|
|
96
|
-
data: "Union[DictLike, StructStub, BaseModelStub, DataclassProtocol, AttrsInstanceStub]",
|
|
97
|
-
*,
|
|
98
|
-
schema_type: "type[ModelDTOT]",
|
|
99
|
-
) -> "ModelDTOT": ...
|
|
119
|
+
def to_schema(data: "ModelT") -> "ModelT": ...
|
|
100
120
|
@overload
|
|
101
121
|
@staticmethod
|
|
102
|
-
def to_schema(
|
|
103
|
-
data: "Union[DictLike, StructStub, BaseModelStub, DataclassProtocol, AttrsInstanceStub]",
|
|
104
|
-
*,
|
|
105
|
-
schema_type: None = None,
|
|
106
|
-
) -> "Union[DictLike, StructStub, BaseModelStub, DataclassProtocol, AttrsInstanceStub]": ...
|
|
122
|
+
def to_schema(data: Any, *, schema_type: None = None) -> Any: ...
|
|
107
123
|
|
|
108
124
|
@staticmethod
|
|
109
125
|
def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) -> Any:
|
|
@@ -123,46 +139,55 @@ class ToSchemaMixin:
|
|
|
123
139
|
return data
|
|
124
140
|
if is_dataclass(schema_type):
|
|
125
141
|
if isinstance(data, list):
|
|
126
|
-
|
|
142
|
+
result: list[Any] = []
|
|
143
|
+
for item in data:
|
|
144
|
+
if hasattr(item, "keys"):
|
|
145
|
+
result.append(schema_type(**dict(item))) # type: ignore[operator]
|
|
146
|
+
else:
|
|
147
|
+
result.append(item)
|
|
148
|
+
return result
|
|
127
149
|
if hasattr(data, "keys"):
|
|
128
150
|
return schema_type(**dict(data)) # type: ignore[operator]
|
|
129
151
|
if isinstance(data, dict):
|
|
130
152
|
return schema_type(**data) # type: ignore[operator]
|
|
131
|
-
# Fallback for other types
|
|
132
153
|
return data
|
|
133
154
|
if is_msgspec_struct(schema_type):
|
|
155
|
+
# Cache the deserializer to avoid repeated partial() calls
|
|
156
|
+
deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
|
|
134
157
|
if not isinstance(data, Sequence):
|
|
135
|
-
return convert(
|
|
136
|
-
obj=data,
|
|
137
|
-
type=schema_type,
|
|
138
|
-
from_attributes=True,
|
|
139
|
-
dec_hook=partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS),
|
|
140
|
-
)
|
|
158
|
+
return convert(obj=data, type=schema_type, from_attributes=True, dec_hook=deserializer)
|
|
141
159
|
return convert(
|
|
142
160
|
obj=data,
|
|
143
161
|
type=list[schema_type], # type: ignore[valid-type] # pyright: ignore
|
|
144
162
|
from_attributes=True,
|
|
145
|
-
dec_hook=
|
|
163
|
+
dec_hook=deserializer,
|
|
146
164
|
)
|
|
147
165
|
if is_pydantic_model(schema_type):
|
|
148
166
|
if not isinstance(data, Sequence):
|
|
149
|
-
|
|
150
|
-
|
|
167
|
+
adapter = get_type_adapter(schema_type)
|
|
168
|
+
return adapter.validate_python(data, from_attributes=True) # pyright: ignore
|
|
169
|
+
list_adapter = get_type_adapter(list[schema_type]) # type: ignore[valid-type] # pyright: ignore
|
|
170
|
+
return list_adapter.validate_python(data, from_attributes=True)
|
|
151
171
|
if is_attrs_schema(schema_type):
|
|
152
172
|
if CATTRS_INSTALLED:
|
|
153
173
|
if isinstance(data, Sequence):
|
|
154
174
|
return cattrs_structure(data, list[schema_type]) # type: ignore[valid-type] # pyright: ignore
|
|
155
|
-
# If data is already structured (attrs instance), unstructure it first
|
|
156
175
|
if hasattr(data, "__attrs_attrs__"):
|
|
157
|
-
|
|
176
|
+
unstructured_data = cattrs_unstructure(data)
|
|
177
|
+
return cattrs_structure(unstructured_data, schema_type) # pyright: ignore
|
|
158
178
|
return cattrs_structure(data, schema_type) # pyright: ignore
|
|
159
179
|
if isinstance(data, list):
|
|
160
|
-
|
|
180
|
+
attrs_result: list[Any] = []
|
|
181
|
+
for item in data:
|
|
182
|
+
if hasattr(item, "keys"):
|
|
183
|
+
attrs_result.append(schema_type(**dict(item)))
|
|
184
|
+
else:
|
|
185
|
+
attrs_result.append(schema_type(**attrs_asdict(item)))
|
|
186
|
+
return attrs_result
|
|
161
187
|
if hasattr(data, "keys"):
|
|
162
188
|
return schema_type(**dict(data))
|
|
163
189
|
if isinstance(data, dict):
|
|
164
190
|
return schema_type(**data)
|
|
165
|
-
# Fallback for other types
|
|
166
191
|
return data
|
|
167
192
|
msg = "`schema_type` should be a valid Dataclass, Pydantic model, Msgspec struct, or Attrs class"
|
|
168
193
|
raise SQLSpecError(msg)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import Final, NoReturn, Optional
|
|
2
|
+
|
|
1
3
|
from mypy_extensions import trait
|
|
2
4
|
from sqlglot import exp, parse_one
|
|
3
5
|
from sqlglot.dialects.dialect import DialectType
|
|
@@ -7,6 +9,9 @@ from sqlspec.exceptions import SQLConversionError
|
|
|
7
9
|
|
|
8
10
|
__all__ = ("SQLTranslatorMixin",)
|
|
9
11
|
|
|
12
|
+
# Constants for better performance
|
|
13
|
+
_DEFAULT_PRETTY: Final[bool] = True
|
|
14
|
+
|
|
10
15
|
|
|
11
16
|
@trait
|
|
12
17
|
class SQLTranslatorMixin:
|
|
@@ -14,23 +19,68 @@ class SQLTranslatorMixin:
|
|
|
14
19
|
|
|
15
20
|
__slots__ = ()
|
|
16
21
|
|
|
17
|
-
def convert_to_dialect(
|
|
22
|
+
def convert_to_dialect(
|
|
23
|
+
self, statement: "Statement", to_dialect: "Optional[DialectType]" = None, pretty: bool = _DEFAULT_PRETTY
|
|
24
|
+
) -> str:
|
|
25
|
+
"""Convert a statement to a target SQL dialect.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
statement: SQL statement to convert
|
|
29
|
+
to_dialect: Target dialect (defaults to current dialect)
|
|
30
|
+
pretty: Whether to format the output SQL
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
SQL string in target dialect
|
|
34
|
+
|
|
35
|
+
Raises:
|
|
36
|
+
SQLConversionError: If parsing or conversion fails
|
|
37
|
+
"""
|
|
38
|
+
# Fast path: get the parsed expression with minimal allocations
|
|
39
|
+
parsed_expression: Optional[exp.Expression] = None
|
|
40
|
+
|
|
18
41
|
if statement is not None and isinstance(statement, SQL):
|
|
19
42
|
if statement.expression is None:
|
|
20
|
-
|
|
21
|
-
raise SQLConversionError(msg)
|
|
43
|
+
self._raise_statement_parse_error()
|
|
22
44
|
parsed_expression = statement.expression
|
|
23
45
|
elif isinstance(statement, exp.Expression):
|
|
24
46
|
parsed_expression = statement
|
|
25
47
|
else:
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
error_msg = f"Failed to parse SQL statement: {e!s}"
|
|
30
|
-
raise SQLConversionError(error_msg) from e
|
|
48
|
+
parsed_expression = self._parse_statement_safely(statement)
|
|
49
|
+
|
|
50
|
+
# Get target dialect with fallback
|
|
31
51
|
target_dialect = to_dialect or self.dialect # type: ignore[attr-defined]
|
|
52
|
+
|
|
53
|
+
# Generate SQL with error handling
|
|
54
|
+
return self._generate_sql_safely(parsed_expression, target_dialect, pretty)
|
|
55
|
+
|
|
56
|
+
def _parse_statement_safely(self, statement: "Statement") -> "exp.Expression":
|
|
57
|
+
"""Parse statement with copy=False optimization and proper error handling."""
|
|
58
|
+
try:
|
|
59
|
+
# Convert statement to string if needed
|
|
60
|
+
sql_string = str(statement)
|
|
61
|
+
# Use copy=False for better performance
|
|
62
|
+
return parse_one(sql_string, dialect=self.dialect, copy=False) # type: ignore[attr-defined]
|
|
63
|
+
except Exception as e:
|
|
64
|
+
self._raise_parse_error(e)
|
|
65
|
+
|
|
66
|
+
def _generate_sql_safely(self, expression: "exp.Expression", dialect: DialectType, pretty: bool) -> str:
|
|
67
|
+
"""Generate SQL with proper error handling."""
|
|
32
68
|
try:
|
|
33
|
-
return
|
|
69
|
+
return expression.sql(dialect=dialect, pretty=pretty)
|
|
34
70
|
except Exception as e:
|
|
35
|
-
|
|
36
|
-
|
|
71
|
+
self._raise_conversion_error(dialect, e)
|
|
72
|
+
|
|
73
|
+
def _raise_statement_parse_error(self) -> NoReturn:
|
|
74
|
+
"""Raise error for unparsable statements."""
|
|
75
|
+
msg = "Statement could not be parsed"
|
|
76
|
+
raise SQLConversionError(msg)
|
|
77
|
+
|
|
78
|
+
def _raise_parse_error(self, e: Exception) -> NoReturn:
|
|
79
|
+
"""Raise error for parsing failures."""
|
|
80
|
+
error_msg = f"Failed to parse SQL statement: {e!s}"
|
|
81
|
+
raise SQLConversionError(error_msg) from e
|
|
82
|
+
|
|
83
|
+
def _raise_conversion_error(self, dialect: DialectType, e: Exception) -> NoReturn:
|
|
84
|
+
"""Raise error for conversion failures."""
|
|
85
|
+
error_msg = f"Failed to convert SQL expression to {dialect}: {e!s}"
|
|
86
|
+
raise SQLConversionError(error_msg) from e
|
|
@@ -39,7 +39,7 @@ def get_database_migration_plugin(app: "Litestar") -> "SQLSpec":
|
|
|
39
39
|
raise ImproperConfigurationError(msg)
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
@click.group(cls=LitestarGroup, name="
|
|
42
|
+
@click.group(cls=LitestarGroup, name="db")
|
|
43
43
|
def database_group(ctx: "click.Context") -> None:
|
|
44
44
|
"""Manage SQLSpec database components."""
|
|
45
45
|
ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, Any, Union
|
|
2
2
|
|
|
3
3
|
from litestar.di import Provide
|
|
4
|
-
from litestar.plugins import InitPluginProtocol
|
|
4
|
+
from litestar.plugins import CLIPlugin, InitPluginProtocol
|
|
5
5
|
|
|
6
6
|
from sqlspec.base import SQLSpec as SQLSpecBase
|
|
7
7
|
from sqlspec.config import AsyncConfigT, DatabaseConfigProtocol, DriverT, SyncConfigT
|
|
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|
|
17
17
|
logger = get_logger("extensions.litestar")
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
20
|
+
class SQLSpec(InitPluginProtocol, CLIPlugin, SQLSpecBase):
|
|
21
21
|
"""Litestar plugin for SQLSpec database integration."""
|
|
22
22
|
|
|
23
23
|
__slots__ = ("_config", "_plugin_configs")
|
sqlspec/protocols.py
CHANGED
|
@@ -371,6 +371,9 @@ class SQLBuilderProtocol(Protocol):
|
|
|
371
371
|
_expression: "Optional[exp.Expression]"
|
|
372
372
|
_parameters: dict[str, Any]
|
|
373
373
|
_parameter_counter: int
|
|
374
|
+
_columns: Any # Optional attribute for some builders
|
|
375
|
+
_table: Any # Optional attribute for some builders
|
|
376
|
+
_with_ctes: Any # Optional attribute for some builders
|
|
374
377
|
dialect: Any
|
|
375
378
|
dialect_name: "Optional[str]"
|
|
376
379
|
|
|
@@ -383,6 +386,10 @@ class SQLBuilderProtocol(Protocol):
|
|
|
383
386
|
"""Add a parameter to the builder."""
|
|
384
387
|
...
|
|
385
388
|
|
|
389
|
+
def _generate_unique_parameter_name(self, base_name: str) -> str:
|
|
390
|
+
"""Generate a unique parameter name."""
|
|
391
|
+
...
|
|
392
|
+
|
|
386
393
|
def _parameterize_expression(self, expression: "exp.Expression") -> "exp.Expression":
|
|
387
394
|
"""Replace literal values in an expression with bound parameters."""
|
|
388
395
|
...
|
sqlspec/utils/sync_tools.py
CHANGED
sqlspec/utils/type_guards.py
CHANGED
|
@@ -841,9 +841,13 @@ def has_sql_method(obj: Any) -> "TypeGuard[HasSQLMethodProtocol]":
|
|
|
841
841
|
|
|
842
842
|
def has_query_builder_parameters(obj: Any) -> "TypeGuard[SQLBuilderProtocol]":
|
|
843
843
|
"""Check if an object is a query builder with parameters property."""
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
844
|
+
return (
|
|
845
|
+
hasattr(obj, "build")
|
|
846
|
+
and callable(getattr(obj, "build", None))
|
|
847
|
+
and hasattr(obj, "parameters")
|
|
848
|
+
and hasattr(obj, "add_parameter")
|
|
849
|
+
and callable(getattr(obj, "add_parameter", None))
|
|
850
|
+
)
|
|
847
851
|
|
|
848
852
|
|
|
849
853
|
def is_object_store_item(obj: Any) -> "TypeGuard[ObjectStoreItemProtocol]":
|