sqlspec 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +7 -15
- sqlspec/_serialization.py +256 -24
- sqlspec/_typing.py +71 -52
- sqlspec/adapters/adbc/_types.py +1 -1
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +870 -0
- sqlspec/adapters/adbc/config.py +69 -12
- sqlspec/adapters/adbc/data_dictionary.py +340 -0
- sqlspec/adapters/adbc/driver.py +266 -58
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +153 -0
- sqlspec/adapters/aiosqlite/_types.py +1 -1
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +527 -0
- sqlspec/adapters/aiosqlite/config.py +88 -15
- sqlspec/adapters/aiosqlite/data_dictionary.py +149 -0
- sqlspec/adapters/aiosqlite/driver.py +143 -40
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +7 -7
- sqlspec/adapters/asyncmy/__init__.py +7 -1
- sqlspec/adapters/asyncmy/_types.py +2 -2
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +493 -0
- sqlspec/adapters/asyncmy/config.py +68 -23
- sqlspec/adapters/asyncmy/data_dictionary.py +161 -0
- sqlspec/adapters/asyncmy/driver.py +313 -58
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +2 -1
- sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
- sqlspec/adapters/asyncpg/_types.py +11 -7
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +450 -0
- sqlspec/adapters/asyncpg/config.py +59 -35
- sqlspec/adapters/asyncpg/data_dictionary.py +173 -0
- sqlspec/adapters/asyncpg/driver.py +170 -25
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/_types.py +1 -1
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +576 -0
- sqlspec/adapters/bigquery/config.py +27 -10
- sqlspec/adapters/bigquery/data_dictionary.py +149 -0
- sqlspec/adapters/bigquery/driver.py +368 -142
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +125 -0
- sqlspec/adapters/duckdb/_types.py +1 -1
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +553 -0
- sqlspec/adapters/duckdb/config.py +80 -20
- sqlspec/adapters/duckdb/data_dictionary.py +163 -0
- sqlspec/adapters/duckdb/driver.py +167 -45
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +4 -4
- sqlspec/adapters/duckdb/type_converter.py +133 -0
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +20 -2
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1745 -0
- sqlspec/adapters/oracledb/config.py +122 -32
- sqlspec/adapters/oracledb/data_dictionary.py +509 -0
- sqlspec/adapters/oracledb/driver.py +353 -91
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +767 -0
- sqlspec/adapters/oracledb/migrations.py +348 -73
- sqlspec/adapters/oracledb/type_converter.py +207 -0
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +2 -1
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +482 -0
- sqlspec/adapters/psqlpy/config.py +46 -17
- sqlspec/adapters/psqlpy/data_dictionary.py +172 -0
- sqlspec/adapters/psqlpy/driver.py +123 -209
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +102 -0
- sqlspec/adapters/psycopg/_type_handlers.py +80 -0
- sqlspec/adapters/psycopg/_types.py +2 -1
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +944 -0
- sqlspec/adapters/psycopg/config.py +69 -35
- sqlspec/adapters/psycopg/data_dictionary.py +331 -0
- sqlspec/adapters/psycopg/driver.py +238 -81
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/sqlite/__init__.py +2 -1
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +1 -1
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +572 -0
- sqlspec/adapters/sqlite/config.py +87 -15
- sqlspec/adapters/sqlite/data_dictionary.py +149 -0
- sqlspec/adapters/sqlite/driver.py +137 -54
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +18 -9
- sqlspec/base.py +45 -26
- sqlspec/builder/__init__.py +73 -4
- sqlspec/builder/_base.py +162 -89
- sqlspec/builder/_column.py +62 -29
- sqlspec/builder/_ddl.py +180 -121
- sqlspec/builder/_delete.py +5 -4
- sqlspec/builder/_dml.py +388 -0
- sqlspec/{_sql.py → builder/_factory.py} +53 -94
- sqlspec/builder/_insert.py +32 -131
- sqlspec/builder/_join.py +375 -0
- sqlspec/builder/_merge.py +446 -11
- sqlspec/builder/_parsing_utils.py +111 -17
- sqlspec/builder/_select.py +1457 -24
- sqlspec/builder/_update.py +11 -42
- sqlspec/cli.py +307 -194
- sqlspec/config.py +252 -67
- sqlspec/core/__init__.py +5 -4
- sqlspec/core/cache.py +17 -17
- sqlspec/core/compiler.py +62 -9
- sqlspec/core/filters.py +37 -37
- sqlspec/core/hashing.py +9 -9
- sqlspec/core/parameters.py +83 -48
- sqlspec/core/result.py +102 -46
- sqlspec/core/splitter.py +16 -17
- sqlspec/core/statement.py +36 -30
- sqlspec/core/type_conversion.py +235 -0
- sqlspec/driver/__init__.py +7 -6
- sqlspec/driver/_async.py +188 -151
- sqlspec/driver/_common.py +285 -80
- sqlspec/driver/_sync.py +188 -152
- sqlspec/driver/mixins/_result_tools.py +20 -236
- sqlspec/driver/mixins/_sql_translator.py +4 -4
- sqlspec/exceptions.py +75 -7
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/adapter.py +73 -53
- sqlspec/extensions/litestar/__init__.py +21 -4
- sqlspec/extensions/litestar/cli.py +54 -10
- sqlspec/extensions/litestar/config.py +59 -266
- sqlspec/extensions/litestar/handlers.py +46 -17
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +324 -223
- sqlspec/extensions/litestar/providers.py +25 -25
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/loader.py +30 -49
- sqlspec/migrations/__init__.py +4 -3
- sqlspec/migrations/base.py +302 -39
- sqlspec/migrations/commands.py +611 -144
- sqlspec/migrations/context.py +142 -0
- sqlspec/migrations/fix.py +199 -0
- sqlspec/migrations/loaders.py +68 -23
- sqlspec/migrations/runner.py +543 -107
- sqlspec/migrations/tracker.py +237 -21
- sqlspec/migrations/utils.py +51 -3
- sqlspec/migrations/validation.py +177 -0
- sqlspec/protocols.py +66 -36
- sqlspec/storage/_utils.py +98 -0
- sqlspec/storage/backends/fsspec.py +134 -106
- sqlspec/storage/backends/local.py +78 -51
- sqlspec/storage/backends/obstore.py +278 -162
- sqlspec/storage/registry.py +75 -39
- sqlspec/typing.py +16 -84
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/correlation.py +4 -5
- sqlspec/utils/data_transformation.py +3 -2
- sqlspec/utils/deprecation.py +9 -8
- sqlspec/utils/fixtures.py +4 -4
- sqlspec/utils/logging.py +46 -6
- sqlspec/utils/module_loader.py +2 -2
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +50 -2
- sqlspec/utils/sync_tools.py +21 -17
- sqlspec/utils/text.py +1 -2
- sqlspec/utils/type_guards.py +111 -20
- sqlspec/utils/version.py +433 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/METADATA +40 -21
- sqlspec-0.27.0.dist-info/RECORD +207 -0
- sqlspec/builder/mixins/__init__.py +0 -55
- sqlspec/builder/mixins/_cte_and_set_ops.py +0 -254
- sqlspec/builder/mixins/_delete_operations.py +0 -50
- sqlspec/builder/mixins/_insert_operations.py +0 -282
- sqlspec/builder/mixins/_join_operations.py +0 -389
- sqlspec/builder/mixins/_merge_operations.py +0 -592
- sqlspec/builder/mixins/_order_limit_operations.py +0 -152
- sqlspec/builder/mixins/_pivot_operations.py +0 -157
- sqlspec/builder/mixins/_select_operations.py +0 -936
- sqlspec/builder/mixins/_update_operations.py +0 -218
- sqlspec/builder/mixins/_where_clause.py +0 -1304
- sqlspec-0.25.0.dist-info/RECORD +0 -139
- sqlspec-0.25.0.dist-info/licenses/NOTICE +0 -29
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlspec/core/cache.py
CHANGED
|
@@ -173,8 +173,8 @@ class CacheNode:
|
|
|
173
173
|
"""
|
|
174
174
|
self.key = key
|
|
175
175
|
self.value = value
|
|
176
|
-
self.prev:
|
|
177
|
-
self.next:
|
|
176
|
+
self.prev: CacheNode | None = None
|
|
177
|
+
self.next: CacheNode | None = None
|
|
178
178
|
self.timestamp = time.time()
|
|
179
179
|
self.access_count = 1
|
|
180
180
|
|
|
@@ -190,7 +190,7 @@ class UnifiedCache:
|
|
|
190
190
|
|
|
191
191
|
__slots__ = UNIFIED_CACHE_SLOTS
|
|
192
192
|
|
|
193
|
-
def __init__(self, max_size: int = DEFAULT_MAX_SIZE, ttl_seconds:
|
|
193
|
+
def __init__(self, max_size: int = DEFAULT_MAX_SIZE, ttl_seconds: int | None = DEFAULT_TTL_SECONDS) -> None:
|
|
194
194
|
"""Initialize unified cache.
|
|
195
195
|
|
|
196
196
|
Args:
|
|
@@ -208,7 +208,7 @@ class UnifiedCache:
|
|
|
208
208
|
self._head.next = self._tail
|
|
209
209
|
self._tail.prev = self._head
|
|
210
210
|
|
|
211
|
-
def get(self, key: CacheKey) ->
|
|
211
|
+
def get(self, key: CacheKey) -> Any | None:
|
|
212
212
|
"""Get value from cache.
|
|
213
213
|
|
|
214
214
|
Args:
|
|
@@ -275,7 +275,7 @@ class UnifiedCache:
|
|
|
275
275
|
True if key was found and deleted, False otherwise
|
|
276
276
|
"""
|
|
277
277
|
with self._lock:
|
|
278
|
-
node:
|
|
278
|
+
node: CacheNode | None = self._cache.get(key)
|
|
279
279
|
if node is None:
|
|
280
280
|
return False
|
|
281
281
|
|
|
@@ -306,7 +306,7 @@ class UnifiedCache:
|
|
|
306
306
|
def _add_to_head(self, node: CacheNode) -> None:
|
|
307
307
|
"""Add node to head of list."""
|
|
308
308
|
node.prev = self._head
|
|
309
|
-
head_next:
|
|
309
|
+
head_next: CacheNode | None = self._head.next
|
|
310
310
|
node.next = head_next
|
|
311
311
|
if head_next is not None:
|
|
312
312
|
head_next.prev = node
|
|
@@ -314,8 +314,8 @@ class UnifiedCache:
|
|
|
314
314
|
|
|
315
315
|
def _remove_node(self, node: CacheNode) -> None:
|
|
316
316
|
"""Remove node from linked list."""
|
|
317
|
-
node_prev:
|
|
318
|
-
node_next:
|
|
317
|
+
node_prev: CacheNode | None = node.prev
|
|
318
|
+
node_next: CacheNode | None = node.next
|
|
319
319
|
if node_prev is not None:
|
|
320
320
|
node_prev.next = node_next
|
|
321
321
|
if node_next is not None:
|
|
@@ -341,7 +341,7 @@ class UnifiedCache:
|
|
|
341
341
|
return not (ttl is not None and time.time() - node.timestamp > ttl)
|
|
342
342
|
|
|
343
343
|
|
|
344
|
-
_default_cache:
|
|
344
|
+
_default_cache: UnifiedCache | None = None
|
|
345
345
|
_cache_lock = threading.Lock()
|
|
346
346
|
|
|
347
347
|
|
|
@@ -381,7 +381,7 @@ def get_cache_statistics() -> dict[str, CacheStats]:
|
|
|
381
381
|
return stats
|
|
382
382
|
|
|
383
383
|
|
|
384
|
-
_global_cache_config: "
|
|
384
|
+
_global_cache_config: "CacheConfig | None" = None
|
|
385
385
|
|
|
386
386
|
|
|
387
387
|
@mypyc_attr(allow_interpreted_subclasses=False)
|
|
@@ -558,7 +558,7 @@ class CachedStatement:
|
|
|
558
558
|
"""
|
|
559
559
|
|
|
560
560
|
compiled_sql: str
|
|
561
|
-
parameters:
|
|
561
|
+
parameters: tuple[Any, ...] | dict[str, Any] | None # None allowed for static script compilation
|
|
562
562
|
expression: Optional["exp.Expression"]
|
|
563
563
|
|
|
564
564
|
def get_parameters_view(self) -> "ParametersView":
|
|
@@ -572,7 +572,7 @@ class CachedStatement:
|
|
|
572
572
|
return ParametersView(list(self.parameters), {})
|
|
573
573
|
|
|
574
574
|
|
|
575
|
-
def create_cache_key(level: str, key: str, dialect:
|
|
575
|
+
def create_cache_key(level: str, key: str, dialect: str | None = None) -> str:
|
|
576
576
|
"""Create optimized cache key using string concatenation.
|
|
577
577
|
|
|
578
578
|
Args:
|
|
@@ -592,7 +592,7 @@ class MultiLevelCache:
|
|
|
592
592
|
|
|
593
593
|
__slots__ = ("_cache",)
|
|
594
594
|
|
|
595
|
-
def __init__(self, max_size: int = DEFAULT_MAX_SIZE, ttl_seconds:
|
|
595
|
+
def __init__(self, max_size: int = DEFAULT_MAX_SIZE, ttl_seconds: int | None = DEFAULT_TTL_SECONDS) -> None:
|
|
596
596
|
"""Initialize multi-level cache.
|
|
597
597
|
|
|
598
598
|
Args:
|
|
@@ -601,7 +601,7 @@ class MultiLevelCache:
|
|
|
601
601
|
"""
|
|
602
602
|
self._cache = UnifiedCache(max_size, ttl_seconds)
|
|
603
603
|
|
|
604
|
-
def get(self, level: str, key: str, dialect:
|
|
604
|
+
def get(self, level: str, key: str, dialect: str | None = None) -> Any | None:
|
|
605
605
|
"""Get value from cache with level and dialect namespace.
|
|
606
606
|
|
|
607
607
|
Args:
|
|
@@ -616,7 +616,7 @@ class MultiLevelCache:
|
|
|
616
616
|
cache_key = CacheKey((full_key,))
|
|
617
617
|
return self._cache.get(cache_key)
|
|
618
618
|
|
|
619
|
-
def put(self, level: str, key: str, value: Any, dialect:
|
|
619
|
+
def put(self, level: str, key: str, value: Any, dialect: str | None = None) -> None:
|
|
620
620
|
"""Put value in cache with level and dialect namespace.
|
|
621
621
|
|
|
622
622
|
Args:
|
|
@@ -629,7 +629,7 @@ class MultiLevelCache:
|
|
|
629
629
|
cache_key = CacheKey((full_key,))
|
|
630
630
|
self._cache.put(cache_key, value)
|
|
631
631
|
|
|
632
|
-
def delete(self, level: str, key: str, dialect:
|
|
632
|
+
def delete(self, level: str, key: str, dialect: str | None = None) -> bool:
|
|
633
633
|
"""Delete entry from cache.
|
|
634
634
|
|
|
635
635
|
Args:
|
|
@@ -653,7 +653,7 @@ class MultiLevelCache:
|
|
|
653
653
|
return self._cache.get_stats()
|
|
654
654
|
|
|
655
655
|
|
|
656
|
-
_multi_level_cache:
|
|
656
|
+
_multi_level_cache: MultiLevelCache | None = None
|
|
657
657
|
|
|
658
658
|
|
|
659
659
|
def get_cache() -> MultiLevelCache:
|
sqlspec/core/compiler.py
CHANGED
|
@@ -8,16 +8,15 @@ Components:
|
|
|
8
8
|
|
|
9
9
|
import hashlib
|
|
10
10
|
from collections import OrderedDict
|
|
11
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional
|
|
12
12
|
|
|
13
13
|
import sqlglot
|
|
14
14
|
from mypy_extensions import mypyc_attr
|
|
15
15
|
from sqlglot import expressions as exp
|
|
16
16
|
from sqlglot.errors import ParseError
|
|
17
|
-
from typing_extensions import Literal
|
|
18
17
|
|
|
18
|
+
import sqlspec.exceptions
|
|
19
19
|
from sqlspec.core.parameters import ParameterProcessor
|
|
20
|
-
from sqlspec.exceptions import SQLSpecError
|
|
21
20
|
from sqlspec.utils.logging import get_logger
|
|
22
21
|
|
|
23
22
|
if TYPE_CHECKING:
|
|
@@ -72,6 +71,7 @@ class CompiledSQL:
|
|
|
72
71
|
"execution_parameters",
|
|
73
72
|
"expression",
|
|
74
73
|
"operation_type",
|
|
74
|
+
"parameter_casts",
|
|
75
75
|
"parameter_style",
|
|
76
76
|
"supports_many",
|
|
77
77
|
)
|
|
@@ -84,8 +84,9 @@ class CompiledSQL:
|
|
|
84
84
|
execution_parameters: Any,
|
|
85
85
|
operation_type: "OperationType",
|
|
86
86
|
expression: Optional["exp.Expression"] = None,
|
|
87
|
-
parameter_style:
|
|
87
|
+
parameter_style: str | None = None,
|
|
88
88
|
supports_many: bool = False,
|
|
89
|
+
parameter_casts: Optional["dict[int, str]"] = None,
|
|
89
90
|
) -> None:
|
|
90
91
|
"""Initialize compiled result.
|
|
91
92
|
|
|
@@ -96,6 +97,7 @@ class CompiledSQL:
|
|
|
96
97
|
expression: SQLGlot AST expression
|
|
97
98
|
parameter_style: Parameter style used in compilation
|
|
98
99
|
supports_many: Whether this supports execute_many operations
|
|
100
|
+
parameter_casts: Mapping of parameter positions to cast types
|
|
99
101
|
"""
|
|
100
102
|
self.compiled_sql = compiled_sql
|
|
101
103
|
self.execution_parameters = execution_parameters
|
|
@@ -103,7 +105,8 @@ class CompiledSQL:
|
|
|
103
105
|
self.expression = expression
|
|
104
106
|
self.parameter_style = parameter_style
|
|
105
107
|
self.supports_many = supports_many
|
|
106
|
-
self.
|
|
108
|
+
self.parameter_casts = parameter_casts or {}
|
|
109
|
+
self._hash: int | None = None
|
|
107
110
|
|
|
108
111
|
def __hash__(self) -> int:
|
|
109
112
|
"""Cached hash value."""
|
|
@@ -224,11 +227,13 @@ class SQLProcessor:
|
|
|
224
227
|
ast_was_transformed = False
|
|
225
228
|
expression = None
|
|
226
229
|
operation_type: OperationType = "EXECUTE"
|
|
230
|
+
parameter_casts: dict[int, str] = {}
|
|
227
231
|
|
|
228
232
|
if self._config.enable_parsing:
|
|
229
233
|
try:
|
|
230
234
|
expression = sqlglot.parse_one(sqlglot_sql, dialect=dialect_str)
|
|
231
235
|
operation_type = self._detect_operation_type(expression)
|
|
236
|
+
parameter_casts = self._detect_parameter_casts(expression)
|
|
232
237
|
|
|
233
238
|
ast_transformer = self._config.parameter_config.ast_transformer
|
|
234
239
|
if ast_transformer:
|
|
@@ -238,6 +243,7 @@ class SQLProcessor:
|
|
|
238
243
|
except ParseError:
|
|
239
244
|
expression = None
|
|
240
245
|
operation_type = "EXECUTE"
|
|
246
|
+
parameter_casts = {}
|
|
241
247
|
|
|
242
248
|
if self._config.parameter_config.needs_static_script_compilation and processed_params is None:
|
|
243
249
|
final_sql, final_params = processed_sql, processed_params
|
|
@@ -264,14 +270,16 @@ class SQLProcessor:
|
|
|
264
270
|
expression=expression,
|
|
265
271
|
parameter_style=self._config.parameter_config.default_parameter_style.value,
|
|
266
272
|
supports_many=isinstance(final_params, list) and len(final_params) > 0,
|
|
273
|
+
parameter_casts=parameter_casts,
|
|
267
274
|
)
|
|
268
275
|
|
|
269
|
-
except SQLSpecError:
|
|
270
|
-
# Re-raise SQLSpecError (validation errors, parameter mismatches) - these should fail hard
|
|
276
|
+
except sqlspec.exceptions.SQLSpecError:
|
|
271
277
|
raise
|
|
272
278
|
except Exception as e:
|
|
273
279
|
logger.warning("Compilation failed, using fallback: %s", e)
|
|
274
|
-
return CompiledSQL(
|
|
280
|
+
return CompiledSQL(
|
|
281
|
+
compiled_sql=sql, execution_parameters=parameters, operation_type="UNKNOWN", parameter_casts={}
|
|
282
|
+
)
|
|
275
283
|
|
|
276
284
|
def _make_cache_key(self, sql: str, parameters: Any, is_many: bool = False) -> str:
|
|
277
285
|
"""Generate cache key.
|
|
@@ -326,8 +334,53 @@ class SQLProcessor:
|
|
|
326
334
|
|
|
327
335
|
return "UNKNOWN"
|
|
328
336
|
|
|
337
|
+
def _detect_parameter_casts(self, expression: Optional["exp.Expression"]) -> "dict[int, str]":
|
|
338
|
+
"""Detect explicit type casts on parameters in the AST.
|
|
339
|
+
|
|
340
|
+
Args:
|
|
341
|
+
expression: SQLGlot AST expression to analyze
|
|
342
|
+
|
|
343
|
+
Returns:
|
|
344
|
+
Dict mapping parameter positions (1-based) to cast type names
|
|
345
|
+
"""
|
|
346
|
+
if not expression:
|
|
347
|
+
return {}
|
|
348
|
+
|
|
349
|
+
cast_positions = {}
|
|
350
|
+
|
|
351
|
+
# Walk all nodes in order to track parameter positions
|
|
352
|
+
for node in expression.walk():
|
|
353
|
+
# Check for cast nodes with parameter children
|
|
354
|
+
if isinstance(node, exp.Cast):
|
|
355
|
+
cast_target = node.this
|
|
356
|
+
position = None
|
|
357
|
+
|
|
358
|
+
if isinstance(cast_target, exp.Parameter):
|
|
359
|
+
# Handle $1, $2 style parameters
|
|
360
|
+
param_value = cast_target.this
|
|
361
|
+
if isinstance(param_value, exp.Literal):
|
|
362
|
+
position = int(param_value.this)
|
|
363
|
+
elif isinstance(cast_target, exp.Placeholder):
|
|
364
|
+
# For ? style, we need to count position (will implement if needed)
|
|
365
|
+
pass
|
|
366
|
+
elif isinstance(cast_target, exp.Column):
|
|
367
|
+
# Handle cases where $1 gets parsed as a column
|
|
368
|
+
column_name = str(cast_target.this) if cast_target.this else str(cast_target)
|
|
369
|
+
if column_name.startswith("$") and column_name[1:].isdigit():
|
|
370
|
+
position = int(column_name[1:])
|
|
371
|
+
|
|
372
|
+
if position is not None:
|
|
373
|
+
# Extract cast type
|
|
374
|
+
if isinstance(node.to, exp.DataType):
|
|
375
|
+
cast_type = node.to.this.value if hasattr(node.to.this, "value") else str(node.to.this)
|
|
376
|
+
else:
|
|
377
|
+
cast_type = str(node.to)
|
|
378
|
+
cast_positions[position] = cast_type.upper()
|
|
379
|
+
|
|
380
|
+
return cast_positions
|
|
381
|
+
|
|
329
382
|
def _apply_final_transformations(
|
|
330
|
-
self, expression: "
|
|
383
|
+
self, expression: "exp.Expression | None", sql: str, parameters: Any, dialect_str: "str | None"
|
|
331
384
|
) -> "tuple[str, Any]":
|
|
332
385
|
"""Apply final transformations.
|
|
333
386
|
|
sqlspec/core/filters.py
CHANGED
|
@@ -23,11 +23,11 @@ from abc import ABC, abstractmethod
|
|
|
23
23
|
from collections import abc
|
|
24
24
|
from collections.abc import Sequence
|
|
25
25
|
from datetime import datetime
|
|
26
|
-
from typing import TYPE_CHECKING, Any, Generic, Literal,
|
|
26
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias
|
|
27
27
|
|
|
28
28
|
import sqlglot
|
|
29
29
|
from sqlglot import exp
|
|
30
|
-
from typing_extensions import
|
|
30
|
+
from typing_extensions import TypeVar
|
|
31
31
|
|
|
32
32
|
if TYPE_CHECKING:
|
|
33
33
|
from sqlglot.expressions import Condition
|
|
@@ -125,7 +125,7 @@ class BeforeAfterFilter(StatementFilter):
|
|
|
125
125
|
|
|
126
126
|
__slots__ = ("_after", "_before", "_field_name")
|
|
127
127
|
|
|
128
|
-
def __init__(self, field_name: str, before:
|
|
128
|
+
def __init__(self, field_name: str, before: datetime | None = None, after: datetime | None = None) -> None:
|
|
129
129
|
self._field_name = field_name
|
|
130
130
|
self._before = before
|
|
131
131
|
self._after = after
|
|
@@ -135,11 +135,11 @@ class BeforeAfterFilter(StatementFilter):
|
|
|
135
135
|
return self._field_name
|
|
136
136
|
|
|
137
137
|
@property
|
|
138
|
-
def before(self) ->
|
|
138
|
+
def before(self) -> datetime | None:
|
|
139
139
|
return self._before
|
|
140
140
|
|
|
141
141
|
@property
|
|
142
|
-
def after(self) ->
|
|
142
|
+
def after(self) -> datetime | None:
|
|
143
143
|
return self._after
|
|
144
144
|
|
|
145
145
|
def get_param_names(self) -> list[str]:
|
|
@@ -206,7 +206,7 @@ class OnBeforeAfterFilter(StatementFilter):
|
|
|
206
206
|
__slots__ = ("_field_name", "_on_or_after", "_on_or_before")
|
|
207
207
|
|
|
208
208
|
def __init__(
|
|
209
|
-
self, field_name: str, on_or_before:
|
|
209
|
+
self, field_name: str, on_or_before: datetime | None = None, on_or_after: datetime | None = None
|
|
210
210
|
) -> None:
|
|
211
211
|
self._field_name = field_name
|
|
212
212
|
self._on_or_before = on_or_before
|
|
@@ -217,11 +217,11 @@ class OnBeforeAfterFilter(StatementFilter):
|
|
|
217
217
|
return self._field_name
|
|
218
218
|
|
|
219
219
|
@property
|
|
220
|
-
def on_or_before(self) ->
|
|
220
|
+
def on_or_before(self) -> datetime | None:
|
|
221
221
|
return self._on_or_before
|
|
222
222
|
|
|
223
223
|
@property
|
|
224
|
-
def on_or_after(self) ->
|
|
224
|
+
def on_or_after(self) -> datetime | None:
|
|
225
225
|
return self._on_or_after
|
|
226
226
|
|
|
227
227
|
def get_param_names(self) -> list[str]:
|
|
@@ -298,7 +298,7 @@ class InCollectionFilter(InAnyFilter[T]):
|
|
|
298
298
|
|
|
299
299
|
__slots__ = ("_field_name", "_values")
|
|
300
300
|
|
|
301
|
-
def __init__(self, field_name: str, values:
|
|
301
|
+
def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None:
|
|
302
302
|
self._field_name = field_name
|
|
303
303
|
self._values = values
|
|
304
304
|
|
|
@@ -307,7 +307,7 @@ class InCollectionFilter(InAnyFilter[T]):
|
|
|
307
307
|
return self._field_name
|
|
308
308
|
|
|
309
309
|
@property
|
|
310
|
-
def values(self) ->
|
|
310
|
+
def values(self) -> abc.Collection[T] | None:
|
|
311
311
|
return self._values
|
|
312
312
|
|
|
313
313
|
def get_param_names(self) -> list[str]:
|
|
@@ -340,7 +340,7 @@ class InCollectionFilter(InAnyFilter[T]):
|
|
|
340
340
|
|
|
341
341
|
result = statement.where(exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions))
|
|
342
342
|
|
|
343
|
-
for resolved_name, value in zip(resolved_names, self.values):
|
|
343
|
+
for resolved_name, value in zip(resolved_names, self.values, strict=False):
|
|
344
344
|
result = result.add_named_parameter(resolved_name, value)
|
|
345
345
|
return result
|
|
346
346
|
|
|
@@ -358,7 +358,7 @@ class NotInCollectionFilter(InAnyFilter[T]):
|
|
|
358
358
|
|
|
359
359
|
__slots__ = ("_field_name", "_values")
|
|
360
360
|
|
|
361
|
-
def __init__(self, field_name: str, values:
|
|
361
|
+
def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None:
|
|
362
362
|
self._field_name = field_name
|
|
363
363
|
self._values = values
|
|
364
364
|
|
|
@@ -367,7 +367,7 @@ class NotInCollectionFilter(InAnyFilter[T]):
|
|
|
367
367
|
return self._field_name
|
|
368
368
|
|
|
369
369
|
@property
|
|
370
|
-
def values(self) ->
|
|
370
|
+
def values(self) -> abc.Collection[T] | None:
|
|
371
371
|
return self._values
|
|
372
372
|
|
|
373
373
|
def get_param_names(self) -> list[str]:
|
|
@@ -400,7 +400,7 @@ class NotInCollectionFilter(InAnyFilter[T]):
|
|
|
400
400
|
exp.Not(this=exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions))
|
|
401
401
|
)
|
|
402
402
|
|
|
403
|
-
for resolved_name, value in zip(resolved_names, self.values):
|
|
403
|
+
for resolved_name, value in zip(resolved_names, self.values, strict=False):
|
|
404
404
|
result = result.add_named_parameter(resolved_name, value)
|
|
405
405
|
return result
|
|
406
406
|
|
|
@@ -418,7 +418,7 @@ class AnyCollectionFilter(InAnyFilter[T]):
|
|
|
418
418
|
|
|
419
419
|
__slots__ = ("_field_name", "_values")
|
|
420
420
|
|
|
421
|
-
def __init__(self, field_name: str, values:
|
|
421
|
+
def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None:
|
|
422
422
|
self._field_name = field_name
|
|
423
423
|
self._values = values
|
|
424
424
|
|
|
@@ -427,7 +427,7 @@ class AnyCollectionFilter(InAnyFilter[T]):
|
|
|
427
427
|
return self._field_name
|
|
428
428
|
|
|
429
429
|
@property
|
|
430
|
-
def values(self) ->
|
|
430
|
+
def values(self) -> abc.Collection[T] | None:
|
|
431
431
|
return self._values
|
|
432
432
|
|
|
433
433
|
def get_param_names(self) -> list[str]:
|
|
@@ -461,7 +461,7 @@ class AnyCollectionFilter(InAnyFilter[T]):
|
|
|
461
461
|
array_expr = exp.Array(expressions=placeholder_expressions)
|
|
462
462
|
result = statement.where(exp.EQ(this=exp.column(self.field_name), expression=exp.Any(this=array_expr)))
|
|
463
463
|
|
|
464
|
-
for resolved_name, value in zip(resolved_names, self.values):
|
|
464
|
+
for resolved_name, value in zip(resolved_names, self.values, strict=False):
|
|
465
465
|
result = result.add_named_parameter(resolved_name, value)
|
|
466
466
|
return result
|
|
467
467
|
|
|
@@ -479,7 +479,7 @@ class NotAnyCollectionFilter(InAnyFilter[T]):
|
|
|
479
479
|
|
|
480
480
|
__slots__ = ("_field_name", "_values")
|
|
481
481
|
|
|
482
|
-
def __init__(self, field_name: str, values:
|
|
482
|
+
def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None:
|
|
483
483
|
self._field_name = field_name
|
|
484
484
|
self._values = values
|
|
485
485
|
|
|
@@ -488,7 +488,7 @@ class NotAnyCollectionFilter(InAnyFilter[T]):
|
|
|
488
488
|
return self._field_name
|
|
489
489
|
|
|
490
490
|
@property
|
|
491
|
-
def values(self) ->
|
|
491
|
+
def values(self) -> abc.Collection[T] | None:
|
|
492
492
|
return self._values
|
|
493
493
|
|
|
494
494
|
def get_param_names(self) -> list[str]:
|
|
@@ -520,7 +520,7 @@ class NotAnyCollectionFilter(InAnyFilter[T]):
|
|
|
520
520
|
condition = exp.EQ(this=exp.column(self.field_name), expression=exp.Any(this=array_expr))
|
|
521
521
|
result = statement.where(exp.Not(this=condition))
|
|
522
522
|
|
|
523
|
-
for resolved_name, value in zip(resolved_names, self.values):
|
|
523
|
+
for resolved_name, value in zip(resolved_names, self.values, strict=False):
|
|
524
524
|
result = result.add_named_parameter(resolved_name, value)
|
|
525
525
|
return result
|
|
526
526
|
|
|
@@ -650,13 +650,13 @@ class SearchFilter(StatementFilter):
|
|
|
650
650
|
|
|
651
651
|
__slots__ = ("_field_name", "_ignore_case", "_value")
|
|
652
652
|
|
|
653
|
-
def __init__(self, field_name:
|
|
653
|
+
def __init__(self, field_name: str | set[str], value: str, ignore_case: bool | None = False) -> None:
|
|
654
654
|
self._field_name = field_name
|
|
655
655
|
self._value = value
|
|
656
656
|
self._ignore_case = ignore_case
|
|
657
657
|
|
|
658
658
|
@property
|
|
659
|
-
def field_name(self) ->
|
|
659
|
+
def field_name(self) -> str | set[str]:
|
|
660
660
|
return self._field_name
|
|
661
661
|
|
|
662
662
|
@property
|
|
@@ -664,10 +664,10 @@ class SearchFilter(StatementFilter):
|
|
|
664
664
|
return self._value
|
|
665
665
|
|
|
666
666
|
@property
|
|
667
|
-
def ignore_case(self) ->
|
|
667
|
+
def ignore_case(self) -> bool | None:
|
|
668
668
|
return self._ignore_case
|
|
669
669
|
|
|
670
|
-
def get_param_name(self) ->
|
|
670
|
+
def get_param_name(self) -> str | None:
|
|
671
671
|
"""Get parameter name without storing it."""
|
|
672
672
|
if not self.value:
|
|
673
673
|
return None
|
|
@@ -726,7 +726,7 @@ class NotInSearchFilter(SearchFilter):
|
|
|
726
726
|
Constructs WHERE field_name NOT LIKE '%value%' clauses.
|
|
727
727
|
"""
|
|
728
728
|
|
|
729
|
-
def get_param_name(self) ->
|
|
729
|
+
def get_param_name(self) -> str | None:
|
|
730
730
|
"""Get parameter name without storing it."""
|
|
731
731
|
if not self.value:
|
|
732
732
|
return None
|
|
@@ -817,18 +817,18 @@ def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL":
|
|
|
817
817
|
return filter_obj.append_to_statement(statement)
|
|
818
818
|
|
|
819
819
|
|
|
820
|
-
FilterTypes: TypeAlias =
|
|
821
|
-
BeforeAfterFilter
|
|
822
|
-
OnBeforeAfterFilter
|
|
823
|
-
InCollectionFilter[Any]
|
|
824
|
-
LimitOffsetFilter
|
|
825
|
-
OrderByFilter
|
|
826
|
-
SearchFilter
|
|
827
|
-
NotInCollectionFilter[Any]
|
|
828
|
-
NotInSearchFilter
|
|
829
|
-
AnyCollectionFilter[Any]
|
|
830
|
-
NotAnyCollectionFilter[Any]
|
|
831
|
-
|
|
820
|
+
FilterTypes: TypeAlias = (
|
|
821
|
+
BeforeAfterFilter
|
|
822
|
+
| OnBeforeAfterFilter
|
|
823
|
+
| InCollectionFilter[Any]
|
|
824
|
+
| LimitOffsetFilter
|
|
825
|
+
| OrderByFilter
|
|
826
|
+
| SearchFilter
|
|
827
|
+
| NotInCollectionFilter[Any]
|
|
828
|
+
| NotInSearchFilter
|
|
829
|
+
| AnyCollectionFilter[Any]
|
|
830
|
+
| NotAnyCollectionFilter[Any]
|
|
831
|
+
)
|
|
832
832
|
|
|
833
833
|
|
|
834
834
|
def create_filters(filters: "list[StatementFilter]") -> tuple["StatementFilter", ...]:
|
sqlspec/core/hashing.py
CHANGED
|
@@ -4,7 +4,7 @@ Provides hashing functions for SQL statements, expressions, parameters,
|
|
|
4
4
|
filters, and AST sub-expressions.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import TYPE_CHECKING, Any
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
8
|
|
|
9
9
|
from sqlglot import exp
|
|
10
10
|
|
|
@@ -23,7 +23,7 @@ __all__ = (
|
|
|
23
23
|
)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def hash_expression(expr:
|
|
26
|
+
def hash_expression(expr: exp.Expression | None, _seen: set[int] | None = None) -> int:
|
|
27
27
|
"""Generate hash from AST structure.
|
|
28
28
|
|
|
29
29
|
Args:
|
|
@@ -77,9 +77,9 @@ def _hash_value(value: Any, _seen: set[int]) -> int:
|
|
|
77
77
|
|
|
78
78
|
|
|
79
79
|
def hash_parameters(
|
|
80
|
-
positional_parameters:
|
|
81
|
-
named_parameters:
|
|
82
|
-
original_parameters:
|
|
80
|
+
positional_parameters: list[Any] | None = None,
|
|
81
|
+
named_parameters: dict[str, Any] | None = None,
|
|
82
|
+
original_parameters: Any | None = None,
|
|
83
83
|
) -> int:
|
|
84
84
|
"""Generate hash for SQL parameters.
|
|
85
85
|
|
|
@@ -148,7 +148,7 @@ def _hash_filter_value(value: Any) -> int:
|
|
|
148
148
|
return hash(repr(value))
|
|
149
149
|
|
|
150
150
|
|
|
151
|
-
def hash_filters(filters:
|
|
151
|
+
def hash_filters(filters: list["StatementFilter"] | None = None) -> int:
|
|
152
152
|
"""Generate hash for statement filters.
|
|
153
153
|
|
|
154
154
|
Args:
|
|
@@ -208,7 +208,7 @@ def hash_sql_statement(statement: "SQL") -> str:
|
|
|
208
208
|
return f"sql:{hash(tuple(state_components))}"
|
|
209
209
|
|
|
210
210
|
|
|
211
|
-
def hash_expression_node(node: exp.Expression, include_children: bool = True, dialect:
|
|
211
|
+
def hash_expression_node(node: exp.Expression, include_children: bool = True, dialect: str | None = None) -> str:
|
|
212
212
|
"""Generate cache key for an expression node.
|
|
213
213
|
|
|
214
214
|
Args:
|
|
@@ -235,8 +235,8 @@ def hash_expression_node(node: exp.Expression, include_children: bool = True, di
|
|
|
235
235
|
def hash_optimized_expression(
|
|
236
236
|
expr: exp.Expression,
|
|
237
237
|
dialect: str,
|
|
238
|
-
schema:
|
|
239
|
-
optimizer_settings:
|
|
238
|
+
schema: dict[str, Any] | None = None,
|
|
239
|
+
optimizer_settings: dict[str, Any] | None = None,
|
|
240
240
|
) -> str:
|
|
241
241
|
"""Generate cache key for optimized expressions.
|
|
242
242
|
|