sqlspec 0.17.1__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (75) hide show
  1. sqlspec/__init__.py +1 -1
  2. sqlspec/_sql.py +54 -159
  3. sqlspec/adapters/adbc/config.py +24 -30
  4. sqlspec/adapters/adbc/driver.py +42 -61
  5. sqlspec/adapters/aiosqlite/config.py +5 -10
  6. sqlspec/adapters/aiosqlite/driver.py +9 -25
  7. sqlspec/adapters/aiosqlite/pool.py +43 -35
  8. sqlspec/adapters/asyncmy/config.py +10 -7
  9. sqlspec/adapters/asyncmy/driver.py +18 -39
  10. sqlspec/adapters/asyncpg/config.py +4 -0
  11. sqlspec/adapters/asyncpg/driver.py +32 -79
  12. sqlspec/adapters/bigquery/config.py +12 -65
  13. sqlspec/adapters/bigquery/driver.py +39 -133
  14. sqlspec/adapters/duckdb/config.py +11 -15
  15. sqlspec/adapters/duckdb/driver.py +61 -85
  16. sqlspec/adapters/duckdb/pool.py +2 -5
  17. sqlspec/adapters/oracledb/_types.py +8 -1
  18. sqlspec/adapters/oracledb/config.py +55 -38
  19. sqlspec/adapters/oracledb/driver.py +35 -92
  20. sqlspec/adapters/oracledb/migrations.py +257 -0
  21. sqlspec/adapters/psqlpy/config.py +13 -9
  22. sqlspec/adapters/psqlpy/driver.py +28 -103
  23. sqlspec/adapters/psycopg/config.py +9 -5
  24. sqlspec/adapters/psycopg/driver.py +107 -175
  25. sqlspec/adapters/sqlite/config.py +7 -5
  26. sqlspec/adapters/sqlite/driver.py +37 -73
  27. sqlspec/adapters/sqlite/pool.py +3 -12
  28. sqlspec/base.py +1 -8
  29. sqlspec/builder/__init__.py +1 -1
  30. sqlspec/builder/_base.py +34 -20
  31. sqlspec/builder/_ddl.py +407 -183
  32. sqlspec/builder/_insert.py +1 -1
  33. sqlspec/builder/mixins/_insert_operations.py +26 -6
  34. sqlspec/builder/mixins/_merge_operations.py +1 -1
  35. sqlspec/builder/mixins/_select_operations.py +1 -5
  36. sqlspec/config.py +32 -13
  37. sqlspec/core/__init__.py +89 -14
  38. sqlspec/core/cache.py +57 -104
  39. sqlspec/core/compiler.py +57 -112
  40. sqlspec/core/filters.py +1 -21
  41. sqlspec/core/hashing.py +13 -47
  42. sqlspec/core/parameters.py +272 -261
  43. sqlspec/core/result.py +12 -27
  44. sqlspec/core/splitter.py +17 -21
  45. sqlspec/core/statement.py +150 -159
  46. sqlspec/driver/_async.py +2 -15
  47. sqlspec/driver/_common.py +16 -95
  48. sqlspec/driver/_sync.py +2 -15
  49. sqlspec/driver/mixins/_result_tools.py +8 -29
  50. sqlspec/driver/mixins/_sql_translator.py +6 -8
  51. sqlspec/exceptions.py +1 -2
  52. sqlspec/loader.py +43 -115
  53. sqlspec/migrations/__init__.py +1 -1
  54. sqlspec/migrations/base.py +34 -45
  55. sqlspec/migrations/commands.py +34 -15
  56. sqlspec/migrations/loaders.py +1 -1
  57. sqlspec/migrations/runner.py +104 -19
  58. sqlspec/migrations/tracker.py +49 -2
  59. sqlspec/protocols.py +3 -6
  60. sqlspec/storage/__init__.py +4 -4
  61. sqlspec/storage/backends/fsspec.py +5 -6
  62. sqlspec/storage/backends/obstore.py +7 -8
  63. sqlspec/storage/registry.py +3 -3
  64. sqlspec/utils/__init__.py +2 -2
  65. sqlspec/utils/logging.py +6 -10
  66. sqlspec/utils/sync_tools.py +27 -4
  67. sqlspec/utils/text.py +6 -1
  68. {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/METADATA +1 -1
  69. sqlspec-0.18.0.dist-info/RECORD +138 -0
  70. sqlspec/builder/_ddl_utils.py +0 -103
  71. sqlspec-0.17.1.dist-info/RECORD +0 -138
  72. {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/WHEEL +0 -0
  73. {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/entry_points.txt +0 -0
  74. {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/licenses/LICENSE +0 -0
  75. {sqlspec-0.17.1.dist-info → sqlspec-0.18.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/_common.py CHANGED
@@ -1,8 +1,4 @@
1
- """Common driver attributes and utilities.
2
-
3
- This module provides core driver infrastructure including execution result handling,
4
- common driver attributes, parameter processing, and SQL compilation utilities.
5
- """
1
+ """Common driver attributes and utilities."""
6
2
 
7
3
  from typing import TYPE_CHECKING, Any, Final, NamedTuple, Optional, Union, cast
8
4
 
@@ -10,7 +6,7 @@ from mypy_extensions import trait
10
6
  from sqlglot import exp
11
7
 
12
8
  from sqlspec.builder import QueryBuilder
13
- from sqlspec.core import SQL, OperationType, ParameterStyle, SQLResult, Statement, StatementConfig, TypedParameter
9
+ from sqlspec.core import SQL, ParameterStyle, SQLResult, Statement, StatementConfig, TypedParameter
14
10
  from sqlspec.core.cache import get_cache_config, sql_cache
15
11
  from sqlspec.core.splitter import split_sql_script
16
12
  from sqlspec.exceptions import ImproperConfigurationError
@@ -38,19 +34,7 @@ logger = get_logger("driver")
38
34
 
39
35
 
40
36
  class ScriptExecutionResult(NamedTuple):
41
- """Result from script execution with statement count information.
42
-
43
- This named tuple eliminates the need for redundant script splitting
44
- by providing statement count information during execution rather than
45
- requiring re-parsing after execution.
46
-
47
- Attributes:
48
- cursor_result: The result returned by the database cursor/driver
49
- rowcount_override: Optional override for the number of affected rows
50
- special_data: Any special metadata or additional information
51
- statement_count: Total number of statements in the script
52
- successful_statements: Number of statements that executed successfully
53
- """
37
+ """Result from script execution with statement count information."""
54
38
 
55
39
  cursor_result: Any
56
40
  rowcount_override: Optional[int]
@@ -60,24 +44,7 @@ class ScriptExecutionResult(NamedTuple):
60
44
 
61
45
 
62
46
  class ExecutionResult(NamedTuple):
63
- """Comprehensive execution result containing all data needed for SQLResult building.
64
-
65
- This named tuple consolidates all execution result data to eliminate the need
66
- for additional data extraction calls and script re-parsing in build_statement_result.
67
-
68
- Attributes:
69
- cursor_result: The raw result returned by the database cursor/driver
70
- rowcount_override: Optional override for the number of affected rows
71
- special_data: Any special metadata or additional information from execution
72
- selected_data: For SELECT operations, the extracted row data
73
- column_names: For SELECT operations, the column names
74
- data_row_count: For SELECT operations, the number of rows returned
75
- statement_count: For script operations, total number of statements
76
- successful_statements: For script operations, number of successful statements
77
- is_script_result: Whether this result is from script execution
78
- is_select_result: Whether this result is from a SELECT operation
79
- is_many_result: Whether this result is from an execute_many operation
80
- """
47
+ """Execution result containing all data needed for SQLResult building."""
81
48
 
82
49
  cursor_result: Any
83
50
  rowcount_override: Optional[int]
@@ -93,20 +60,15 @@ class ExecutionResult(NamedTuple):
93
60
  last_inserted_id: Optional[Union[int, str]] = None
94
61
 
95
62
 
96
- EXEC_CURSOR_RESULT = 0
97
- EXEC_ROWCOUNT_OVERRIDE = 1
98
- EXEC_SPECIAL_DATA = 2
63
+ EXEC_CURSOR_RESULT: Final[int] = 0
64
+ EXEC_ROWCOUNT_OVERRIDE: Final[int] = 1
65
+ EXEC_SPECIAL_DATA: Final[int] = 2
99
66
  DEFAULT_EXECUTION_RESULT: Final[tuple[Any, Optional[int], Any]] = (None, None, None)
100
67
 
101
68
 
102
69
  @trait
103
70
  class CommonDriverAttributesMixin:
104
- """Common attributes and methods for driver adapters.
105
-
106
- This mixin provides the foundation for all SQLSpec drivers, including
107
- connection and configuration management, parameter processing, caching,
108
- and SQL compilation.
109
- """
71
+ """Common attributes and methods for driver adapters."""
110
72
 
111
73
  __slots__ = ("connection", "driver_features", "statement_config")
112
74
  connection: "Any"
@@ -180,9 +142,6 @@ class CommonDriverAttributesMixin:
180
142
  def build_statement_result(self, statement: "SQL", execution_result: ExecutionResult) -> "SQLResult":
181
143
  """Build and return the SQLResult from ExecutionResult data.
182
144
 
183
- Creates SQLResult objects from ExecutionResult data without requiring
184
- additional data extraction calls or script re-parsing.
185
-
186
145
  Args:
187
146
  statement: SQL statement that was executed
188
147
  execution_result: ExecutionResult containing all necessary data
@@ -215,51 +174,11 @@ class CommonDriverAttributesMixin:
215
174
  statement=statement,
216
175
  data=[],
217
176
  rows_affected=execution_result.rowcount_override or 0,
218
- operation_type=self._determine_operation_type(statement),
177
+ operation_type=statement.operation_type,
219
178
  last_inserted_id=execution_result.last_inserted_id,
220
179
  metadata=execution_result.special_data or {"status_message": "OK"},
221
180
  )
222
181
 
223
- def _determine_operation_type(self, statement: "Any") -> OperationType:
224
- """Determine operation type from SQL statement expression.
225
-
226
- Examines the statement's expression type to determine if it's
227
- INSERT, UPDATE, DELETE, SELECT, SCRIPT, or generic EXECUTE.
228
-
229
- Args:
230
- statement: SQL statement object with expression attribute
231
-
232
- Returns:
233
- OperationType literal value
234
- """
235
- if statement.is_script:
236
- return "SCRIPT"
237
-
238
- try:
239
- expression = statement.expression
240
- except AttributeError:
241
- return "EXECUTE"
242
-
243
- if not expression:
244
- return "EXECUTE"
245
-
246
- expr_type = type(expression).__name__.upper()
247
-
248
- if "ANONYMOUS" in expr_type and statement.is_script:
249
- return "SCRIPT"
250
-
251
- if "INSERT" in expr_type:
252
- return "INSERT"
253
- if "UPDATE" in expr_type:
254
- return "UPDATE"
255
- if "DELETE" in expr_type:
256
- return "DELETE"
257
- if "SELECT" in expr_type:
258
- return "SELECT"
259
- if "COPY" in expr_type:
260
- return "COPY"
261
- return "EXECUTE"
262
-
263
182
  def prepare_statement(
264
183
  self,
265
184
  statement: "Union[Statement, QueryBuilder]",
@@ -489,7 +408,8 @@ class CommonDriverAttributesMixin:
489
408
  if cached_result is not None:
490
409
  return cached_result
491
410
 
492
- compiled_sql, execution_parameters = statement.compile()
411
+ prepared_statement = self.prepare_statement(statement, statement_config=statement_config)
412
+ compiled_sql, execution_parameters = prepared_statement.compile()
493
413
 
494
414
  prepared_parameters = self.prepare_driver_parameters(
495
415
  execution_parameters, statement_config, is_many=statement.is_many
@@ -590,7 +510,7 @@ class CommonDriverAttributesMixin:
590
510
  def find_filter(
591
511
  filter_type: "type[FilterTypeT]",
592
512
  filters: "Sequence[StatementFilter | StatementParameters] | Sequence[StatementFilter]",
593
- ) -> "FilterTypeT | None":
513
+ ) -> "Optional[FilterTypeT]":
594
514
  """Get the filter specified by filter type from the filters.
595
515
 
596
516
  Args:
@@ -600,9 +520,10 @@ class CommonDriverAttributesMixin:
600
520
  Returns:
601
521
  The match filter instance or None
602
522
  """
603
- return next(
604
- (cast("FilterTypeT | None", filter_) for filter_ in filters if isinstance(filter_, filter_type)), None
605
- )
523
+ for filter_ in filters:
524
+ if isinstance(filter_, filter_type):
525
+ return filter_
526
+ return None
606
527
 
607
528
  def _create_count_query(self, original_sql: "SQL") -> "SQL":
608
529
  """Create a COUNT query from the original SQL statement.
sqlspec/driver/_sync.py CHANGED
@@ -1,8 +1,4 @@
1
- """Synchronous driver protocol implementation.
2
-
3
- This module provides the sync driver infrastructure for database adapters,
4
- including connection management, transaction support, and result processing.
5
- """
1
+ """Synchronous driver protocol implementation."""
6
2
 
7
3
  from abc import abstractmethod
8
4
  from typing import TYPE_CHECKING, Any, Final, NoReturn, Optional, Union, cast, overload
@@ -32,22 +28,13 @@ EMPTY_FILTERS: Final["list[StatementFilter]"] = []
32
28
 
33
29
 
34
30
  class SyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin, ToSchemaMixin):
35
- """Base class for synchronous database drivers.
36
-
37
- Provides the foundation for sync database adapters, including connection management,
38
- transaction support, and SQL execution methods. All database operations are performed
39
- synchronously and support context manager patterns for proper resource cleanup.
40
- """
31
+ """Base class for synchronous database drivers."""
41
32
 
42
33
  __slots__ = ()
43
34
 
44
35
  def dispatch_statement_execution(self, statement: "SQL", connection: "Any") -> "SQLResult":
45
36
  """Central execution dispatcher using the Template Method Pattern.
46
37
 
47
- Orchestrates the common execution flow, delegating database-specific steps
48
- to abstract methods that concrete adapters must implement.
49
- All database operations are wrapped in exception handling.
50
-
51
38
  Args:
52
39
  statement: The SQL statement to execute
53
40
  connection: The database connection to use
sqlspec/driver/mixins/_result_tools.py CHANGED
@@ -1,4 +1,3 @@
1
- # pyright: reportCallIssue=false, reportAttributeAccessIssue=false, reportArgumentType=false
2
1
  import datetime
3
2
  import logging
4
3
  from collections.abc import Sequence
@@ -28,10 +27,8 @@ __all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")
28
27
 
29
28
  logger = logging.getLogger(__name__)
30
29
 
31
- # Constants for performance optimization
32
- _DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
33
- _PATH_TYPES: Final[tuple[type, ...]] = (Path, PurePath, UUID)
34
30
 
31
+ _DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
35
32
  _DEFAULT_TYPE_DECODERS: Final[list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]] = [
36
33
  (lambda x: x is UUID, lambda t, v: t(v.hex)),
37
34
  (lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())),
@@ -44,21 +41,15 @@ _DEFAULT_TYPE_DECODERS: Final[list[tuple[Callable[[Any], bool], Callable[[Any, A
44
41
  def _default_msgspec_deserializer(
45
42
  target_type: Any, value: Any, type_decoders: "Optional[Sequence[tuple[Any, Any]]]" = None
46
43
  ) -> Any:
47
- """Default msgspec deserializer with type conversion support.
48
-
49
- Converts values to appropriate types for msgspec deserialization, including
50
- UUID, datetime, date, time, Enum, Path, and PurePath types.
51
- """
44
+ """Default msgspec deserializer with type conversion support."""
52
45
  if type_decoders:
53
46
  for predicate, decoder in type_decoders:
54
47
  if predicate(target_type):
55
48
  return decoder(target_type, value)
56
49
 
57
- # Fast path checks using type identity and isinstance
58
50
  if target_type is UUID and isinstance(value, UUID):
59
51
  return value.hex
60
52
 
61
- # Use pre-computed set for faster lookup
62
53
  if target_type in _DATETIME_TYPES:
63
54
  try:
64
55
  return value.isoformat()
@@ -71,7 +62,6 @@ def _default_msgspec_deserializer(
71
62
  if isinstance(value, target_type):
72
63
  return value
73
64
 
74
- # Check for path types using pre-computed tuple
75
65
  if isinstance(target_type, type):
76
66
  try:
77
67
  if issubclass(target_type, (Path, PurePath)) or issubclass(target_type, UUID):
@@ -86,7 +76,6 @@ def _default_msgspec_deserializer(
86
76
  class ToSchemaMixin:
87
77
  __slots__ = ()
88
78
 
89
- # Schema conversion overloads - handle common cases first
90
79
  @overload
91
80
  @staticmethod
92
81
  def to_schema(data: "list[dict[str, Any]]") -> "list[dict[str, Any]]": ...
@@ -125,15 +114,11 @@ class ToSchemaMixin:
125
114
  def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) -> Any:
126
115
  """Convert data to a specified schema type.
127
116
 
128
- Supports conversion to dataclasses, msgspec structs, Pydantic models, and attrs classes.
129
- Handles both single objects and sequences.
130
-
131
117
  Raises:
132
118
  SQLSpecError if `schema_type` is not a valid type.
133
119
 
134
120
  Returns:
135
121
  Converted data in the specified schema type.
136
-
137
122
  """
138
123
  if schema_type is None:
139
124
  return data
@@ -152,30 +137,24 @@ class ToSchemaMixin:
152
137
  return schema_type(**data) # type: ignore[operator]
153
138
  return data
154
139
  if is_msgspec_struct(schema_type):
155
- # Cache the deserializer to avoid repeated partial() calls
156
140
  deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
157
141
  if not isinstance(data, Sequence):
158
142
  return convert(obj=data, type=schema_type, from_attributes=True, dec_hook=deserializer)
159
- return convert(
160
- obj=data,
161
- type=list[schema_type], # type: ignore[valid-type] # pyright: ignore
162
- from_attributes=True,
163
- dec_hook=deserializer,
164
- )
143
+ return convert(obj=data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]
165
144
  if is_pydantic_model(schema_type):
166
145
  if not isinstance(data, Sequence):
167
146
  adapter = get_type_adapter(schema_type)
168
- return adapter.validate_python(data, from_attributes=True) # pyright: ignore
169
- list_adapter = get_type_adapter(list[schema_type]) # type: ignore[valid-type] # pyright: ignore
147
+ return adapter.validate_python(data, from_attributes=True)
148
+ list_adapter = get_type_adapter(list[schema_type]) # type: ignore[valid-type]
170
149
  return list_adapter.validate_python(data, from_attributes=True)
171
150
  if is_attrs_schema(schema_type):
172
151
  if CATTRS_INSTALLED:
173
152
  if isinstance(data, Sequence):
174
- return cattrs_structure(data, list[schema_type]) # type: ignore[valid-type] # pyright: ignore
153
+ return cattrs_structure(data, list[schema_type]) # type: ignore[valid-type]
175
154
  if hasattr(data, "__attrs_attrs__"):
176
155
  unstructured_data = cattrs_unstructure(data)
177
- return cattrs_structure(unstructured_data, schema_type) # pyright: ignore
178
- return cattrs_structure(data, schema_type) # pyright: ignore
156
+ return cattrs_structure(unstructured_data, schema_type)
157
+ return cattrs_structure(data, schema_type)
179
158
  if isinstance(data, list):
180
159
  attrs_result: list[Any] = []
181
160
  for item in data:
sqlspec/driver/mixins/_sql_translator.py CHANGED
@@ -9,7 +9,7 @@ from sqlspec.exceptions import SQLConversionError
9
9
 
10
10
  __all__ = ("SQLTranslatorMixin",)
11
11
 
12
- # Constants for better performance
12
+
13
13
  _DEFAULT_PRETTY: Final[bool] = True
14
14
 
15
15
 
@@ -18,6 +18,7 @@ class SQLTranslatorMixin:
18
18
  """Mixin for drivers supporting SQL translation."""
19
19
 
20
20
  __slots__ = ()
21
+ dialect: "Optional[DialectType]"
21
22
 
22
23
  def convert_to_dialect(
23
24
  self, statement: "Statement", to_dialect: "Optional[DialectType]" = None, pretty: bool = _DEFAULT_PRETTY
@@ -35,7 +36,7 @@ class SQLTranslatorMixin:
35
36
  Raises:
36
37
  SQLConversionError: If parsing or conversion fails
37
38
  """
38
- # Fast path: get the parsed expression with minimal allocations
39
+
39
40
  parsed_expression: Optional[exp.Expression] = None
40
41
 
41
42
  if statement is not None and isinstance(statement, SQL):
@@ -47,19 +48,16 @@ class SQLTranslatorMixin:
47
48
  else:
48
49
  parsed_expression = self._parse_statement_safely(statement)
49
50
 
50
- # Get target dialect with fallback
51
- target_dialect = to_dialect or self.dialect # type: ignore[attr-defined]
51
+ target_dialect = to_dialect or self.dialect
52
52
 
53
- # Generate SQL with error handling
54
53
  return self._generate_sql_safely(parsed_expression, target_dialect, pretty)
55
54
 
56
55
  def _parse_statement_safely(self, statement: "Statement") -> "exp.Expression":
57
56
  """Parse statement with copy=False optimization and proper error handling."""
58
57
  try:
59
- # Convert statement to string if needed
60
58
  sql_string = str(statement)
61
- # Use copy=False for better performance
62
- return parse_one(sql_string, dialect=self.dialect, copy=False) # type: ignore[attr-defined]
59
+
60
+ return parse_one(sql_string, dialect=self.dialect, copy=False)
63
61
  except Exception as e:
64
62
  self._raise_parse_error(e)
65
63
 
sqlspec/exceptions.py CHANGED
@@ -181,9 +181,8 @@ def wrap_exceptions(
181
181
  (isinstance(suppress, type) and isinstance(exc, suppress))
182
182
  or (isinstance(suppress, tuple) and isinstance(exc, suppress))
183
183
  ):
184
- return # Suppress this exception
184
+ return
185
185
 
186
- # If it's already a SQLSpec exception, don't wrap it
187
186
  if isinstance(exc, SQLSpecError):
188
187
  raise
189
188
 
sqlspec/loader.py CHANGED
@@ -1,17 +1,15 @@
1
- """SQL file loader module for managing SQL statements from files.
1
+ """SQL file loader for managing SQL statements from files.
2
2
 
3
- This module provides functionality to load, cache, and manage SQL statements
3
+ Provides functionality to load, cache, and manage SQL statements
4
4
  from files using aiosql-style named queries.
5
5
  """
6
6
 
7
7
  import hashlib
8
8
  import re
9
9
  import time
10
- from dataclasses import dataclass, field
11
10
  from datetime import datetime, timezone
12
- from difflib import get_close_matches
13
11
  from pathlib import Path
14
- from typing import Any, Optional, Union
12
+ from typing import TYPE_CHECKING, Any, Final, Optional, Union
15
13
 
16
14
  from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
17
15
  from sqlspec.core.statement import SQL
@@ -21,11 +19,13 @@ from sqlspec.exceptions import (
21
19
  SQLFileParseError,
22
20
  StorageOperationFailedError,
23
21
  )
24
- from sqlspec.storage import storage_registry
25
- from sqlspec.storage.registry import StorageRegistry
22
+ from sqlspec.storage.registry import storage_registry as default_storage_registry
26
23
  from sqlspec.utils.correlation import CorrelationContext
27
24
  from sqlspec.utils.logging import get_logger
28
25
 
26
+ if TYPE_CHECKING:
27
+ from sqlspec.storage.registry import StorageRegistry
28
+
29
29
  __all__ = ("CachedSQLFile", "NamedStatement", "SQLFile", "SQLFileLoader")
30
30
 
31
31
  logger = get_logger("loader")
@@ -38,48 +38,8 @@ TRIM_SPECIAL_CHARS = re.compile(r"[^\w.-]")
38
38
  # Matches: -- dialect: dialect_name (optional dialect specification)
39
39
  DIALECT_PATTERN = re.compile(r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$", re.IGNORECASE | re.MULTILINE)
40
40
 
41
- # Supported SQL dialects (based on SQLGlot's available dialects)
42
- SUPPORTED_DIALECTS = {
43
- # Core databases
44
- "sqlite",
45
- "postgresql",
46
- "postgres",
47
- "mysql",
48
- "oracle",
49
- "mssql",
50
- "tsql",
51
- # Cloud platforms
52
- "bigquery",
53
- "snowflake",
54
- "redshift",
55
- "athena",
56
- "fabric",
57
- # Analytics engines
58
- "clickhouse",
59
- "duckdb",
60
- "databricks",
61
- "spark",
62
- "spark2",
63
- "trino",
64
- "presto",
65
- # Specialized
66
- "hive",
67
- "drill",
68
- "druid",
69
- "materialize",
70
- "teradata",
71
- "dremio",
72
- "doris",
73
- "risingwave",
74
- "singlestore",
75
- "starrocks",
76
- "tableau",
77
- "exasol",
78
- "dune",
79
- }
80
41
 
81
- # Dialect aliases for common variants
82
- DIALECT_ALIASES = {
42
+ DIALECT_ALIASES: Final = {
83
43
  "postgresql": "postgres",
84
44
  "pg": "postgres",
85
45
  "pgplsql": "postgres",
@@ -88,7 +48,7 @@ DIALECT_ALIASES = {
88
48
  "tsql": "mssql",
89
49
  }
90
50
 
91
- MIN_QUERY_PARTS = 3
51
+ MIN_QUERY_PARTS: Final = 3
92
52
 
93
53
 
94
54
  def _normalize_query_name(name: str) -> str:
@@ -129,19 +89,6 @@ def _normalize_dialect_for_sqlglot(dialect: str) -> str:
129
89
  return DIALECT_ALIASES.get(normalized, normalized)
130
90
 
131
91
 
132
- def _get_dialect_suggestions(invalid_dialect: str) -> "list[str]":
133
- """Get dialect suggestions using fuzzy matching.
134
-
135
- Args:
136
- invalid_dialect: Invalid dialect name that was provided
137
-
138
- Returns:
139
- List of suggested dialect names (up to 3 suggestions)
140
- """
141
-
142
- return get_close_matches(invalid_dialect, SUPPORTED_DIALECTS, n=3, cutoff=0.6)
143
-
144
-
145
92
  class NamedStatement:
146
93
  """Represents a parsed SQL statement with metadata.
147
94
 
@@ -159,7 +106,6 @@ class NamedStatement:
159
106
  self.start_line = start_line
160
107
 
161
108
 
162
- @dataclass
163
109
  class SQLFile:
164
110
  """Represents a loaded SQL file with metadata.
165
111
 
@@ -167,28 +113,32 @@ class SQLFile:
167
113
  timestamps, and content hash.
168
114
  """
169
115
 
170
- content: str
171
- """The raw SQL content from the file."""
116
+ __slots__ = ("checksum", "content", "loaded_at", "metadata", "path")
172
117
 
173
- path: str
174
- """Path where the SQL file was loaded from."""
118
+ def __init__(
119
+ self,
120
+ content: str,
121
+ path: str,
122
+ metadata: "Optional[dict[str, Any]]" = None,
123
+ loaded_at: "Optional[datetime]" = None,
124
+ ) -> None:
125
+ """Initialize SQLFile.
175
126
 
176
- metadata: "dict[str, Any]" = field(default_factory=dict)
177
- """Optional metadata associated with the SQL file."""
178
-
179
- checksum: str = field(init=False)
180
- """MD5 checksum of the SQL content for cache invalidation."""
181
-
182
- loaded_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
183
- """Timestamp when the file was loaded."""
184
-
185
- def __post_init__(self) -> None:
186
- """Calculate checksum after initialization."""
127
+ Args:
128
+ content: The raw SQL content from the file.
129
+ path: Path where the SQL file was loaded from.
130
+ metadata: Optional metadata associated with the SQL file.
131
+ loaded_at: Timestamp when the file was loaded.
132
+ """
133
+ self.content = content
134
+ self.path = path
135
+ self.metadata = metadata or {}
136
+ self.loaded_at = loaded_at or datetime.now(timezone.utc)
187
137
  self.checksum = hashlib.md5(self.content.encode(), usedforsecurity=False).hexdigest()
188
138
 
189
139
 
190
140
  class CachedSQLFile:
191
- """Cached SQL file with parsed statements for efficient reloading.
141
+ """Cached SQL file with parsed statements.
192
142
 
193
143
  Stored in the file cache to avoid re-parsing SQL files when their
194
144
  content hasn't changed.
@@ -205,17 +155,19 @@ class CachedSQLFile:
205
155
  """
206
156
  self.sql_file = sql_file
207
157
  self.parsed_statements = parsed_statements
208
- self.statement_names = list(parsed_statements.keys())
158
+ self.statement_names = tuple(parsed_statements.keys())
209
159
 
210
160
 
211
161
  class SQLFileLoader:
212
162
  """Loads and parses SQL files with aiosql-style named queries.
213
163
 
214
- Provides functionality to load SQL files containing named queries
215
- (using -- name: syntax) and retrieve them by name.
164
+ Loads SQL files containing named queries (using -- name: syntax)
165
+ and retrieves them by name.
216
166
  """
217
167
 
218
- def __init__(self, *, encoding: str = "utf-8", storage_registry: StorageRegistry = storage_registry) -> None:
168
+ __slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry")
169
+
170
+ def __init__(self, *, encoding: str = "utf-8", storage_registry: "Optional[StorageRegistry]" = None) -> None:
219
171
  """Initialize the SQL file loader.
220
172
 
221
173
  Args:
@@ -223,7 +175,8 @@ class SQLFileLoader:
223
175
  storage_registry: Storage registry for handling file URIs.
224
176
  """
225
177
  self.encoding = encoding
226
- self.storage_registry = storage_registry
178
+
179
+ self.storage_registry = storage_registry or default_storage_registry
227
180
  self._queries: dict[str, NamedStatement] = {}
228
181
  self._files: dict[str, SQLFile] = {}
229
182
  self._query_to_file: dict[str, str] = {}
@@ -309,7 +262,6 @@ class SQLFileLoader:
309
262
  except KeyError as e:
310
263
  raise SQLFileNotFoundError(path_str) from e
311
264
  except MissingDependencyError:
312
- # Fall back to standard file reading when no storage backend is available
313
265
  try:
314
266
  return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
315
267
  except FileNotFoundError as e:
@@ -350,7 +302,6 @@ class SQLFileLoader:
350
302
  or invalid dialect names are specified
351
303
  """
352
304
  statements: dict[str, NamedStatement] = {}
353
- content.splitlines()
354
305
 
355
306
  name_matches = list(QUERY_NAME_PATTERN.finditer(content))
356
307
  if not name_matches:
@@ -379,20 +330,7 @@ class SQLFileLoader:
379
330
  if dialect_match:
380
331
  declared_dialect = dialect_match.group("dialect").lower()
381
332
 
382
- normalized_dialect = _normalize_dialect(declared_dialect)
383
-
384
- if normalized_dialect not in SUPPORTED_DIALECTS:
385
- suggestions = _get_dialect_suggestions(normalized_dialect)
386
- warning_msg = f"Unknown dialect '{declared_dialect}' at line {statement_start_line + 1}"
387
- if suggestions:
388
- warning_msg += f". Did you mean: {', '.join(suggestions)}?"
389
- warning_msg += (
390
- f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
391
- )
392
- logger.warning(warning_msg)
393
- dialect = declared_dialect.lower()
394
- else:
395
- dialect = normalized_dialect
333
+ dialect = _normalize_dialect(declared_dialect)
396
334
  remaining_lines = section_lines[1:]
397
335
  statement_sql = "\n".join(remaining_lines)
398
336
 
@@ -473,7 +411,7 @@ class SQLFileLoader:
473
411
  raise
474
412
 
475
413
  def _load_directory(self, dir_path: Path) -> int:
476
- """Load all SQL files from a directory with namespacing."""
414
+ """Load all SQL files from a directory."""
477
415
  sql_files = list(dir_path.rglob("*.sql"))
478
416
  if not sql_files:
479
417
  return 0
@@ -486,7 +424,7 @@ class SQLFileLoader:
486
424
  return len(sql_files)
487
425
 
488
426
  def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
489
- """Load a single SQL file with optional namespace and caching.
427
+ """Load a single SQL file with optional namespace.
490
428
 
491
429
  Args:
492
430
  file_path: Path to the SQL file.
@@ -543,7 +481,7 @@ class SQLFileLoader:
543
481
  unified_cache.put(cache_key, cached_file_data)
544
482
 
545
483
  def _load_file_without_cache(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
546
- """Load a single SQL file without caching.
484
+ """Load a single SQL file without using cache.
547
485
 
548
486
  Args:
549
487
  file_path: Path to the SQL file.
@@ -580,7 +518,7 @@ class SQLFileLoader:
580
518
  Raises:
581
519
  ValueError: If query name already exists.
582
520
  """
583
- # Normalize the name for consistency with file-loaded queries
521
+
584
522
  normalized_name = _normalize_query_name(name)
585
523
 
586
524
  if normalized_name in self._queries:
@@ -589,17 +527,7 @@ class SQLFileLoader:
589
527
  raise ValueError(msg)
590
528
 
591
529
  if dialect is not None:
592
- normalized_dialect = _normalize_dialect(dialect)
593
- if normalized_dialect not in SUPPORTED_DIALECTS:
594
- suggestions = _get_dialect_suggestions(normalized_dialect)
595
- warning_msg = f"Unknown dialect '{dialect}'"
596
- if suggestions:
597
- warning_msg += f". Did you mean: {', '.join(suggestions)}?"
598
- warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
599
- logger.warning(warning_msg)
600
- dialect = dialect.lower()
601
- else:
602
- dialect = normalized_dialect
530
+ dialect = _normalize_dialect(dialect)
603
531
 
604
532
  statement = NamedStatement(name=normalized_name, sql=sql.strip(), dialect=dialect, start_line=0)
605
533
  self._queries[normalized_name] = statement