sqlspec 0.24.1__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (95) hide show
  1. sqlspec/_serialization.py +223 -21
  2. sqlspec/_sql.py +20 -62
  3. sqlspec/_typing.py +11 -0
  4. sqlspec/adapters/adbc/config.py +8 -1
  5. sqlspec/adapters/adbc/data_dictionary.py +290 -0
  6. sqlspec/adapters/adbc/driver.py +129 -20
  7. sqlspec/adapters/adbc/type_converter.py +159 -0
  8. sqlspec/adapters/aiosqlite/config.py +3 -0
  9. sqlspec/adapters/aiosqlite/data_dictionary.py +117 -0
  10. sqlspec/adapters/aiosqlite/driver.py +17 -3
  11. sqlspec/adapters/asyncmy/_types.py +1 -1
  12. sqlspec/adapters/asyncmy/config.py +11 -8
  13. sqlspec/adapters/asyncmy/data_dictionary.py +122 -0
  14. sqlspec/adapters/asyncmy/driver.py +31 -7
  15. sqlspec/adapters/asyncpg/config.py +3 -0
  16. sqlspec/adapters/asyncpg/data_dictionary.py +134 -0
  17. sqlspec/adapters/asyncpg/driver.py +19 -4
  18. sqlspec/adapters/bigquery/config.py +3 -0
  19. sqlspec/adapters/bigquery/data_dictionary.py +109 -0
  20. sqlspec/adapters/bigquery/driver.py +21 -3
  21. sqlspec/adapters/bigquery/type_converter.py +93 -0
  22. sqlspec/adapters/duckdb/_types.py +1 -1
  23. sqlspec/adapters/duckdb/config.py +2 -0
  24. sqlspec/adapters/duckdb/data_dictionary.py +124 -0
  25. sqlspec/adapters/duckdb/driver.py +32 -5
  26. sqlspec/adapters/duckdb/pool.py +1 -1
  27. sqlspec/adapters/duckdb/type_converter.py +103 -0
  28. sqlspec/adapters/oracledb/config.py +6 -0
  29. sqlspec/adapters/oracledb/data_dictionary.py +442 -0
  30. sqlspec/adapters/oracledb/driver.py +68 -9
  31. sqlspec/adapters/oracledb/migrations.py +51 -67
  32. sqlspec/adapters/oracledb/type_converter.py +132 -0
  33. sqlspec/adapters/psqlpy/config.py +3 -0
  34. sqlspec/adapters/psqlpy/data_dictionary.py +133 -0
  35. sqlspec/adapters/psqlpy/driver.py +23 -179
  36. sqlspec/adapters/psqlpy/type_converter.py +73 -0
  37. sqlspec/adapters/psycopg/config.py +8 -4
  38. sqlspec/adapters/psycopg/data_dictionary.py +257 -0
  39. sqlspec/adapters/psycopg/driver.py +40 -5
  40. sqlspec/adapters/sqlite/config.py +3 -0
  41. sqlspec/adapters/sqlite/data_dictionary.py +117 -0
  42. sqlspec/adapters/sqlite/driver.py +18 -3
  43. sqlspec/adapters/sqlite/pool.py +13 -4
  44. sqlspec/base.py +3 -4
  45. sqlspec/builder/_base.py +130 -48
  46. sqlspec/builder/_column.py +66 -24
  47. sqlspec/builder/_ddl.py +91 -41
  48. sqlspec/builder/_insert.py +40 -58
  49. sqlspec/builder/_parsing_utils.py +127 -12
  50. sqlspec/builder/_select.py +147 -2
  51. sqlspec/builder/_update.py +1 -1
  52. sqlspec/builder/mixins/_cte_and_set_ops.py +31 -23
  53. sqlspec/builder/mixins/_delete_operations.py +12 -7
  54. sqlspec/builder/mixins/_insert_operations.py +50 -36
  55. sqlspec/builder/mixins/_join_operations.py +15 -30
  56. sqlspec/builder/mixins/_merge_operations.py +210 -78
  57. sqlspec/builder/mixins/_order_limit_operations.py +4 -10
  58. sqlspec/builder/mixins/_pivot_operations.py +1 -0
  59. sqlspec/builder/mixins/_select_operations.py +44 -22
  60. sqlspec/builder/mixins/_update_operations.py +30 -37
  61. sqlspec/builder/mixins/_where_clause.py +52 -70
  62. sqlspec/cli.py +246 -140
  63. sqlspec/config.py +33 -19
  64. sqlspec/core/__init__.py +3 -2
  65. sqlspec/core/cache.py +298 -352
  66. sqlspec/core/compiler.py +61 -4
  67. sqlspec/core/filters.py +246 -213
  68. sqlspec/core/hashing.py +9 -11
  69. sqlspec/core/parameters.py +27 -10
  70. sqlspec/core/statement.py +72 -12
  71. sqlspec/core/type_conversion.py +234 -0
  72. sqlspec/driver/__init__.py +6 -3
  73. sqlspec/driver/_async.py +108 -5
  74. sqlspec/driver/_common.py +186 -17
  75. sqlspec/driver/_sync.py +108 -5
  76. sqlspec/driver/mixins/_result_tools.py +60 -7
  77. sqlspec/exceptions.py +5 -0
  78. sqlspec/loader.py +8 -9
  79. sqlspec/migrations/__init__.py +4 -3
  80. sqlspec/migrations/base.py +153 -14
  81. sqlspec/migrations/commands.py +34 -96
  82. sqlspec/migrations/context.py +145 -0
  83. sqlspec/migrations/loaders.py +25 -8
  84. sqlspec/migrations/runner.py +352 -82
  85. sqlspec/storage/backends/fsspec.py +1 -0
  86. sqlspec/typing.py +4 -0
  87. sqlspec/utils/config_resolver.py +153 -0
  88. sqlspec/utils/serializers.py +50 -2
  89. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/METADATA +1 -1
  90. sqlspec-0.26.0.dist-info/RECORD +157 -0
  91. sqlspec-0.24.1.dist-info/RECORD +0 -139
  92. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/WHEEL +0 -0
  93. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/entry_points.txt +0 -0
  94. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/LICENSE +0 -0
  95. {sqlspec-0.24.1.dist-info → sqlspec-0.26.0.dist-info}/licenses/NOTICE +0 -0
@@ -14,6 +14,7 @@ PostgreSQL Features:
14
14
  - PostgreSQL-specific error handling
15
15
  """
16
16
 
17
+ import datetime
17
18
  import io
18
19
  from typing import TYPE_CHECKING, Any, Optional
19
20
 
@@ -32,7 +33,9 @@ from sqlspec.utils.serializers import to_json
32
33
  if TYPE_CHECKING:
33
34
  from contextlib import AbstractAsyncContextManager, AbstractContextManager
34
35
 
36
+ from sqlspec.driver._async import AsyncDataDictionaryBase
35
37
  from sqlspec.driver._common import ExecutionResult
38
+ from sqlspec.driver._sync import SyncDataDictionaryBase
36
39
 
37
40
  logger = get_logger("adapters.psycopg")
38
41
 
@@ -94,7 +97,12 @@ psycopg_statement_config = StatementConfig(
94
97
  ParameterStyle.NAMED_PYFORMAT,
95
98
  ParameterStyle.NUMERIC,
96
99
  },
97
- type_coercion_map={dict: to_json},
100
+ type_coercion_map={
101
+ dict: to_json,
102
+ datetime.datetime: lambda x: x,
103
+ datetime.date: lambda x: x,
104
+ datetime.time: lambda x: x,
105
+ },
98
106
  has_native_list_expansion=True,
99
107
  needs_static_script_compilation=False,
100
108
  preserve_parameter_format=True,
@@ -125,8 +133,7 @@ class PsycopgSyncCursor:
125
133
  self.cursor = self.connection.cursor()
126
134
  return self.cursor
127
135
 
128
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
129
- _ = (exc_type, exc_val, exc_tb)
136
+ def __exit__(self, *_: Any) -> None:
130
137
  if self.cursor is not None:
131
138
  self.cursor.close()
132
139
 
@@ -187,7 +194,7 @@ class PsycopgSyncDriver(SyncDriverAdapterBase):
187
194
  bulk data transfer, and PostgreSQL-specific error handling.
188
195
  """
189
196
 
190
- __slots__ = ()
197
+ __slots__ = ("_data_dictionary",)
191
198
  dialect = "postgres"
192
199
 
193
200
  def __init__(
@@ -207,6 +214,7 @@ class PsycopgSyncDriver(SyncDriverAdapterBase):
207
214
  statement_config = default_config
208
215
 
209
216
  super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
217
+ self._data_dictionary: Optional[SyncDataDictionaryBase] = None
210
218
 
211
219
  def with_cursor(self, connection: PsycopgSyncConnection) -> PsycopgSyncCursor:
212
220
  """Create context manager for PostgreSQL cursor."""
@@ -411,6 +419,19 @@ class PsycopgSyncDriver(SyncDriverAdapterBase):
411
419
  affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
412
420
  return self.create_execution_result(cursor, rowcount_override=affected_rows)
413
421
 
422
+ @property
423
+ def data_dictionary(self) -> "SyncDataDictionaryBase":
424
+ """Get the data dictionary for this driver.
425
+
426
+ Returns:
427
+ Data dictionary instance for metadata queries
428
+ """
429
+ if self._data_dictionary is None:
430
+ from sqlspec.adapters.psycopg.data_dictionary import PostgresSyncDataDictionary
431
+
432
+ self._data_dictionary = PostgresSyncDataDictionary()
433
+ return self._data_dictionary
434
+
414
435
 
415
436
  class PsycopgAsyncCursor:
416
437
  """Async context manager for PostgreSQL psycopg cursor management."""
@@ -488,7 +509,7 @@ class PsycopgAsyncDriver(AsyncDriverAdapterBase):
488
509
  and async pub/sub support.
489
510
  """
490
511
 
491
- __slots__ = ()
512
+ __slots__ = ("_data_dictionary",)
492
513
  dialect = "postgres"
493
514
 
494
515
  def __init__(
@@ -508,6 +529,7 @@ class PsycopgAsyncDriver(AsyncDriverAdapterBase):
508
529
  statement_config = default_config
509
530
 
510
531
  super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
532
+ self._data_dictionary: Optional[AsyncDataDictionaryBase] = None
511
533
 
512
534
  def with_cursor(self, connection: "PsycopgAsyncConnection") -> "PsycopgAsyncCursor":
513
535
  """Create async context manager for PostgreSQL cursor."""
@@ -714,3 +736,16 @@ class PsycopgAsyncDriver(AsyncDriverAdapterBase):
714
736
 
715
737
  affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
716
738
  return self.create_execution_result(cursor, rowcount_override=affected_rows)
739
+
740
+ @property
741
+ def data_dictionary(self) -> "AsyncDataDictionaryBase":
742
+ """Get the data dictionary for this driver.
743
+
744
+ Returns:
745
+ Data dictionary instance for metadata queries
746
+ """
747
+ if self._data_dictionary is None:
748
+ from sqlspec.adapters.psycopg.data_dictionary import PostgresAsyncDataDictionary
749
+
750
+ self._data_dictionary = PostgresAsyncDataDictionary()
751
+ return self._data_dictionary
@@ -47,6 +47,7 @@ class SqliteConfig(SyncDatabaseConfig[SqliteConnection, SqliteConnectionPool, Sq
47
47
  migration_config: "Optional[dict[str, Any]]" = None,
48
48
  statement_config: "Optional[StatementConfig]" = None,
49
49
  driver_features: "Optional[dict[str, Any]]" = None,
50
+ bind_key: "Optional[str]" = None,
50
51
  ) -> None:
51
52
  """Initialize SQLite configuration.
52
53
 
@@ -56,6 +57,7 @@ class SqliteConfig(SyncDatabaseConfig[SqliteConnection, SqliteConnectionPool, Sq
56
57
  migration_config: Migration configuration
57
58
  statement_config: Default SQL statement configuration
58
59
  driver_features: Optional driver feature configuration
60
+ bind_key: Optional bind key for the configuration
59
61
  """
60
62
  if pool_config is None:
61
63
  pool_config = {}
@@ -64,6 +66,7 @@ class SqliteConfig(SyncDatabaseConfig[SqliteConnection, SqliteConnectionPool, Sq
64
66
  pool_config["uri"] = True
65
67
 
66
68
  super().__init__(
69
+ bind_key=bind_key,
67
70
  pool_instance=pool_instance,
68
71
  pool_config=cast("dict[str, Any]", pool_config),
69
72
  migration_config=migration_config,
@@ -0,0 +1,117 @@
1
+ """SQLite-specific data dictionary for metadata queries."""
2
+
3
+ import re
4
+ from typing import TYPE_CHECKING, Callable, Optional, cast
5
+
6
+ from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo
7
+ from sqlspec.utils.logging import get_logger
8
+
9
+ if TYPE_CHECKING:
10
+ from sqlspec.adapters.sqlite.driver import SqliteDriver
11
+
12
+ logger = get_logger("adapters.sqlite.data_dictionary")
13
+
14
+ # Compiled regex patterns
15
+ SQLITE_VERSION_PATTERN = re.compile(r"(\d+)\.(\d+)\.(\d+)")
16
+
17
+ __all__ = ("SqliteSyncDataDictionary",)
18
+
19
+
20
+ class SqliteSyncDataDictionary(SyncDataDictionaryBase):
21
+ """SQLite-specific sync data dictionary."""
22
+
23
+ def get_version(self, driver: SyncDriverAdapterBase) -> "Optional[VersionInfo]":
24
+ """Get SQLite database version information.
25
+
26
+ Args:
27
+ driver: Sync database driver instance
28
+
29
+ Returns:
30
+ SQLite version information or None if detection fails
31
+ """
32
+ version_str = cast("SqliteDriver", driver).select_value("SELECT sqlite_version()")
33
+ if not version_str:
34
+ logger.warning("No SQLite version information found")
35
+ return None
36
+
37
+ # Parse version like "3.45.0"
38
+ version_match = SQLITE_VERSION_PATTERN.match(str(version_str))
39
+ if not version_match:
40
+ logger.warning("Could not parse SQLite version: %s", version_str)
41
+ return None
42
+
43
+ major, minor, patch = map(int, version_match.groups())
44
+ version_info = VersionInfo(major, minor, patch)
45
+ logger.debug("Detected SQLite version: %s", version_info)
46
+ return version_info
47
+
48
+ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
49
+ """Check if SQLite database supports a specific feature.
50
+
51
+ Args:
52
+ driver: SQLite driver instance
53
+ feature: Feature name to check
54
+
55
+ Returns:
56
+ True if feature is supported, False otherwise
57
+ """
58
+ version_info = self.get_version(driver)
59
+ if not version_info:
60
+ return False
61
+
62
+ feature_checks: dict[str, Callable[[VersionInfo], bool]] = {
63
+ "supports_json": lambda v: v >= VersionInfo(3, 38, 0),
64
+ "supports_returning": lambda v: v >= VersionInfo(3, 35, 0),
65
+ "supports_upsert": lambda v: v >= VersionInfo(3, 24, 0),
66
+ "supports_window_functions": lambda v: v >= VersionInfo(3, 25, 0),
67
+ "supports_cte": lambda v: v >= VersionInfo(3, 8, 3),
68
+ "supports_transactions": lambda _: True,
69
+ "supports_prepared_statements": lambda _: True,
70
+ "supports_schemas": lambda _: False, # SQLite has ATTACH but not schemas
71
+ "supports_arrays": lambda _: False,
72
+ "supports_uuid": lambda _: False,
73
+ }
74
+
75
+ if feature in feature_checks:
76
+ return bool(feature_checks[feature](version_info))
77
+
78
+ return False
79
+
80
+ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> str:
81
+ """Get optimal SQLite type for a category.
82
+
83
+ Args:
84
+ driver: SQLite driver instance
85
+ type_category: Type category
86
+
87
+ Returns:
88
+ SQLite-specific type name
89
+ """
90
+ version_info = self.get_version(driver)
91
+
92
+ if type_category == "json":
93
+ if version_info and version_info >= VersionInfo(3, 38, 0):
94
+ return "JSON"
95
+ return "TEXT"
96
+
97
+ type_map = {"uuid": "TEXT", "boolean": "INTEGER", "timestamp": "TIMESTAMP", "text": "TEXT", "blob": "BLOB"}
98
+ return type_map.get(type_category, "TEXT")
99
+
100
+ def list_available_features(self) -> "list[str]":
101
+ """List available SQLite feature flags.
102
+
103
+ Returns:
104
+ List of supported feature names
105
+ """
106
+ return [
107
+ "supports_json",
108
+ "supports_returning",
109
+ "supports_upsert",
110
+ "supports_window_functions",
111
+ "supports_cte",
112
+ "supports_transactions",
113
+ "supports_prepared_statements",
114
+ "supports_schemas",
115
+ "supports_arrays",
116
+ "supports_uuid",
117
+ ]
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
20
20
  from sqlspec.core.result import SQLResult
21
21
  from sqlspec.core.statement import SQL
22
22
  from sqlspec.driver import ExecutionResult
23
+ from sqlspec.driver._sync import SyncDataDictionaryBase
23
24
 
24
25
  __all__ = ("SqliteCursor", "SqliteDriver", "SqliteExceptionHandler", "sqlite_statement_config")
25
26
 
@@ -36,6 +37,7 @@ sqlite_statement_config = StatementConfig(
36
37
  datetime.datetime: lambda v: v.isoformat(),
37
38
  datetime.date: lambda v: v.isoformat(),
38
39
  Decimal: str,
40
+ dict: to_json,
39
41
  list: to_json,
40
42
  },
41
43
  has_native_list_expansion=False,
@@ -75,7 +77,7 @@ class SqliteCursor:
75
77
  self.cursor = self.connection.cursor()
76
78
  return self.cursor
77
79
 
78
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
80
+ def __exit__(self, *_: Any) -> None:
79
81
  """Clean up cursor resources.
80
82
 
81
83
  Args:
@@ -83,7 +85,6 @@ class SqliteCursor:
83
85
  exc_val: Exception value if an exception occurred
84
86
  exc_tb: Exception traceback if an exception occurred
85
87
  """
86
- _ = (exc_type, exc_val, exc_tb)
87
88
  if self.cursor is not None:
88
89
  with contextlib.suppress(Exception):
89
90
  self.cursor.close()
@@ -159,7 +160,7 @@ class SqliteDriver(SyncDriverAdapterBase):
159
160
  for SQLite databases using the standard sqlite3 module.
160
161
  """
161
162
 
162
- __slots__ = ()
163
+ __slots__ = ("_data_dictionary",)
163
164
  dialect = "sqlite"
164
165
 
165
166
  def __init__(
@@ -185,6 +186,7 @@ class SqliteDriver(SyncDriverAdapterBase):
185
186
  )
186
187
 
187
188
  super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
189
+ self._data_dictionary: Optional[SyncDataDictionaryBase] = None
188
190
 
189
191
  def with_cursor(self, connection: "SqliteConnection") -> "SqliteCursor":
190
192
  """Create context manager for SQLite cursor.
@@ -325,3 +327,16 @@ class SqliteDriver(SyncDriverAdapterBase):
325
327
  except sqlite3.Error as e:
326
328
  msg = f"Failed to commit transaction: {e}"
327
329
  raise SQLSpecError(msg) from e
330
+
331
+ @property
332
+ def data_dictionary(self) -> "SyncDataDictionaryBase":
333
+ """Get the data dictionary for this driver.
334
+
335
+ Returns:
336
+ Data dictionary instance for metadata queries
337
+ """
338
+ if self._data_dictionary is None:
339
+ from sqlspec.adapters.sqlite.data_dictionary import SqliteSyncDataDictionary
340
+
341
+ self._data_dictionary = SqliteSyncDataDictionary()
342
+ return self._data_dictionary
@@ -1,5 +1,6 @@
1
1
  """SQLite database configuration with thread-local connections."""
2
2
 
3
+ import contextlib
3
4
  import sqlite3
4
5
  import threading
5
6
  from contextlib import contextmanager
@@ -49,6 +50,8 @@ class SqliteConnectionPool:
49
50
  enable_optimizations: Whether to apply performance PRAGMAs
50
51
  **kwargs: Ignored pool parameters for compatibility
51
52
  """
53
+ if "check_same_thread" not in connection_parameters:
54
+ connection_parameters = {**connection_parameters, "check_same_thread": False}
52
55
  self._connection_parameters = connection_parameters
53
56
  self._thread_local = threading.local()
54
57
  self._enable_optimizations = enable_optimizations
@@ -62,8 +65,7 @@ class SqliteConnectionPool:
62
65
  is_memory = database == ":memory:" or database.startswith("file::memory:")
63
66
 
64
67
  if not is_memory:
65
- connection.execute("PRAGMA journal_mode = WAL")
66
-
68
+ connection.execute("PRAGMA journal_mode = DELETE")
67
69
  connection.execute("PRAGMA busy_timeout = 5000")
68
70
  connection.execute("PRAGMA optimize")
69
71
 
@@ -97,7 +99,13 @@ class SqliteConnectionPool:
97
99
  Yields:
98
100
  SqliteConnection: A thread-local connection.
99
101
  """
100
- yield self._get_thread_connection()
102
+ connection = self._get_thread_connection()
103
+ try:
104
+ yield connection
105
+ finally:
106
+ with contextlib.suppress(Exception):
107
+ if connection.in_transaction:
108
+ connection.commit()
101
109
 
102
110
  def close(self) -> None:
103
111
  """Close the thread-local connection if it exists."""
@@ -124,7 +132,8 @@ class SqliteConnectionPool:
124
132
  _ = self._thread_local.connection
125
133
  except AttributeError:
126
134
  return 0
127
- return 1
135
+ else:
136
+ return 1
128
137
 
129
138
  def checked_out(self) -> int:
130
139
  """Get number of checked out connections (always 0)."""
sqlspec/base.py CHANGED
@@ -15,9 +15,8 @@ from sqlspec.config import (
15
15
  )
16
16
  from sqlspec.core.cache import (
17
17
  CacheConfig,
18
- CacheStatsAggregate,
19
18
  get_cache_config,
20
- get_cache_stats,
19
+ get_cache_statistics,
21
20
  log_cache_stats,
22
21
  reset_cache_stats,
23
22
  update_cache_config,
@@ -532,13 +531,13 @@ class SQLSpec:
532
531
  update_cache_config(config)
533
532
 
534
533
  @staticmethod
535
- def get_cache_stats() -> CacheStatsAggregate:
534
+ def get_cache_stats() -> "dict[str, Any]":
536
535
  """Get current cache statistics.
537
536
 
538
537
  Returns:
539
538
  Cache statistics object with detailed metrics.
540
539
  """
541
- return get_cache_stats()
540
+ return get_cache_statistics()
542
541
 
543
542
  @staticmethod
544
543
  def reset_cache_stats() -> None:
sqlspec/builder/_base.py CHANGED
@@ -3,6 +3,8 @@
3
3
  Provides abstract base classes and core functionality for SQL query builders.
4
4
  """
5
5
 
6
+ import hashlib
7
+ import uuid
6
8
  from abc import ABC, abstractmethod
7
9
  from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union, cast
8
10
 
@@ -13,19 +15,21 @@ from sqlglot.errors import ParseError as SQLGlotParseError
13
15
  from sqlglot.optimizer import optimize
14
16
  from typing_extensions import Self
15
17
 
16
- from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
18
+ from sqlspec.core.cache import get_cache, get_cache_config
17
19
  from sqlspec.core.hashing import hash_optimized_expression
18
20
  from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
19
21
  from sqlspec.core.statement import SQL, StatementConfig
20
22
  from sqlspec.exceptions import SQLBuilderError
21
23
  from sqlspec.utils.logging import get_logger
22
- from sqlspec.utils.type_guards import has_expression_and_parameters, has_sql_method, has_with_method
24
+ from sqlspec.utils.type_guards import has_expression_and_parameters, has_sql_method, has_with_method, is_expression
23
25
 
24
26
  if TYPE_CHECKING:
25
27
  from sqlspec.core.result import SQLResult
26
28
 
27
29
  __all__ = ("QueryBuilder", "SafeQuery")
28
30
 
31
+ MAX_PARAMETER_COLLISION_ATTEMPTS = 1000
32
+
29
33
  logger = get_logger(__name__)
30
34
 
31
35
 
@@ -91,6 +95,32 @@ class QueryBuilder(ABC):
91
95
  "QueryBuilder._create_base_expression must return a valid sqlglot expression."
92
96
  )
93
97
 
98
+ def get_expression(self) -> Optional[exp.Expression]:
99
+ """Get expression reference (no copy).
100
+
101
+ Returns:
102
+ The current SQLGlot expression or None if not set
103
+ """
104
+ return self._expression
105
+
106
+ def set_expression(self, expression: exp.Expression) -> None:
107
+ """Set expression with validation.
108
+
109
+ Args:
110
+ expression: SQLGlot expression to set
111
+ """
112
+ if not is_expression(expression):
113
+ self._raise_invalid_expression_type(expression)
114
+ self._expression = expression
115
+
116
+ def has_expression(self) -> bool:
117
+ """Check if expression exists.
118
+
119
+ Returns:
120
+ True if expression is set, False otherwise
121
+ """
122
+ return self._expression is not None
123
+
94
124
  @abstractmethod
95
125
  def _create_base_expression(self) -> exp.Expression:
96
126
  """Create the base sqlglot expression for the specific query type.
@@ -121,6 +151,46 @@ class QueryBuilder(ABC):
121
151
  """
122
152
  raise SQLBuilderError(message) from cause
123
153
 
154
+ @staticmethod
155
+ def _raise_invalid_expression_type(expression: Any) -> NoReturn:
156
+ """Raise error for invalid expression type.
157
+
158
+ Args:
159
+ expression: The invalid expression object
160
+
161
+ Raises:
162
+ TypeError: Always raised for type mismatch
163
+ """
164
+ msg = f"Expected Expression, got {type(expression)}"
165
+ raise TypeError(msg)
166
+
167
+ @staticmethod
168
+ def _raise_cte_query_error(alias: str, message: str) -> NoReturn:
169
+ """Raise error for CTE query issues.
170
+
171
+ Args:
172
+ alias: CTE alias name
173
+ message: Specific error message
174
+
175
+ Raises:
176
+ SQLBuilderError: Always raised for CTE errors
177
+ """
178
+ msg = f"CTE '{alias}': {message}"
179
+ raise SQLBuilderError(msg)
180
+
181
+ @staticmethod
182
+ def _raise_cte_parse_error(cause: BaseException) -> NoReturn:
183
+ """Raise error for CTE parsing failures.
184
+
185
+ Args:
186
+ cause: The original parsing exception
187
+
188
+ Raises:
189
+ SQLBuilderError: Always raised with chained cause
190
+ """
191
+ msg = f"Failed to parse CTE query: {cause!s}"
192
+ raise SQLBuilderError(msg) from cause
193
+
124
194
  def _add_parameter(self, value: Any, context: Optional[str] = None) -> str:
125
195
  """Adds a parameter to the query and returns its placeholder name.
126
196
 
@@ -199,13 +269,11 @@ class QueryBuilder(ABC):
199
269
  if base_name not in self._parameters:
200
270
  return base_name
201
271
 
202
- for i in range(1, 1000):
272
+ for i in range(1, MAX_PARAMETER_COLLISION_ATTEMPTS):
203
273
  name = f"{base_name}_{i}"
204
274
  if name not in self._parameters:
205
275
  return name
206
276
 
207
- import uuid
208
-
209
277
  return f"{base_name}_{uuid.uuid4().hex[:8]}"
210
278
 
211
279
  def _merge_cte_parameters(self, cte_name: str, parameters: dict[str, Any]) -> dict[str, str]:
@@ -254,8 +322,6 @@ class QueryBuilder(ABC):
254
322
  Returns:
255
323
  A unique cache key representing the builder state and configuration
256
324
  """
257
- import hashlib
258
-
259
325
  dialect_name: str = self.dialect_name or "default"
260
326
 
261
327
  if self._expression is None:
@@ -307,36 +373,31 @@ class QueryBuilder(ABC):
307
373
  cte_select_expression: exp.Select
308
374
 
309
375
  if isinstance(query, QueryBuilder):
310
- if query._expression is None:
311
- self._raise_sql_builder_error("CTE query builder has no expression.")
312
- if not isinstance(query._expression, exp.Select):
313
- msg = f"CTE query builder expression must be a Select, got {type(query._expression).__name__}."
314
- self._raise_sql_builder_error(msg)
315
- cte_select_expression = query._expression
376
+ query_expr = query.get_expression()
377
+ if query_expr is None:
378
+ self._raise_cte_query_error(alias, "query builder has no expression")
379
+ if not isinstance(query_expr, exp.Select):
380
+ self._raise_cte_query_error(alias, f"expression must be a Select, got {type(query_expr).__name__}")
381
+ cte_select_expression = query_expr
316
382
  param_mapping = self._merge_cte_parameters(alias, query.parameters)
317
- updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping)
318
- if not isinstance(updated_expression, exp.Select):
319
- msg = f"Updated CTE expression must be a Select, got {type(updated_expression).__name__}."
320
- self._raise_sql_builder_error(msg)
321
- cte_select_expression = updated_expression
383
+ cte_select_expression = cast(
384
+ "exp.Select", self._update_placeholders_in_expression(cte_select_expression, param_mapping)
385
+ )
322
386
 
323
387
  elif isinstance(query, str):
324
388
  try:
325
389
  parsed_expression = sqlglot.parse_one(query, read=self.dialect_name)
326
390
  if not isinstance(parsed_expression, exp.Select):
327
- msg = f"CTE query string must parse to a SELECT statement, got {type(parsed_expression).__name__}."
328
- self._raise_sql_builder_error(msg)
391
+ self._raise_cte_query_error(
392
+ alias, f"query string must parse to SELECT, got {type(parsed_expression).__name__}"
393
+ )
329
394
  cte_select_expression = parsed_expression
330
395
  except SQLGlotParseError as e:
331
- self._raise_sql_builder_error(f"Failed to parse CTE query string: {e!s}", e)
332
- except Exception as e:
333
- msg = f"An unexpected error occurred while parsing CTE query string: {e!s}"
334
- self._raise_sql_builder_error(msg, e)
396
+ self._raise_cte_parse_error(e)
335
397
  elif isinstance(query, exp.Select):
336
398
  cte_select_expression = query
337
399
  else:
338
- msg = f"Invalid query type for CTE: {type(query).__name__}"
339
- self._raise_sql_builder_error(msg)
400
+ self._raise_cte_query_error(alias, f"invalid query type: {type(query).__name__}")
340
401
 
341
402
  self._with_ctes[alias] = exp.CTE(this=cte_select_expression, alias=exp.to_table(alias))
342
403
  return self
@@ -398,9 +459,8 @@ class QueryBuilder(ABC):
398
459
  expression, dialect=dialect_name, schema=self.schema, optimizer_settings=optimizer_settings
399
460
  )
400
461
 
401
- cache_key_obj = CacheKey((cache_key,))
402
- unified_cache = get_default_cache()
403
- cached_optimized = unified_cache.get(cache_key_obj)
462
+ cache = get_cache()
463
+ cached_optimized = cache.get("optimized", cache_key)
404
464
  if cached_optimized:
405
465
  return cast("exp.Expression", cached_optimized)
406
466
 
@@ -408,10 +468,9 @@ class QueryBuilder(ABC):
408
468
  optimized = optimize(
409
469
  expression, schema=self.schema, dialect=self.dialect_name, optimizer_settings=optimizer_settings
410
470
  )
411
-
412
- unified_cache.put(cache_key_obj, optimized)
413
-
471
+ cache.put("optimized", cache_key, optimized)
414
472
  except Exception:
473
+ logger.debug("Expression optimization failed, using original expression")
415
474
  return expression
416
475
  else:
417
476
  return optimized
@@ -430,15 +489,14 @@ class QueryBuilder(ABC):
430
489
  return self._to_statement(config)
431
490
 
432
491
  cache_key_str = self._generate_builder_cache_key(config)
433
- cache_key = CacheKey((cache_key_str,))
434
492
 
435
- unified_cache = get_default_cache()
436
- cached_sql = unified_cache.get(cache_key)
493
+ cache = get_cache()
494
+ cached_sql = cache.get("builder", cache_key_str)
437
495
  if cached_sql is not None:
438
496
  return cast("SQL", cached_sql)
439
497
 
440
498
  sql_statement = self._to_statement(config)
441
- unified_cache.put(cache_key, sql_statement)
499
+ cache.put("builder", cache_key_str, sql_statement)
442
500
 
443
501
  return sql_statement
444
502
 
@@ -453,18 +511,7 @@ class QueryBuilder(ABC):
453
511
  """
454
512
  safe_query = self.build()
455
513
 
456
- if isinstance(safe_query.parameters, dict):
457
- kwargs = safe_query.parameters
458
- parameters: Optional[tuple[Any, ...]] = None
459
- else:
460
- kwargs = None
461
- parameters = (
462
- safe_query.parameters
463
- if isinstance(safe_query.parameters, tuple)
464
- else tuple(safe_query.parameters)
465
- if safe_query.parameters
466
- else None
467
- )
514
+ kwargs, parameters = self._extract_statement_parameters(safe_query.parameters)
468
515
 
469
516
  if config is None:
470
517
  config = StatementConfig(
@@ -492,6 +539,28 @@ class QueryBuilder(ABC):
492
539
  return SQL(sql_string, *parameters, statement_config=config)
493
540
  return SQL(sql_string, statement_config=config)
494
541
 
542
+ def _extract_statement_parameters(
543
+ self, raw_parameters: Any
544
+ ) -> "tuple[Optional[dict[str, Any]], Optional[tuple[Any, ...]]]":
545
+ """Extract parameters for SQL statement creation.
546
+
547
+ Args:
548
+ raw_parameters: Raw parameter data from SafeQuery
549
+
550
+ Returns:
551
+ Tuple of (kwargs, parameters) for SQL statement construction
552
+ """
553
+ if isinstance(raw_parameters, dict):
554
+ return raw_parameters, None
555
+
556
+ if isinstance(raw_parameters, tuple):
557
+ return None, raw_parameters
558
+
559
+ if raw_parameters:
560
+ return None, tuple(raw_parameters)
561
+
562
+ return None, None
563
+
495
564
  def __str__(self) -> str:
496
565
  """Return the SQL string representation of the query.
497
566
 
@@ -531,3 +600,16 @@ class QueryBuilder(ABC):
531
600
  def parameters(self) -> dict[str, Any]:
532
601
  """Public access to query parameters."""
533
602
  return self._parameters
603
+
604
+ def set_parameters(self, parameters: dict[str, Any]) -> None:
605
+ """Set query parameters (public API)."""
606
+ self._parameters = parameters.copy()
607
+
608
+ @property
609
+ def with_ctes(self) -> "dict[str, exp.CTE]":
610
+ """Get WITH clause CTEs (public API)."""
611
+ return dict(self._with_ctes)
612
+
613
+ def generate_unique_parameter_name(self, base_name: str) -> str:
614
+ """Generate unique parameter name (public API)."""
615
+ return self._generate_unique_parameter_name(base_name)