sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of sqlspec might be problematic.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +18 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/parameters.py ADDED
@@ -0,0 +1,138 @@
+"""Consolidated parameter processing utilities for database drivers.
+
+This module provides centralized parameter handling logic to avoid duplication
+across sync and async driver implementations.
+"""
+
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from sqlspec.statement.filters import StatementFilter
+from sqlspec.utils.type_guards import is_sync_transaction_capable
+
+if TYPE_CHECKING:
+    from sqlspec.typing import StatementParameters
+
+__all__ = (
+    "convert_parameters_to_positional",
+    "normalize_parameter_sequence",
+    "process_execute_many_parameters",
+    "separate_filters_and_parameters",
+    "should_use_transaction",
+)
+
+
+def separate_filters_and_parameters(
+    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
+) -> "tuple[list[StatementFilter], list[Any]]":
+    """Separate filters from parameters in a mixed parameter tuple.
+
+    Args:
+        parameters: Mixed tuple of parameters and filters
+
+    Returns:
+        Tuple of (filters, parameters) lists
+    """
+
+    filters: list[StatementFilter] = []
+    param_values: list[Any] = []
+
+    for param in parameters:
+        if isinstance(param, StatementFilter):
+            filters.append(param)
+        else:
+            param_values.append(param)
+
+    return filters, param_values
+
+
+def process_execute_many_parameters(
+    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
+) -> "tuple[list[StatementFilter], Optional[list[Any]]]":
+    """Process parameters for execute_many operations.
+
+    Args:
+        parameters: Mixed tuple of parameters and filters
+
+    Returns:
+        Tuple of (filters, parameter_sequence)
+    """
+    filters, param_values = separate_filters_and_parameters(parameters)
+
+    # Use first parameter as the sequence for execute_many
+    param_sequence = param_values[0] if param_values else None
+
+    # Normalize the parameter sequence
+    param_sequence = normalize_parameter_sequence(param_sequence)
+
+    return filters, param_sequence
+
+
+def normalize_parameter_sequence(params: Any) -> Optional[list[Any]]:
+    """Normalize a parameter sequence to a list format.
+
+    Args:
+        params: Parameter sequence in various formats
+
+    Returns:
+        Normalized list of parameters or None
+    """
+    if params is None:
+        return None
+
+    if isinstance(params, list):
+        return params
+
+    if isinstance(params, tuple):
+        return list(params)
+
+    # Check if it's iterable (but not string or dict)
+    # Use duck typing to check for iterable protocol
+    try:
+        iter(params)
+        if not isinstance(params, (str, dict)):
+            return list(params)
+    except TypeError:
+        pass
+
+    # Single parameter, wrap in list
+    return [params]
+
+
+def convert_parameters_to_positional(params: "dict[str, Any]", parameter_info: "list[Any]") -> list[Any]:
+    """Convert named parameters to positional based on SQL order.
+
+    Args:
+        params: Dictionary of named parameters
+        parameter_info: List of parameter info from SQL parsing
+
+    Returns:
+        List of positional parameters
+    """
+    if not params:
+        return []
+
+    # Handle param_0, param_1, etc. pattern
+    if all(key.startswith("param_") for key in params):
+        return [params[f"param_{i}"] for i in range(len(params))]
+
+    # Convert based on parameter info order
+    # Check for name attribute using getattr with default
+    result = []
+    for info in parameter_info:
+        param_name = getattr(info, "name", None)
+        if param_name is not None:
+            result.append(params.get(param_name, None))
+    return result
+
+
+def should_use_transaction(connection: Any, auto_commit: bool = True) -> bool:
+    """Determine if a transaction should be used.
+
+    Args:
+        connection: Database connection object
+        auto_commit: Whether auto-commit is enabled
+
+    Returns:
+        True if transaction capabilities are available and should be used
+    """
+    return False if auto_commit else is_sync_transaction_capable(connection)
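
Illustrative note (not part of the diff): a minimal sketch of how the new helpers in sqlspec/driver/parameters.py behave, based only on the code above; the inputs and call sites are invented.

    from sqlspec.driver.parameters import (
        convert_parameters_to_positional,
        normalize_parameter_sequence,
        process_execute_many_parameters,
    )

    # execute_many style: the whole row sequence arrives as one positional argument.
    filters, rows = process_execute_many_parameters(([{"id": 1}, {"id": 2}],))
    assert filters == [] and rows == [{"id": 1}, {"id": 2}]

    # Tuples are converted to lists; strings and dicts are wrapped rather than expanded.
    assert normalize_parameter_sequence((1, 2, 3)) == [1, 2, 3]
    assert normalize_parameter_sequence("abc") == ["abc"]

    # Auto-generated param_N names map back to positional order.
    assert convert_parameters_to_positional({"param_0": "a", "param_1": "b"}, []) == ["a", "b"]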
sqlspec/exceptions.py CHANGED
@@ -19,6 +19,7 @@ __all__ = (
     "RepositoryError",
     "RiskLevel",
     "SQLBuilderError",
+    "SQLCompilationError",
     "SQLConversionError",
     "SQLFileNotFoundError",
     "SQLFileParseError",
@@ -122,6 +123,15 @@ class SQLBuilderError(SQLSpecError):
         super().__init__(message)
 
 
+class SQLCompilationError(SQLSpecError):
+    """Issues Compiling SQL statements."""
+
+    def __init__(self, message: Optional[str] = None) -> None:
+        if message is None:
+            message = "Issues compiling SQL statement."
+        super().__init__(message)
+
+
 class SQLConversionError(SQLSpecError):
     """Issues converting SQL statements."""
 
@@ -374,7 +384,6 @@ def wrap_exceptions(
         yield
 
     except Exception as exc:
-        # Handle suppression first
        if suppress is not None and (
            (isinstance(suppress, type) and isinstance(exc, suppress))
            or (isinstance(suppress, tuple) and isinstance(exc, suppress))
@@ -385,7 +394,6 @@ def wrap_exceptions(
        if isinstance(exc, SQLSpecError):
            raise
 
-        # Handle wrapping
        if wrap_exceptions is False:
            raise
        msg = "An error occurred during the operation."
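
Illustrative note (not part of the diff): the new SQLCompilationError carries a default message and, like the other classes in this module, derives from SQLSpecError, so existing handlers still catch it; the raise below is invented.

    from sqlspec.exceptions import SQLCompilationError, SQLSpecError

    try:
        raise SQLCompilationError()  # default message: "Issues compiling SQL statement."
    except SQLSpecError as exc:
        print(type(exc).__name__, exc)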
sqlspec/extensions/aiosql/adapter.py CHANGED
@@ -40,11 +40,9 @@ def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
     Returns:
         Normalized dialect name
     """
-    # Handle different dialect types
     if dialect is None:
         return "sql"
 
-    # Extract string from dialect class or instance
     if hasattr(dialect, "__name__"):  # It's a class
         dialect_str = str(dialect.__name__).lower()  # pyright: ignore
     elif hasattr(dialect, "name"):  # It's an instance with name attribute
@@ -134,7 +132,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
                "Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
            )
 
-        # Create SQL object and apply filters
        sql_obj = self._create_sql_object(sql, parameters)
        # Execute using SQLSpec driver
        result = self.driver.execute(sql_obj, connection=conn)
@@ -192,12 +189,9 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
            return None
 
        if isinstance(row, dict):
-            # Return first value from dict
            return next(iter(row.values())) if row else None
        if hasattr(row, "__getitem__"):
-            # Handle tuple/list-like objects
            return row[0] if len(row) > 0 else None
-        # Handle scalar or object with attributes
        return row
 
    @contextmanager
@@ -216,7 +210,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
        sql_obj = self._create_sql_object(sql, parameters)
        result = self.driver.execute(sql_obj, connection=conn)
 
-        # Create a cursor-like object
        class CursorLike:
            def __init__(self, result: Any) -> None:
                self.result = result
@@ -386,12 +379,9 @@ class AiosqlAsyncAdapter(_AiosqlAdapterBase):
            return None
 
        if isinstance(row, dict):
-            # Return first value from dict
            return next(iter(row.values())) if row else None
        if hasattr(row, "__getitem__"):
-            # Handle tuple/list-like objects
            return row[0] if len(row) > 0 else None
-        # Handle scalar or object with attributes
        return row
 
    @asynccontextmanager
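
Illustrative note (not part of the diff): both adapters keep the same first-value extraction rule shown in the hunks above; a standalone restatement of that rule with made-up rows, and an assumed None guard:

    from typing import Any

    def first_value(row: Any) -> Any:
        # Same branching as the adapter code above (the None guard is assumed).
        if row is None:
            return None
        if isinstance(row, dict):
            return next(iter(row.values())) if row else None
        if hasattr(row, "__getitem__"):
            return row[0] if len(row) > 0 else None
        return row

    assert first_value({"id": 7, "name": "x"}) == 7
    assert first_value((42, "y")) == 42
    assert first_value(3.14) == 3.14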
sqlspec/extensions/litestar/handlers.py CHANGED
@@ -246,7 +246,6 @@ def session_provider_maker(
 
    conn_type_annotation = config.connection_type
 
-    # Import Dependency at function level to avoid circular imports
    from litestar.params import Dependency
 
    db_conn_param = inspect.Parameter(
sqlspec/extensions/litestar/plugin.py CHANGED
@@ -69,7 +69,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
            [SQLSpec, ConnectionT, PoolT, DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT]
        )
 
-        # Create signature namespace for connection types
        signature_namespace = {}
 
        for c in self._plugin_configs:
@@ -78,7 +77,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
            app_config.signature_types.append(c.config.connection_type)  # type: ignore[union-attr]
            app_config.signature_types.append(c.config.driver_type)  # type: ignore[union-attr]
 
-            # Get signature namespace from the config
            if hasattr(c.config, "get_signature_namespace"):
                config_namespace = c.config.get_signature_namespace()  # type: ignore[attr-defined]
                signature_namespace.update(config_namespace)
@@ -93,7 +91,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
                }
            )
 
-        # Update app config with signature namespace
        if signature_namespace:
            app_config.signature_namespace.update(signature_namespace)
 
sqlspec/extensions/litestar/providers.py CHANGED
@@ -173,7 +173,6 @@ def _make_hashable(value: Any) -> HashableType:
        A hashable version of the value.
    """
    if isinstance(value, dict):
-        # Convert dict to tuple of tuples with sorted keys
        items = []
        for k in sorted(value.keys()):  # pyright: ignore
            v = value[k]  # pyright: ignore
@@ -261,7 +260,6 @@ def _create_statement_filters(
            required=False,
        ),
    ) -> SearchFilter:
-        # Handle both string and set input types for search fields
        field_names = set(search_fields.split(",")) if isinstance(search_fields, str) else set(search_fields)
 
        return SearchFilter(
@@ -286,9 +284,7 @@ def _create_statement_filters(
 
    filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)
 
-    # Add not_in filter providers
    if not_in_fields := config.get("not_in_fields"):
-        # Get all field names, handling both strings and FieldNameType objects
        not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
 
        for field_def in not_in_fields:
@@ -313,9 +309,7 @@ def _create_statement_filters(
            provider = create_not_in_filter_provider(field_def)  # pyright: ignore
            filters[f"{field_def.name}_not_in_filter"] = Provide(provider, sync_to_thread=False)  # pyright: ignore
 
-    # Add in filter providers
    if in_fields := config.get("in_fields"):
-        # Get all field names, handling both strings and FieldNameType objects
        in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
 
        for field_def in in_fields:
@@ -361,7 +355,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
    parameters: dict[str, inspect.Parameter] = {}
    annotations: dict[str, Any] = {}
 
-    # Build parameters based on config
    if cls := config.get("id_filter"):
        parameters["id_filter"] = inspect.Parameter(
            name="id_filter",
@@ -416,7 +409,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
        )
        annotations["order_by_filter"] = OrderByFilter
 
-    # Add parameters for not_in filters
    if not_in_fields := config.get("not_in_fields"):
        for field_def in not_in_fields:
            field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -428,7 +420,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
            )
            annotations[f"{field_def.name}_not_in_filter"] = NotInCollectionFilter[field_def.type_hint]  # type: ignore
 
-    # Add parameters for in filters
    if in_fields := config.get("in_fields"):
        for field_def in in_fields:
            field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -472,9 +463,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
        ):
            filters.append(order_by)
 
-        # Add not_in filters
        if not_in_fields := config.get("not_in_fields"):
-            # Get all field names, handling both strings and FieldNameType objects
            not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
            for field_def in not_in_fields:
                field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -482,9 +471,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
            if filter_ is not None:
                filters.append(filter_)
 
-        # Add in filters
        if in_fields := config.get("in_fields"):
-            # Get all field names, handling both strings and FieldNameType objects
            in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
            for field_def in in_fields:
                field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -493,7 +480,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
                filters.append(filter_)
        return filters
 
-    # Set both signature and annotations
    provide_filters.__signature__ = inspect.Signature(  # type: ignore
        parameters=list(parameters.values()), return_annotation=list[FilterTypes]
    )
sqlspec/loader.py CHANGED
@@ -26,7 +26,7 @@ logger = get_logger("loader")
 # Matches: -- name: query_name (supports hyphens and special suffixes)
 # We capture the name plus any trailing special characters
 QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
-
+TRIM_TRAILING_SPECIAL_CHARS = re.compile(r"[^\w-]+$")
 MIN_QUERY_PARTS = 3
 
 
@@ -42,10 +42,8 @@ def _normalize_query_name(name: str) -> str:
    Returns:
        Normalized query name suitable as Python identifier
    """
-    # First strip any trailing special characters
-    name = re.sub(r"[^\w-]+$", "", name)
-    # Then replace hyphens with underscores
-    return name.replace("-", "_")
+    # Strip trailing non-alphanumeric characters (excluding underscore) and replace hyphens
+    return TRIM_TRAILING_SPECIAL_CHARS.sub("", name).replace("-", "_")
 
 
 @dataclass
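
Illustrative note (not part of the diff): the precompiled TRIM_TRAILING_SPECIAL_CHARS regex is the same pattern the old inline re.sub used, so normalization is unchanged; a small demonstration with a made-up header line:

    import re

    QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
    TRIM_TRAILING_SPECIAL_CHARS = re.compile(r"[^\w-]+$")

    match = QUERY_NAME_PATTERN.search("-- name: get-user-by-id!")
    raw = match.group(1)  # "get-user-by-id!"
    assert TRIM_TRAILING_SPECIAL_CHARS.sub("", raw).replace("-", "_") == "get_user_by_id"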
@@ -127,8 +125,6 @@ class SQLFileLoader:
        path_str = str(path)
 
        try:
-            # Always use storage backend for consistent behavior
-            # Pass the original path object to allow storage registry to handle Path -> file:// conversion
            backend = self.storage_registry.get(path)
            return backend.read_text(path_str, encoding=self.encoding)
        except KeyError as e:
@@ -151,48 +147,27 @@ class SQLFileLoader:
 
    @staticmethod
    def _parse_sql_content(content: str, file_path: str) -> dict[str, str]:
-        """Parse SQL content and extract named queries.
-
-        Args:
-            content: SQL file content.
-            file_path: Path to the file (for error messages).
-
-        Returns:
-            Dictionary mapping query names to SQL text.
-
-        Raises:
-            SQLFileParseError: If no named queries found.
-        """
+        """Parse SQL content and extract named queries."""
        queries: dict[str, str] = {}
-
-        # Split content by query name patterns
-        parts = QUERY_NAME_PATTERN.split(content)
-
-        if len(parts) < MIN_QUERY_PARTS:
-            # No named queries found
+        matches = list(QUERY_NAME_PATTERN.finditer(content))
+        if not matches:
            raise SQLFileParseError(
                file_path, file_path, ValueError("No named SQL statements found (-- name: query_name)")
            )
 
-        # Process each named query
-        for i in range(1, len(parts), 2):
-            if i + 1 >= len(parts):
-                break
-
-            raw_query_name = parts[i].strip()
-            sql_text = parts[i + 1].strip()
+        for i, match in enumerate(matches):
+            raw_query_name = match.group(1).strip()
+            start_pos = match.end()
+            end_pos = matches[i + 1].start() if i + 1 < len(matches) else len(content)
 
+            sql_text = content[start_pos:end_pos].strip()
            if not raw_query_name or not sql_text:
                continue
 
            clean_sql = SQLFileLoader._strip_leading_comments(sql_text)
-
            if clean_sql:
-                # Normalize to Python-compatible identifier
                query_name = _normalize_query_name(raw_query_name)
-
                if query_name in queries:
-                    # Duplicate query name
                    raise SQLFileParseError(file_path, file_path, ValueError(f"Duplicate query name: {raw_query_name}"))
                queries[query_name] = clean_sql
 
@@ -221,19 +196,13 @@ class SQLFileLoader:
        try:
            for path in paths:
                path_str = str(path)
-
-                # Check if it's a URI
                if "://" in path_str:
-                    # URIs are always treated as files, not directories
                    self._load_single_file(path, None)
                    loaded_count += 1
                else:
-                    # Local path - check if it's a directory or file
                    path_obj = Path(path)
                    if path_obj.is_dir():
-                        file_count_before = len(self._files)
-                        self._load_directory(path_obj)
-                        loaded_count += len(self._files) - file_count_before
+                        loaded_count += self._load_directory(path_obj)
                    else:
                        self._load_single_file(path_obj, None)
                        loaded_count += 1
@@ -267,31 +236,18 @@ class SQLFileLoader:
            )
            raise
 
-    def _load_directory(self, dir_path: Path) -> None:
-        """Load all SQL files from a directory with namespacing.
-
-        Args:
-            dir_path: Directory path to scan for SQL files.
-
-        Raises:
-            SQLFileParseError: If directory contains no SQL files.
-        """
+    def _load_directory(self, dir_path: Path) -> int:
+        """Load all SQL files from a directory with namespacing."""
        sql_files = list(dir_path.rglob("*.sql"))
-
        if not sql_files:
-            raise SQLFileParseError(
-                str(dir_path), str(dir_path), ValueError(f"No SQL files found in directory: {dir_path}")
-            )
+            return 0
 
        for file_path in sql_files:
-            # Calculate namespace based on relative path from base directory
            relative_path = file_path.relative_to(dir_path)
            namespace_parts = relative_path.parent.parts
-
-            # Create namespace (empty for root-level files)
            namespace = ".".join(namespace_parts) if namespace_parts else None
-
            self._load_single_file(file_path, namespace)
+        return len(sql_files)
 
    def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
        """Load a single SQL file with optional namespace.
@@ -302,42 +258,24 @@ class SQLFileLoader:
        """
        path_str = str(file_path)
 
-        # Check if already loaded
        if path_str in self._files:
-            # File already loaded, just ensure queries are in the main dict
-            file_obj = self._files[path_str]
-            queries = self._parse_sql_content(file_obj.content, path_str)
-            for name in queries:
-                namespaced_name = f"{namespace}.{name}" if namespace else name
-                if namespaced_name not in self._queries:
-                    self._queries[namespaced_name] = queries[name]
-                    self._query_to_file[namespaced_name] = path_str
-            return
-
-        # Read file content
-        content = self._read_file_content(file_path)
+            return  # Already loaded
 
-        # Create SQLFile object
+        content = self._read_file_content(file_path)
        sql_file = SQLFile(content=content, path=path_str)
-
-        # Cache the file
        self._files[path_str] = sql_file
 
-        # Parse and cache queries
        queries = self._parse_sql_content(content, path_str)
-
-        # Merge into main query dictionary with namespace
        for name, sql in queries.items():
            namespaced_name = f"{namespace}.{name}" if namespace else name
-
-            if namespaced_name in self._queries and self._query_to_file.get(namespaced_name) != path_str:
-                # Query name exists from a different file
+            if namespaced_name in self._queries:
                existing_file = self._query_to_file.get(namespaced_name, "unknown")
-                raise SQLFileParseError(
-                    path_str,
-                    path_str,
-                    ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
-                )
+                if existing_file != path_str:
+                    raise SQLFileParseError(
+                        path_str,
+                        path_str,
+                        ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
+                    )
            self._queries[namespaced_name] = sql
            self._query_to_file[namespaced_name] = path_str
 
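
Illustrative note (not part of the diff): directory loading namespaces each query by its path relative to the loaded directory, using the same expressions shown above; the layout below is hypothetical.

    from pathlib import Path

    base = Path("queries")                    # hypothetical directory passed to the loader
    file_path = base / "users" / "fetch.sql"  # hypothetical file containing "-- name: get-all-users"

    namespace_parts = file_path.relative_to(base).parent.parts   # ("users",)
    namespace = ".".join(namespace_parts) if namespace_parts else None
    name = "get_all_users"                    # after _normalize_query_name
    namespaced_name = f"{namespace}.{name}" if namespace else name
    assert namespaced_name == "users.get_all_users"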
@@ -357,7 +295,6 @@ class SQLFileLoader:
            raise ValueError(msg)
 
        self._queries[name] = sql.strip()
-        # Use special marker for directly added queries
        self._query_to_file[name] = "<directly added>"
 
    def get_sql(self, name: str, parameters: "Optional[Any]" = None, **kwargs: "Any") -> "SQL":
@@ -405,12 +342,10 @@ class SQLFileLoader:
            )
            raise SQLFileNotFoundError(name, path=f"Query '{name}' not found. Available queries: {available}")
 
-        # Merge parameters and kwargs for SQL object creation
        sql_kwargs = dict(kwargs)
        if parameters is not None:
            sql_kwargs["parameters"] = parameters
 
-        # Get source file for additional context
        source_file = self._query_to_file.get(safe_name, "unknown")
 
        logger.debug(