sqlspec 0.15.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (43)
  1. sqlspec/_sql.py +699 -43
  2. sqlspec/builder/_base.py +77 -44
  3. sqlspec/builder/_column.py +0 -4
  4. sqlspec/builder/_ddl.py +15 -52
  5. sqlspec/builder/_ddl_utils.py +0 -1
  6. sqlspec/builder/_delete.py +4 -5
  7. sqlspec/builder/_insert.py +61 -35
  8. sqlspec/builder/_merge.py +17 -2
  9. sqlspec/builder/_parsing_utils.py +16 -12
  10. sqlspec/builder/_select.py +29 -33
  11. sqlspec/builder/_update.py +4 -2
  12. sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
  13. sqlspec/builder/mixins/_delete_operations.py +6 -1
  14. sqlspec/builder/mixins/_insert_operations.py +126 -24
  15. sqlspec/builder/mixins/_join_operations.py +11 -4
  16. sqlspec/builder/mixins/_merge_operations.py +91 -19
  17. sqlspec/builder/mixins/_order_limit_operations.py +15 -3
  18. sqlspec/builder/mixins/_pivot_operations.py +11 -2
  19. sqlspec/builder/mixins/_select_operations.py +16 -10
  20. sqlspec/builder/mixins/_update_operations.py +43 -10
  21. sqlspec/builder/mixins/_where_clause.py +177 -65
  22. sqlspec/core/cache.py +26 -28
  23. sqlspec/core/compiler.py +58 -37
  24. sqlspec/core/filters.py +12 -10
  25. sqlspec/core/parameters.py +80 -52
  26. sqlspec/core/result.py +30 -17
  27. sqlspec/core/statement.py +47 -22
  28. sqlspec/driver/_async.py +76 -46
  29. sqlspec/driver/_common.py +25 -6
  30. sqlspec/driver/_sync.py +73 -43
  31. sqlspec/driver/mixins/_result_tools.py +62 -37
  32. sqlspec/driver/mixins/_sql_translator.py +61 -11
  33. sqlspec/extensions/litestar/cli.py +1 -1
  34. sqlspec/extensions/litestar/plugin.py +2 -2
  35. sqlspec/protocols.py +7 -0
  36. sqlspec/utils/sync_tools.py +1 -1
  37. sqlspec/utils/type_guards.py +7 -3
  38. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/METADATA +1 -1
  39. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/RECORD +43 -43
  40. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/WHEEL +0 -0
  41. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/entry_points.txt +0 -0
  42. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/LICENSE +0 -0
  43. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/_sync.py CHANGED
@@ -5,7 +5,7 @@ including connection management, transaction support, and result processing.
  """
 
  from abc import abstractmethod
- from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
+ from typing import TYPE_CHECKING, Any, Final, NoReturn, Optional, Union, cast, overload
 
  from sqlspec.core import SQL
  from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult
@@ -20,14 +20,15 @@ if TYPE_CHECKING:
 
  from sqlspec.builder import QueryBuilder
  from sqlspec.core import SQLResult, Statement, StatementConfig, StatementFilter
- from sqlspec.typing import ModelDTOT, ModelT, RowT, StatementParameters
+ from sqlspec.typing import ModelDTOT, StatementParameters
 
- logger = get_logger("sqlspec")
+ _LOGGER_NAME: Final[str] = "sqlspec"
+ logger = get_logger(_LOGGER_NAME)
 
  __all__ = ("SyncDriverAdapterBase",)
 
 
- EMPTY_FILTERS: "list[StatementFilter]" = []
+ EMPTY_FILTERS: Final["list[StatementFilter]"] = []
 
 
  class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToSchemaMixin):
@@ -128,12 +129,16 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
  statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
 
+ statement_count: int = len(statements)
+ successful_count: int = 0
+
  for stmt in statements:
  single_stmt = statement.copy(statement=stmt, parameters=prepared_parameters)
  self._execute_statement(cursor, single_stmt)
+ successful_count += 1
 
  return self.create_execution_result(
- cursor, statement_count=len(statements), successful_statements=len(statements), is_script_result=True
+ cursor, statement_count=statement_count, successful_statements=successful_count, is_script_result=True
  )
 
  @abstractmethod
@@ -214,8 +219,8 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  By default, validates each statement and logs warnings for dangerous
  operations. Use suppress_warnings=True for migrations and admin scripts.
  """
- script_config = statement_config or self.statement_config
- sql_statement = self.prepare_statement(statement, parameters, statement_config=script_config, kwargs=kwargs)
+ config = statement_config or self.statement_config
+ sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
 
  return self.dispatch_statement_execution(statement=sql_statement.as_script(), connection=self.connection)
 
@@ -239,7 +244,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  schema_type: None = None,
  statement_config: "Optional[StatementConfig]" = None,
  **kwargs: Any,
- ) -> "Union[ModelT, RowT, dict[str, Any]]": ... # pyright: ignore[reportInvalidTypeVarUse]
+ ) -> "dict[str, Any]": ...
 
  def select_one(
  self,
@@ -249,23 +254,20 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  schema_type: "Optional[type[ModelDTOT]]" = None,
  statement_config: "Optional[StatementConfig]" = None,
  **kwargs: Any,
- ) -> "Union[ModelT, RowT, ModelDTOT]": # pyright: ignore[reportInvalidTypeVarUse]
+ ) -> "Union[dict[str, Any], ModelDTOT]":
  """Execute a select statement and return exactly one row.
 
  Raises an exception if no rows or more than one row is returned.
  """
  result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
  data = result.get_data()
- if not data:
- msg = "No rows found"
- raise NotFoundError(msg)
- if len(data) > 1:
- msg = f"Expected exactly one row, found {len(data)}"
- raise ValueError(msg)
- return cast(
- "Union[ModelT, RowT, ModelDTOT]",
- self.to_schema(data[0], schema_type=schema_type) if schema_type else data[0],
- )
+ data_len: int = len(data)
+ if data_len == 0:
+ self._raise_no_rows_found()
+ if data_len > 1:
+ self._raise_expected_one_row(data_len)
+ first_row = data[0]
+ return self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row
 
  @overload
  def select_one_or_none(
@@ -287,7 +289,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  schema_type: None = None,
  statement_config: "Optional[StatementConfig]" = None,
  **kwargs: Any,
- ) -> "Optional[ModelT]": ... # pyright: ignore[reportInvalidTypeVarUse]
+ ) -> "Optional[dict[str, Any]]": ...
 
  def select_one_or_none(
  self,
@@ -297,7 +299,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  schema_type: "Optional[type[ModelDTOT]]" = None,
  statement_config: "Optional[StatementConfig]" = None,
  **kwargs: Any,
- ) -> "Optional[Union[ModelT, ModelDTOT]]": # pyright: ignore[reportInvalidTypeVarUse]
+ ) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
  """Execute a select statement and return at most one row.
 
  Returns None if no rows are found.
@@ -305,12 +307,16 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  """
  result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
  data = result.get_data()
- if not data:
+ data_len: int = len(data)
+ if data_len == 0:
  return None
- if len(data) > 1:
- msg = f"Expected at most one row, found {len(data)}"
- raise ValueError(msg)
- return cast("Optional[Union[ModelT, ModelDTOT]]", self.to_schema(data[0], schema_type=schema_type))
+ if data_len > 1:
+ self._raise_expected_at_most_one_row(data_len)
+ first_row = data[0]
+ return cast(
+ "Optional[Union[dict[str, Any], ModelDTOT]]",
+ self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row,
+ )
 
  @overload
  def select(
@@ -332,7 +338,7 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  schema_type: None = None,
  statement_config: "Optional[StatementConfig]" = None,
  **kwargs: Any,
- ) -> "list[ModelT]": ... # pyright: ignore[reportInvalidTypeVarUse]
+ ) -> "list[dict[str, Any]]": ...
 
  def select(
  self,
@@ -342,12 +348,11 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  schema_type: "Optional[type[ModelDTOT]]" = None,
  statement_config: "Optional[StatementConfig]" = None,
  **kwargs: Any,
- ) -> "Union[list[ModelT], list[ModelDTOT]]": # pyright: ignore[reportInvalidTypeVarUse]
+ ) -> "Union[list[dict[str, Any]], list[ModelDTOT]]":
  """Execute a select statement and return all rows."""
  result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
  return cast(
- "Union[list[ModelT], list[ModelDTOT]]",
- self.to_schema(cast("list[ModelT]", result.get_data()), schema_type=schema_type),
+ "Union[list[dict[str, Any]], list[ModelDTOT]]", self.to_schema(result.get_data(), schema_type=schema_type)
  )
 
  def select_value(
@@ -367,23 +372,19 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  try:
  row = result.one()
  except ValueError as e:
- msg = "No rows found"
- raise NotFoundError(msg) from e
+ self._raise_no_rows_found_from_exception(e)
  if not row:
- msg = "No rows found"
- raise NotFoundError(msg)
+ self._raise_no_rows_found()
  if is_dict_row(row):
  if not row:
- msg = "Row has no columns"
- raise ValueError(msg)
+ self._raise_row_no_columns()
  return next(iter(row.values()))
  if is_indexable_row(row):
  if not row:
- msg = "Row has no columns"
- raise ValueError(msg)
+ self._raise_row_no_columns()
  return row[0]
- msg = f"Unexpected row type: {type(row)}"
- raise ValueError(msg)
+ self._raise_unexpected_row_type(type(row))
+ return None
 
  def select_value_or_none(
  self,
@@ -401,10 +402,11 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  """
  result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
  data = result.get_data()
- if not data:
+ data_len: int = len(data)
+ if data_len == 0:
  return None
- if len(data) > 1:
- msg = f"Expected at most one row, found {len(data)}"
+ if data_len > 1:
+ msg = f"Expected at most one row, found {data_len}"
  raise ValueError(msg)
  row = data[0]
  if isinstance(row, dict):
@@ -471,3 +473,31 @@ class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToS
  select_result = self.execute(sql_statement)
 
  return (self.to_schema(select_result.get_data(), schema_type=schema_type), count_result.scalar())
+
+ def _raise_no_rows_found(self) -> NoReturn:
+ msg = "No rows found"
+ raise NotFoundError(msg)
+
+ def _raise_no_rows_found_from_exception(self, e: ValueError) -> NoReturn:
+ msg = "No rows found"
+ raise NotFoundError(msg) from e
+
+ def _raise_expected_one_row(self, data_len: int) -> NoReturn:
+ msg = f"Expected exactly one row, found {data_len}"
+ raise ValueError(msg)
+
+ def _raise_expected_at_most_one_row(self, data_len: int) -> NoReturn:
+ msg = f"Expected at most one row, found {data_len}"
+ raise ValueError(msg)
+
+ def _raise_row_no_columns(self) -> NoReturn:
+ msg = "Row has no columns"
+ raise ValueError(msg)
+
+ def _raise_unexpected_row_type(self, row_type: type) -> NoReturn:
+ msg = f"Unexpected row type: {row_type}"
+ raise ValueError(msg)
+
+ def _raise_cannot_extract_value_from_row_type(self, type_name: str) -> NoReturn:
+ msg = f"Cannot extract value from row type {type_name}"
+ raise TypeError(msg)
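
A minimal usage sketch of the reworked sync driver API above; the `driver` instance, the `users` table, the `User` dataclass, and the `?` placeholder style are illustrative assumptions, while the method names, keyword arguments, and dict-based return types come from the diff itself:

    from dataclasses import dataclass

    @dataclass
    class User:
        id: int
        name: str

    # `driver` is assumed to be an already-connected SyncDriverAdapterBase subclass;
    # how it is constructed is adapter-specific and not shown in this diff.
    row = driver.select_one("SELECT id, name FROM users WHERE id = ?", 1)                     # dict[str, Any]
    user = driver.select_one("SELECT id, name FROM users WHERE id = ?", 1, schema_type=User)  # User instance
    maybe = driver.select_one_or_none("SELECT id, name FROM users WHERE id = ?", 99)          # dict or None
    name = driver.select_value("SELECT name FROM users WHERE id = ?", 1)                      # first column of the single row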
sqlspec/driver/mixins/_result_tools.py CHANGED
@@ -5,16 +5,14 @@ from collections.abc import Sequence
  from enum import Enum
  from functools import partial
  from pathlib import Path, PurePath
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload
+ from typing import Any, Callable, Final, Optional, overload
  from uuid import UUID
 
  from mypy_extensions import trait
 
- from sqlspec.exceptions import SQLSpecError, wrap_exceptions
+ from sqlspec.exceptions import SQLSpecError
  from sqlspec.typing import (
  CATTRS_INSTALLED,
- DataclassProtocol,
- DictLike,
  ModelDTOT,
  ModelT,
  attrs_asdict,
@@ -25,14 +23,16 @@ from sqlspec.typing import (
  )
  from sqlspec.utils.type_guards import is_attrs_schema, is_dataclass, is_msgspec_struct, is_pydantic_model
 
- if TYPE_CHECKING:
- from sqlspec._typing import AttrsInstanceStub, BaseModelStub, StructStub
-
  __all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")
 
 
  logger = logging.getLogger(__name__)
- _DEFAULT_TYPE_DECODERS: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [
+
+ # Constants for performance optimization
+ _DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
+ _PATH_TYPES: Final[tuple[type, ...]] = (Path, PurePath, UUID)
+
+ _DEFAULT_TYPE_DECODERS: Final[list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]] = [
  (lambda x: x is UUID, lambda t, v: t(v.hex)),
  (lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())),
  (lambda x: x is datetime.date, lambda t, v: t(v.isoformat())),
@@ -53,17 +53,32 @@ def _default_msgspec_deserializer(
  for predicate, decoder in type_decoders:
  if predicate(target_type):
  return decoder(target_type, value)
+
+ # Fast path checks using type identity and isinstance
  if target_type is UUID and isinstance(value, UUID):
  return value.hex
- if target_type in {datetime.datetime, datetime.date, datetime.time}:
- with wrap_exceptions(suppress=AttributeError):
+
+ # Use pre-computed set for faster lookup
+ if target_type in _DATETIME_TYPES:
+ try:
  return value.isoformat()
+ except AttributeError:
+ pass
+
  if isinstance(target_type, type) and issubclass(target_type, Enum) and isinstance(value, Enum):
  return value.value
+
  if isinstance(value, target_type):
  return value
- if issubclass(target_type, (Path, PurePath, UUID)):
- return target_type(value)
+
+ # Check for path types using pre-computed tuple
+ if isinstance(target_type, type):
+ try:
+ if issubclass(target_type, (Path, PurePath)) or issubclass(target_type, UUID):
+ return target_type(str(value))
+ except (TypeError, ValueError):
+ pass
+
  return value
 
 
@@ -74,36 +89,37 @@ class ToSchemaMixin:
  # Schema conversion overloads - handle common cases first
  @overload
  @staticmethod
+ def to_schema(data: "list[dict[str, Any]]") -> "list[dict[str, Any]]": ...
+ @overload
+ @staticmethod
  def to_schema(data: "list[dict[str, Any]]", *, schema_type: "type[ModelDTOT]") -> "list[ModelDTOT]": ...
  @overload
  @staticmethod
  def to_schema(data: "list[dict[str, Any]]", *, schema_type: None = None) -> "list[dict[str, Any]]": ...
  @overload
  @staticmethod
+ def to_schema(data: "dict[str, Any]") -> "dict[str, Any]": ...
+ @overload
+ @staticmethod
  def to_schema(data: "dict[str, Any]", *, schema_type: "type[ModelDTOT]") -> "ModelDTOT": ...
  @overload
  @staticmethod
  def to_schema(data: "dict[str, Any]", *, schema_type: None = None) -> "dict[str, Any]": ...
  @overload
  @staticmethod
+ def to_schema(data: "list[ModelT]") -> "list[ModelT]": ...
+ @overload
+ @staticmethod
  def to_schema(data: "list[ModelT]", *, schema_type: "type[ModelDTOT]") -> "list[ModelDTOT]": ...
  @overload
  @staticmethod
  def to_schema(data: "list[ModelT]", *, schema_type: None = None) -> "list[ModelT]": ...
  @overload
  @staticmethod
- def to_schema(
- data: "Union[DictLike, StructStub, BaseModelStub, DataclassProtocol, AttrsInstanceStub]",
- *,
- schema_type: "type[ModelDTOT]",
- ) -> "ModelDTOT": ...
+ def to_schema(data: "ModelT") -> "ModelT": ...
  @overload
  @staticmethod
- def to_schema(
- data: "Union[DictLike, StructStub, BaseModelStub, DataclassProtocol, AttrsInstanceStub]",
- *,
- schema_type: None = None,
- ) -> "Union[DictLike, StructStub, BaseModelStub, DataclassProtocol, AttrsInstanceStub]": ...
+ def to_schema(data: Any, *, schema_type: None = None) -> Any: ...
 
  @staticmethod
  def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) -> Any:
@@ -123,46 +139,55 @@ class ToSchemaMixin:
  return data
  if is_dataclass(schema_type):
  if isinstance(data, list):
- return [schema_type(**dict(item) if hasattr(item, "keys") else item) for item in data] # type: ignore[operator]
+ result: list[Any] = []
+ for item in data:
+ if hasattr(item, "keys"):
+ result.append(schema_type(**dict(item))) # type: ignore[operator]
+ else:
+ result.append(item)
+ return result
  if hasattr(data, "keys"):
  return schema_type(**dict(data)) # type: ignore[operator]
  if isinstance(data, dict):
  return schema_type(**data) # type: ignore[operator]
- # Fallback for other types
  return data
  if is_msgspec_struct(schema_type):
+ # Cache the deserializer to avoid repeated partial() calls
+ deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
  if not isinstance(data, Sequence):
- return convert(
- obj=data,
- type=schema_type,
- from_attributes=True,
- dec_hook=partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS),
- )
+ return convert(obj=data, type=schema_type, from_attributes=True, dec_hook=deserializer)
  return convert(
  obj=data,
  type=list[schema_type], # type: ignore[valid-type] # pyright: ignore
  from_attributes=True,
- dec_hook=partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS),
+ dec_hook=deserializer,
  )
  if is_pydantic_model(schema_type):
  if not isinstance(data, Sequence):
- return get_type_adapter(schema_type).validate_python(data, from_attributes=True) # pyright: ignore
- return get_type_adapter(list[schema_type]).validate_python(data, from_attributes=True) # type: ignore[valid-type] # pyright: ignore
+ adapter = get_type_adapter(schema_type)
+ return adapter.validate_python(data, from_attributes=True) # pyright: ignore
+ list_adapter = get_type_adapter(list[schema_type]) # type: ignore[valid-type] # pyright: ignore
+ return list_adapter.validate_python(data, from_attributes=True)
  if is_attrs_schema(schema_type):
  if CATTRS_INSTALLED:
  if isinstance(data, Sequence):
  return cattrs_structure(data, list[schema_type]) # type: ignore[valid-type] # pyright: ignore
- # If data is already structured (attrs instance), unstructure it first
  if hasattr(data, "__attrs_attrs__"):
- data = cattrs_unstructure(data)
+ unstructured_data = cattrs_unstructure(data)
+ return cattrs_structure(unstructured_data, schema_type) # pyright: ignore
  return cattrs_structure(data, schema_type) # pyright: ignore
  if isinstance(data, list):
- return [schema_type(**dict(item) if hasattr(item, "keys") else attrs_asdict(item)) for item in data]
+ attrs_result: list[Any] = []
+ for item in data:
+ if hasattr(item, "keys"):
+ attrs_result.append(schema_type(**dict(item)))
+ else:
+ attrs_result.append(schema_type(**attrs_asdict(item)))
+ return attrs_result
  if hasattr(data, "keys"):
  return schema_type(**dict(data))
  if isinstance(data, dict):
  return schema_type(**data)
- # Fallback for other types
  return data
  msg = "`schema_type` should be a valid Dataclass, Pydantic model, Msgspec struct, or Attrs class"
  raise SQLSpecError(msg)
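
A minimal sketch of the `ToSchemaMixin.to_schema` conversion refactored above; the import path and the `User` dataclass are assumptions made for illustration:

    from dataclasses import dataclass

    from sqlspec.driver.mixins._result_tools import ToSchemaMixin  # module path assumed from this diff

    @dataclass
    class User:
        id: int
        name: str

    rows = [{"id": 1, "name": "Ada"}, {"id": 2, "name": "Grace"}]
    users = ToSchemaMixin.to_schema(rows, schema_type=User)  # dict rows -> list of User dataclasses
    raw = ToSchemaMixin.to_schema(rows)                      # no schema_type: data is returned unchanged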
sqlspec/driver/mixins/_sql_translator.py CHANGED
@@ -1,3 +1,5 @@
+ from typing import Final, NoReturn, Optional
+
  from mypy_extensions import trait
  from sqlglot import exp, parse_one
  from sqlglot.dialects.dialect import DialectType
@@ -7,6 +9,9 @@ from sqlspec.exceptions import SQLConversionError
 
  __all__ = ("SQLTranslatorMixin",)
 
+ # Constants for better performance
+ _DEFAULT_PRETTY: Final[bool] = True
+
 
  @trait
  class SQLTranslatorMixin:
@@ -14,23 +19,68 @@ class SQLTranslatorMixin:
 
  __slots__ = ()
 
- def convert_to_dialect(self, statement: "Statement", to_dialect: DialectType = None, pretty: bool = True) -> str:
+ def convert_to_dialect(
+ self, statement: "Statement", to_dialect: "Optional[DialectType]" = None, pretty: bool = _DEFAULT_PRETTY
+ ) -> str:
+ """Convert a statement to a target SQL dialect.
+
+ Args:
+ statement: SQL statement to convert
+ to_dialect: Target dialect (defaults to current dialect)
+ pretty: Whether to format the output SQL
+
+ Returns:
+ SQL string in target dialect
+
+ Raises:
+ SQLConversionError: If parsing or conversion fails
+ """
+ # Fast path: get the parsed expression with minimal allocations
+ parsed_expression: Optional[exp.Expression] = None
+
  if statement is not None and isinstance(statement, SQL):
  if statement.expression is None:
- msg = "Statement could not be parsed"
- raise SQLConversionError(msg)
+ self._raise_statement_parse_error()
  parsed_expression = statement.expression
  elif isinstance(statement, exp.Expression):
  parsed_expression = statement
  else:
- try:
- parsed_expression = parse_one(statement, dialect=self.dialect) # type: ignore[attr-defined]
- except Exception as e:
- error_msg = f"Failed to parse SQL statement: {e!s}"
- raise SQLConversionError(error_msg) from e
+ parsed_expression = self._parse_statement_safely(statement)
+
+ # Get target dialect with fallback
  target_dialect = to_dialect or self.dialect # type: ignore[attr-defined]
+
+ # Generate SQL with error handling
+ return self._generate_sql_safely(parsed_expression, target_dialect, pretty)
+
+ def _parse_statement_safely(self, statement: "Statement") -> "exp.Expression":
+ """Parse statement with copy=False optimization and proper error handling."""
+ try:
+ # Convert statement to string if needed
+ sql_string = str(statement)
+ # Use copy=False for better performance
+ return parse_one(sql_string, dialect=self.dialect, copy=False) # type: ignore[attr-defined]
+ except Exception as e:
+ self._raise_parse_error(e)
+
+ def _generate_sql_safely(self, expression: "exp.Expression", dialect: DialectType, pretty: bool) -> str:
+ """Generate SQL with proper error handling."""
  try:
- return parsed_expression.sql(dialect=target_dialect, pretty=pretty)
+ return expression.sql(dialect=dialect, pretty=pretty)
  except Exception as e:
- error_msg = f"Failed to convert SQL expression to {target_dialect}: {e!s}"
- raise SQLConversionError(error_msg) from e
+ self._raise_conversion_error(dialect, e)
+
+ def _raise_statement_parse_error(self) -> NoReturn:
+ """Raise error for unparsable statements."""
+ msg = "Statement could not be parsed"
+ raise SQLConversionError(msg)
+
+ def _raise_parse_error(self, e: Exception) -> NoReturn:
+ """Raise error for parsing failures."""
+ error_msg = f"Failed to parse SQL statement: {e!s}"
+ raise SQLConversionError(error_msg) from e
+
+ def _raise_conversion_error(self, dialect: DialectType, e: Exception) -> NoReturn:
+ """Raise error for conversion failures."""
+ error_msg = f"Failed to convert SQL expression to {dialect}: {e!s}"
+ raise SQLConversionError(error_msg) from e
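
A brief sketch of the refactored `convert_to_dialect` call above; `driver` stands for any adapter that mixes in `SQLTranslatorMixin` (for example the sync adapter earlier in this diff), and the SQL text is illustrative:

    # The statement is parsed with the driver's own dialect, then re-rendered
    # for the requested target dialect via sqlglot.
    tsql = driver.convert_to_dialect(
        "SELECT id, name FROM users LIMIT 10",
        to_dialect="tsql",  # sqlglot dialect name; LIMIT is rendered as SELECT TOP 10
        pretty=True,
    )
    print(tsql)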
sqlspec/extensions/litestar/cli.py CHANGED
@@ -39,7 +39,7 @@ def get_database_migration_plugin(app: "Litestar") -> "SQLSpec":
  raise ImproperConfigurationError(msg)
 
 
- @click.group(cls=LitestarGroup, name="database")
+ @click.group(cls=LitestarGroup, name="db")
  def database_group(ctx: "click.Context") -> None:
  """Manage SQLSpec database components."""
  ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}
sqlspec/extensions/litestar/plugin.py CHANGED
@@ -1,7 +1,7 @@
  from typing import TYPE_CHECKING, Any, Union
 
  from litestar.di import Provide
- from litestar.plugins import InitPluginProtocol
+ from litestar.plugins import CLIPlugin, InitPluginProtocol
 
  from sqlspec.base import SQLSpec as SQLSpecBase
  from sqlspec.config import AsyncConfigT, DatabaseConfigProtocol, DriverT, SyncConfigT
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
  logger = get_logger("extensions.litestar")
 
 
- class SQLSpec(InitPluginProtocol, SQLSpecBase):
+ class SQLSpec(InitPluginProtocol, CLIPlugin, SQLSpecBase):
  """Litestar plugin for SQLSpec database integration."""
 
  __slots__ = ("_config", "_plugin_configs")
sqlspec/protocols.py CHANGED
@@ -371,6 +371,9 @@ class SQLBuilderProtocol(Protocol):
  _expression: "Optional[exp.Expression]"
  _parameters: dict[str, Any]
  _parameter_counter: int
+ _columns: Any # Optional attribute for some builders
+ _table: Any # Optional attribute for some builders
+ _with_ctes: Any # Optional attribute for some builders
  dialect: Any
  dialect_name: "Optional[str]"
 
@@ -383,6 +386,10 @@ class SQLBuilderProtocol(Protocol):
  """Add a parameter to the builder."""
  ...
 
+ def _generate_unique_parameter_name(self, base_name: str) -> str:
+ """Generate a unique parameter name."""
+ ...
+
  def _parameterize_expression(self, expression: "exp.Expression") -> "exp.Expression":
  """Replace literal values in an expression with bound parameters."""
  ...
sqlspec/utils/sync_tools.py CHANGED
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
  try:
  import uvloop # pyright: ignore[reportMissingImports]
  except ImportError:
- uvloop = None
+ uvloop = None # type: ignore[assignment,unused-ignore]
 
 
  ReturnT = TypeVar("ReturnT")
sqlspec/utils/type_guards.py CHANGED
@@ -841,9 +841,13 @@ def has_sql_method(obj: Any) -> "TypeGuard[HasSQLMethodProtocol]":
 
  def has_query_builder_parameters(obj: Any) -> "TypeGuard[SQLBuilderProtocol]":
  """Check if an object is a query builder with parameters property."""
- from sqlspec.protocols import SQLBuilderProtocol
-
- return isinstance(obj, SQLBuilderProtocol)
+ return (
+ hasattr(obj, "build")
+ and callable(getattr(obj, "build", None))
+ and hasattr(obj, "parameters")
+ and hasattr(obj, "add_parameter")
+ and callable(getattr(obj, "add_parameter", None))
+ )
 
 
  def is_object_store_item(obj: Any) -> "TypeGuard[ObjectStoreItemProtocol]":
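
The rewritten `has_query_builder_parameters` guard above replaces the `isinstance` protocol check with a structural (duck-typing) check; a small sketch using a purely illustrative stand-in class:

    from sqlspec.utils.type_guards import has_query_builder_parameters

    class FakeBuilder:
        """Structural stand-in for a query builder, used only for this sketch."""

        parameters: dict = {}

        def build(self):
            return None

        def add_parameter(self, name, value):
            return name, value

    assert has_query_builder_parameters(FakeBuilder())  # exposes build(), parameters, add_parameter()
    assert not has_query_builder_parameters(object())   # plain objects fail the structural check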
{sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sqlspec
- Version: 0.15.0
+ Version: 0.16.1
  Summary: SQL Experiments in Python
  Project-URL: Discord, https://discord.gg/litestar
  Project-URL: Issue, https://github.com/litestar-org/sqlspec/issues/