sqlspec 0.15.0__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec has been flagged as potentially problematic; see the package registry page for more details.
- sqlspec/_sql.py +699 -43
- sqlspec/builder/_base.py +77 -44
- sqlspec/builder/_column.py +0 -4
- sqlspec/builder/_ddl.py +15 -52
- sqlspec/builder/_ddl_utils.py +0 -1
- sqlspec/builder/_delete.py +4 -5
- sqlspec/builder/_insert.py +61 -35
- sqlspec/builder/_merge.py +17 -2
- sqlspec/builder/_parsing_utils.py +16 -12
- sqlspec/builder/_select.py +29 -33
- sqlspec/builder/_update.py +4 -2
- sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
- sqlspec/builder/mixins/_delete_operations.py +6 -1
- sqlspec/builder/mixins/_insert_operations.py +126 -24
- sqlspec/builder/mixins/_join_operations.py +11 -4
- sqlspec/builder/mixins/_merge_operations.py +91 -19
- sqlspec/builder/mixins/_order_limit_operations.py +15 -3
- sqlspec/builder/mixins/_pivot_operations.py +11 -2
- sqlspec/builder/mixins/_select_operations.py +16 -10
- sqlspec/builder/mixins/_update_operations.py +43 -10
- sqlspec/builder/mixins/_where_clause.py +177 -65
- sqlspec/core/cache.py +26 -28
- sqlspec/core/compiler.py +58 -37
- sqlspec/core/filters.py +12 -10
- sqlspec/core/parameters.py +80 -52
- sqlspec/core/result.py +30 -17
- sqlspec/core/statement.py +47 -22
- sqlspec/driver/_async.py +76 -46
- sqlspec/driver/_common.py +25 -6
- sqlspec/driver/_sync.py +73 -43
- sqlspec/driver/mixins/_result_tools.py +62 -37
- sqlspec/driver/mixins/_sql_translator.py +61 -11
- sqlspec/extensions/litestar/cli.py +1 -1
- sqlspec/extensions/litestar/plugin.py +2 -2
- sqlspec/protocols.py +7 -0
- sqlspec/utils/sync_tools.py +1 -1
- sqlspec/utils/type_guards.py +7 -3
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/METADATA +1 -1
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/RECORD +43 -43
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/core/statement.py
CHANGED
|
@@ -220,13 +220,20 @@ class SQL:
|
|
|
220
220
|
if "is_script" in kwargs:
|
|
221
221
|
self._is_script = bool(kwargs.pop("is_script"))
|
|
222
222
|
|
|
223
|
-
|
|
224
|
-
|
|
223
|
+
# Optimize parameter filtering with direct iteration
|
|
224
|
+
filters: list[StatementFilter] = []
|
|
225
|
+
actual_params: list[Any] = []
|
|
226
|
+
for p in parameters:
|
|
227
|
+
if is_statement_filter(p):
|
|
228
|
+
filters.append(p)
|
|
229
|
+
else:
|
|
230
|
+
actual_params.append(p)
|
|
225
231
|
|
|
226
232
|
self._filters.extend(filters)
|
|
227
233
|
|
|
228
234
|
if actual_params:
|
|
229
|
-
|
|
235
|
+
param_count = len(actual_params)
|
|
236
|
+
if param_count == 1:
|
|
230
237
|
param = actual_params[0]
|
|
231
238
|
if isinstance(param, dict):
|
|
232
239
|
self._named_parameters.update(param)
|
|
@@ -339,10 +346,11 @@ class SQL:
|
|
|
339
346
|
"""Explicitly compile the SQL statement."""
|
|
340
347
|
if self._processed_state is Empty:
|
|
341
348
|
try:
|
|
342
|
-
|
|
349
|
+
# Avoid unnecessary variable assignment
|
|
343
350
|
processor = SQLProcessor(self._statement_config)
|
|
344
|
-
|
|
345
|
-
|
|
351
|
+
compiled_result = processor.compile(
|
|
352
|
+
self._raw_sql, self._named_parameters or self._positional_parameters, is_many=self._is_many
|
|
353
|
+
)
|
|
346
354
|
|
|
347
355
|
self._processed_state = ProcessedState(
|
|
348
356
|
compiled_sql=compiled_result.compiled_sql,
|
|
@@ -368,6 +376,10 @@ class SQL:
|
|
|
368
376
|
new_sql = SQL(
|
|
369
377
|
self._raw_sql, *self._original_parameters, statement_config=self._statement_config, is_many=self._is_many
|
|
370
378
|
)
|
|
379
|
+
# Preserve accumulated parameters when marking as script
|
|
380
|
+
new_sql._named_parameters.update(self._named_parameters)
|
|
381
|
+
new_sql._positional_parameters = self._positional_parameters.copy()
|
|
382
|
+
new_sql._filters = self._filters.copy()
|
|
371
383
|
new_sql._is_script = True
|
|
372
384
|
return new_sql
|
|
373
385
|
|
|
@@ -375,13 +387,19 @@ class SQL:
|
|
|
375
387
|
self, statement: "Optional[Union[str, exp.Expression]]" = None, parameters: Optional[Any] = None, **kwargs: Any
|
|
376
388
|
) -> "SQL":
|
|
377
389
|
"""Create copy with modifications."""
|
|
378
|
-
|
|
390
|
+
new_sql = SQL(
|
|
379
391
|
statement or self._raw_sql,
|
|
380
392
|
*(parameters if parameters is not None else self._original_parameters),
|
|
381
393
|
statement_config=self._statement_config,
|
|
382
394
|
is_many=self._is_many,
|
|
383
395
|
**kwargs,
|
|
384
396
|
)
|
|
397
|
+
# Only preserve accumulated parameters when no explicit parameters are provided
|
|
398
|
+
if parameters is None:
|
|
399
|
+
new_sql._named_parameters.update(self._named_parameters)
|
|
400
|
+
new_sql._positional_parameters = self._positional_parameters.copy()
|
|
401
|
+
new_sql._filters = self._filters.copy()
|
|
402
|
+
return new_sql
|
|
385
403
|
|
|
386
404
|
def add_named_parameter(self, name: str, value: Any) -> "SQL":
|
|
387
405
|
"""Add a named parameter and return a new SQL instance.
|
|
@@ -411,6 +429,7 @@ class SQL:
|
|
|
411
429
|
Returns:
|
|
412
430
|
New SQL instance with the WHERE condition applied
|
|
413
431
|
"""
|
|
432
|
+
# Parse current SQL with copy=False optimization
|
|
414
433
|
current_expr = None
|
|
415
434
|
with contextlib.suppress(ParseError):
|
|
416
435
|
current_expr = sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
|
|
@@ -419,8 +438,11 @@ class SQL:
|
|
|
419
438
|
try:
|
|
420
439
|
current_expr = sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
|
|
421
440
|
except ParseError:
|
|
422
|
-
|
|
441
|
+
# Use f-string optimization and copy=False
|
|
442
|
+
subquery_sql = f"SELECT * FROM ({self._raw_sql}) AS subquery"
|
|
443
|
+
current_expr = sqlglot.parse_one(subquery_sql, dialect=self._dialect)
|
|
423
444
|
|
|
445
|
+
# Parse condition with copy=False optimization
|
|
424
446
|
condition_expr: exp.Expression
|
|
425
447
|
if isinstance(condition, str):
|
|
426
448
|
try:
|
|
@@ -430,29 +452,32 @@ class SQL:
|
|
|
430
452
|
else:
|
|
431
453
|
condition_expr = condition
|
|
432
454
|
|
|
455
|
+
# Apply WHERE clause
|
|
433
456
|
if isinstance(current_expr, exp.Select) or supports_where(current_expr):
|
|
434
|
-
new_expr = current_expr.where(condition_expr)
|
|
457
|
+
new_expr = current_expr.where(condition_expr, copy=False)
|
|
435
458
|
else:
|
|
436
|
-
new_expr = exp.Select().from_(current_expr).where(condition_expr)
|
|
459
|
+
new_expr = exp.Select().from_(current_expr).where(condition_expr, copy=False)
|
|
437
460
|
|
|
461
|
+
# Generate SQL and create new instance
|
|
438
462
|
new_sql_text = new_expr.sql(dialect=self._dialect)
|
|
439
|
-
|
|
440
|
-
return SQL(
|
|
463
|
+
new_sql = SQL(
|
|
441
464
|
new_sql_text, *self._original_parameters, statement_config=self._statement_config, is_many=self._is_many
|
|
442
465
|
)
|
|
443
466
|
|
|
467
|
+
# Preserve state efficiently
|
|
468
|
+
new_sql._named_parameters.update(self._named_parameters)
|
|
469
|
+
new_sql._positional_parameters = self._positional_parameters.copy()
|
|
470
|
+
new_sql._filters = self._filters.copy()
|
|
471
|
+
return new_sql
|
|
472
|
+
|
|
444
473
|
def __hash__(self) -> int:
|
|
445
|
-
"""Hash value."""
|
|
474
|
+
"""Hash value with optimized computation."""
|
|
446
475
|
if self._hash is None:
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
self._is_many,
|
|
453
|
-
self._is_script,
|
|
454
|
-
)
|
|
455
|
-
)
|
|
476
|
+
# Pre-compute tuple components to avoid multiple tuple() calls
|
|
477
|
+
positional_tuple = tuple(self._positional_parameters)
|
|
478
|
+
named_tuple = tuple(sorted(self._named_parameters.items())) if self._named_parameters else ()
|
|
479
|
+
|
|
480
|
+
self._hash = hash((self._raw_sql, positional_tuple, named_tuple, self._is_many, self._is_script))
|
|
456
481
|
return self._hash
|
|
457
482
|
|
|
458
483
|
def __eq__(self, other: object) -> bool:
|
sqlspec/driver/_async.py
CHANGED
|
@@ -5,7 +5,7 @@ including connection management, transaction support, and result processing.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from abc import abstractmethod
|
|
8
|
-
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Final, NoReturn, Optional, Union, cast, overload
|
|
9
9
|
|
|
10
10
|
from sqlspec.core import SQL, Statement
|
|
11
11
|
from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult
|
|
@@ -20,14 +20,15 @@ if TYPE_CHECKING:
|
|
|
20
20
|
|
|
21
21
|
from sqlspec.builder import QueryBuilder
|
|
22
22
|
from sqlspec.core import SQLResult, StatementConfig, StatementFilter
|
|
23
|
-
from sqlspec.typing import ModelDTOT,
|
|
23
|
+
from sqlspec.typing import ModelDTOT, StatementParameters
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
_LOGGER_NAME: Final[str] = "sqlspec"
|
|
26
|
+
logger = get_logger(_LOGGER_NAME)
|
|
26
27
|
|
|
27
28
|
__all__ = ("AsyncDriverAdapterBase",)
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
EMPTY_FILTERS: "list[StatementFilter]" = []
|
|
31
|
+
EMPTY_FILTERS: Final["list[StatementFilter]"] = []
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToSchemaMixin):
|
|
@@ -128,12 +129,16 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
128
129
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
129
130
|
statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
|
|
130
131
|
|
|
132
|
+
statement_count: int = len(statements)
|
|
133
|
+
successful_count: int = 0
|
|
134
|
+
|
|
131
135
|
for stmt in statements:
|
|
132
136
|
single_stmt = statement.copy(statement=stmt, parameters=prepared_parameters)
|
|
133
137
|
await self._execute_statement(cursor, single_stmt)
|
|
138
|
+
successful_count += 1
|
|
134
139
|
|
|
135
140
|
return self.create_execution_result(
|
|
136
|
-
cursor, statement_count=
|
|
141
|
+
cursor, statement_count=statement_count, successful_statements=successful_count, is_script_result=True
|
|
137
142
|
)
|
|
138
143
|
|
|
139
144
|
@abstractmethod
|
|
@@ -214,8 +219,8 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
214
219
|
By default, validates each statement and logs warnings for dangerous
|
|
215
220
|
operations. Use suppress_warnings=True for migrations and admin scripts.
|
|
216
221
|
"""
|
|
217
|
-
|
|
218
|
-
sql_statement = self.prepare_statement(statement, parameters, statement_config=
|
|
222
|
+
config = statement_config or self.statement_config
|
|
223
|
+
sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
|
|
219
224
|
|
|
220
225
|
return await self.dispatch_statement_execution(statement=sql_statement.as_script(), connection=self.connection)
|
|
221
226
|
|
|
@@ -239,7 +244,7 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
239
244
|
schema_type: None = None,
|
|
240
245
|
statement_config: "Optional[StatementConfig]" = None,
|
|
241
246
|
**kwargs: Any,
|
|
242
|
-
) -> "
|
|
247
|
+
) -> "dict[str, Any]": ...
|
|
243
248
|
|
|
244
249
|
async def select_one(
|
|
245
250
|
self,
|
|
@@ -249,23 +254,20 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
249
254
|
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
250
255
|
statement_config: "Optional[StatementConfig]" = None,
|
|
251
256
|
**kwargs: Any,
|
|
252
|
-
) -> "Union[
|
|
257
|
+
) -> "Union[dict[str, Any], ModelDTOT]":
|
|
253
258
|
"""Execute a select statement and return exactly one row.
|
|
254
259
|
|
|
255
260
|
Raises an exception if no rows or more than one row is returned.
|
|
256
261
|
"""
|
|
257
262
|
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
258
263
|
data = result.get_data()
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
if
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
return
|
|
266
|
-
"Union[ModelT, RowT, ModelDTOT]",
|
|
267
|
-
self.to_schema(data[0], schema_type=schema_type) if schema_type else data[0],
|
|
268
|
-
)
|
|
264
|
+
data_len: int = len(data)
|
|
265
|
+
if data_len == 0:
|
|
266
|
+
self._raise_no_rows_found()
|
|
267
|
+
if data_len > 1:
|
|
268
|
+
self._raise_expected_one_row(data_len)
|
|
269
|
+
first_row = data[0]
|
|
270
|
+
return self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row
|
|
269
271
|
|
|
270
272
|
@overload
|
|
271
273
|
async def select_one_or_none(
|
|
@@ -287,7 +289,7 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
287
289
|
schema_type: None = None,
|
|
288
290
|
statement_config: "Optional[StatementConfig]" = None,
|
|
289
291
|
**kwargs: Any,
|
|
290
|
-
) -> "Optional[
|
|
292
|
+
) -> "Optional[dict[str, Any]]": ...
|
|
291
293
|
|
|
292
294
|
async def select_one_or_none(
|
|
293
295
|
self,
|
|
@@ -297,7 +299,7 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
297
299
|
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
298
300
|
statement_config: "Optional[StatementConfig]" = None,
|
|
299
301
|
**kwargs: Any,
|
|
300
|
-
) -> "Optional[Union[
|
|
302
|
+
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
|
|
301
303
|
"""Execute a select statement and return at most one row.
|
|
302
304
|
|
|
303
305
|
Returns None if no rows are found.
|
|
@@ -305,12 +307,16 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
305
307
|
"""
|
|
306
308
|
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
307
309
|
data = result.get_data()
|
|
308
|
-
|
|
310
|
+
data_len: int = len(data)
|
|
311
|
+
if data_len == 0:
|
|
309
312
|
return None
|
|
310
|
-
if
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
return cast(
|
|
313
|
+
if data_len > 1:
|
|
314
|
+
self._raise_expected_at_most_one_row(data_len)
|
|
315
|
+
first_row = data[0]
|
|
316
|
+
return cast(
|
|
317
|
+
"Optional[Union[dict[str, Any], ModelDTOT]]",
|
|
318
|
+
self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row,
|
|
319
|
+
)
|
|
314
320
|
|
|
315
321
|
@overload
|
|
316
322
|
async def select(
|
|
@@ -332,7 +338,8 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
332
338
|
schema_type: None = None,
|
|
333
339
|
statement_config: "Optional[StatementConfig]" = None,
|
|
334
340
|
**kwargs: Any,
|
|
335
|
-
) -> "list[
|
|
341
|
+
) -> "list[dict[str, Any]]": ...
|
|
342
|
+
|
|
336
343
|
async def select(
|
|
337
344
|
self,
|
|
338
345
|
statement: "Union[Statement, QueryBuilder]",
|
|
@@ -341,12 +348,11 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
341
348
|
schema_type: "Optional[type[ModelDTOT]]" = None,
|
|
342
349
|
statement_config: "Optional[StatementConfig]" = None,
|
|
343
350
|
**kwargs: Any,
|
|
344
|
-
) -> "Union[list[
|
|
351
|
+
) -> "Union[list[dict[str, Any]], list[ModelDTOT]]":
|
|
345
352
|
"""Execute a select statement and return all rows."""
|
|
346
353
|
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
347
354
|
return cast(
|
|
348
|
-
"Union[list[
|
|
349
|
-
self.to_schema(cast("list[ModelT]", result.get_data()), schema_type=schema_type),
|
|
355
|
+
"Union[list[dict[str, Any]], list[ModelDTOT]]", self.to_schema(result.get_data(), schema_type=schema_type)
|
|
350
356
|
)
|
|
351
357
|
|
|
352
358
|
async def select_value(
|
|
@@ -366,23 +372,19 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
366
372
|
try:
|
|
367
373
|
row = result.one()
|
|
368
374
|
except ValueError as e:
|
|
369
|
-
|
|
370
|
-
raise NotFoundError(msg) from e
|
|
375
|
+
self._raise_no_rows_found_from_exception(e)
|
|
371
376
|
if not row:
|
|
372
|
-
|
|
373
|
-
raise NotFoundError(msg)
|
|
377
|
+
self._raise_no_rows_found()
|
|
374
378
|
if is_dict_row(row):
|
|
375
379
|
if not row:
|
|
376
|
-
|
|
377
|
-
raise ValueError(msg)
|
|
380
|
+
self._raise_row_no_columns()
|
|
378
381
|
return next(iter(row.values()))
|
|
379
382
|
if is_indexable_row(row):
|
|
380
383
|
if not row:
|
|
381
|
-
|
|
382
|
-
raise ValueError(msg)
|
|
384
|
+
self._raise_row_no_columns()
|
|
383
385
|
return row[0]
|
|
384
|
-
|
|
385
|
-
|
|
386
|
+
self._raise_unexpected_row_type(type(row))
|
|
387
|
+
return None
|
|
386
388
|
|
|
387
389
|
async def select_value_or_none(
|
|
388
390
|
self,
|
|
@@ -400,11 +402,11 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
400
402
|
"""
|
|
401
403
|
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
402
404
|
data = result.get_data()
|
|
403
|
-
|
|
405
|
+
data_len: int = len(data)
|
|
406
|
+
if data_len == 0:
|
|
404
407
|
return None
|
|
405
|
-
if
|
|
406
|
-
|
|
407
|
-
raise ValueError(msg)
|
|
408
|
+
if data_len > 1:
|
|
409
|
+
self._raise_expected_at_most_one_row(data_len)
|
|
408
410
|
row = data[0]
|
|
409
411
|
if is_dict_row(row):
|
|
410
412
|
if not row:
|
|
@@ -412,8 +414,8 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
412
414
|
return next(iter(row.values()))
|
|
413
415
|
if is_indexable_row(row):
|
|
414
416
|
return row[0]
|
|
415
|
-
|
|
416
|
-
|
|
417
|
+
self._raise_cannot_extract_value_from_row_type(type(row).__name__)
|
|
418
|
+
return None
|
|
417
419
|
|
|
418
420
|
@overload
|
|
419
421
|
async def select_with_total(
|
|
@@ -470,3 +472,31 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, To
|
|
|
470
472
|
select_result = await self.execute(sql_statement)
|
|
471
473
|
|
|
472
474
|
return (self.to_schema(select_result.get_data(), schema_type=schema_type), count_result.scalar())
|
|
475
|
+
|
|
476
|
+
def _raise_no_rows_found(self) -> NoReturn:
|
|
477
|
+
msg = "No rows found"
|
|
478
|
+
raise NotFoundError(msg)
|
|
479
|
+
|
|
480
|
+
def _raise_no_rows_found_from_exception(self, e: ValueError) -> NoReturn:
|
|
481
|
+
msg = "No rows found"
|
|
482
|
+
raise NotFoundError(msg) from e
|
|
483
|
+
|
|
484
|
+
def _raise_expected_one_row(self, data_len: int) -> NoReturn:
|
|
485
|
+
msg = f"Expected exactly one row, found {data_len}"
|
|
486
|
+
raise ValueError(msg)
|
|
487
|
+
|
|
488
|
+
def _raise_expected_at_most_one_row(self, data_len: int) -> NoReturn:
|
|
489
|
+
msg = f"Expected at most one row, found {data_len}"
|
|
490
|
+
raise ValueError(msg)
|
|
491
|
+
|
|
492
|
+
def _raise_row_no_columns(self) -> NoReturn:
|
|
493
|
+
msg = "Row has no columns"
|
|
494
|
+
raise ValueError(msg)
|
|
495
|
+
|
|
496
|
+
def _raise_unexpected_row_type(self, row_type: type) -> NoReturn:
|
|
497
|
+
msg = f"Unexpected row type: {row_type}"
|
|
498
|
+
raise ValueError(msg)
|
|
499
|
+
|
|
500
|
+
def _raise_cannot_extract_value_from_row_type(self, type_name: str) -> NoReturn:
|
|
501
|
+
msg = f"Cannot extract value from row type {type_name}"
|
|
502
|
+
raise TypeError(msg)
|
sqlspec/driver/_common.py
CHANGED
|
@@ -17,7 +17,9 @@ from sqlspec.exceptions import ImproperConfigurationError
|
|
|
17
17
|
from sqlspec.utils.logging import get_logger
|
|
18
18
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
|
-
from
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
|
|
22
|
+
from sqlspec.core.filters import FilterTypeT, StatementFilter
|
|
21
23
|
from sqlspec.typing import StatementParameters
|
|
22
24
|
|
|
23
25
|
|
|
@@ -424,10 +426,9 @@ class CommonDriverAttributesMixin:
|
|
|
424
426
|
if isinstance(parameters, dict):
|
|
425
427
|
if not parameters:
|
|
426
428
|
return []
|
|
427
|
-
if (
|
|
428
|
-
statement_config.parameter_config.supported_execution_parameter_styles
|
|
429
|
-
|
|
430
|
-
in statement_config.parameter_config.supported_execution_parameter_styles
|
|
429
|
+
if statement_config.parameter_config.supported_execution_parameter_styles and (
|
|
430
|
+
ParameterStyle.NAMED_PYFORMAT in statement_config.parameter_config.supported_execution_parameter_styles
|
|
431
|
+
or ParameterStyle.NAMED_COLON in statement_config.parameter_config.supported_execution_parameter_styles
|
|
431
432
|
):
|
|
432
433
|
return {k: apply_type_coercion(v) for k, v in parameters.items()}
|
|
433
434
|
if statement_config.parameter_config.default_parameter_style in {
|
|
@@ -577,6 +578,24 @@ class CommonDriverAttributesMixin:
|
|
|
577
578
|
|
|
578
579
|
return max(style_counts.keys(), key=lambda style: (style_counts[style], -precedence.get(style, 99)))
|
|
579
580
|
|
|
581
|
+
@staticmethod
|
|
582
|
+
def find_filter(
|
|
583
|
+
filter_type: "type[FilterTypeT]",
|
|
584
|
+
filters: "Sequence[StatementFilter | StatementParameters] | Sequence[StatementFilter]",
|
|
585
|
+
) -> "FilterTypeT | None":
|
|
586
|
+
"""Get the filter specified by filter type from the filters.
|
|
587
|
+
|
|
588
|
+
Args:
|
|
589
|
+
filter_type: The type of filter to find.
|
|
590
|
+
filters: filter types to apply to the query
|
|
591
|
+
|
|
592
|
+
Returns:
|
|
593
|
+
The match filter instance or None
|
|
594
|
+
"""
|
|
595
|
+
return next(
|
|
596
|
+
(cast("FilterTypeT | None", filter_) for filter_ in filters if isinstance(filter_, filter_type)), None
|
|
597
|
+
)
|
|
598
|
+
|
|
580
599
|
def _create_count_query(self, original_sql: "SQL") -> "SQL":
|
|
581
600
|
"""Create a COUNT query from the original SQL statement.
|
|
582
601
|
|
|
@@ -586,7 +605,7 @@ class CommonDriverAttributesMixin:
|
|
|
586
605
|
if not original_sql.expression:
|
|
587
606
|
msg = "Cannot create COUNT query from empty SQL expression"
|
|
588
607
|
raise ImproperConfigurationError(msg)
|
|
589
|
-
expr = original_sql.expression
|
|
608
|
+
expr = original_sql.expression
|
|
590
609
|
|
|
591
610
|
if isinstance(expr, exp.Select):
|
|
592
611
|
if expr.args.get("group"):
|