sqlspec 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_serialization.py +223 -21
- sqlspec/_sql.py +12 -50
- sqlspec/_typing.py +9 -0
- sqlspec/adapters/adbc/config.py +8 -1
- sqlspec/adapters/adbc/data_dictionary.py +290 -0
- sqlspec/adapters/adbc/driver.py +127 -18
- sqlspec/adapters/adbc/type_converter.py +159 -0
- sqlspec/adapters/aiosqlite/config.py +3 -0
- sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
- sqlspec/adapters/aiosqlite/driver.py +17 -3
- sqlspec/adapters/asyncmy/_types.py +1 -1
- sqlspec/adapters/asyncmy/config.py +11 -8
- sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
- sqlspec/adapters/asyncmy/driver.py +31 -7
- sqlspec/adapters/asyncpg/config.py +3 -0
- sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
- sqlspec/adapters/asyncpg/driver.py +19 -4
- sqlspec/adapters/bigquery/config.py +3 -0
- sqlspec/adapters/bigquery/data_dictionary.py +109 -0
- sqlspec/adapters/bigquery/driver.py +21 -3
- sqlspec/adapters/bigquery/type_converter.py +93 -0
- sqlspec/adapters/duckdb/_types.py +1 -1
- sqlspec/adapters/duckdb/config.py +2 -0
- sqlspec/adapters/duckdb/data_dictionary.py +124 -0
- sqlspec/adapters/duckdb/driver.py +32 -5
- sqlspec/adapters/duckdb/pool.py +1 -1
- sqlspec/adapters/duckdb/type_converter.py +103 -0
- sqlspec/adapters/oracledb/config.py +6 -0
- sqlspec/adapters/oracledb/data_dictionary.py +442 -0
- sqlspec/adapters/oracledb/driver.py +63 -9
- sqlspec/adapters/oracledb/migrations.py +51 -67
- sqlspec/adapters/oracledb/type_converter.py +132 -0
- sqlspec/adapters/psqlpy/config.py +3 -0
- sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
- sqlspec/adapters/psqlpy/driver.py +23 -179
- sqlspec/adapters/psqlpy/type_converter.py +73 -0
- sqlspec/adapters/psycopg/config.py +6 -0
- sqlspec/adapters/psycopg/data_dictionary.py +257 -0
- sqlspec/adapters/psycopg/driver.py +40 -5
- sqlspec/adapters/sqlite/config.py +3 -0
- sqlspec/adapters/sqlite/data_dictionary.py +117 -0
- sqlspec/adapters/sqlite/driver.py +18 -3
- sqlspec/adapters/sqlite/pool.py +13 -4
- sqlspec/builder/_base.py +82 -42
- sqlspec/builder/_column.py +57 -24
- sqlspec/builder/_ddl.py +84 -34
- sqlspec/builder/_insert.py +30 -52
- sqlspec/builder/_parsing_utils.py +104 -8
- sqlspec/builder/_select.py +147 -2
- sqlspec/builder/mixins/_cte_and_set_ops.py +1 -2
- sqlspec/builder/mixins/_join_operations.py +14 -30
- sqlspec/builder/mixins/_merge_operations.py +167 -61
- sqlspec/builder/mixins/_order_limit_operations.py +3 -10
- sqlspec/builder/mixins/_select_operations.py +3 -9
- sqlspec/builder/mixins/_update_operations.py +3 -22
- sqlspec/builder/mixins/_where_clause.py +4 -10
- sqlspec/cli.py +246 -140
- sqlspec/config.py +33 -19
- sqlspec/core/cache.py +2 -2
- sqlspec/core/compiler.py +56 -1
- sqlspec/core/parameters.py +7 -3
- sqlspec/core/statement.py +5 -0
- sqlspec/core/type_conversion.py +234 -0
- sqlspec/driver/__init__.py +6 -3
- sqlspec/driver/_async.py +106 -3
- sqlspec/driver/_common.py +156 -4
- sqlspec/driver/_sync.py +106 -3
- sqlspec/exceptions.py +5 -0
- sqlspec/migrations/__init__.py +4 -3
- sqlspec/migrations/base.py +153 -14
- sqlspec/migrations/commands.py +34 -96
- sqlspec/migrations/context.py +145 -0
- sqlspec/migrations/loaders.py +25 -8
- sqlspec/migrations/runner.py +352 -82
- sqlspec/typing.py +2 -0
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/serializers.py +50 -2
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
- sqlspec-0.26.0.dist-info/RECORD +157 -0
- sqlspec-0.25.0.dist-info/RECORD +0 -139
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/config.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, TypeVar, Union, cast
|
|
3
4
|
|
|
4
5
|
from typing_extensions import NotRequired, TypedDict
|
|
5
6
|
|
|
@@ -11,11 +12,10 @@ from sqlspec.utils.logging import get_logger
|
|
|
11
12
|
if TYPE_CHECKING:
|
|
12
13
|
from collections.abc import Awaitable
|
|
13
14
|
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
14
|
-
from pathlib import Path
|
|
15
15
|
|
|
16
16
|
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
17
17
|
from sqlspec.loader import SQLFileLoader
|
|
18
|
-
from sqlspec.migrations.commands import
|
|
18
|
+
from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
__all__ = (
|
|
@@ -89,6 +89,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
89
89
|
__slots__ = (
|
|
90
90
|
"_migration_commands",
|
|
91
91
|
"_migration_loader",
|
|
92
|
+
"bind_key",
|
|
92
93
|
"driver_features",
|
|
93
94
|
"migration_config",
|
|
94
95
|
"pool_instance",
|
|
@@ -96,7 +97,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
96
97
|
)
|
|
97
98
|
|
|
98
99
|
_migration_loader: "SQLFileLoader"
|
|
99
|
-
_migration_commands: "
|
|
100
|
+
_migration_commands: "Union[SyncMigrationCommands, AsyncMigrationCommands]"
|
|
100
101
|
driver_type: "ClassVar[type[Any]]"
|
|
101
102
|
connection_type: "ClassVar[type[Any]]"
|
|
102
103
|
is_async: "ClassVar[bool]" = False
|
|
@@ -105,6 +106,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
105
106
|
supports_native_arrow_export: "ClassVar[bool]" = False
|
|
106
107
|
supports_native_parquet_import: "ClassVar[bool]" = False
|
|
107
108
|
supports_native_parquet_export: "ClassVar[bool]" = False
|
|
109
|
+
bind_key: "Optional[str]"
|
|
108
110
|
statement_config: "StatementConfig"
|
|
109
111
|
pool_instance: "Optional[PoolT]"
|
|
110
112
|
migration_config: "Union[dict[str, Any], MigrationConfig]"
|
|
@@ -176,10 +178,10 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
176
178
|
at runtime when needed.
|
|
177
179
|
"""
|
|
178
180
|
from sqlspec.loader import SQLFileLoader
|
|
179
|
-
from sqlspec.migrations.commands import
|
|
181
|
+
from sqlspec.migrations.commands import create_migration_commands
|
|
180
182
|
|
|
181
183
|
self._migration_loader = SQLFileLoader()
|
|
182
|
-
self._migration_commands =
|
|
184
|
+
self._migration_commands = create_migration_commands(self) # type: ignore[arg-type]
|
|
183
185
|
|
|
184
186
|
def _ensure_migration_loader(self) -> "SQLFileLoader":
|
|
185
187
|
"""Get the migration SQL loader and auto-load files if needed.
|
|
@@ -200,7 +202,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
200
202
|
|
|
201
203
|
return self._migration_loader
|
|
202
204
|
|
|
203
|
-
def _ensure_migration_commands(self) -> "
|
|
205
|
+
def _ensure_migration_commands(self) -> "Union[SyncMigrationCommands, AsyncMigrationCommands]":
|
|
204
206
|
"""Get the migration commands instance.
|
|
205
207
|
|
|
206
208
|
Returns:
|
|
@@ -225,7 +227,6 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
225
227
|
Args:
|
|
226
228
|
*paths: One or more file paths or directory paths to load migration SQL files from.
|
|
227
229
|
"""
|
|
228
|
-
from pathlib import Path
|
|
229
230
|
|
|
230
231
|
loader = self._ensure_migration_loader()
|
|
231
232
|
for path in paths:
|
|
@@ -236,7 +237,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
236
237
|
else:
|
|
237
238
|
logger.warning("Migration path does not exist: %s", path_obj)
|
|
238
239
|
|
|
239
|
-
def get_migration_commands(self) -> "
|
|
240
|
+
def get_migration_commands(self) -> "Union[SyncMigrationCommands, AsyncMigrationCommands]":
|
|
240
241
|
"""Get migration commands for this configuration.
|
|
241
242
|
|
|
242
243
|
Returns:
|
|
@@ -244,25 +245,27 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
244
245
|
"""
|
|
245
246
|
return self._ensure_migration_commands()
|
|
246
247
|
|
|
247
|
-
def migrate_up(self, revision: str = "head") -> None:
|
|
248
|
+
async def migrate_up(self, revision: str = "head") -> None:
    """Apply migrations up to the specified revision.

    Works with both sync and async migration command objects: the value
    returned by ``upgrade`` is awaited only when it is awaitable. For a sync
    configuration the migration therefore runs synchronously inside this
    coroutine.

    Args:
        revision: Target revision or "head" for latest. Defaults to "head".
    """
    import inspect

    commands = self._ensure_migration_commands()
    # _ensure_migration_commands() may return SyncMigrationCommands, whose
    # upgrade() is synchronous (it was called without ``await`` before this
    # method became async). Awaiting its result unconditionally would raise
    # TypeError after the migration had already been applied.
    result = commands.upgrade(revision)
    if inspect.isawaitable(result):
        await result
|
|
257
|
+
|
|
258
|
+
async def migrate_down(self, revision: str = "-1") -> None:
    """Apply migrations down to the specified revision.

    Works with both sync and async migration command objects: the value
    returned by ``downgrade`` is awaited only when it is awaitable. For a sync
    configuration the downgrade therefore runs synchronously inside this
    coroutine.

    Args:
        revision: Target revision, "-1" for one step back, or "base" for all migrations. Defaults to "-1".
    """
    import inspect

    commands = self._ensure_migration_commands()
    # _ensure_migration_commands() may return SyncMigrationCommands, whose
    # downgrade() is synchronous (it was called without ``await`` before this
    # method became async). Awaiting its result unconditionally would raise
    # TypeError after the downgrade had already been applied.
    result = commands.downgrade(revision)
    if inspect.isawaitable(result):
        await result
|
|
267
|
+
|
|
268
|
+
async def get_current_migration(self, verbose: bool = False) -> "Optional[str]":
|
|
266
269
|
"""Get the current migration version.
|
|
267
270
|
|
|
268
271
|
Args:
|
|
@@ -272,9 +275,10 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
272
275
|
Current migration version or None if no migrations applied.
|
|
273
276
|
"""
|
|
274
277
|
commands = self._ensure_migration_commands()
|
|
275
|
-
return commands.current(verbose=verbose)
|
|
276
278
|
|
|
277
|
-
|
|
279
|
+
return await cast("AsyncMigrationCommands", commands).current(verbose=verbose)
|
|
280
|
+
|
|
281
|
+
async def create_migration(self, message: str, file_type: str = "sql") -> None:
|
|
278
282
|
"""Create a new migration file.
|
|
279
283
|
|
|
280
284
|
Args:
|
|
@@ -282,9 +286,10 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
282
286
|
file_type: Type of migration file to create ('sql' or 'py'). Defaults to 'sql'.
|
|
283
287
|
"""
|
|
284
288
|
commands = self._ensure_migration_commands()
|
|
285
|
-
commands.revision(message, file_type)
|
|
286
289
|
|
|
287
|
-
|
|
290
|
+
await cast("AsyncMigrationCommands", commands).revision(message, file_type)
|
|
291
|
+
|
|
292
|
+
async def init_migrations(self, directory: "Optional[str]" = None, package: bool = True) -> None:
|
|
288
293
|
"""Initialize migration directory structure.
|
|
289
294
|
|
|
290
295
|
Args:
|
|
@@ -297,7 +302,8 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
297
302
|
|
|
298
303
|
commands = self._ensure_migration_commands()
|
|
299
304
|
assert directory is not None
|
|
300
|
-
|
|
305
|
+
|
|
306
|
+
await cast("AsyncMigrationCommands", commands).init(directory, package)
|
|
301
307
|
|
|
302
308
|
|
|
303
309
|
class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
|
|
@@ -315,7 +321,9 @@ class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
|
|
|
315
321
|
migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
|
|
316
322
|
statement_config: "Optional[StatementConfig]" = None,
|
|
317
323
|
driver_features: "Optional[dict[str, Any]]" = None,
|
|
324
|
+
bind_key: "Optional[str]" = None,
|
|
318
325
|
) -> None:
|
|
326
|
+
self.bind_key = bind_key
|
|
319
327
|
self.pool_instance = None
|
|
320
328
|
self.connection_config = connection_config or {}
|
|
321
329
|
self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
|
|
@@ -369,7 +377,9 @@ class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
|
|
|
369
377
|
migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
|
|
370
378
|
statement_config: "Optional[StatementConfig]" = None,
|
|
371
379
|
driver_features: "Optional[dict[str, Any]]" = None,
|
|
380
|
+
bind_key: "Optional[str]" = None,
|
|
372
381
|
) -> None:
|
|
382
|
+
self.bind_key = bind_key
|
|
373
383
|
self.pool_instance = None
|
|
374
384
|
self.connection_config = connection_config or {}
|
|
375
385
|
self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
|
|
@@ -424,7 +434,9 @@ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
|
|
|
424
434
|
migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
|
|
425
435
|
statement_config: "Optional[StatementConfig]" = None,
|
|
426
436
|
driver_features: "Optional[dict[str, Any]]" = None,
|
|
437
|
+
bind_key: "Optional[str]" = None,
|
|
427
438
|
) -> None:
|
|
439
|
+
self.bind_key = bind_key
|
|
428
440
|
self.pool_instance = pool_instance
|
|
429
441
|
self.pool_config = pool_config or {}
|
|
430
442
|
self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
|
|
@@ -501,7 +513,9 @@ class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
|
|
|
501
513
|
migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
|
|
502
514
|
statement_config: "Optional[StatementConfig]" = None,
|
|
503
515
|
driver_features: "Optional[dict[str, Any]]" = None,
|
|
516
|
+
bind_key: "Optional[str]" = None,
|
|
504
517
|
) -> None:
|
|
518
|
+
self.bind_key = bind_key
|
|
505
519
|
self.pool_instance = pool_instance
|
|
506
520
|
self.pool_config = pool_config or {}
|
|
507
521
|
self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
|
sqlspec/core/cache.py
CHANGED
|
@@ -14,7 +14,7 @@ Components:
|
|
|
14
14
|
import threading
|
|
15
15
|
import time
|
|
16
16
|
from dataclasses import dataclass
|
|
17
|
-
from typing import TYPE_CHECKING, Any, Final, Optional
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Final, Optional, Union
|
|
18
18
|
|
|
19
19
|
from mypy_extensions import mypyc_attr
|
|
20
20
|
from typing_extensions import TypeVar
|
|
@@ -558,7 +558,7 @@ class CachedStatement:
|
|
|
558
558
|
"""
|
|
559
559
|
|
|
560
560
|
compiled_sql: str
|
|
561
|
-
parameters: Optional[tuple[Any, ...]] # None allowed for static script compilation
|
|
561
|
+
parameters: Optional[Union[tuple[Any, ...], dict[str, Any]]] # None allowed for static script compilation
|
|
562
562
|
expression: Optional["exp.Expression"]
|
|
563
563
|
|
|
564
564
|
def get_parameters_view(self) -> "ParametersView":
|
sqlspec/core/compiler.py
CHANGED
|
@@ -72,6 +72,7 @@ class CompiledSQL:
|
|
|
72
72
|
"execution_parameters",
|
|
73
73
|
"expression",
|
|
74
74
|
"operation_type",
|
|
75
|
+
"parameter_casts",
|
|
75
76
|
"parameter_style",
|
|
76
77
|
"supports_many",
|
|
77
78
|
)
|
|
@@ -86,6 +87,7 @@ class CompiledSQL:
|
|
|
86
87
|
expression: Optional["exp.Expression"] = None,
|
|
87
88
|
parameter_style: Optional[str] = None,
|
|
88
89
|
supports_many: bool = False,
|
|
90
|
+
parameter_casts: Optional["dict[int, str]"] = None,
|
|
89
91
|
) -> None:
|
|
90
92
|
"""Initialize compiled result.
|
|
91
93
|
|
|
@@ -96,6 +98,7 @@ class CompiledSQL:
|
|
|
96
98
|
expression: SQLGlot AST expression
|
|
97
99
|
parameter_style: Parameter style used in compilation
|
|
98
100
|
supports_many: Whether this supports execute_many operations
|
|
101
|
+
parameter_casts: Mapping of parameter positions to cast types
|
|
99
102
|
"""
|
|
100
103
|
self.compiled_sql = compiled_sql
|
|
101
104
|
self.execution_parameters = execution_parameters
|
|
@@ -103,6 +106,7 @@ class CompiledSQL:
|
|
|
103
106
|
self.expression = expression
|
|
104
107
|
self.parameter_style = parameter_style
|
|
105
108
|
self.supports_many = supports_many
|
|
109
|
+
self.parameter_casts = parameter_casts or {}
|
|
106
110
|
self._hash: Optional[int] = None
|
|
107
111
|
|
|
108
112
|
def __hash__(self) -> int:
|
|
@@ -224,11 +228,13 @@ class SQLProcessor:
|
|
|
224
228
|
ast_was_transformed = False
|
|
225
229
|
expression = None
|
|
226
230
|
operation_type: OperationType = "EXECUTE"
|
|
231
|
+
parameter_casts: dict[int, str] = {}
|
|
227
232
|
|
|
228
233
|
if self._config.enable_parsing:
|
|
229
234
|
try:
|
|
230
235
|
expression = sqlglot.parse_one(sqlglot_sql, dialect=dialect_str)
|
|
231
236
|
operation_type = self._detect_operation_type(expression)
|
|
237
|
+
parameter_casts = self._detect_parameter_casts(expression)
|
|
232
238
|
|
|
233
239
|
ast_transformer = self._config.parameter_config.ast_transformer
|
|
234
240
|
if ast_transformer:
|
|
@@ -238,6 +244,7 @@ class SQLProcessor:
|
|
|
238
244
|
except ParseError:
|
|
239
245
|
expression = None
|
|
240
246
|
operation_type = "EXECUTE"
|
|
247
|
+
parameter_casts = {}
|
|
241
248
|
|
|
242
249
|
if self._config.parameter_config.needs_static_script_compilation and processed_params is None:
|
|
243
250
|
final_sql, final_params = processed_sql, processed_params
|
|
@@ -264,6 +271,7 @@ class SQLProcessor:
|
|
|
264
271
|
expression=expression,
|
|
265
272
|
parameter_style=self._config.parameter_config.default_parameter_style.value,
|
|
266
273
|
supports_many=isinstance(final_params, list) and len(final_params) > 0,
|
|
274
|
+
parameter_casts=parameter_casts,
|
|
267
275
|
)
|
|
268
276
|
|
|
269
277
|
except SQLSpecError:
|
|
@@ -271,7 +279,9 @@ class SQLProcessor:
|
|
|
271
279
|
raise
|
|
272
280
|
except Exception as e:
|
|
273
281
|
logger.warning("Compilation failed, using fallback: %s", e)
|
|
274
|
-
return CompiledSQL(
|
|
282
|
+
return CompiledSQL(
|
|
283
|
+
compiled_sql=sql, execution_parameters=parameters, operation_type="UNKNOWN", parameter_casts={}
|
|
284
|
+
)
|
|
275
285
|
|
|
276
286
|
def _make_cache_key(self, sql: str, parameters: Any, is_many: bool = False) -> str:
|
|
277
287
|
"""Generate cache key.
|
|
@@ -326,6 +336,51 @@ class SQLProcessor:
|
|
|
326
336
|
|
|
327
337
|
return "UNKNOWN"
|
|
328
338
|
|
|
339
|
+
def _detect_parameter_casts(self, expression: Optional["exp.Expression"]) -> "dict[int, str]":
    """Detect explicit type casts applied directly to numbered parameters.

    Walks the AST looking for ``exp.Cast`` nodes whose operand is a numbered
    parameter such as ``$1`` (parsed either as ``exp.Parameter`` wrapping a
    literal, or as an ``exp.Column`` named ``$N``), and records the cast's
    target type for that parameter position. ``exp.Placeholder`` operands
    (``?``-style) are not tracked.

    Args:
        expression: SQLGlot AST expression to analyze, or None.

    Returns:
        Dict mapping parameter positions (1-based) to upper-cased cast type
        names. Empty when ``expression`` is None or no parameter casts exist.
    """
    if expression is None:
        return {}

    cast_positions: dict[int, str] = {}

    for node in expression.walk():
        if not isinstance(node, exp.Cast):
            continue

        cast_target = node.this
        position: Optional[int] = None

        if isinstance(cast_target, exp.Parameter):
            # $1, $2 style parameters: the number is held in a Literal.
            param_value = cast_target.this
            if isinstance(param_value, exp.Literal):
                # Guard against non-numeric literals: an uncaught ValueError here
                # would escape the caller's ParseError handler and abort compilation.
                try:
                    position = int(param_value.this)
                except (TypeError, ValueError):
                    position = None
        elif isinstance(cast_target, exp.Column):
            # Handle cases where $1 gets parsed as a column.
            column_name = str(cast_target.this) if cast_target.this else str(cast_target)
            # isdecimal() (not isdigit()) so every accepted string is valid for int().
            if column_name.startswith("$") and column_name[1:].isdecimal():
                position = int(column_name[1:])

        if position is None:
            continue

        # Extract the cast type name.
        if isinstance(node.to, exp.DataType):
            cast_type = node.to.this.value if hasattr(node.to.this, "value") else str(node.to.this)
        else:
            cast_type = str(node.to)
        cast_positions[position] = cast_type.upper()

    return cast_positions
|
|
383
|
+
|
|
329
384
|
def _apply_final_transformations(
|
|
330
385
|
self, expression: "Optional[exp.Expression]", sql: str, parameters: Any, dialect_str: "Optional[str]"
|
|
331
386
|
) -> "tuple[str, Any]":
|
sqlspec/core/parameters.py
CHANGED
|
@@ -619,7 +619,9 @@ class ParameterConverter:
|
|
|
619
619
|
|
|
620
620
|
return converted_sql
|
|
621
621
|
|
|
622
|
-
def _convert_sequence_to_dict(
|
|
622
|
+
def _convert_sequence_to_dict(
|
|
623
|
+
self, parameters: "Sequence[Any]", param_info: "list[ParameterInfo]"
|
|
624
|
+
) -> "dict[str, Any]":
|
|
623
625
|
"""Convert sequence parameters to dictionary for named styles.
|
|
624
626
|
|
|
625
627
|
Args:
|
|
@@ -637,7 +639,7 @@ class ParameterConverter:
|
|
|
637
639
|
return param_dict
|
|
638
640
|
|
|
639
641
|
def _extract_param_value_mixed_styles(
|
|
640
|
-
self, param: ParameterInfo, parameters: Mapping, param_keys: "list[str]"
|
|
642
|
+
self, param: ParameterInfo, parameters: "Mapping[str, Any]", param_keys: "list[str]"
|
|
641
643
|
) -> "tuple[Any, bool]":
|
|
642
644
|
"""Extract parameter value for mixed style parameters.
|
|
643
645
|
|
|
@@ -670,7 +672,9 @@ class ParameterConverter:
|
|
|
670
672
|
|
|
671
673
|
return None, False
|
|
672
674
|
|
|
673
|
-
def _extract_param_value_single_style(
|
|
675
|
+
def _extract_param_value_single_style(
|
|
676
|
+
self, param: ParameterInfo, parameters: "Mapping[str, Any]"
|
|
677
|
+
) -> "tuple[Any, bool]":
|
|
674
678
|
"""Extract parameter value for single style parameters.
|
|
675
679
|
|
|
676
680
|
Args:
|
sqlspec/core/statement.py
CHANGED
|
@@ -59,6 +59,7 @@ PROCESSED_STATE_SLOTS: Final = (
|
|
|
59
59
|
"execution_parameters",
|
|
60
60
|
"parsed_expression",
|
|
61
61
|
"operation_type",
|
|
62
|
+
"parameter_casts",
|
|
62
63
|
"validation_errors",
|
|
63
64
|
"is_many",
|
|
64
65
|
)
|
|
@@ -81,6 +82,7 @@ class ProcessedState:
|
|
|
81
82
|
execution_parameters: Any,
|
|
82
83
|
parsed_expression: "Optional[exp.Expression]" = None,
|
|
83
84
|
operation_type: "OperationType" = "UNKNOWN",
|
|
85
|
+
parameter_casts: "Optional[dict[int, str]]" = None,
|
|
84
86
|
validation_errors: "Optional[list[str]]" = None,
|
|
85
87
|
is_many: bool = False,
|
|
86
88
|
) -> None:
|
|
@@ -88,6 +90,7 @@ class ProcessedState:
|
|
|
88
90
|
self.execution_parameters = execution_parameters
|
|
89
91
|
self.parsed_expression = parsed_expression
|
|
90
92
|
self.operation_type = operation_type
|
|
93
|
+
self.parameter_casts = parameter_casts or {}
|
|
91
94
|
self.validation_errors = validation_errors or []
|
|
92
95
|
self.is_many = is_many
|
|
93
96
|
|
|
@@ -447,6 +450,7 @@ class SQL:
|
|
|
447
450
|
execution_parameters=compiled_result.execution_parameters,
|
|
448
451
|
parsed_expression=compiled_result.expression,
|
|
449
452
|
operation_type=compiled_result.operation_type,
|
|
453
|
+
parameter_casts=compiled_result.parameter_casts,
|
|
450
454
|
validation_errors=[],
|
|
451
455
|
is_many=self._is_many,
|
|
452
456
|
)
|
|
@@ -458,6 +462,7 @@ class SQL:
|
|
|
458
462
|
compiled_sql=self._raw_sql,
|
|
459
463
|
execution_parameters=self._named_parameters or self._positional_parameters,
|
|
460
464
|
operation_type="UNKNOWN",
|
|
465
|
+
parameter_casts={},
|
|
461
466
|
is_many=self._is_many,
|
|
462
467
|
)
|
|
463
468
|
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Centralized type conversion and detection for SQLSpec.
|
|
2
|
+
|
|
3
|
+
Provides unified type detection and conversion utilities for all database
|
|
4
|
+
adapters, with MyPyC-compatible optimizations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import re
|
|
8
|
+
from datetime import date, datetime, time, timezone
|
|
9
|
+
from decimal import Decimal
|
|
10
|
+
from typing import Any, Callable, Final, Optional
|
|
11
|
+
from uuid import UUID
|
|
12
|
+
|
|
13
|
+
from sqlspec._serialization import decode_json
|
|
14
|
+
|
|
15
|
+
# MyPyC-compatible pre-compiled pattern consulted by BaseTypeConverter.detect_type.
# Each named group is one detectable type. The alternation is anchored with ^...$,
# so a value is classified only when one alternative matches it in full.
SPECIAL_TYPE_REGEX: Final[re.Pattern[str]] = re.compile(
    "^(?:"
    + "|".join(
        (
            r"(?P<uuid>[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})",
            r"(?P<iso_datetime>\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?)",
            r"(?P<iso_date>\d{4}-\d{2}-\d{2})",
            r"(?P<iso_time>\d{2}:\d{2}:\d{2}(?:\.\d+)?)",
            r"(?P<json>[\[{].*[\]}])",
            r"(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})",
            r"(?P<ipv6>(?:[0-9a-f]{1,4}:){7}[0-9a-f]{1,4})",
            r"(?P<mac>(?:[0-9a-f]{2}:){5}[0-9a-f]{2})",
        )
    )
    + ")$",
    re.DOTALL | re.IGNORECASE,
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BaseTypeConverter:
    """Shared string-to-Python type detection and conversion.

    Classifies string values against ``SPECIAL_TYPE_REGEX`` (UUID, ISO 8601
    datetime/date/time, JSON, IPv4, IPv6, MAC address) and converts the
    classified ones through the module-level converter registry. Types
    without a registered converter are returned unchanged. Intended to be
    shared across database adapters and subclassed for custom needs.
    """

    __slots__ = ()

    def detect_type(self, value: str) -> Optional[str]:
        """Classify a string by the special type it looks like.

        Args:
            value: Candidate string. Non-strings and empty strings are never classified.

        Returns:
            Name of the matching pattern group (e.g. ``"uuid"``, ``"iso_date"``),
            or None when the value matches no pattern.
        """
        if not isinstance(value, str) or not value:  # pyright: ignore
            return None

        found = SPECIAL_TYPE_REGEX.match(value)
        # Every named group is a top-level alternative with no nested capturing
        # groups, so the last matched group is exactly the one that matched.
        return found.lastgroup if found is not None else None

    def convert_value(self, value: str, detected_type: str) -> Any:
        """Run the registered converter for ``detected_type`` on ``value``.

        Args:
            value: String value to convert.
            detected_type: Type name as returned by :meth:`detect_type`.

        Returns:
            The converted object, or ``value`` itself when no converter is
            registered for ``detected_type``.
        """
        handler = _TYPE_CONVERTERS.get(detected_type)
        return handler(value) if handler else value

    def convert_if_detected(self, value: Any) -> Any:
        """Convert ``value`` when it is a string of a recognised special type.

        Strings containing none of ``{``, ``[``, ``-``, ``:`` or ``T`` are
        returned immediately without running the regex. Any exception raised
        during conversion is swallowed and the original value is returned.

        Args:
            value: Value to potentially convert; non-strings pass through.

        Returns:
            Converted value if a special type was detected and converted,
            otherwise the original value.
        """
        if not isinstance(value, str):
            return value

        # Cheap character scan so plain text ("hello world") never hits the regex.
        if not any(marker in value for marker in "{[-:T"):
            return value

        kind = self.detect_type(value)
        if not kind:
            return value

        try:
            return self.convert_value(value, kind)
        except Exception:
            # Best effort: a failed conversion leaves the value untouched.
            return value
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def convert_uuid(value: str) -> UUID:
    """Parse a textual UUID into a :class:`uuid.UUID`.

    Args:
        value: UUID text, e.g. ``"550e8400-e29b-41d4-a716-446655440000"``.

    Returns:
        The parsed UUID.
    """
    parsed = UUID(hex=value)
    return parsed
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def convert_iso_datetime(value: str) -> datetime:
    """Convert ISO 8601 datetime string to datetime object.

    Accepts every shape the case-insensitive ``iso_datetime`` detection
    pattern matches:

    - a ``T``/``t`` or space date/time separator;
    - optional fractional seconds;
    - an optional ``Z``/``z`` UTC designator or a ``±HH:MM`` / ``±HHMM`` offset.

    Args:
        value: ISO datetime string.

    Returns:
        datetime object (timezone-aware when the string carries an offset or
        UTC designator, naive otherwise).

    Raises:
        ValueError: If ``value`` is not a valid ISO datetime.
    """
    # The detection regex is case-insensitive, so the UTC designator may be
    # "Z" or "z". fromisoformat() never accepts "z", and accepts "Z" only on
    # Python 3.11+, so rewrite both to an explicit "+00:00" offset.
    if value[-1:] in ("Z", "z"):
        value = value[:-1] + "+00:00"

    # Replace space with T for standard ISO format.
    if " " in value and "T" not in value:
        value = value.replace(" ", "T")

    # Python < 3.11 rejects colon-less offsets such as "+0530"; insert the colon.
    # Anchored on the time component so the date's "-MM-DD" is never touched.
    value = re.sub(r"([Tt]\d{2}:\d{2}(?::\d{2}(?:\.\d+)?)?[+-]\d{2})(\d{2})$", r"\1:\2", value)

    return datetime.fromisoformat(value)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def convert_iso_date(value: str) -> date:
    """Parse an ISO 8601 calendar date (``YYYY-MM-DD``).

    Args:
        value: Date text.

    Returns:
        The corresponding :class:`datetime.date`.
    """
    parsed_date = date.fromisoformat(value)
    return parsed_date
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def convert_iso_time(value: str) -> time:
    """Parse an ISO 8601 time of day (``HH:MM:SS[.ffffff]``).

    Args:
        value: Time text.

    Returns:
        The corresponding :class:`datetime.time`.
    """
    parsed_time = time.fromisoformat(value)
    return parsed_time
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def convert_json(value: str) -> Any:
    """Decode JSON text into the matching Python object.

    Args:
        value: JSON document as a string.

    Returns:
        Whatever ``decode_json`` yields for ``value``.
    """
    decoded = decode_json(value)
    return decoded
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def convert_decimal(value: str) -> Decimal:
    """Parse numeric text into a :class:`decimal.Decimal` (exact, no float rounding).

    Args:
        value: Numeric text such as ``"1.10"``.

    Returns:
        The parsed Decimal.
    """
    exact = Decimal(value)
    return exact
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# Registry mapping a detected type name to its converter. Detected types with no
# entry here (e.g. "ipv4", "ipv6", "mac") are returned unchanged by
# BaseTypeConverter.convert_value.
_TYPE_CONVERTERS: Final[dict[str, Callable[[str], Any]]] = dict(
    uuid=convert_uuid,
    iso_datetime=convert_iso_datetime,
    iso_date=convert_iso_date,
    iso_time=convert_iso_time,
    json=convert_json,
)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def format_datetime_rfc3339(dt: datetime) -> str:
    """Render a datetime as an RFC 3339 timestamp string.

    Naive datetimes are treated as UTC (``tzinfo`` is set to UTC, not converted).

    Args:
        dt: datetime to format.

    Returns:
        ISO 8601 / RFC 3339 text that always carries a UTC offset.
    """
    aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
    return aware.isoformat()
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def parse_datetime_rfc3339(dt_str: str) -> datetime:
    """Parse RFC 3339 datetime string.

    RFC 3339 (section 5.6) allows the UTC designator to be written as either
    ``Z`` or ``z``. Both are normalised to ``+00:00`` before parsing, because
    ``datetime.fromisoformat`` rejects ``z`` on all versions and ``Z`` before
    Python 3.11.

    Args:
        dt_str: RFC 3339 datetime string.

    Returns:
        datetime object.

    Raises:
        ValueError: If ``dt_str`` cannot be parsed.
    """
    # Handle Z / z suffix (case-insensitive per RFC 3339).
    if dt_str[-1:] in ("Z", "z"):
        dt_str = dt_str[:-1] + "+00:00"
    return datetime.fromisoformat(dt_str)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# Public API of this module, in case-sensitive alphabetical order.
__all__ = (
    "BaseTypeConverter",
    "convert_decimal", "convert_iso_date", "convert_iso_datetime",
    "convert_iso_time", "convert_json", "convert_uuid",
    "format_datetime_rfc3339", "parse_datetime_rfc3339",
)
|
sqlspec/driver/__init__.py
CHANGED
|
@@ -3,16 +3,19 @@
|
|
|
3
3
|
from typing import Union
|
|
4
4
|
|
|
5
5
|
from sqlspec.driver import mixins
|
|
6
|
-
from sqlspec.driver._async import AsyncDriverAdapterBase
|
|
7
|
-
from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult
|
|
8
|
-
from sqlspec.driver._sync import SyncDriverAdapterBase
|
|
6
|
+
from sqlspec.driver._async import AsyncDataDictionaryBase, AsyncDriverAdapterBase
|
|
7
|
+
from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult, VersionInfo
|
|
8
|
+
from sqlspec.driver._sync import SyncDataDictionaryBase, SyncDriverAdapterBase
|
|
9
9
|
|
|
10
10
|
# Public API of sqlspec.driver, in case-sensitive alphabetical order.
__all__ = (
    "AsyncDataDictionaryBase", "AsyncDriverAdapterBase",
    "CommonDriverAttributesMixin", "DriverAdapterProtocol", "ExecutionResult",
    "SyncDataDictionaryBase", "SyncDriverAdapterBase",
    "VersionInfo", "mixins",
)
|
|
18
21
|
|