sqlspec 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (84)
  1. sqlspec/_serialization.py +223 -21
  2. sqlspec/_sql.py +12 -50
  3. sqlspec/_typing.py +9 -0
  4. sqlspec/adapters/adbc/config.py +8 -1
  5. sqlspec/adapters/adbc/data_dictionary.py +290 -0
  6. sqlspec/adapters/adbc/driver.py +127 -18
  7. sqlspec/adapters/adbc/type_converter.py +159 -0
  8. sqlspec/adapters/aiosqlite/config.py +3 -0
  9. sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
  10. sqlspec/adapters/aiosqlite/driver.py +17 -3
  11. sqlspec/adapters/asyncmy/_types.py +1 -1
  12. sqlspec/adapters/asyncmy/config.py +11 -8
  13. sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
  14. sqlspec/adapters/asyncmy/driver.py +31 -7
  15. sqlspec/adapters/asyncpg/config.py +3 -0
  16. sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
  17. sqlspec/adapters/asyncpg/driver.py +19 -4
  18. sqlspec/adapters/bigquery/config.py +3 -0
  19. sqlspec/adapters/bigquery/data_dictionary.py +109 -0
  20. sqlspec/adapters/bigquery/driver.py +21 -3
  21. sqlspec/adapters/bigquery/type_converter.py +93 -0
  22. sqlspec/adapters/duckdb/_types.py +1 -1
  23. sqlspec/adapters/duckdb/config.py +2 -0
  24. sqlspec/adapters/duckdb/data_dictionary.py +124 -0
  25. sqlspec/adapters/duckdb/driver.py +32 -5
  26. sqlspec/adapters/duckdb/pool.py +1 -1
  27. sqlspec/adapters/duckdb/type_converter.py +103 -0
  28. sqlspec/adapters/oracledb/config.py +6 -0
  29. sqlspec/adapters/oracledb/data_dictionary.py +442 -0
  30. sqlspec/adapters/oracledb/driver.py +63 -9
  31. sqlspec/adapters/oracledb/migrations.py +51 -67
  32. sqlspec/adapters/oracledb/type_converter.py +132 -0
  33. sqlspec/adapters/psqlpy/config.py +3 -0
  34. sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
  35. sqlspec/adapters/psqlpy/driver.py +23 -179
  36. sqlspec/adapters/psqlpy/type_converter.py +73 -0
  37. sqlspec/adapters/psycopg/config.py +6 -0
  38. sqlspec/adapters/psycopg/data_dictionary.py +257 -0
  39. sqlspec/adapters/psycopg/driver.py +40 -5
  40. sqlspec/adapters/sqlite/config.py +3 -0
  41. sqlspec/adapters/sqlite/data_dictionary.py +117 -0
  42. sqlspec/adapters/sqlite/driver.py +18 -3
  43. sqlspec/adapters/sqlite/pool.py +13 -4
  44. sqlspec/builder/_base.py +82 -42
  45. sqlspec/builder/_column.py +57 -24
  46. sqlspec/builder/_ddl.py +84 -34
  47. sqlspec/builder/_insert.py +30 -52
  48. sqlspec/builder/_parsing_utils.py +104 -8
  49. sqlspec/builder/_select.py +147 -2
  50. sqlspec/builder/mixins/_cte_and_set_ops.py +1 -2
  51. sqlspec/builder/mixins/_join_operations.py +14 -30
  52. sqlspec/builder/mixins/_merge_operations.py +167 -61
  53. sqlspec/builder/mixins/_order_limit_operations.py +3 -10
  54. sqlspec/builder/mixins/_select_operations.py +3 -9
  55. sqlspec/builder/mixins/_update_operations.py +3 -22
  56. sqlspec/builder/mixins/_where_clause.py +4 -10
  57. sqlspec/cli.py +246 -140
  58. sqlspec/config.py +33 -19
  59. sqlspec/core/cache.py +2 -2
  60. sqlspec/core/compiler.py +56 -1
  61. sqlspec/core/parameters.py +7 -3
  62. sqlspec/core/statement.py +5 -0
  63. sqlspec/core/type_conversion.py +234 -0
  64. sqlspec/driver/__init__.py +6 -3
  65. sqlspec/driver/_async.py +106 -3
  66. sqlspec/driver/_common.py +156 -4
  67. sqlspec/driver/_sync.py +106 -3
  68. sqlspec/exceptions.py +5 -0
  69. sqlspec/migrations/__init__.py +4 -3
  70. sqlspec/migrations/base.py +153 -14
  71. sqlspec/migrations/commands.py +34 -96
  72. sqlspec/migrations/context.py +145 -0
  73. sqlspec/migrations/loaders.py +25 -8
  74. sqlspec/migrations/runner.py +352 -82
  75. sqlspec/typing.py +2 -0
  76. sqlspec/utils/config_resolver.py +153 -0
  77. sqlspec/utils/serializers.py +50 -2
  78. {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
  79. sqlspec-0.26.0.dist-info/RECORD +157 -0
  80. sqlspec-0.25.0.dist-info/RECORD +0 -139
  81. {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
  82. {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
  83. {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
  84. {sqlspec-0.25.0.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/config.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, TypeVar, Union, cast
3
4
 
4
5
  from typing_extensions import NotRequired, TypedDict
5
6
 
@@ -11,11 +12,10 @@ from sqlspec.utils.logging import get_logger
11
12
  if TYPE_CHECKING:
12
13
  from collections.abc import Awaitable
13
14
  from contextlib import AbstractAsyncContextManager, AbstractContextManager
14
- from pathlib import Path
15
15
 
16
16
  from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
17
17
  from sqlspec.loader import SQLFileLoader
18
- from sqlspec.migrations.commands import MigrationCommands
18
+ from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands
19
19
 
20
20
 
21
21
  __all__ = (
@@ -89,6 +89,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
89
89
  __slots__ = (
90
90
  "_migration_commands",
91
91
  "_migration_loader",
92
+ "bind_key",
92
93
  "driver_features",
93
94
  "migration_config",
94
95
  "pool_instance",
@@ -96,7 +97,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
96
97
  )
97
98
 
98
99
  _migration_loader: "SQLFileLoader"
99
- _migration_commands: "MigrationCommands"
100
+ _migration_commands: "Union[SyncMigrationCommands, AsyncMigrationCommands]"
100
101
  driver_type: "ClassVar[type[Any]]"
101
102
  connection_type: "ClassVar[type[Any]]"
102
103
  is_async: "ClassVar[bool]" = False
@@ -105,6 +106,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
105
106
  supports_native_arrow_export: "ClassVar[bool]" = False
106
107
  supports_native_parquet_import: "ClassVar[bool]" = False
107
108
  supports_native_parquet_export: "ClassVar[bool]" = False
109
+ bind_key: "Optional[str]"
108
110
  statement_config: "StatementConfig"
109
111
  pool_instance: "Optional[PoolT]"
110
112
  migration_config: "Union[dict[str, Any], MigrationConfig]"
@@ -176,10 +178,10 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
176
178
  at runtime when needed.
177
179
  """
178
180
  from sqlspec.loader import SQLFileLoader
179
- from sqlspec.migrations.commands import MigrationCommands
181
+ from sqlspec.migrations.commands import create_migration_commands
180
182
 
181
183
  self._migration_loader = SQLFileLoader()
182
- self._migration_commands = MigrationCommands(self) # type: ignore[arg-type]
184
+ self._migration_commands = create_migration_commands(self) # type: ignore[arg-type]
183
185
 
184
186
  def _ensure_migration_loader(self) -> "SQLFileLoader":
185
187
  """Get the migration SQL loader and auto-load files if needed.
@@ -200,7 +202,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
200
202
 
201
203
  return self._migration_loader
202
204
 
203
- def _ensure_migration_commands(self) -> "MigrationCommands":
205
+ def _ensure_migration_commands(self) -> "Union[SyncMigrationCommands, AsyncMigrationCommands]":
204
206
  """Get the migration commands instance.
205
207
 
206
208
  Returns:
@@ -225,7 +227,6 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
225
227
  Args:
226
228
  *paths: One or more file paths or directory paths to load migration SQL files from.
227
229
  """
228
- from pathlib import Path
229
230
 
230
231
  loader = self._ensure_migration_loader()
231
232
  for path in paths:
@@ -236,7 +237,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
236
237
  else:
237
238
  logger.warning("Migration path does not exist: %s", path_obj)
238
239
 
239
- def get_migration_commands(self) -> "MigrationCommands":
240
+ def get_migration_commands(self) -> "Union[SyncMigrationCommands, AsyncMigrationCommands]":
240
241
  """Get migration commands for this configuration.
241
242
 
242
243
  Returns:
@@ -244,25 +245,27 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
244
245
  """
245
246
  return self._ensure_migration_commands()
246
247
 
247
- def migrate_up(self, revision: str = "head") -> None:
248
+ async def migrate_up(self, revision: str = "head") -> None:
248
249
  """Apply migrations up to the specified revision.
249
250
 
250
251
  Args:
251
252
  revision: Target revision or "head" for latest. Defaults to "head".
252
253
  """
253
254
  commands = self._ensure_migration_commands()
254
- commands.upgrade(revision)
255
255
 
256
- def migrate_down(self, revision: str = "-1") -> None:
256
+ await cast("AsyncMigrationCommands", commands).upgrade(revision)
257
+
258
+ async def migrate_down(self, revision: str = "-1") -> None:
257
259
  """Apply migrations down to the specified revision.
258
260
 
259
261
  Args:
260
262
  revision: Target revision, "-1" for one step back, or "base" for all migrations. Defaults to "-1".
261
263
  """
262
264
  commands = self._ensure_migration_commands()
263
- commands.downgrade(revision)
264
265
 
265
- def get_current_migration(self, verbose: bool = False) -> "Optional[str]":
266
+ await cast("AsyncMigrationCommands", commands).downgrade(revision)
267
+
268
+ async def get_current_migration(self, verbose: bool = False) -> "Optional[str]":
266
269
  """Get the current migration version.
267
270
 
268
271
  Args:
@@ -272,9 +275,10 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
272
275
  Current migration version or None if no migrations applied.
273
276
  """
274
277
  commands = self._ensure_migration_commands()
275
- return commands.current(verbose=verbose)
276
278
 
277
- def create_migration(self, message: str, file_type: str = "sql") -> None:
279
+ return await cast("AsyncMigrationCommands", commands).current(verbose=verbose)
280
+
281
+ async def create_migration(self, message: str, file_type: str = "sql") -> None:
278
282
  """Create a new migration file.
279
283
 
280
284
  Args:
@@ -282,9 +286,10 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
282
286
  file_type: Type of migration file to create ('sql' or 'py'). Defaults to 'sql'.
283
287
  """
284
288
  commands = self._ensure_migration_commands()
285
- commands.revision(message, file_type)
286
289
 
287
- def init_migrations(self, directory: "Optional[str]" = None, package: bool = True) -> None:
290
+ await cast("AsyncMigrationCommands", commands).revision(message, file_type)
291
+
292
+ async def init_migrations(self, directory: "Optional[str]" = None, package: bool = True) -> None:
288
293
  """Initialize migration directory structure.
289
294
 
290
295
  Args:
@@ -297,7 +302,8 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
297
302
 
298
303
  commands = self._ensure_migration_commands()
299
304
  assert directory is not None
300
- commands.init(directory, package)
305
+
306
+ await cast("AsyncMigrationCommands", commands).init(directory, package)
301
307
 
302
308
 
303
309
  class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
@@ -315,7 +321,9 @@ class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
315
321
  migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
316
322
  statement_config: "Optional[StatementConfig]" = None,
317
323
  driver_features: "Optional[dict[str, Any]]" = None,
324
+ bind_key: "Optional[str]" = None,
318
325
  ) -> None:
326
+ self.bind_key = bind_key
319
327
  self.pool_instance = None
320
328
  self.connection_config = connection_config or {}
321
329
  self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
@@ -369,7 +377,9 @@ class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
369
377
  migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
370
378
  statement_config: "Optional[StatementConfig]" = None,
371
379
  driver_features: "Optional[dict[str, Any]]" = None,
380
+ bind_key: "Optional[str]" = None,
372
381
  ) -> None:
382
+ self.bind_key = bind_key
373
383
  self.pool_instance = None
374
384
  self.connection_config = connection_config or {}
375
385
  self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
@@ -424,7 +434,9 @@ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
424
434
  migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
425
435
  statement_config: "Optional[StatementConfig]" = None,
426
436
  driver_features: "Optional[dict[str, Any]]" = None,
437
+ bind_key: "Optional[str]" = None,
427
438
  ) -> None:
439
+ self.bind_key = bind_key
428
440
  self.pool_instance = pool_instance
429
441
  self.pool_config = pool_config or {}
430
442
  self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
@@ -501,7 +513,9 @@ class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
501
513
  migration_config: "Optional[Union[dict[str, Any], MigrationConfig]]" = None,
502
514
  statement_config: "Optional[StatementConfig]" = None,
503
515
  driver_features: "Optional[dict[str, Any]]" = None,
516
+ bind_key: "Optional[str]" = None,
504
517
  ) -> None:
518
+ self.bind_key = bind_key
505
519
  self.pool_instance = pool_instance
506
520
  self.pool_config = pool_config or {}
507
521
  self.migration_config: Union[dict[str, Any], MigrationConfig] = migration_config or {}
sqlspec/core/cache.py CHANGED
@@ -14,7 +14,7 @@ Components:
14
14
  import threading
15
15
  import time
16
16
  from dataclasses import dataclass
17
- from typing import TYPE_CHECKING, Any, Final, Optional
17
+ from typing import TYPE_CHECKING, Any, Final, Optional, Union
18
18
 
19
19
  from mypy_extensions import mypyc_attr
20
20
  from typing_extensions import TypeVar
@@ -558,7 +558,7 @@ class CachedStatement:
558
558
  """
559
559
 
560
560
  compiled_sql: str
561
- parameters: Optional[tuple[Any, ...]] # None allowed for static script compilation
561
+ parameters: Optional[Union[tuple[Any, ...], dict[str, Any]]] # None allowed for static script compilation
562
562
  expression: Optional["exp.Expression"]
563
563
 
564
564
  def get_parameters_view(self) -> "ParametersView":
sqlspec/core/compiler.py CHANGED
@@ -72,6 +72,7 @@ class CompiledSQL:
72
72
  "execution_parameters",
73
73
  "expression",
74
74
  "operation_type",
75
+ "parameter_casts",
75
76
  "parameter_style",
76
77
  "supports_many",
77
78
  )
@@ -86,6 +87,7 @@ class CompiledSQL:
86
87
  expression: Optional["exp.Expression"] = None,
87
88
  parameter_style: Optional[str] = None,
88
89
  supports_many: bool = False,
90
+ parameter_casts: Optional["dict[int, str]"] = None,
89
91
  ) -> None:
90
92
  """Initialize compiled result.
91
93
 
@@ -96,6 +98,7 @@ class CompiledSQL:
96
98
  expression: SQLGlot AST expression
97
99
  parameter_style: Parameter style used in compilation
98
100
  supports_many: Whether this supports execute_many operations
101
+ parameter_casts: Mapping of parameter positions to cast types
99
102
  """
100
103
  self.compiled_sql = compiled_sql
101
104
  self.execution_parameters = execution_parameters
@@ -103,6 +106,7 @@ class CompiledSQL:
103
106
  self.expression = expression
104
107
  self.parameter_style = parameter_style
105
108
  self.supports_many = supports_many
109
+ self.parameter_casts = parameter_casts or {}
106
110
  self._hash: Optional[int] = None
107
111
 
108
112
  def __hash__(self) -> int:
@@ -224,11 +228,13 @@ class SQLProcessor:
224
228
  ast_was_transformed = False
225
229
  expression = None
226
230
  operation_type: OperationType = "EXECUTE"
231
+ parameter_casts: dict[int, str] = {}
227
232
 
228
233
  if self._config.enable_parsing:
229
234
  try:
230
235
  expression = sqlglot.parse_one(sqlglot_sql, dialect=dialect_str)
231
236
  operation_type = self._detect_operation_type(expression)
237
+ parameter_casts = self._detect_parameter_casts(expression)
232
238
 
233
239
  ast_transformer = self._config.parameter_config.ast_transformer
234
240
  if ast_transformer:
@@ -238,6 +244,7 @@ class SQLProcessor:
238
244
  except ParseError:
239
245
  expression = None
240
246
  operation_type = "EXECUTE"
247
+ parameter_casts = {}
241
248
 
242
249
  if self._config.parameter_config.needs_static_script_compilation and processed_params is None:
243
250
  final_sql, final_params = processed_sql, processed_params
@@ -264,6 +271,7 @@ class SQLProcessor:
264
271
  expression=expression,
265
272
  parameter_style=self._config.parameter_config.default_parameter_style.value,
266
273
  supports_many=isinstance(final_params, list) and len(final_params) > 0,
274
+ parameter_casts=parameter_casts,
267
275
  )
268
276
 
269
277
  except SQLSpecError:
@@ -271,7 +279,9 @@ class SQLProcessor:
271
279
  raise
272
280
  except Exception as e:
273
281
  logger.warning("Compilation failed, using fallback: %s", e)
274
- return CompiledSQL(compiled_sql=sql, execution_parameters=parameters, operation_type="UNKNOWN")
282
+ return CompiledSQL(
283
+ compiled_sql=sql, execution_parameters=parameters, operation_type="UNKNOWN", parameter_casts={}
284
+ )
275
285
 
276
286
  def _make_cache_key(self, sql: str, parameters: Any, is_many: bool = False) -> str:
277
287
  """Generate cache key.
@@ -326,6 +336,51 @@ class SQLProcessor:
326
336
 
327
337
  return "UNKNOWN"
328
338
 
339
+ def _detect_parameter_casts(self, expression: Optional["exp.Expression"]) -> "dict[int, str]":
340
+ """Detect explicit type casts on parameters in the AST.
341
+
342
+ Args:
343
+ expression: SQLGlot AST expression to analyze
344
+
345
+ Returns:
346
+ Dict mapping parameter positions (1-based) to cast type names
347
+ """
348
+ if not expression:
349
+ return {}
350
+
351
+ cast_positions = {}
352
+
353
+ # Walk all nodes in order to track parameter positions
354
+ for node in expression.walk():
355
+ # Check for cast nodes with parameter children
356
+ if isinstance(node, exp.Cast):
357
+ cast_target = node.this
358
+ position = None
359
+
360
+ if isinstance(cast_target, exp.Parameter):
361
+ # Handle $1, $2 style parameters
362
+ param_value = cast_target.this
363
+ if isinstance(param_value, exp.Literal):
364
+ position = int(param_value.this)
365
+ elif isinstance(cast_target, exp.Placeholder):
366
+ # For ? style, we need to count position (will implement if needed)
367
+ pass
368
+ elif isinstance(cast_target, exp.Column):
369
+ # Handle cases where $1 gets parsed as a column
370
+ column_name = str(cast_target.this) if cast_target.this else str(cast_target)
371
+ if column_name.startswith("$") and column_name[1:].isdigit():
372
+ position = int(column_name[1:])
373
+
374
+ if position is not None:
375
+ # Extract cast type
376
+ if isinstance(node.to, exp.DataType):
377
+ cast_type = node.to.this.value if hasattr(node.to.this, "value") else str(node.to.this)
378
+ else:
379
+ cast_type = str(node.to)
380
+ cast_positions[position] = cast_type.upper()
381
+
382
+ return cast_positions
383
+
329
384
  def _apply_final_transformations(
330
385
  self, expression: "Optional[exp.Expression]", sql: str, parameters: Any, dialect_str: "Optional[str]"
331
386
  ) -> "tuple[str, Any]":
sqlspec/core/parameters.py CHANGED
@@ -619,7 +619,9 @@ class ParameterConverter:
619
619
 
620
620
  return converted_sql
621
621
 
622
- def _convert_sequence_to_dict(self, parameters: Sequence, param_info: "list[ParameterInfo]") -> "dict[str, Any]":
622
+ def _convert_sequence_to_dict(
623
+ self, parameters: "Sequence[Any]", param_info: "list[ParameterInfo]"
624
+ ) -> "dict[str, Any]":
623
625
  """Convert sequence parameters to dictionary for named styles.
624
626
 
625
627
  Args:
@@ -637,7 +639,7 @@ class ParameterConverter:
637
639
  return param_dict
638
640
 
639
641
  def _extract_param_value_mixed_styles(
640
- self, param: ParameterInfo, parameters: Mapping, param_keys: "list[str]"
642
+ self, param: ParameterInfo, parameters: "Mapping[str, Any]", param_keys: "list[str]"
641
643
  ) -> "tuple[Any, bool]":
642
644
  """Extract parameter value for mixed style parameters.
643
645
 
@@ -670,7 +672,9 @@ class ParameterConverter:
670
672
 
671
673
  return None, False
672
674
 
673
- def _extract_param_value_single_style(self, param: ParameterInfo, parameters: Mapping) -> "tuple[Any, bool]":
675
+ def _extract_param_value_single_style(
676
+ self, param: ParameterInfo, parameters: "Mapping[str, Any]"
677
+ ) -> "tuple[Any, bool]":
674
678
  """Extract parameter value for single style parameters.
675
679
 
676
680
  Args:
sqlspec/core/statement.py CHANGED
@@ -59,6 +59,7 @@ PROCESSED_STATE_SLOTS: Final = (
59
59
  "execution_parameters",
60
60
  "parsed_expression",
61
61
  "operation_type",
62
+ "parameter_casts",
62
63
  "validation_errors",
63
64
  "is_many",
64
65
  )
@@ -81,6 +82,7 @@ class ProcessedState:
81
82
  execution_parameters: Any,
82
83
  parsed_expression: "Optional[exp.Expression]" = None,
83
84
  operation_type: "OperationType" = "UNKNOWN",
85
+ parameter_casts: "Optional[dict[int, str]]" = None,
84
86
  validation_errors: "Optional[list[str]]" = None,
85
87
  is_many: bool = False,
86
88
  ) -> None:
@@ -88,6 +90,7 @@ class ProcessedState:
88
90
  self.execution_parameters = execution_parameters
89
91
  self.parsed_expression = parsed_expression
90
92
  self.operation_type = operation_type
93
+ self.parameter_casts = parameter_casts or {}
91
94
  self.validation_errors = validation_errors or []
92
95
  self.is_many = is_many
93
96
 
@@ -447,6 +450,7 @@ class SQL:
447
450
  execution_parameters=compiled_result.execution_parameters,
448
451
  parsed_expression=compiled_result.expression,
449
452
  operation_type=compiled_result.operation_type,
453
+ parameter_casts=compiled_result.parameter_casts,
450
454
  validation_errors=[],
451
455
  is_many=self._is_many,
452
456
  )
@@ -458,6 +462,7 @@ class SQL:
458
462
  compiled_sql=self._raw_sql,
459
463
  execution_parameters=self._named_parameters or self._positional_parameters,
460
464
  operation_type="UNKNOWN",
465
+ parameter_casts={},
461
466
  is_many=self._is_many,
462
467
  )
463
468
 
sqlspec/core/type_conversion.py ADDED
@@ -0,0 +1,234 @@
1
+ """Centralized type conversion and detection for SQLSpec.
2
+
3
+ Provides unified type detection and conversion utilities for all database
4
+ adapters, with MyPyC-compatible optimizations.
5
+ """
6
+
7
+ import re
8
+ from datetime import date, datetime, time, timezone
9
+ from decimal import Decimal
10
+ from typing import Any, Callable, Final, Optional
11
+ from uuid import UUID
12
+
13
+ from sqlspec._serialization import decode_json
14
+
15
+ # MyPyC-compatible pre-compiled patterns
16
+ SPECIAL_TYPE_REGEX: Final[re.Pattern[str]] = re.compile(
17
+ r"^(?:"
18
+ r"(?P<uuid>[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})|"
19
+ r"(?P<iso_datetime>\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?)|"
20
+ r"(?P<iso_date>\d{4}-\d{2}-\d{2})|"
21
+ r"(?P<iso_time>\d{2}:\d{2}:\d{2}(?:\.\d+)?)|"
22
+ r"(?P<json>[\[{].*[\]}])|"
23
+ r"(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})|"
24
+ r"(?P<ipv6>(?:[0-9a-f]{1,4}:){7}[0-9a-f]{1,4})|"
25
+ r"(?P<mac>(?:[0-9a-f]{2}:){5}[0-9a-f]{2})"
26
+ r")$",
27
+ re.IGNORECASE | re.DOTALL,
28
+ )
29
+
30
+
31
+ class BaseTypeConverter:
32
+ """Universal type detection and conversion for all adapters.
33
+
34
+ Provides centralized type detection and conversion functionality
35
+ that can be used across all database adapters to ensure consistent
36
+ behavior. Users can extend this class for custom type conversion needs.
37
+ """
38
+
39
+ __slots__ = ()
40
+
41
+ def detect_type(self, value: str) -> Optional[str]:
42
+ """Detect special types from string values.
43
+
44
+ Args:
45
+ value: String value to analyze.
46
+
47
+ Returns:
48
+ Type name if detected, None otherwise.
49
+ """
50
+ if not isinstance(value, str): # pyright: ignore
51
+ return None
52
+ if not value:
53
+ return None
54
+
55
+ match = SPECIAL_TYPE_REGEX.match(value)
56
+ if not match:
57
+ return None
58
+
59
+ return next((k for k, v in match.groupdict().items() if v), None)
60
+
61
+ def convert_value(self, value: str, detected_type: str) -> Any:
62
+ """Convert string value to appropriate Python type.
63
+
64
+ Args:
65
+ value: String value to convert.
66
+ detected_type: Detected type name.
67
+
68
+ Returns:
69
+ Converted value in appropriate Python type.
70
+ """
71
+ converter = _TYPE_CONVERTERS.get(detected_type)
72
+ if converter:
73
+ return converter(value)
74
+ return value
75
+
76
+ def convert_if_detected(self, value: Any) -> Any:
77
+ """Convert value only if special type detected, else return original.
78
+
79
+ This method provides performance optimization by avoiding expensive
80
+ regex operations on plain strings that don't contain special characters.
81
+
82
+ Args:
83
+ value: Value to potentially convert.
84
+
85
+ Returns:
86
+ Converted value if special type detected, original value otherwise.
87
+ """
88
+ if not isinstance(value, str):
89
+ return value
90
+
91
+ # Quick pre-check for performance - avoid regex on plain strings
92
+ if not any(c in value for c in ["{", "[", "-", ":", "T"]):
93
+ return value # Skip regex entirely for "hello world" etc.
94
+
95
+ detected_type = self.detect_type(value)
96
+ if detected_type:
97
+ try:
98
+ return self.convert_value(value, detected_type)
99
+ except Exception:
100
+ # If conversion fails, return original value
101
+ return value
102
+ return value
103
+
104
+
105
+ def convert_uuid(value: str) -> UUID:
106
+ """Convert UUID string to UUID object.
107
+
108
+ Args:
109
+ value: UUID string.
110
+
111
+ Returns:
112
+ UUID object.
113
+ """
114
+ return UUID(value)
115
+
116
+
117
+ def convert_iso_datetime(value: str) -> datetime:
118
+ """Convert ISO 8601 datetime string to datetime object.
119
+
120
+ Args:
121
+ value: ISO datetime string.
122
+
123
+ Returns:
124
+ datetime object.
125
+ """
126
+ # Handle various ISO formats with timezone
127
+ if value.endswith("Z"):
128
+ value = value[:-1] + "+00:00"
129
+
130
+ # Replace space with T for standard ISO format
131
+ if " " in value and "T" not in value:
132
+ value = value.replace(" ", "T")
133
+
134
+ return datetime.fromisoformat(value)
135
+
136
+
137
+ def convert_iso_date(value: str) -> date:
138
+ """Convert ISO date string to date object.
139
+
140
+ Args:
141
+ value: ISO date string.
142
+
143
+ Returns:
144
+ date object.
145
+ """
146
+ return date.fromisoformat(value)
147
+
148
+
149
+ def convert_iso_time(value: str) -> time:
150
+ """Convert ISO time string to time object.
151
+
152
+ Args:
153
+ value: ISO time string.
154
+
155
+ Returns:
156
+ time object.
157
+ """
158
+ return time.fromisoformat(value)
159
+
160
+
161
+ def convert_json(value: str) -> Any:
162
+ """Convert JSON string to Python object.
163
+
164
+ Args:
165
+ value: JSON string.
166
+
167
+ Returns:
168
+ Decoded Python object.
169
+ """
170
+ return decode_json(value)
171
+
172
+
173
+ def convert_decimal(value: str) -> Decimal:
174
+ """Convert string to Decimal for precise arithmetic.
175
+
176
+ Args:
177
+ value: Decimal string.
178
+
179
+ Returns:
180
+ Decimal object.
181
+ """
182
+ return Decimal(value)
183
+
184
+
185
+ # Converter registry
186
+ _TYPE_CONVERTERS: Final[dict[str, Callable[[str], Any]]] = {
187
+ "uuid": convert_uuid,
188
+ "iso_datetime": convert_iso_datetime,
189
+ "iso_date": convert_iso_date,
190
+ "iso_time": convert_iso_time,
191
+ "json": convert_json,
192
+ }
193
+
194
+
195
+ def format_datetime_rfc3339(dt: datetime) -> str:
196
+ """Format datetime as RFC 3339 compliant string.
197
+
198
+ Args:
199
+ dt: datetime object.
200
+
201
+ Returns:
202
+ RFC 3339 formatted datetime string.
203
+ """
204
+ if dt.tzinfo is None:
205
+ dt = dt.replace(tzinfo=timezone.utc)
206
+ return dt.isoformat()
207
+
208
+
209
+ def parse_datetime_rfc3339(dt_str: str) -> datetime:
210
+ """Parse RFC 3339 datetime string.
211
+
212
+ Args:
213
+ dt_str: RFC 3339 datetime string.
214
+
215
+ Returns:
216
+ datetime object.
217
+ """
218
+ # Handle Z suffix
219
+ if dt_str.endswith("Z"):
220
+ dt_str = dt_str[:-1] + "+00:00"
221
+ return datetime.fromisoformat(dt_str)
222
+
223
+
224
+ __all__ = (
225
+ "BaseTypeConverter",
226
+ "convert_decimal",
227
+ "convert_iso_date",
228
+ "convert_iso_datetime",
229
+ "convert_iso_time",
230
+ "convert_json",
231
+ "convert_uuid",
232
+ "format_datetime_rfc3339",
233
+ "parse_datetime_rfc3339",
234
+ )
sqlspec/driver/__init__.py CHANGED
@@ -3,16 +3,19 @@
3
3
  from typing import Union
4
4
 
5
5
  from sqlspec.driver import mixins
6
- from sqlspec.driver._async import AsyncDriverAdapterBase
7
- from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult
8
- from sqlspec.driver._sync import SyncDriverAdapterBase
6
+ from sqlspec.driver._async import AsyncDataDictionaryBase, AsyncDriverAdapterBase
7
+ from sqlspec.driver._common import CommonDriverAttributesMixin, ExecutionResult, VersionInfo
8
+ from sqlspec.driver._sync import SyncDataDictionaryBase, SyncDriverAdapterBase
9
9
 
10
10
  __all__ = (
11
+ "AsyncDataDictionaryBase",
11
12
  "AsyncDriverAdapterBase",
12
13
  "CommonDriverAttributesMixin",
13
14
  "DriverAdapterProtocol",
14
15
  "ExecutionResult",
16
+ "SyncDataDictionaryBase",
15
17
  "SyncDriverAdapterBase",
18
+ "VersionInfo",
16
19
  "mixins",
17
20
  )
18
21