sqlspec 0.14.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (159)
  1. sqlspec/__init__.py +50 -25
  2. sqlspec/__main__.py +1 -1
  3. sqlspec/__metadata__.py +1 -3
  4. sqlspec/_serialization.py +1 -2
  5. sqlspec/_sql.py +480 -121
  6. sqlspec/_typing.py +278 -142
  7. sqlspec/adapters/adbc/__init__.py +4 -3
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/config.py +115 -260
  10. sqlspec/adapters/adbc/driver.py +462 -367
  11. sqlspec/adapters/aiosqlite/__init__.py +18 -3
  12. sqlspec/adapters/aiosqlite/_types.py +13 -0
  13. sqlspec/adapters/aiosqlite/config.py +199 -129
  14. sqlspec/adapters/aiosqlite/driver.py +230 -269
  15. sqlspec/adapters/asyncmy/__init__.py +18 -3
  16. sqlspec/adapters/asyncmy/_types.py +12 -0
  17. sqlspec/adapters/asyncmy/config.py +80 -168
  18. sqlspec/adapters/asyncmy/driver.py +260 -225
  19. sqlspec/adapters/asyncpg/__init__.py +19 -4
  20. sqlspec/adapters/asyncpg/_types.py +17 -0
  21. sqlspec/adapters/asyncpg/config.py +82 -181
  22. sqlspec/adapters/asyncpg/driver.py +285 -383
  23. sqlspec/adapters/bigquery/__init__.py +17 -3
  24. sqlspec/adapters/bigquery/_types.py +12 -0
  25. sqlspec/adapters/bigquery/config.py +191 -258
  26. sqlspec/adapters/bigquery/driver.py +474 -646
  27. sqlspec/adapters/duckdb/__init__.py +14 -3
  28. sqlspec/adapters/duckdb/_types.py +12 -0
  29. sqlspec/adapters/duckdb/config.py +415 -351
  30. sqlspec/adapters/duckdb/driver.py +343 -413
  31. sqlspec/adapters/oracledb/__init__.py +19 -5
  32. sqlspec/adapters/oracledb/_types.py +14 -0
  33. sqlspec/adapters/oracledb/config.py +123 -379
  34. sqlspec/adapters/oracledb/driver.py +507 -560
  35. sqlspec/adapters/psqlpy/__init__.py +13 -3
  36. sqlspec/adapters/psqlpy/_types.py +11 -0
  37. sqlspec/adapters/psqlpy/config.py +93 -254
  38. sqlspec/adapters/psqlpy/driver.py +505 -234
  39. sqlspec/adapters/psycopg/__init__.py +19 -5
  40. sqlspec/adapters/psycopg/_types.py +17 -0
  41. sqlspec/adapters/psycopg/config.py +143 -403
  42. sqlspec/adapters/psycopg/driver.py +706 -872
  43. sqlspec/adapters/sqlite/__init__.py +14 -3
  44. sqlspec/adapters/sqlite/_types.py +11 -0
  45. sqlspec/adapters/sqlite/config.py +202 -118
  46. sqlspec/adapters/sqlite/driver.py +264 -303
  47. sqlspec/base.py +105 -9
  48. sqlspec/{statement/builder → builder}/__init__.py +12 -14
  49. sqlspec/{statement/builder → builder}/_base.py +120 -55
  50. sqlspec/{statement/builder → builder}/_column.py +17 -6
  51. sqlspec/{statement/builder → builder}/_ddl.py +46 -79
  52. sqlspec/{statement/builder → builder}/_ddl_utils.py +5 -10
  53. sqlspec/{statement/builder → builder}/_delete.py +6 -25
  54. sqlspec/{statement/builder → builder}/_insert.py +18 -65
  55. sqlspec/builder/_merge.py +56 -0
  56. sqlspec/{statement/builder → builder}/_parsing_utils.py +8 -11
  57. sqlspec/{statement/builder → builder}/_select.py +11 -56
  58. sqlspec/{statement/builder → builder}/_update.py +12 -18
  59. sqlspec/{statement/builder → builder}/mixins/__init__.py +10 -14
  60. sqlspec/{statement/builder → builder}/mixins/_cte_and_set_ops.py +48 -59
  61. sqlspec/{statement/builder → builder}/mixins/_insert_operations.py +34 -18
  62. sqlspec/{statement/builder → builder}/mixins/_join_operations.py +1 -3
  63. sqlspec/{statement/builder → builder}/mixins/_merge_operations.py +19 -9
  64. sqlspec/{statement/builder → builder}/mixins/_order_limit_operations.py +3 -3
  65. sqlspec/{statement/builder → builder}/mixins/_pivot_operations.py +4 -8
  66. sqlspec/{statement/builder → builder}/mixins/_select_operations.py +25 -38
  67. sqlspec/{statement/builder → builder}/mixins/_update_operations.py +15 -16
  68. sqlspec/{statement/builder → builder}/mixins/_where_clause.py +210 -137
  69. sqlspec/cli.py +4 -5
  70. sqlspec/config.py +180 -133
  71. sqlspec/core/__init__.py +63 -0
  72. sqlspec/core/cache.py +873 -0
  73. sqlspec/core/compiler.py +396 -0
  74. sqlspec/core/filters.py +830 -0
  75. sqlspec/core/hashing.py +310 -0
  76. sqlspec/core/parameters.py +1209 -0
  77. sqlspec/core/result.py +664 -0
  78. sqlspec/{statement → core}/splitter.py +321 -191
  79. sqlspec/core/statement.py +666 -0
  80. sqlspec/driver/__init__.py +7 -10
  81. sqlspec/driver/_async.py +387 -176
  82. sqlspec/driver/_common.py +527 -289
  83. sqlspec/driver/_sync.py +390 -172
  84. sqlspec/driver/mixins/__init__.py +2 -19
  85. sqlspec/driver/mixins/_result_tools.py +164 -0
  86. sqlspec/driver/mixins/_sql_translator.py +6 -3
  87. sqlspec/exceptions.py +5 -252
  88. sqlspec/extensions/aiosql/adapter.py +93 -96
  89. sqlspec/extensions/litestar/cli.py +1 -1
  90. sqlspec/extensions/litestar/config.py +0 -1
  91. sqlspec/extensions/litestar/handlers.py +15 -26
  92. sqlspec/extensions/litestar/plugin.py +18 -16
  93. sqlspec/extensions/litestar/providers.py +17 -52
  94. sqlspec/loader.py +424 -105
  95. sqlspec/migrations/__init__.py +12 -0
  96. sqlspec/migrations/base.py +92 -68
  97. sqlspec/migrations/commands.py +24 -106
  98. sqlspec/migrations/loaders.py +402 -0
  99. sqlspec/migrations/runner.py +49 -51
  100. sqlspec/migrations/tracker.py +31 -44
  101. sqlspec/migrations/utils.py +64 -24
  102. sqlspec/protocols.py +7 -183
  103. sqlspec/storage/__init__.py +1 -1
  104. sqlspec/storage/backends/base.py +37 -40
  105. sqlspec/storage/backends/fsspec.py +136 -112
  106. sqlspec/storage/backends/obstore.py +138 -160
  107. sqlspec/storage/capabilities.py +5 -4
  108. sqlspec/storage/registry.py +57 -106
  109. sqlspec/typing.py +136 -115
  110. sqlspec/utils/__init__.py +2 -3
  111. sqlspec/utils/correlation.py +0 -3
  112. sqlspec/utils/deprecation.py +6 -6
  113. sqlspec/utils/fixtures.py +6 -6
  114. sqlspec/utils/logging.py +0 -2
  115. sqlspec/utils/module_loader.py +7 -12
  116. sqlspec/utils/singleton.py +0 -1
  117. sqlspec/utils/sync_tools.py +17 -38
  118. sqlspec/utils/text.py +12 -51
  119. sqlspec/utils/type_guards.py +443 -232
  120. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/METADATA +7 -2
  121. sqlspec-0.16.0.dist-info/RECORD +134 -0
  122. sqlspec/adapters/adbc/transformers.py +0 -108
  123. sqlspec/driver/connection.py +0 -207
  124. sqlspec/driver/mixins/_cache.py +0 -114
  125. sqlspec/driver/mixins/_csv_writer.py +0 -91
  126. sqlspec/driver/mixins/_pipeline.py +0 -508
  127. sqlspec/driver/mixins/_query_tools.py +0 -796
  128. sqlspec/driver/mixins/_result_utils.py +0 -138
  129. sqlspec/driver/mixins/_storage.py +0 -912
  130. sqlspec/driver/mixins/_type_coercion.py +0 -128
  131. sqlspec/driver/parameters.py +0 -138
  132. sqlspec/statement/__init__.py +0 -21
  133. sqlspec/statement/builder/_merge.py +0 -95
  134. sqlspec/statement/cache.py +0 -50
  135. sqlspec/statement/filters.py +0 -625
  136. sqlspec/statement/parameters.py +0 -956
  137. sqlspec/statement/pipelines/__init__.py +0 -210
  138. sqlspec/statement/pipelines/analyzers/__init__.py +0 -9
  139. sqlspec/statement/pipelines/analyzers/_analyzer.py +0 -646
  140. sqlspec/statement/pipelines/context.py +0 -109
  141. sqlspec/statement/pipelines/transformers/__init__.py +0 -7
  142. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +0 -88
  143. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +0 -1247
  144. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +0 -76
  145. sqlspec/statement/pipelines/validators/__init__.py +0 -23
  146. sqlspec/statement/pipelines/validators/_dml_safety.py +0 -290
  147. sqlspec/statement/pipelines/validators/_parameter_style.py +0 -370
  148. sqlspec/statement/pipelines/validators/_performance.py +0 -714
  149. sqlspec/statement/pipelines/validators/_security.py +0 -967
  150. sqlspec/statement/result.py +0 -435
  151. sqlspec/statement/sql.py +0 -1774
  152. sqlspec/utils/cached_property.py +0 -25
  153. sqlspec/utils/statement_hashing.py +0 -203
  154. sqlspec-0.14.1.dist-info/RECORD +0 -145
  155. sqlspec/{statement/builder → builder}/mixins/_delete_operations.py +0 -0
  156. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/WHEEL +0 -0
  157. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/entry_points.txt +0 -0
  158. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/LICENSE +0 -0
  159. {sqlspec-0.14.1.dist-info → sqlspec-0.16.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/loader.py CHANGED
@@ -9,49 +9,158 @@ import re
9
9
  import time
10
10
  from dataclasses import dataclass, field
11
11
  from datetime import datetime, timezone
12
+ from difflib import get_close_matches
12
13
  from pathlib import Path
13
14
  from typing import Any, Optional, Union
14
15
 
15
- from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError
16
- from sqlspec.statement.sql import SQL
16
+ from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
17
+ from sqlspec.core.parameters import ParameterStyleConfig, ParameterValidator
18
+ from sqlspec.core.statement import SQL, StatementConfig
19
+ from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
17
20
  from sqlspec.storage import storage_registry
18
21
  from sqlspec.storage.registry import StorageRegistry
19
22
  from sqlspec.utils.correlation import CorrelationContext
20
23
  from sqlspec.utils.logging import get_logger
21
24
 
22
- __all__ = ("SQLFile", "SQLFileLoader")
25
+ __all__ = ("CachedSQLFile", "NamedStatement", "SQLFile", "SQLFileLoader")
23
26
 
24
27
  logger = get_logger("loader")
25
28
 
26
29
  # Matches: -- name: query_name (supports hyphens and special suffixes)
27
30
  # We capture the name plus any trailing special characters
28
31
  QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
29
- TRIM_TRAILING_SPECIAL_CHARS = re.compile(r"[^\w-]+$")
32
+ TRIM_SPECIAL_CHARS = re.compile(r"[^\w-]")
33
+
34
+ # Matches: -- dialect: dialect_name (optional dialect specification)
35
+ DIALECT_PATTERN = re.compile(r"^\s*--\s*dialect\s*:\s*(?P<dialect>[a-zA-Z0-9_]+)\s*$", re.IGNORECASE | re.MULTILINE)
36
+
37
+ # Supported SQL dialects (based on SQLGlot's available dialects)
38
+ SUPPORTED_DIALECTS = {
39
+ # Core databases
40
+ "sqlite",
41
+ "postgresql",
42
+ "postgres",
43
+ "mysql",
44
+ "oracle",
45
+ "mssql",
46
+ "tsql",
47
+ # Cloud platforms
48
+ "bigquery",
49
+ "snowflake",
50
+ "redshift",
51
+ "athena",
52
+ "fabric",
53
+ # Analytics engines
54
+ "clickhouse",
55
+ "duckdb",
56
+ "databricks",
57
+ "spark",
58
+ "spark2",
59
+ "trino",
60
+ "presto",
61
+ # Specialized
62
+ "hive",
63
+ "drill",
64
+ "druid",
65
+ "materialize",
66
+ "teradata",
67
+ "dremio",
68
+ "doris",
69
+ "risingwave",
70
+ "singlestore",
71
+ "starrocks",
72
+ "tableau",
73
+ "exasol",
74
+ "dune",
75
+ }
76
+
77
+ # Dialect aliases for common variants
78
+ DIALECT_ALIASES = {
79
+ "postgresql": "postgres",
80
+ "pg": "postgres",
81
+ "pgplsql": "postgres",
82
+ "plsql": "oracle",
83
+ "oracledb": "oracle",
84
+ "tsql": "mssql",
85
+ }
86
+
30
87
  MIN_QUERY_PARTS = 3
31
88
 
32
89
 
33
90
  def _normalize_query_name(name: str) -> str:
34
91
  """Normalize query name to be a valid Python identifier.
35
92
 
36
- - Strips trailing special characters (like $, !, etc from aiosql)
37
- - Replaces hyphens with underscores
38
-
39
93
  Args:
40
94
  name: Raw query name from SQL file
41
95
 
42
96
  Returns:
43
- converted query name suitable as Python identifier
97
+ Normalized query name suitable as Python identifier
98
+ """
99
+ return TRIM_SPECIAL_CHARS.sub("", name).replace("-", "_")
100
+
101
+
102
+ def _normalize_dialect(dialect: str) -> str:
103
+ """Normalize dialect name with aliases.
104
+
105
+ Args:
106
+ dialect: Raw dialect name from SQL file
107
+
108
+ Returns:
109
+ Normalized dialect name
110
+ """
111
+ normalized = dialect.lower().strip()
112
+ return DIALECT_ALIASES.get(normalized, normalized)
113
+
114
+
115
+ def _normalize_dialect_for_sqlglot(dialect: str) -> str:
116
+ """Normalize dialect name for SQLGlot compatibility.
117
+
118
+ Args:
119
+ dialect: Dialect name from SQL file or parameter
120
+
121
+ Returns:
122
+ SQLGlot-compatible dialect name
123
+ """
124
+ normalized = dialect.lower().strip()
125
+ return DIALECT_ALIASES.get(normalized, normalized)
126
+
127
+
128
+ def _get_dialect_suggestions(invalid_dialect: str) -> "list[str]":
129
+ """Get dialect suggestions using fuzzy matching.
130
+
131
+ Args:
132
+ invalid_dialect: Invalid dialect name that was provided
133
+
134
+ Returns:
135
+ List of suggested dialect names (up to 3 suggestions)
136
+ """
137
+
138
+ return get_close_matches(invalid_dialect, SUPPORTED_DIALECTS, n=3, cutoff=0.6)
139
+
140
+
141
+ class NamedStatement:
142
+ """Represents a parsed SQL statement with metadata.
143
+
144
+ Contains individual SQL statements extracted from files with their
145
+ normalized names, SQL content, optional dialect specifications,
146
+ and line position for error reporting.
44
147
  """
45
- # Strip trailing non-alphanumeric characters (excluding underscore) and replace hyphens
46
- return TRIM_TRAILING_SPECIAL_CHARS.sub("", name).replace("-", "_")
148
+
149
+ __slots__ = ("dialect", "name", "sql", "start_line")
150
+
151
+ def __init__(self, name: str, sql: str, dialect: "Optional[str]" = None, start_line: int = 0) -> None:
152
+ self.name = name
153
+ self.sql = sql
154
+ self.dialect = dialect
155
+ self.start_line = start_line
47
156
 
48
157
 
49
158
  @dataclass
50
159
  class SQLFile:
51
160
  """Represents a loaded SQL file with metadata.
52
161
 
53
- This class holds the SQL content along with metadata about the file
54
- such as its location, timestamps, and content hash.
162
+ Contains SQL content and associated metadata including file location,
163
+ timestamps, and content hash.
55
164
  """
56
165
 
57
166
  content: str
@@ -74,26 +183,32 @@ class SQLFile:
74
183
  self.checksum = hashlib.md5(self.content.encode(), usedforsecurity=False).hexdigest()
75
184
 
76
185
 
77
- class SQLFileLoader:
78
- """Loads and parses SQL files with aiosql-style named queries.
186
+ class CachedSQLFile:
187
+ """Cached SQL file with parsed statements for efficient reloading.
188
+
189
+ Stored in the file cache to avoid re-parsing SQL files when their
190
+ content hasn't changed.
191
+ """
79
192
 
80
- This class provides functionality to load SQL files containing
81
- named queries (using -- name: syntax) and retrieve them by name.
193
+ __slots__ = ("parsed_statements", "sql_file", "statement_names")
82
194
 
83
- Example:
84
- ```python
85
- # Initialize loader
86
- loader = SQLFileLoader()
195
+ def __init__(self, sql_file: SQLFile, parsed_statements: "dict[str, NamedStatement]") -> None:
196
+ """Initialize cached SQL file.
87
197
 
88
- # Load SQL files
89
- loader.load_sql("queries/users.sql")
90
- loader.load_sql(
91
- "queries/products.sql", "queries/orders.sql"
92
- )
198
+ Args:
199
+ sql_file: The original SQLFile with content and metadata.
200
+ parsed_statements: Named statements from the file.
201
+ """
202
+ self.sql_file = sql_file
203
+ self.parsed_statements = parsed_statements
204
+ self.statement_names = list(parsed_statements.keys())
93
205
 
94
- # Get SQL by query name
95
- sql = loader.get_sql("get_user_by_id", user_id=123)
96
- ```
206
+
207
+ class SQLFileLoader:
208
+ """Loads and parses SQL files with aiosql-style named queries.
209
+
210
+ Provides functionality to load SQL files containing named queries
211
+ (using -- name: syntax) and retrieve them by name.
97
212
  """
98
213
 
99
214
  def __init__(self, *, encoding: str = "utf-8", storage_registry: StorageRegistry = storage_registry) -> None:
@@ -105,10 +220,68 @@ class SQLFileLoader:
105
220
  """
106
221
  self.encoding = encoding
107
222
  self.storage_registry = storage_registry
108
- # Instance-level storage for loaded queries and files
109
- self._queries: dict[str, str] = {}
223
+ self._queries: dict[str, NamedStatement] = {}
110
224
  self._files: dict[str, SQLFile] = {}
111
- self._query_to_file: dict[str, str] = {} # Maps query name to file path
225
+ self._query_to_file: dict[str, str] = {}
226
+
227
+ def _raise_file_not_found(self, path: str) -> None:
228
+ """Raise SQLFileNotFoundError for nonexistent file.
229
+
230
+ Args:
231
+ path: File path that was not found.
232
+
233
+ Raises:
234
+ SQLFileNotFoundError: Always raised.
235
+ """
236
+ raise SQLFileNotFoundError(path)
237
+
238
+ def _generate_file_cache_key(self, path: Union[str, Path]) -> str:
239
+ """Generate cache key for a file path.
240
+
241
+ Args:
242
+ path: File path to generate key for.
243
+
244
+ Returns:
245
+ Cache key string for the file.
246
+ """
247
+ path_str = str(path)
248
+ path_hash = hashlib.md5(path_str.encode(), usedforsecurity=False).hexdigest()
249
+ return f"file:{path_hash[:16]}"
250
+
251
+ def _calculate_file_checksum(self, path: Union[str, Path]) -> str:
252
+ """Calculate checksum for file content validation.
253
+
254
+ Args:
255
+ path: File path to calculate checksum for.
256
+
257
+ Returns:
258
+ MD5 checksum of file content.
259
+
260
+ Raises:
261
+ SQLFileParseError: If file cannot be read.
262
+ """
263
+ try:
264
+ content = self._read_file_content(path)
265
+ return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
266
+ except Exception as e:
267
+ raise SQLFileParseError(str(path), str(path), e) from e
268
+
269
+ def _is_file_unchanged(self, path: Union[str, Path], cached_file: CachedSQLFile) -> bool:
270
+ """Check if file has changed since caching.
271
+
272
+ Args:
273
+ path: File path to check.
274
+ cached_file: Cached file data.
275
+
276
+ Returns:
277
+ True if file is unchanged, False otherwise.
278
+ """
279
+ try:
280
+ current_checksum = self._calculate_file_checksum(path)
281
+ except Exception:
282
+ return False
283
+ else:
284
+ return current_checksum == cached_file.sql_file.checksum
112
285
 
113
286
  def _read_file_content(self, path: Union[str, Path]) -> str:
114
287
  """Read file content using storage backend.
@@ -120,8 +293,10 @@ class SQLFileLoader:
120
293
  File content as string.
121
294
 
122
295
  Raises:
123
- SQLFileParseError: If file cannot be read.
296
+ SQLFileNotFoundError: If file does not exist.
297
+ SQLFileParseError: If file cannot be read or parsed.
124
298
  """
299
+
125
300
  path_str = str(path)
126
301
 
127
302
  try:
@@ -129,6 +304,10 @@ class SQLFileLoader:
129
304
  return backend.read_text(path_str, encoding=self.encoding)
130
305
  except KeyError as e:
131
306
  raise SQLFileNotFoundError(path_str) from e
307
+ except StorageOperationFailedError as e:
308
+ if "not found" in str(e).lower() or "no such file" in str(e).lower():
309
+ raise SQLFileNotFoundError(path_str) from e
310
+ raise SQLFileParseError(path_str, path_str, e) from e
132
311
  except Exception as e:
133
312
  raise SQLFileParseError(path_str, path_str, e) from e
134
313
 
@@ -142,46 +321,91 @@ class SQLFileLoader:
142
321
  first_sql_line_index = i
143
322
  break
144
323
  if first_sql_line_index == -1:
145
- return "" # All comments or empty
324
+ return ""
146
325
  return "\n".join(lines[first_sql_line_index:]).strip()
147
326
 
148
327
  @staticmethod
149
- def _parse_sql_content(content: str, file_path: str) -> dict[str, str]:
150
- """Parse SQL content and extract named queries."""
151
- queries: dict[str, str] = {}
152
- matches = list(QUERY_NAME_PATTERN.finditer(content))
153
- if not matches:
328
+ def _parse_sql_content(content: str, file_path: str) -> "dict[str, NamedStatement]":
329
+ """Parse SQL content and extract named statements with dialect specifications.
330
+
331
+ Args:
332
+ content: Raw SQL file content to parse
333
+ file_path: File path for error reporting
334
+
335
+ Returns:
336
+ Dictionary mapping normalized statement names to NamedStatement objects
337
+
338
+ Raises:
339
+ SQLFileParseError: If no named statements found, duplicate names exist,
340
+ or invalid dialect names are specified
341
+ """
342
+ statements: dict[str, NamedStatement] = {}
343
+ content.splitlines()
344
+
345
+ name_matches = list(QUERY_NAME_PATTERN.finditer(content))
346
+ if not name_matches:
154
347
  raise SQLFileParseError(
155
- file_path, file_path, ValueError("No named SQL statements found (-- name: query_name)")
348
+ file_path, file_path, ValueError("No named SQL statements found (-- name: statement_name)")
156
349
  )
157
350
 
158
- for i, match in enumerate(matches):
159
- raw_query_name = match.group(1).strip()
351
+ for i, match in enumerate(name_matches):
352
+ raw_statement_name = match.group(1).strip()
353
+ statement_start_line = content[: match.start()].count("\n")
354
+
160
355
  start_pos = match.end()
161
- end_pos = matches[i + 1].start() if i + 1 < len(matches) else len(content)
356
+ end_pos = name_matches[i + 1].start() if i + 1 < len(name_matches) else len(content)
162
357
 
163
- sql_text = content[start_pos:end_pos].strip()
164
- if not raw_query_name or not sql_text:
358
+ statement_section = content[start_pos:end_pos].strip()
359
+ if not raw_statement_name or not statement_section:
165
360
  continue
166
361
 
167
- clean_sql = SQLFileLoader._strip_leading_comments(sql_text)
362
+ dialect = None
363
+ statement_sql = statement_section
364
+
365
+ section_lines = [line.strip() for line in statement_section.split("\n") if line.strip()]
366
+ if section_lines:
367
+ first_line = section_lines[0]
368
+ dialect_match = DIALECT_PATTERN.match(first_line)
369
+ if dialect_match:
370
+ declared_dialect = dialect_match.group("dialect").lower()
371
+
372
+ normalized_dialect = _normalize_dialect(declared_dialect)
373
+
374
+ if normalized_dialect not in SUPPORTED_DIALECTS:
375
+ suggestions = _get_dialect_suggestions(normalized_dialect)
376
+ warning_msg = f"Unknown dialect '{declared_dialect}' at line {statement_start_line + 1}"
377
+ if suggestions:
378
+ warning_msg += f". Did you mean: {', '.join(suggestions)}?"
379
+ warning_msg += (
380
+ f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
381
+ )
382
+ logger.warning(warning_msg)
383
+ dialect = declared_dialect.lower()
384
+ else:
385
+ dialect = normalized_dialect
386
+ remaining_lines = section_lines[1:]
387
+ statement_sql = "\n".join(remaining_lines)
388
+
389
+ clean_sql = SQLFileLoader._strip_leading_comments(statement_sql)
168
390
  if clean_sql:
169
- query_name = _normalize_query_name(raw_query_name)
170
- if query_name in queries:
171
- raise SQLFileParseError(file_path, file_path, ValueError(f"Duplicate query name: {raw_query_name}"))
172
- queries[query_name] = clean_sql
391
+ normalized_name = _normalize_query_name(raw_statement_name)
392
+ if normalized_name in statements:
393
+ raise SQLFileParseError(
394
+ file_path, file_path, ValueError(f"Duplicate statement name: {raw_statement_name}")
395
+ )
396
+
397
+ statements[normalized_name] = NamedStatement(
398
+ name=normalized_name, sql=clean_sql, dialect=dialect, start_line=statement_start_line
399
+ )
173
400
 
174
- if not queries:
175
- raise SQLFileParseError(file_path, file_path, ValueError("No valid SQL queries found after parsing"))
401
+ if not statements:
402
+ raise SQLFileParseError(file_path, file_path, ValueError("No valid SQL statements found after parsing"))
176
403
 
177
- return queries
404
+ return statements
178
405
 
179
406
  def load_sql(self, *paths: Union[str, Path]) -> None:
180
407
  """Load SQL files and parse named queries.
181
408
 
182
- Supports both individual files and directories. When loading directories,
183
- automatically namespaces queries based on subdirectory structure.
184
-
185
409
  Args:
186
410
  *paths: One or more file paths or directory paths to load.
187
411
  """
@@ -203,9 +427,11 @@ class SQLFileLoader:
203
427
  path_obj = Path(path)
204
428
  if path_obj.is_dir():
205
429
  loaded_count += self._load_directory(path_obj)
206
- else:
430
+ elif path_obj.exists():
207
431
  self._load_single_file(path_obj, None)
208
432
  loaded_count += 1
433
+ elif path_obj.suffix:
434
+ self._raise_file_not_found(str(path))
209
435
 
210
436
  duration = time.perf_counter() - start_time
211
437
  new_queries = len(self._queries) - query_count_before
@@ -250,23 +476,77 @@ class SQLFileLoader:
250
476
  return len(sql_files)
251
477
 
252
478
  def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
253
- """Load a single SQL file with optional namespace.
479
+ """Load a single SQL file with optional namespace and caching.
254
480
 
255
481
  Args:
256
- file_path: Path to the SQL file (can be string for URIs or Path for local files).
482
+ file_path: Path to the SQL file.
257
483
  namespace: Optional namespace prefix for queries.
258
484
  """
259
485
  path_str = str(file_path)
260
486
 
261
487
  if path_str in self._files:
262
- return # Already loaded
488
+ return
489
+
490
+ cache_config = get_cache_config()
491
+ if not cache_config.compiled_cache_enabled:
492
+ self._load_file_without_cache(file_path, namespace)
493
+ return
494
+
495
+ cache_key_str = self._generate_file_cache_key(file_path)
496
+ cache_key = CacheKey((cache_key_str,))
497
+ unified_cache = get_default_cache()
498
+ cached_file = unified_cache.get(cache_key)
499
+
500
+ if (
501
+ cached_file is not None
502
+ and isinstance(cached_file, CachedSQLFile)
503
+ and self._is_file_unchanged(file_path, cached_file)
504
+ ):
505
+ self._files[path_str] = cached_file.sql_file
506
+ for name, statement in cached_file.parsed_statements.items():
507
+ namespaced_name = f"{namespace}.{name}" if namespace else name
508
+ if namespaced_name in self._queries:
509
+ existing_file = self._query_to_file.get(namespaced_name, "unknown")
510
+ if existing_file != path_str:
511
+ raise SQLFileParseError(
512
+ path_str,
513
+ path_str,
514
+ ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
515
+ )
516
+ self._queries[namespaced_name] = statement
517
+ self._query_to_file[namespaced_name] = path_str
518
+ return
519
+
520
+ self._load_file_without_cache(file_path, namespace)
521
+
522
+ if path_str in self._files:
523
+ sql_file = self._files[path_str]
524
+ file_statements: dict[str, NamedStatement] = {}
525
+ for query_name, query_path in self._query_to_file.items():
526
+ if query_path == path_str:
527
+ stored_name = query_name
528
+ if namespace and query_name.startswith(f"{namespace}."):
529
+ stored_name = query_name[len(namespace) + 1 :]
530
+ file_statements[stored_name] = self._queries[query_name]
531
+
532
+ cached_file_data = CachedSQLFile(sql_file=sql_file, parsed_statements=file_statements)
533
+ unified_cache.put(cache_key, cached_file_data)
534
+
535
+ def _load_file_without_cache(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
536
+ """Load a single SQL file without caching.
537
+
538
+ Args:
539
+ file_path: Path to the SQL file.
540
+ namespace: Optional namespace prefix for queries.
541
+ """
542
+ path_str = str(file_path)
263
543
 
264
544
  content = self._read_file_content(file_path)
265
545
  sql_file = SQLFile(content=content, path=path_str)
266
546
  self._files[path_str] = sql_file
267
547
 
268
- queries = self._parse_sql_content(content, path_str)
269
- for name, sql in queries.items():
548
+ statements = self._parse_sql_content(content, path_str)
549
+ for name, statement in statements.items():
270
550
  namespaced_name = f"{namespace}.{name}" if namespace else name
271
551
  if namespaced_name in self._queries:
272
552
  existing_file = self._query_to_file.get(namespaced_name, "unknown")
@@ -276,15 +556,16 @@ class SQLFileLoader:
276
556
  path_str,
277
557
  ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
278
558
  )
279
- self._queries[namespaced_name] = sql
559
+ self._queries[namespaced_name] = statement
280
560
  self._query_to_file[namespaced_name] = path_str
281
561
 
282
- def add_named_sql(self, name: str, sql: str) -> None:
562
+ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) -> None:
283
563
  """Add a named SQL query directly without loading from a file.
284
564
 
285
565
  Args:
286
566
  name: Name for the SQL query.
287
567
  sql: Raw SQL content.
568
+ dialect: Optional dialect for the SQL statement.
288
569
 
289
570
  Raises:
290
571
  ValueError: If query name already exists.
@@ -294,74 +575,100 @@ class SQLFileLoader:
294
575
  msg = f"Query name '{name}' already exists (source: {existing_source})"
295
576
  raise ValueError(msg)
296
577
 
297
- self._queries[name] = sql.strip()
578
+ if dialect is not None:
579
+ normalized_dialect = _normalize_dialect(dialect)
580
+ if normalized_dialect not in SUPPORTED_DIALECTS:
581
+ suggestions = _get_dialect_suggestions(normalized_dialect)
582
+ warning_msg = f"Unknown dialect '{dialect}'"
583
+ if suggestions:
584
+ warning_msg += f". Did you mean: {', '.join(suggestions)}?"
585
+ warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
586
+ logger.warning(warning_msg)
587
+ dialect = dialect.lower()
588
+ else:
589
+ dialect = normalized_dialect
590
+
591
+ statement = NamedStatement(name=name, sql=sql.strip(), dialect=dialect, start_line=0)
592
+ self._queries[name] = statement
298
593
  self._query_to_file[name] = "<directly added>"
299
594
 
300
- def get_sql(self, name: str, parameters: "Optional[Any]" = None, **kwargs: "Any") -> "SQL":
301
- """Get a SQL object by query name.
595
+ def get_sql(
596
+ self, name: str, parameters: "Optional[Any]" = None, dialect: "Optional[str]" = None, **kwargs: "Any"
597
+ ) -> "SQL":
598
+ """Get a SQL object by statement name with dialect support.
302
599
 
303
600
  Args:
304
- name: Name of the query (from -- name: in SQL file).
305
- Hyphens in names are automatically converted to underscores.
306
- parameters: Parameters for the SQL query (aiosql-compatible).
601
+ name: Name of the statement (from -- name: in SQL file).
602
+ Hyphens in names are converted to underscores.
603
+ parameters: Parameters for the SQL statement.
604
+ dialect: Optional dialect override.
307
605
  **kwargs: Additional parameters to pass to the SQL object.
308
606
 
309
607
  Returns:
310
608
  SQL object ready for execution.
311
609
 
312
610
  Raises:
313
- SQLFileNotFoundError: If query name not found.
611
+ SQLFileNotFoundError: If statement name not found.
314
612
  """
315
613
  correlation_id = CorrelationContext.get()
316
614
 
317
- # Normalize query name for lookup
318
615
  safe_name = _normalize_query_name(name)
319
616
 
320
- logger.debug(
321
- "Retrieving SQL query: %s",
322
- name,
323
- extra={
324
- "query_name": name,
325
- "safe_name": safe_name,
326
- "has_parameters": parameters is not None,
327
- "correlation_id": correlation_id,
328
- },
329
- )
330
-
331
617
  if safe_name not in self._queries:
332
618
  available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
333
619
  logger.error(
334
- "Query not found: %s",
620
+ "Statement not found: %s",
335
621
  name,
336
622
  extra={
337
- "query_name": name,
623
+ "statement_name": name,
338
624
  "safe_name": safe_name,
339
- "available_queries": len(self._queries),
625
+ "available_statements": len(self._queries),
340
626
  "correlation_id": correlation_id,
341
627
  },
342
628
  )
343
- raise SQLFileNotFoundError(name, path=f"Query '{name}' not found. Available queries: {available}")
629
+ raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
630
+
631
+ parsed_statement = self._queries[safe_name]
632
+
633
+ effective_dialect = dialect or parsed_statement.dialect
634
+
635
+ if dialect is not None:
636
+ normalized_dialect = _normalize_dialect(dialect)
637
+ if normalized_dialect not in SUPPORTED_DIALECTS:
638
+ suggestions = _get_dialect_suggestions(normalized_dialect)
639
+ warning_msg = f"Unknown dialect '{dialect}'"
640
+ if suggestions:
641
+ warning_msg += f". Did you mean: {', '.join(suggestions)}?"
642
+ warning_msg += f". Supported dialects: {', '.join(sorted(SUPPORTED_DIALECTS))}. Using dialect as-is."
643
+ logger.warning(warning_msg)
644
+ effective_dialect = dialect.lower()
645
+ else:
646
+ effective_dialect = normalized_dialect
344
647
 
345
648
  sql_kwargs = dict(kwargs)
346
649
  if parameters is not None:
347
650
  sql_kwargs["parameters"] = parameters
348
651
 
349
- source_file = self._query_to_file.get(safe_name, "unknown")
350
-
351
- logger.debug(
352
- "Found query %s from %s",
353
- name,
354
- source_file,
355
- extra={
356
- "query_name": name,
357
- "safe_name": safe_name,
358
- "source_file": source_file,
359
- "sql_length": len(self._queries[safe_name]),
360
- "correlation_id": correlation_id,
361
- },
362
- )
652
+ sqlglot_dialect = None
653
+ if effective_dialect:
654
+ sqlglot_dialect = _normalize_dialect_for_sqlglot(effective_dialect)
655
+
656
+ if not effective_dialect and "statement_config" not in sql_kwargs:
657
+ validator = ParameterValidator()
658
+ param_info = validator.extract_parameters(parsed_statement.sql)
659
+ if param_info:
660
+ styles = {p.style for p in param_info}
661
+ if styles:
662
+ detected_style = next(iter(styles))
663
+ sql_kwargs["statement_config"] = StatementConfig(
664
+ parameter_config=ParameterStyleConfig(
665
+ default_parameter_style=detected_style,
666
+ supported_parameter_styles=styles,
667
+ preserve_parameter_format=True,
668
+ )
669
+ )
363
670
 
364
- return SQL(self._queries[safe_name], **sql_kwargs)
671
+ return SQL(parsed_statement.sql, dialect=sqlglot_dialect, **sql_kwargs)
365
672
 
366
673
  def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
367
674
  """Get a loaded SQLFile object by path.
@@ -375,7 +682,7 @@ class SQLFileLoader:
375
682
  return self._files.get(str(path))
376
683
 
377
684
  def get_file_for_query(self, name: str) -> "Optional[SQLFile]":
378
- """Get the SQLFile object that contains a query.
685
+ """Get the SQLFile object containing a query.
379
686
 
380
687
  Args:
381
688
  name: Query name (hyphens are converted to underscores).
@@ -409,7 +716,7 @@ class SQLFileLoader:
409
716
  """Check if a query exists.
410
717
 
411
718
  Args:
412
- name: Query name to check (hyphens are converted to underscores).
719
+ name: Query name to check.
413
720
 
414
721
  Returns:
415
722
  True if query exists.
@@ -423,11 +730,23 @@ class SQLFileLoader:
423
730
  self._queries.clear()
424
731
  self._query_to_file.clear()
425
732
 
733
+ cache_config = get_cache_config()
734
+ if cache_config.compiled_cache_enabled:
735
+ unified_cache = get_default_cache()
736
+ unified_cache.clear()
737
+
738
+ def clear_file_cache(self) -> None:
739
+ """Clear the file cache only, keeping loaded queries."""
740
+ cache_config = get_cache_config()
741
+ if cache_config.compiled_cache_enabled:
742
+ unified_cache = get_default_cache()
743
+ unified_cache.clear()
744
+
426
745
  def get_query_text(self, name: str) -> str:
427
746
  """Get raw SQL text for a query.
428
747
 
429
748
  Args:
430
- name: Query name (hyphens are converted to underscores).
749
+ name: Query name.
431
750
 
432
751
  Returns:
433
752
  Raw SQL text.
@@ -438,4 +757,4 @@ class SQLFileLoader:
438
757
  safe_name = _normalize_query_name(name)
439
758
  if safe_name not in self._queries:
440
759
  raise SQLFileNotFoundError(name)
441
- return self._queries[safe_name]
760
+ return self._queries[safe_name].sql