sqlspec-0.13.1-py3-none-any.whl → sqlspec-0.14.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +39 -1
- sqlspec/__main__.py +12 -0
- sqlspec/adapters/adbc/config.py +16 -40
- sqlspec/adapters/adbc/driver.py +43 -16
- sqlspec/adapters/adbc/transformers.py +108 -0
- sqlspec/adapters/aiosqlite/config.py +2 -20
- sqlspec/adapters/aiosqlite/driver.py +36 -18
- sqlspec/adapters/asyncmy/config.py +2 -33
- sqlspec/adapters/asyncmy/driver.py +23 -16
- sqlspec/adapters/asyncpg/config.py +5 -39
- sqlspec/adapters/asyncpg/driver.py +41 -18
- sqlspec/adapters/bigquery/config.py +2 -43
- sqlspec/adapters/bigquery/driver.py +26 -14
- sqlspec/adapters/duckdb/config.py +2 -49
- sqlspec/adapters/duckdb/driver.py +35 -16
- sqlspec/adapters/oracledb/config.py +4 -83
- sqlspec/adapters/oracledb/driver.py +54 -27
- sqlspec/adapters/psqlpy/config.py +2 -55
- sqlspec/adapters/psqlpy/driver.py +28 -8
- sqlspec/adapters/psycopg/config.py +4 -73
- sqlspec/adapters/psycopg/driver.py +69 -24
- sqlspec/adapters/sqlite/config.py +3 -21
- sqlspec/adapters/sqlite/driver.py +50 -26
- sqlspec/cli.py +248 -0
- sqlspec/config.py +18 -20
- sqlspec/driver/_async.py +28 -10
- sqlspec/driver/_common.py +5 -4
- sqlspec/driver/_sync.py +28 -10
- sqlspec/driver/mixins/__init__.py +6 -0
- sqlspec/driver/mixins/_cache.py +114 -0
- sqlspec/driver/mixins/_pipeline.py +0 -4
- sqlspec/{service/base.py → driver/mixins/_query_tools.py} +86 -421
- sqlspec/driver/mixins/_result_utils.py +0 -2
- sqlspec/driver/mixins/_sql_translator.py +0 -2
- sqlspec/driver/mixins/_storage.py +4 -18
- sqlspec/driver/mixins/_type_coercion.py +0 -2
- sqlspec/driver/parameters.py +4 -4
- sqlspec/extensions/aiosql/adapter.py +4 -4
- sqlspec/extensions/litestar/__init__.py +2 -1
- sqlspec/extensions/litestar/cli.py +48 -0
- sqlspec/extensions/litestar/plugin.py +3 -0
- sqlspec/loader.py +1 -1
- sqlspec/migrations/__init__.py +23 -0
- sqlspec/migrations/base.py +390 -0
- sqlspec/migrations/commands.py +525 -0
- sqlspec/migrations/runner.py +215 -0
- sqlspec/migrations/tracker.py +153 -0
- sqlspec/migrations/utils.py +89 -0
- sqlspec/protocols.py +37 -3
- sqlspec/statement/builder/__init__.py +8 -8
- sqlspec/statement/builder/{column.py → _column.py} +82 -52
- sqlspec/statement/builder/{ddl.py → _ddl.py} +5 -5
- sqlspec/statement/builder/_ddl_utils.py +1 -1
- sqlspec/statement/builder/{delete.py → _delete.py} +1 -1
- sqlspec/statement/builder/{insert.py → _insert.py} +1 -1
- sqlspec/statement/builder/{merge.py → _merge.py} +1 -1
- sqlspec/statement/builder/_parsing_utils.py +5 -3
- sqlspec/statement/builder/{select.py → _select.py} +59 -61
- sqlspec/statement/builder/{update.py → _update.py} +2 -2
- sqlspec/statement/builder/mixins/__init__.py +24 -30
- sqlspec/statement/builder/mixins/{_set_ops.py → _cte_and_set_ops.py} +86 -2
- sqlspec/statement/builder/mixins/{_delete_from.py → _delete_operations.py} +2 -0
- sqlspec/statement/builder/mixins/{_insert_values.py → _insert_operations.py} +70 -1
- sqlspec/statement/builder/mixins/{_merge_clauses.py → _merge_operations.py} +2 -0
- sqlspec/statement/builder/mixins/_order_limit_operations.py +123 -0
- sqlspec/statement/builder/mixins/{_pivot.py → _pivot_operations.py} +71 -2
- sqlspec/statement/builder/mixins/_select_operations.py +612 -0
- sqlspec/statement/builder/mixins/{_update_set.py → _update_operations.py} +73 -2
- sqlspec/statement/builder/mixins/_where_clause.py +536 -0
- sqlspec/statement/cache.py +50 -0
- sqlspec/statement/filters.py +37 -8
- sqlspec/statement/parameters.py +143 -54
- sqlspec/statement/pipelines/__init__.py +1 -1
- sqlspec/statement/pipelines/context.py +4 -10
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +3 -3
- sqlspec/statement/pipelines/validators/_parameter_style.py +22 -22
- sqlspec/statement/pipelines/validators/_performance.py +1 -5
- sqlspec/statement/sql.py +246 -176
- sqlspec/utils/__init__.py +2 -1
- sqlspec/utils/statement_hashing.py +203 -0
- sqlspec/utils/type_guards.py +32 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/METADATA +1 -1
- sqlspec-0.14.1.dist-info/RECORD +145 -0
- sqlspec-0.14.1.dist-info/entry_points.txt +2 -0
- sqlspec/service/__init__.py +0 -4
- sqlspec/service/_util.py +0 -147
- sqlspec/service/pagination.py +0 -26
- sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
- sqlspec/statement/builder/mixins/_case_builder.py +0 -91
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
- sqlspec/statement/builder/mixins/_from.py +0 -63
- sqlspec/statement/builder/mixins/_group_by.py +0 -118
- sqlspec/statement/builder/mixins/_having.py +0 -35
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
- sqlspec/statement/builder/mixins/_insert_into.py +0 -36
- sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
- sqlspec/statement/builder/mixins/_order_by.py +0 -46
- sqlspec/statement/builder/mixins/_returning.py +0 -37
- sqlspec/statement/builder/mixins/_select_columns.py +0 -61
- sqlspec/statement/builder/mixins/_unpivot.py +0 -77
- sqlspec/statement/builder/mixins/_update_from.py +0 -55
- sqlspec/statement/builder/mixins/_update_table.py +0 -29
- sqlspec/statement/builder/mixins/_where.py +0 -401
- sqlspec/statement/builder/mixins/_window_functions.py +0 -86
- sqlspec/statement/parameter_manager.py +0 -220
- sqlspec/statement/sql_compiler.py +0 -140
- sqlspec-0.13.1.dist-info/RECORD +0 -150
- /sqlspec/statement/builder/{base.py → _base.py} +0 -0
- /sqlspec/statement/builder/mixins/{_join.py → _join_operations.py} +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/filters.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
from collections import abc
|
|
5
|
+
from collections.abc import Sequence
|
|
5
6
|
from dataclasses import dataclass
|
|
6
7
|
from datetime import datetime
|
|
7
8
|
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Protocol, Union, runtime_checkable
|
|
@@ -25,6 +26,7 @@ __all__ = (
|
|
|
25
26
|
"NotAnyCollectionFilter",
|
|
26
27
|
"NotInCollectionFilter",
|
|
27
28
|
"NotInSearchFilter",
|
|
29
|
+
"OffsetPagination",
|
|
28
30
|
"OnBeforeAfterFilter",
|
|
29
31
|
"OrderByFilter",
|
|
30
32
|
"PaginationFilter",
|
|
@@ -430,8 +432,7 @@ class LimitOffsetFilter(PaginationFilter):
|
|
|
430
432
|
_, named_params = self.extract_parameters()
|
|
431
433
|
for name, value in named_params.items():
|
|
432
434
|
result = result.add_named_parameter(name, value)
|
|
433
|
-
|
|
434
|
-
return result
|
|
435
|
+
return result.filter(self)
|
|
435
436
|
|
|
436
437
|
|
|
437
438
|
@dataclass
|
|
@@ -449,12 +450,21 @@ class OrderByFilter(StatementFilter):
|
|
|
449
450
|
return [], {}
|
|
450
451
|
|
|
451
452
|
def append_to_statement(self, statement: "SQL") -> "SQL":
|
|
452
|
-
|
|
453
|
-
if
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
453
|
+
converted_sort_order = self.sort_order.lower()
|
|
454
|
+
if converted_sort_order not in {"asc", "desc"}:
|
|
455
|
+
converted_sort_order = "asc"
|
|
456
|
+
|
|
457
|
+
col_expr = exp.column(self.field_name)
|
|
458
|
+
order_expr = col_expr.desc() if converted_sort_order == "desc" else col_expr.asc()
|
|
459
|
+
|
|
460
|
+
# Check if the statement supports ORDER BY directly
|
|
461
|
+
if isinstance(statement._statement, exp.Select):
|
|
462
|
+
new_statement = statement._statement.order_by(order_expr)
|
|
463
|
+
else:
|
|
464
|
+
# Wrap in a SELECT if the statement doesn't support ORDER BY directly
|
|
465
|
+
new_statement = exp.Select().from_(statement._statement).order_by(order_expr)
|
|
466
|
+
|
|
467
|
+
return statement.copy(statement=new_statement)
|
|
458
468
|
|
|
459
469
|
|
|
460
470
|
@dataclass
|
|
@@ -568,6 +578,25 @@ class NotInSearchFilter(SearchFilter):
|
|
|
568
578
|
return result
|
|
569
579
|
|
|
570
580
|
|
|
581
|
+
@dataclass
|
|
582
|
+
class OffsetPagination(Generic[T]):
|
|
583
|
+
"""Container for data returned using limit/offset pagination."""
|
|
584
|
+
|
|
585
|
+
__slots__ = ("items", "limit", "offset", "total")
|
|
586
|
+
|
|
587
|
+
items: Sequence[T]
|
|
588
|
+
"""List of data being sent as part of the response."""
|
|
589
|
+
limit: int
|
|
590
|
+
"""Maximal number of items to send."""
|
|
591
|
+
offset: int
|
|
592
|
+
"""Offset from the beginning of the query.
|
|
593
|
+
|
|
594
|
+
Identical to an index.
|
|
595
|
+
"""
|
|
596
|
+
total: int
|
|
597
|
+
"""Total number of items."""
|
|
598
|
+
|
|
599
|
+
|
|
571
600
|
def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL":
|
|
572
601
|
"""Apply a statement filter to a SQL query object.
|
|
573
602
|
|
sqlspec/statement/parameters.py
CHANGED
|
@@ -10,22 +10,23 @@ import re
|
|
|
10
10
|
from collections.abc import Mapping, Sequence
|
|
11
11
|
from dataclasses import dataclass, field
|
|
12
12
|
from enum import Enum
|
|
13
|
-
from typing import
|
|
13
|
+
from typing import Any, Final, Optional, Union
|
|
14
14
|
|
|
15
|
+
from sqlglot import exp
|
|
15
16
|
from typing_extensions import TypedDict
|
|
16
17
|
|
|
17
18
|
from sqlspec.exceptions import ExtraParameterError, MissingParameterError, ParameterStyleMismatchError
|
|
18
19
|
from sqlspec.typing import SQLParameterType
|
|
19
20
|
|
|
20
|
-
|
|
21
|
-
|
|
21
|
+
# Constants
|
|
22
|
+
MAX_32BIT_INT: Final[int] = 2147483647
|
|
22
23
|
|
|
23
24
|
__all__ = (
|
|
24
25
|
"ConvertedParameters",
|
|
25
26
|
"ParameterConverter",
|
|
26
27
|
"ParameterInfo",
|
|
27
|
-
"ParameterNormalizationState",
|
|
28
28
|
"ParameterStyle",
|
|
29
|
+
"ParameterStyleConversionState",
|
|
29
30
|
"ParameterValidator",
|
|
30
31
|
"SQLParameterType",
|
|
31
32
|
"TypedParameter",
|
|
@@ -140,40 +141,56 @@ class TypedParameter:
|
|
|
140
141
|
semantic_name: "Optional[str]" = None
|
|
141
142
|
"""Optional semantic name derived from SQL context (e.g., 'user_id', 'email')."""
|
|
142
143
|
|
|
144
|
+
def __hash__(self) -> int:
|
|
145
|
+
"""Make TypedParameter hashable for use in cache keys.
|
|
146
|
+
|
|
147
|
+
We hash based on the value and type_hint, which are the key attributes
|
|
148
|
+
that affect SQL compilation and parameter handling.
|
|
149
|
+
"""
|
|
150
|
+
if isinstance(self.value, (list, dict)):
|
|
151
|
+
value_hash = hash(repr(self.value))
|
|
152
|
+
else:
|
|
153
|
+
try:
|
|
154
|
+
value_hash = hash(self.value)
|
|
155
|
+
except TypeError:
|
|
156
|
+
value_hash = hash(repr(self.value))
|
|
157
|
+
|
|
158
|
+
return hash((value_hash, self.type_hint, self.semantic_name))
|
|
143
159
|
|
|
144
|
-
class NormalizationInfo(TypedDict, total=False):
|
|
145
|
-
"""Information about SQL parameter normalization."""
|
|
146
160
|
|
|
147
|
-
|
|
161
|
+
class ParameterStyleInfo(TypedDict, total=False):
|
|
162
|
+
"""Information about SQL parameter style transformation."""
|
|
163
|
+
|
|
164
|
+
was_converted: bool
|
|
148
165
|
placeholder_map: dict[str, Union[str, int]]
|
|
149
166
|
original_styles: list[ParameterStyle]
|
|
150
167
|
|
|
151
168
|
|
|
152
169
|
@dataclass
|
|
153
|
-
class
|
|
154
|
-
"""Encapsulates all information about parameter
|
|
170
|
+
class ParameterStyleConversionState:
|
|
171
|
+
"""Encapsulates all information about parameter style transformation.
|
|
155
172
|
|
|
156
173
|
This class provides a single source of truth for parameter style conversions,
|
|
157
|
-
making it easier to track and reverse
|
|
174
|
+
making it easier to track and reverse transformations applied for SQLGlot compatibility.
|
|
158
175
|
"""
|
|
159
176
|
|
|
160
|
-
|
|
161
|
-
"""Whether parameter
|
|
177
|
+
was_transformed: bool = False
|
|
178
|
+
"""Whether parameter transformation was applied."""
|
|
162
179
|
|
|
163
180
|
original_styles: list[ParameterStyle] = field(default_factory=list)
|
|
164
181
|
"""Original parameter style(s) detected in the SQL."""
|
|
165
182
|
|
|
166
|
-
|
|
167
|
-
"""Target style used for
|
|
183
|
+
transformation_style: Optional[ParameterStyle] = None
|
|
184
|
+
"""Target style used for transformation (if transformed)."""
|
|
168
185
|
|
|
169
186
|
placeholder_map: dict[str, Union[str, int]] = field(default_factory=dict)
|
|
170
|
-
"""Mapping from
|
|
187
|
+
"""Mapping from transformed names to original names/positions."""
|
|
171
188
|
|
|
172
189
|
reverse_map: dict[Union[str, int], str] = field(default_factory=dict)
|
|
173
190
|
"""Reverse mapping for quick lookups."""
|
|
174
191
|
|
|
175
192
|
original_param_info: list["ParameterInfo"] = field(default_factory=list)
|
|
176
|
-
"""Original parameter info before
|
|
193
|
+
"""Original parameter info before conversion."""
|
|
177
194
|
|
|
178
195
|
def __post_init__(self) -> None:
|
|
179
196
|
"""Build reverse map if not provided."""
|
|
@@ -194,8 +211,8 @@ class ConvertedParameters:
|
|
|
194
211
|
merged_parameters: "SQLParameterType"
|
|
195
212
|
"""Parameters after merging from various sources."""
|
|
196
213
|
|
|
197
|
-
|
|
198
|
-
"""Complete
|
|
214
|
+
conversion_state: ParameterStyleConversionState
|
|
215
|
+
"""Complete conversion state for tracking conversions."""
|
|
199
216
|
|
|
200
217
|
|
|
201
218
|
@dataclass
|
|
@@ -295,17 +312,13 @@ class ParameterValidator:
|
|
|
295
312
|
"""
|
|
296
313
|
if not parameters_info:
|
|
297
314
|
return ParameterStyle.NONE
|
|
298
|
-
|
|
299
|
-
# Note: This logic prioritizes pyformat if present, then named, then positional.
|
|
300
315
|
is_pyformat_named = any(p.style == ParameterStyle.NAMED_PYFORMAT for p in parameters_info)
|
|
301
316
|
is_pyformat_positional = any(p.style == ParameterStyle.POSITIONAL_PYFORMAT for p in parameters_info)
|
|
302
317
|
|
|
303
318
|
if is_pyformat_named:
|
|
304
319
|
return ParameterStyle.NAMED_PYFORMAT
|
|
305
|
-
if is_pyformat_positional:
|
|
320
|
+
if is_pyformat_positional:
|
|
306
321
|
return ParameterStyle.POSITIONAL_PYFORMAT
|
|
307
|
-
|
|
308
|
-
# Simplified logic if not pyformat, checks for any named or any positional
|
|
309
322
|
has_named = any(
|
|
310
323
|
p.style
|
|
311
324
|
in {
|
|
@@ -317,13 +330,7 @@ class ParameterValidator:
|
|
|
317
330
|
for p in parameters_info
|
|
318
331
|
)
|
|
319
332
|
has_positional = any(p.style in {ParameterStyle.QMARK, ParameterStyle.NUMERIC} for p in parameters_info)
|
|
320
|
-
|
|
321
|
-
# If mixed named and positional (non-pyformat), prefer named as dominant.
|
|
322
|
-
# The choice of NAMED_COLON here is somewhat arbitrary if multiple named styles are mixed.
|
|
323
333
|
if has_named:
|
|
324
|
-
# Could refine to return the style of the first named param encountered, or most frequent.
|
|
325
|
-
# For simplicity, returning a general named style like NAMED_COLON is often sufficient.
|
|
326
|
-
# Or, more accurately, find the first named style:
|
|
327
334
|
for p_style in (
|
|
328
335
|
ParameterStyle.NAMED_COLON,
|
|
329
336
|
ParameterStyle.POSITIONAL_COLON,
|
|
@@ -335,12 +342,11 @@ class ParameterValidator:
|
|
|
335
342
|
return ParameterStyle.NAMED_COLON
|
|
336
343
|
|
|
337
344
|
if has_positional:
|
|
338
|
-
# Similarly, could choose QMARK or NUMERIC based on presence.
|
|
339
345
|
if any(p.style == ParameterStyle.NUMERIC for p in parameters_info):
|
|
340
346
|
return ParameterStyle.NUMERIC
|
|
341
|
-
return ParameterStyle.QMARK
|
|
347
|
+
return ParameterStyle.QMARK
|
|
342
348
|
|
|
343
|
-
return ParameterStyle.NONE
|
|
349
|
+
return ParameterStyle.NONE
|
|
344
350
|
|
|
345
351
|
@staticmethod
|
|
346
352
|
def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "Optional[type]":
|
|
@@ -365,9 +371,8 @@ class ParameterValidator:
|
|
|
365
371
|
if any(
|
|
366
372
|
p.name is not None and p.style not in {ParameterStyle.POSITIONAL_COLON, ParameterStyle.NUMERIC}
|
|
367
373
|
for p in parameters_info
|
|
368
|
-
):
|
|
374
|
+
):
|
|
369
375
|
return dict
|
|
370
|
-
# All parameters must have p.name is None or be positional styles (POSITIONAL_COLON, NUMERIC)
|
|
371
376
|
if all(
|
|
372
377
|
p.name is None or p.style in {ParameterStyle.POSITIONAL_COLON, ParameterStyle.NUMERIC}
|
|
373
378
|
for p in parameters_info
|
|
@@ -381,9 +386,7 @@ class ParameterValidator:
|
|
|
381
386
|
"Ambiguous parameter structure for determining input type. "
|
|
382
387
|
"Query might contain a mix of named and unnamed styles not typically supported together."
|
|
383
388
|
)
|
|
384
|
-
|
|
385
|
-
# However, strict validation should ideally prevent such mixed styles from being valid.
|
|
386
|
-
return dict # Or raise an error for unsupported mixed styles.
|
|
389
|
+
return dict
|
|
387
390
|
|
|
388
391
|
def validate_parameters(
|
|
389
392
|
self,
|
|
@@ -402,12 +405,7 @@ class ParameterValidator:
|
|
|
402
405
|
ParameterStyleMismatchError: When style doesn't match
|
|
403
406
|
"""
|
|
404
407
|
expected_input_type = self.determine_parameter_input_type(parameters_info)
|
|
405
|
-
|
|
406
|
-
# Allow creating SQL statements with placeholders but no parameters
|
|
407
|
-
# This enables patterns like SQL("SELECT * FROM users WHERE id = ?").as_many([...])
|
|
408
|
-
# Validation will happen later when parameters are actually provided
|
|
409
408
|
if provided_params is None and parameters_info:
|
|
410
|
-
# Don't raise an error, just return - validation will happen later
|
|
411
409
|
return
|
|
412
410
|
|
|
413
411
|
if (
|
|
@@ -673,7 +671,7 @@ class ParameterConverter:
|
|
|
673
671
|
"""
|
|
674
672
|
parameters_info = self.validator.extract_parameters(sql)
|
|
675
673
|
|
|
676
|
-
|
|
674
|
+
needs_conversion = any(p.style in SQLGLOT_INCOMPATIBLE_STYLES for p in parameters_info)
|
|
677
675
|
|
|
678
676
|
has_positional = any(p.name is None for p in parameters_info)
|
|
679
677
|
has_named = any(p.name is not None for p in parameters_info)
|
|
@@ -686,19 +684,19 @@ class ParameterConverter:
|
|
|
686
684
|
|
|
687
685
|
if validate:
|
|
688
686
|
self.validator.validate_parameters(parameters_info, merged_params, sql)
|
|
689
|
-
if
|
|
687
|
+
if needs_conversion:
|
|
690
688
|
transformed_sql, placeholder_map = self._transform_sql_for_parsing(sql, parameters_info)
|
|
691
|
-
|
|
692
|
-
|
|
689
|
+
conversion_state = ParameterStyleConversionState(
|
|
690
|
+
was_transformed=True,
|
|
693
691
|
original_styles=list({p.style for p in parameters_info}),
|
|
694
|
-
|
|
692
|
+
transformation_style=ParameterStyle.NAMED_COLON,
|
|
695
693
|
placeholder_map=placeholder_map,
|
|
696
694
|
original_param_info=parameters_info,
|
|
697
695
|
)
|
|
698
696
|
else:
|
|
699
697
|
transformed_sql = sql
|
|
700
|
-
|
|
701
|
-
|
|
698
|
+
conversion_state = ParameterStyleConversionState(
|
|
699
|
+
was_transformed=False,
|
|
702
700
|
original_styles=list({p.style for p in parameters_info}),
|
|
703
701
|
original_param_info=parameters_info,
|
|
704
702
|
)
|
|
@@ -707,7 +705,7 @@ class ParameterConverter:
|
|
|
707
705
|
transformed_sql=transformed_sql,
|
|
708
706
|
parameter_info=parameters_info,
|
|
709
707
|
merged_parameters=merged_params,
|
|
710
|
-
|
|
708
|
+
conversion_state=conversion_state,
|
|
711
709
|
)
|
|
712
710
|
|
|
713
711
|
@staticmethod
|
|
@@ -756,10 +754,10 @@ class ParameterConverter:
|
|
|
756
754
|
return parameters
|
|
757
755
|
|
|
758
756
|
if kwargs is not None:
|
|
759
|
-
return dict(kwargs)
|
|
757
|
+
return dict(kwargs)
|
|
760
758
|
|
|
761
759
|
if args is not None:
|
|
762
|
-
return list(args)
|
|
760
|
+
return list(args)
|
|
763
761
|
|
|
764
762
|
return None
|
|
765
763
|
|
|
@@ -781,7 +779,98 @@ class ParameterConverter:
|
|
|
781
779
|
Returns:
|
|
782
780
|
Parameters with TypedParameter wrapping where appropriate
|
|
783
781
|
"""
|
|
784
|
-
|
|
782
|
+
if parameters is None:
|
|
783
|
+
return None
|
|
784
|
+
|
|
785
|
+
# Import here to avoid circular imports
|
|
786
|
+
from datetime import date, datetime, time
|
|
787
|
+
from decimal import Decimal
|
|
788
|
+
|
|
789
|
+
def infer_type_from_value(value: Any) -> tuple[str, "exp.DataType"]:
|
|
790
|
+
"""Infer SQL type hint and SQLGlot DataType from Python value."""
|
|
791
|
+
|
|
792
|
+
# None/NULL
|
|
793
|
+
if value is None:
|
|
794
|
+
return "null", exp.DataType.build("NULL")
|
|
795
|
+
if isinstance(value, bool):
|
|
796
|
+
return "boolean", exp.DataType.build("BOOLEAN")
|
|
797
|
+
if isinstance(value, int) and not isinstance(value, bool):
|
|
798
|
+
if abs(value) > MAX_32BIT_INT:
|
|
799
|
+
return "bigint", exp.DataType.build("BIGINT")
|
|
800
|
+
return "integer", exp.DataType.build("INT")
|
|
801
|
+
if isinstance(value, float):
|
|
802
|
+
return "float", exp.DataType.build("FLOAT")
|
|
803
|
+
if isinstance(value, Decimal):
|
|
804
|
+
return "decimal", exp.DataType.build("DECIMAL")
|
|
805
|
+
if isinstance(value, datetime):
|
|
806
|
+
return "timestamp", exp.DataType.build("TIMESTAMP")
|
|
807
|
+
if isinstance(value, date):
|
|
808
|
+
return "date", exp.DataType.build("DATE")
|
|
809
|
+
if isinstance(value, time):
|
|
810
|
+
return "time", exp.DataType.build("TIME")
|
|
811
|
+
if isinstance(value, dict):
|
|
812
|
+
return "json", exp.DataType.build("JSON")
|
|
813
|
+
if isinstance(value, (list, tuple)):
|
|
814
|
+
return "array", exp.DataType.build("ARRAY")
|
|
815
|
+
if isinstance(value, str):
|
|
816
|
+
return "string", exp.DataType.build("VARCHAR")
|
|
817
|
+
if isinstance(value, bytes):
|
|
818
|
+
return "binary", exp.DataType.build("BINARY")
|
|
819
|
+
return "string", exp.DataType.build("VARCHAR")
|
|
820
|
+
|
|
821
|
+
def wrap_value(value: Any, semantic_name: Optional[str] = None) -> Any:
|
|
822
|
+
"""Wrap a single value with TypedParameter if beneficial."""
|
|
823
|
+
# Don't wrap if already a TypedParameter
|
|
824
|
+
if hasattr(value, "__class__") and value.__class__.__name__ == "TypedParameter":
|
|
825
|
+
return value
|
|
826
|
+
|
|
827
|
+
# Don't wrap simple scalar types unless they need special handling
|
|
828
|
+
if isinstance(value, (str, int, float)) and not isinstance(value, bool):
|
|
829
|
+
# For simple types, only wrap if we have special type needs
|
|
830
|
+
# (e.g., bigint, decimal precision, etc.)
|
|
831
|
+
if isinstance(value, int) and abs(value) > MAX_32BIT_INT:
|
|
832
|
+
# Wrap large integers as bigint
|
|
833
|
+
type_hint, sqlglot_type = infer_type_from_value(value)
|
|
834
|
+
return TypedParameter(
|
|
835
|
+
value=value, sqlglot_type=sqlglot_type, type_hint=type_hint, semantic_name=semantic_name
|
|
836
|
+
)
|
|
837
|
+
# Otherwise, return unwrapped for performance
|
|
838
|
+
return value
|
|
839
|
+
|
|
840
|
+
# Wrap complex types and types needing special handling
|
|
841
|
+
if isinstance(value, (datetime, date, time, Decimal, dict, list, tuple, bytes, bool, type(None))):
|
|
842
|
+
type_hint, sqlglot_type = infer_type_from_value(value)
|
|
843
|
+
return TypedParameter(
|
|
844
|
+
value=value, sqlglot_type=sqlglot_type, type_hint=type_hint, semantic_name=semantic_name
|
|
845
|
+
)
|
|
846
|
+
|
|
847
|
+
# Default: return unwrapped
|
|
848
|
+
return value
|
|
849
|
+
|
|
850
|
+
# Handle different parameter structures
|
|
851
|
+
if isinstance(parameters, dict):
|
|
852
|
+
# Wrap dict values selectively
|
|
853
|
+
wrapped_dict = {}
|
|
854
|
+
for key, value in parameters.items():
|
|
855
|
+
wrapped_dict[key] = wrap_value(value, semantic_name=key)
|
|
856
|
+
return wrapped_dict
|
|
857
|
+
|
|
858
|
+
if isinstance(parameters, (list, tuple)):
|
|
859
|
+
# Wrap list/tuple values selectively
|
|
860
|
+
wrapped_list: list[Any] = []
|
|
861
|
+
for i, value in enumerate(parameters):
|
|
862
|
+
# Try to get semantic name from parameters_info if available
|
|
863
|
+
semantic_name = None
|
|
864
|
+
if parameters_info and i < len(parameters_info) and parameters_info[i].name:
|
|
865
|
+
semantic_name = parameters_info[i].name
|
|
866
|
+
wrapped_list.append(wrap_value(value, semantic_name=semantic_name))
|
|
867
|
+
return wrapped_list if isinstance(parameters, list) else tuple(wrapped_list)
|
|
868
|
+
|
|
869
|
+
# Single scalar parameter
|
|
870
|
+
semantic_name = None
|
|
871
|
+
if parameters_info and parameters_info[0].name:
|
|
872
|
+
semantic_name = parameters_info[0].name
|
|
873
|
+
return wrap_value(parameters, semantic_name=semantic_name)
|
|
785
874
|
|
|
786
875
|
def _convert_sql_placeholders(
|
|
787
876
|
self, rendered_sql: str, final_parameter_info: "list[ParameterInfo]", target_style: "ParameterStyle"
|
|
@@ -816,7 +905,7 @@ class ParameterConverter:
|
|
|
816
905
|
from sqlspec.exceptions import SQLTransformationError
|
|
817
906
|
|
|
818
907
|
msg = (
|
|
819
|
-
f"Parameter count mismatch during
|
|
908
|
+
f"Parameter count mismatch during deconversion. "
|
|
820
909
|
f"Expected at least {len(final_parameter_info)} parameters, "
|
|
821
910
|
f"found {len(canonical_params)} in SQL"
|
|
822
911
|
)
|
|
sqlspec/statement/pipelines/__init__.py
CHANGED
|
@@ -155,7 +155,7 @@ class StatementPipeline:
|
|
|
155
155
|
UnsupportedParameterStyleError,
|
|
156
156
|
)
|
|
157
157
|
|
|
158
|
-
if context.config.
|
|
158
|
+
if not context.config.parse_errors_as_warnings and isinstance(
|
|
159
159
|
e, (MissingParameterError, MixedParameterStyleError, UnsupportedParameterStyleError)
|
|
160
160
|
):
|
|
161
161
|
raise
|
|
sqlspec/statement/pipelines/context.py
CHANGED
|
@@ -8,7 +8,7 @@ from sqlspec.exceptions import RiskLevel
|
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
9
|
from sqlglot.dialects.dialect import DialectType
|
|
10
10
|
|
|
11
|
-
from sqlspec.statement.parameters import ParameterInfo,
|
|
11
|
+
from sqlspec.statement.parameters import ParameterInfo, ParameterStyleConversionState
|
|
12
12
|
from sqlspec.statement.sql import SQLConfig
|
|
13
13
|
from sqlspec.typing import SQLParameterType
|
|
14
14
|
|
|
@@ -66,12 +66,6 @@ class SQLProcessingContext:
|
|
|
66
66
|
# Current state
|
|
67
67
|
current_expression: Optional[exp.Expression] = None
|
|
68
68
|
"""The SQL expression, potentially modified by transformers."""
|
|
69
|
-
|
|
70
|
-
# Parameters
|
|
71
|
-
initial_parameters: "Optional[SQLParameterType]" = None
|
|
72
|
-
"""The initial parameters as provided to the SQL object (before merging with kwargs)."""
|
|
73
|
-
initial_kwargs: "Optional[dict[str, Any]]" = None
|
|
74
|
-
"""The initial keyword arguments as provided to the SQL object."""
|
|
75
69
|
merged_parameters: "SQLParameterType" = field(default_factory=list)
|
|
76
70
|
"""Parameters after merging initial_parameters and initial_kwargs."""
|
|
77
71
|
parameter_info: "list[ParameterInfo]" = field(default_factory=list)
|
|
@@ -97,10 +91,10 @@ class SQLProcessingContext:
|
|
|
97
91
|
statement_type: Optional[str] = None
|
|
98
92
|
"""The detected type of the SQL statement (e.g., SELECT, INSERT, DDL)."""
|
|
99
93
|
extra_info: dict[str, Any] = field(default_factory=dict)
|
|
100
|
-
"""Extra information from parameter processing, including
|
|
94
|
+
"""Extra information from parameter processing, including conversion state."""
|
|
101
95
|
|
|
102
|
-
|
|
103
|
-
"""Single source of truth for parameter
|
|
96
|
+
parameter_conversion: "Optional[ParameterStyleConversionState]" = None
|
|
97
|
+
"""Single source of truth for parameter style conversion tracking."""
|
|
104
98
|
|
|
105
99
|
@property
|
|
106
100
|
def has_errors(self) -> bool:
|
|
sqlspec/statement/pipelines/transformers/_expression_simplifier.py
CHANGED
|
@@ -21,7 +21,7 @@ class SimplificationConfig:
|
|
|
21
21
|
enable_literal_folding: bool = True
|
|
22
22
|
enable_boolean_optimization: bool = True
|
|
23
23
|
enable_connector_optimization: bool = True
|
|
24
|
-
|
|
24
|
+
enable_equality_conversion: bool = True
|
|
25
25
|
enable_complement_removal: bool = True
|
|
26
26
|
|
|
27
27
|
|
|
@@ -74,8 +74,8 @@ class ExpressionSimplifier(ProcessorProtocol):
|
|
|
74
74
|
optimizations.append("boolean_optimization")
|
|
75
75
|
if self.config.enable_connector_optimization:
|
|
76
76
|
optimizations.append("connector_optimization")
|
|
77
|
-
if self.config.
|
|
78
|
-
optimizations.append("
|
|
77
|
+
if self.config.enable_equality_conversion:
|
|
78
|
+
optimizations.append("equality_conversion")
|
|
79
79
|
if self.config.enable_complement_removal:
|
|
80
80
|
optimizations.append("complement_removal")
|
|
81
81
|
|
|
sqlspec/statement/pipelines/validators/_parameter_style.py
CHANGED
|
@@ -73,13 +73,13 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
73
73
|
config = context.config
|
|
74
74
|
param_info = context.parameter_info
|
|
75
75
|
|
|
76
|
-
# Check if parameters were
|
|
77
|
-
# This happens when Oracle numeric parameters (:1, :2) are
|
|
78
|
-
|
|
76
|
+
# Check if parameters were converted by looking for param_ placeholders
|
|
77
|
+
# This happens when Oracle numeric parameters (:1, :2) are converted
|
|
78
|
+
is_converted = param_info and any(p.name and p.name.startswith("param_") for p in param_info)
|
|
79
79
|
|
|
80
|
-
# First check parameter styles if configured (skip if
|
|
80
|
+
# First check parameter styles if configured (skip if converted)
|
|
81
81
|
has_style_errors = False
|
|
82
|
-
if not
|
|
82
|
+
if not is_converted and config.allowed_parameter_styles is not None and param_info:
|
|
83
83
|
unique_styles = {p.style for p in param_info}
|
|
84
84
|
|
|
85
85
|
if len(unique_styles) > 1 and not config.allow_mixed_parameter_styles:
|
|
@@ -279,11 +279,11 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
279
279
|
"""Handle validation for named parameters."""
|
|
280
280
|
missing: list[str] = []
|
|
281
281
|
|
|
282
|
-
# Check if we have
|
|
283
|
-
|
|
282
|
+
# Check if we have converted parameters (e.g., param_0)
|
|
283
|
+
is_converted = any(p.name and p.name.startswith("param_") for p in param_info)
|
|
284
284
|
|
|
285
|
-
if
|
|
286
|
-
# For
|
|
285
|
+
if is_converted and hasattr(context, "extra_info"):
|
|
286
|
+
# For converted parameters, we need to check against the original placeholder mapping
|
|
287
287
|
placeholder_map = context.extra_info.get("placeholder_map", {})
|
|
288
288
|
|
|
289
289
|
# Check if we have Oracle numeric keys in merged_params
|
|
@@ -291,9 +291,9 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
291
291
|
|
|
292
292
|
if all_numeric_keys:
|
|
293
293
|
# Parameters were provided as list and converted to Oracle numeric dict {"1": val1, "2": val2}
|
|
294
|
-
for i
|
|
295
|
-
|
|
296
|
-
original_key = placeholder_map.get(
|
|
294
|
+
for i in range(len(param_info)):
|
|
295
|
+
converted_name = f"param_{i}"
|
|
296
|
+
original_key = placeholder_map.get(converted_name)
|
|
297
297
|
|
|
298
298
|
if original_key is not None:
|
|
299
299
|
# Check using the original key (e.g., "1", "2" for Oracle)
|
|
@@ -309,11 +309,11 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
309
309
|
|
|
310
310
|
if all_param_keys:
|
|
311
311
|
# This was originally a list converted to dict with param_N keys
|
|
312
|
-
for i
|
|
313
|
-
|
|
314
|
-
if
|
|
312
|
+
for i in range(len(param_info)):
|
|
313
|
+
converted_name = f"param_{i}"
|
|
314
|
+
if converted_name not in merged_params or merged_params[converted_name] is None:
|
|
315
315
|
# Get original parameter style from placeholder map
|
|
316
|
-
original_key = placeholder_map.get(
|
|
316
|
+
original_key = placeholder_map.get(converted_name)
|
|
317
317
|
if original_key is not None:
|
|
318
318
|
original_key_str = str(original_key)
|
|
319
319
|
if original_key_str.isdigit():
|
|
@@ -322,16 +322,16 @@ class ParameterStyleValidator(ProcessorProtocol):
|
|
|
322
322
|
missing.append(f":{original_key}")
|
|
323
323
|
else:
|
|
324
324
|
# Mixed parameter names, check using placeholder map
|
|
325
|
-
for i
|
|
326
|
-
|
|
327
|
-
original_key = placeholder_map.get(
|
|
325
|
+
for i in range(len(param_info)):
|
|
326
|
+
converted_name = f"param_{i}"
|
|
327
|
+
original_key = placeholder_map.get(converted_name)
|
|
328
328
|
|
|
329
329
|
if original_key is not None:
|
|
330
|
-
# For mixed params, check both
|
|
330
|
+
# For mixed params, check both converted and original keys
|
|
331
331
|
original_key_str = str(original_key)
|
|
332
332
|
|
|
333
|
-
# First check with
|
|
334
|
-
found =
|
|
333
|
+
# First check with converted name
|
|
334
|
+
found = converted_name in merged_params and merged_params[converted_name] is not None
|
|
335
335
|
|
|
336
336
|
# If not found, check with original key
|
|
337
337
|
if not found:
|
|
sqlspec/statement/pipelines/validators/_performance.py
CHANGED
|
@@ -601,11 +601,7 @@ class PerformanceValidator(ProcessorProtocol):
|
|
|
601
601
|
),
|
|
602
602
|
("join_optimization", optimize_joins.optimize_joins, "Optimize join order and conditions"),
|
|
603
603
|
("simplification", simplify.simplify, "Simplify expressions and conditions"),
|
|
604
|
-
(
|
|
605
|
-
"identifier_normalization",
|
|
606
|
-
normalize_identifiers.normalize_identifiers,
|
|
607
|
-
"Normalize identifier casing",
|
|
608
|
-
),
|
|
604
|
+
("identifier_conversion", normalize_identifiers.normalize_identifiers, "Normalize identifier casing"),
|
|
609
605
|
]
|
|
610
606
|
|
|
611
607
|
best_optimized = expression.copy()
|