sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the package registry page for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -7,6 +7,8 @@ TypedParameter objects and perform appropriate type conversions.
7
7
  from decimal import Decimal
8
8
  from typing import TYPE_CHECKING, Any, Optional, Union
9
9
 
10
+ from sqlspec.utils.type_guards import has_parameter_value
11
+
10
12
  if TYPE_CHECKING:
11
13
  from sqlspec.typing import SQLParameterType
12
14
 
@@ -68,13 +70,10 @@ class TypeCoercionMixin:
68
70
  Returns:
69
71
  Coerced parameter value suitable for the database
70
72
  """
71
- # Check if it's a TypedParameter
72
- if hasattr(param, "__class__") and param.__class__.__name__ == "TypedParameter":
73
- # Extract value and type hint
73
+ if has_parameter_value(param):
74
74
  value = param.value
75
75
  type_hint = param.type_hint
76
76
 
77
- # Apply driver-specific coercion based on type hint
78
77
  return self._apply_type_coercion(value, type_hint)
79
78
  # Regular parameter - apply default coercion
80
79
  return self._apply_type_coercion(param, None)
@@ -0,0 +1,138 @@
1
+ """Consolidated parameter processing utilities for database drivers.
2
+
3
+ This module provides centralized parameter handling logic to avoid duplication
4
+ across sync and async driver implementations.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING, Any, Optional, Union
8
+
9
+ from sqlspec.statement.filters import StatementFilter
10
+ from sqlspec.utils.type_guards import is_sync_transaction_capable
11
+
12
+ if TYPE_CHECKING:
13
+ from sqlspec.typing import StatementParameters
14
+
15
# Public parameter-processing helpers exported by this module.
__all__ = (
    "convert_parameters_to_positional",
    "normalize_parameter_sequence",
    "process_execute_many_parameters",
    "separate_filters_and_parameters",
    "should_use_transaction",
)
22
+
23
+
24
def separate_filters_and_parameters(
    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
) -> "tuple[list[StatementFilter], list[Any]]":
    """Split a mixed tuple into statement filters and plain parameter values.

    Each item is routed by an ``isinstance(item, StatementFilter)`` check.
    Items keep their relative order within each output list.

    Args:
        parameters: Tuple mixing statement filters and parameter values.

    Returns:
        A ``(filters, parameter_values)`` pair of lists.
    """
    filters: list[StatementFilter] = []
    values: list[Any] = []

    for item in parameters:
        # Pick the destination list, then append once.
        bucket = filters if isinstance(item, StatementFilter) else values
        bucket.append(item)

    return filters, values
46
+
47
+
48
def process_execute_many_parameters(
    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
) -> "tuple[list[StatementFilter], Optional[list[Any]]]":
    """Prepare the arguments of an ``execute_many`` call.

    Filters are separated from plain values. Only the first remaining value
    is treated as the batch of parameter sets; any further values are
    ignored. The batch is normalized to a list, or ``None`` when no value
    was supplied.

    Args:
        parameters: Tuple mixing statement filters and parameter values.

    Returns:
        A ``(filters, parameter_sequence)`` pair.
    """
    filters, values = separate_filters_and_parameters(parameters)
    batch = next(iter(values), None)
    return filters, normalize_parameter_sequence(batch)
68
+
69
+
70
def normalize_parameter_sequence(params: Any) -> Optional[list[Any]]:
    """Normalize a parameter sequence to a list.

    Rules:
        * ``None`` is returned unchanged.
        * A ``list`` is returned as-is (the same object, not a copy).
        * A ``tuple`` is converted to a list.
        * ``str``, bytes-like objects (``bytes``, ``bytearray``, ``memoryview``)
          and ``dict`` are treated as a single value and wrapped in a list.
        * Any other iterable is materialized into a list.
        * Non-iterable values are wrapped in a one-element list.

    Args:
        params: Parameter sequence in various formats.

    Returns:
        Normalized list of parameters, or ``None``.
    """
    if params is None:
        return None

    if isinstance(params, list):
        return params

    if isinstance(params, tuple):
        return list(params)

    # These types are iterable but represent ONE parameter value. Iterating
    # them would split e.g. b"ab" into [97, 98] or a dict into its keys.
    if isinstance(params, (str, bytes, bytearray, memoryview, dict)):
        return [params]

    # Only probe iterability here. Errors raised while consuming the
    # iterator are real errors and must not be swallowed.
    try:
        iterator = iter(params)
    except TypeError:
        return [params]
    return list(iterator)
99
+
100
+
101
def convert_parameters_to_positional(params: "dict[str, Any]", parameter_info: "list[Any]") -> list[Any]:
    """Convert named parameters to a positional list.

    Two strategies are used:

    1. If every key has the form ``param_<N>`` (``N`` a decimal integer),
       the values are returned ordered by ``N``. Gaps and a non-zero start
       are allowed, so ``param_1``/``param_2`` work as well as ``param_0``...
    2. Otherwise, values are taken in the order of ``parameter_info``. Each
       entry whose ``name`` attribute is not ``None`` contributes
       ``params.get(name)`` (``None`` when the key is missing). Entries
       without a name are skipped.

    Args:
        params: Dictionary of named parameters.
        parameter_info: Parameter info objects from SQL parsing, in SQL order.

    Returns:
        List of positional parameters (empty when ``params`` is empty).
    """
    if not params:
        return []

    prefix = "param_"
    offset = len(prefix)
    # Require a numeric suffix on every key. Keys like "param_x", or indices
    # that don't run 0..n-1, previously caused a KeyError here.
    if all(key.startswith(prefix) and key[offset:].isdecimal() for key in params):
        ordered_keys = sorted(params, key=lambda key: int(key[offset:]))
        return [params[key] for key in ordered_keys]

    names = (getattr(info, "name", None) for info in parameter_info)
    return [params.get(name) for name in names if name is not None]
126
+
127
+
128
def should_use_transaction(connection: Any, auto_commit: bool = True) -> bool:
    """Decide whether an explicit transaction should wrap the operation.

    Args:
        connection: Database connection object.
        auto_commit: Whether auto-commit is enabled.

    Returns:
        ``False`` whenever ``auto_commit`` is set. Otherwise, the result of
        ``is_sync_transaction_capable(connection)``.
    """
    if auto_commit:
        return False
    return is_sync_transaction_capable(connection)
sqlspec/exceptions.py CHANGED
@@ -19,6 +19,7 @@ __all__ = (
19
19
  "RepositoryError",
20
20
  "RiskLevel",
21
21
  "SQLBuilderError",
22
+ "SQLCompilationError",
22
23
  "SQLConversionError",
23
24
  "SQLFileNotFoundError",
24
25
  "SQLFileParseError",
@@ -122,6 +123,15 @@ class SQLBuilderError(SQLSpecError):
122
123
  super().__init__(message)
123
124
 
124
125
 
126
class SQLCompilationError(SQLSpecError):
    """Issues compiling SQL statements.

    Args:
        message: Optional error description. A generic default message is
            used when it is omitted (``None``).
    """

    def __init__(self, message: Optional[str] = None) -> None:
        default_message = "Issues compiling SQL statement."
        super().__init__(default_message if message is None else message)
133
+
134
+
125
135
  class SQLConversionError(SQLSpecError):
126
136
  """Issues converting SQL statements."""
127
137
 
@@ -374,7 +384,6 @@ def wrap_exceptions(
374
384
  yield
375
385
 
376
386
  except Exception as exc:
377
- # Handle suppression first
378
387
  if suppress is not None and (
379
388
  (isinstance(suppress, type) and isinstance(exc, suppress))
380
389
  or (isinstance(suppress, tuple) and isinstance(exc, suppress))
@@ -385,7 +394,6 @@ def wrap_exceptions(
385
394
  if isinstance(exc, SQLSpecError):
386
395
  raise
387
396
 
388
- # Handle wrapping
389
397
  if wrap_exceptions is False:
390
398
  raise
391
399
  msg = "An error occurred during the operation."
@@ -40,11 +40,9 @@ def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
40
40
  Returns:
41
41
  Normalized dialect name
42
42
  """
43
- # Handle different dialect types
44
43
  if dialect is None:
45
44
  return "sql"
46
45
 
47
- # Extract string from dialect class or instance
48
46
  if hasattr(dialect, "__name__"): # It's a class
49
47
  dialect_str = str(dialect.__name__).lower() # pyright: ignore
50
48
  elif hasattr(dialect, "name"): # It's an instance with name attribute
@@ -134,7 +132,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
134
132
  "Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
135
133
  )
136
134
 
137
- # Create SQL object and apply filters
138
135
  sql_obj = self._create_sql_object(sql, parameters)
139
136
  # Execute using SQLSpec driver
140
137
  result = self.driver.execute(sql_obj, connection=conn)
@@ -192,12 +189,9 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
192
189
  return None
193
190
 
194
191
  if isinstance(row, dict):
195
- # Return first value from dict
196
192
  return next(iter(row.values())) if row else None
197
193
  if hasattr(row, "__getitem__"):
198
- # Handle tuple/list-like objects
199
194
  return row[0] if len(row) > 0 else None
200
- # Handle scalar or object with attributes
201
195
  return row
202
196
 
203
197
  @contextmanager
@@ -216,7 +210,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
216
210
  sql_obj = self._create_sql_object(sql, parameters)
217
211
  result = self.driver.execute(sql_obj, connection=conn)
218
212
 
219
- # Create a cursor-like object
220
213
  class CursorLike:
221
214
  def __init__(self, result: Any) -> None:
222
215
  self.result = result
@@ -386,12 +379,9 @@ class AiosqlAsyncAdapter(_AiosqlAdapterBase):
386
379
  return None
387
380
 
388
381
  if isinstance(row, dict):
389
- # Return first value from dict
390
382
  return next(iter(row.values())) if row else None
391
383
  if hasattr(row, "__getitem__"):
392
- # Handle tuple/list-like objects
393
384
  return row[0] if len(row) > 0 else None
394
- # Handle scalar or object with attributes
395
385
  return row
396
386
 
397
387
  @asynccontextmanager
@@ -246,7 +246,6 @@ def session_provider_maker(
246
246
 
247
247
  conn_type_annotation = config.connection_type
248
248
 
249
- # Import Dependency at function level to avoid circular imports
250
249
  from litestar.params import Dependency
251
250
 
252
251
  db_conn_param = inspect.Parameter(
@@ -69,7 +69,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
69
69
  [SQLSpec, ConnectionT, PoolT, DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT]
70
70
  )
71
71
 
72
- # Create signature namespace for connection types
73
72
  signature_namespace = {}
74
73
 
75
74
  for c in self._plugin_configs:
@@ -78,7 +77,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
78
77
  app_config.signature_types.append(c.config.connection_type) # type: ignore[union-attr]
79
78
  app_config.signature_types.append(c.config.driver_type) # type: ignore[union-attr]
80
79
 
81
- # Get signature namespace from the config
82
80
  if hasattr(c.config, "get_signature_namespace"):
83
81
  config_namespace = c.config.get_signature_namespace() # type: ignore[attr-defined]
84
82
  signature_namespace.update(config_namespace)
@@ -93,7 +91,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
93
91
  }
94
92
  )
95
93
 
96
- # Update app config with signature namespace
97
94
  if signature_namespace:
98
95
  app_config.signature_namespace.update(signature_namespace)
99
96
 
@@ -173,7 +173,6 @@ def _make_hashable(value: Any) -> HashableType:
173
173
  A hashable version of the value.
174
174
  """
175
175
  if isinstance(value, dict):
176
- # Convert dict to tuple of tuples with sorted keys
177
176
  items = []
178
177
  for k in sorted(value.keys()): # pyright: ignore
179
178
  v = value[k] # pyright: ignore
@@ -261,7 +260,6 @@ def _create_statement_filters(
261
260
  required=False,
262
261
  ),
263
262
  ) -> SearchFilter:
264
- # Handle both string and set input types for search fields
265
263
  field_names = set(search_fields.split(",")) if isinstance(search_fields, str) else set(search_fields)
266
264
 
267
265
  return SearchFilter(
@@ -286,9 +284,7 @@ def _create_statement_filters(
286
284
 
287
285
  filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)
288
286
 
289
- # Add not_in filter providers
290
287
  if not_in_fields := config.get("not_in_fields"):
291
- # Get all field names, handling both strings and FieldNameType objects
292
288
  not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
293
289
 
294
290
  for field_def in not_in_fields:
@@ -313,9 +309,7 @@ def _create_statement_filters(
313
309
  provider = create_not_in_filter_provider(field_def) # pyright: ignore
314
310
  filters[f"{field_def.name}_not_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
315
311
 
316
- # Add in filter providers
317
312
  if in_fields := config.get("in_fields"):
318
- # Get all field names, handling both strings and FieldNameType objects
319
313
  in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
320
314
 
321
315
  for field_def in in_fields:
@@ -361,7 +355,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
361
355
  parameters: dict[str, inspect.Parameter] = {}
362
356
  annotations: dict[str, Any] = {}
363
357
 
364
- # Build parameters based on config
365
358
  if cls := config.get("id_filter"):
366
359
  parameters["id_filter"] = inspect.Parameter(
367
360
  name="id_filter",
@@ -416,7 +409,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
416
409
  )
417
410
  annotations["order_by_filter"] = OrderByFilter
418
411
 
419
- # Add parameters for not_in filters
420
412
  if not_in_fields := config.get("not_in_fields"):
421
413
  for field_def in not_in_fields:
422
414
  field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -428,7 +420,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
428
420
  )
429
421
  annotations[f"{field_def.name}_not_in_filter"] = NotInCollectionFilter[field_def.type_hint] # type: ignore
430
422
 
431
- # Add parameters for in filters
432
423
  if in_fields := config.get("in_fields"):
433
424
  for field_def in in_fields:
434
425
  field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -472,9 +463,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
472
463
  ):
473
464
  filters.append(order_by)
474
465
 
475
- # Add not_in filters
476
466
  if not_in_fields := config.get("not_in_fields"):
477
- # Get all field names, handling both strings and FieldNameType objects
478
467
  not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
479
468
  for field_def in not_in_fields:
480
469
  field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -482,9 +471,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
482
471
  if filter_ is not None:
483
472
  filters.append(filter_)
484
473
 
485
- # Add in filters
486
474
  if in_fields := config.get("in_fields"):
487
- # Get all field names, handling both strings and FieldNameType objects
488
475
  in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
489
476
  for field_def in in_fields:
490
477
  field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
@@ -493,7 +480,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
493
480
  filters.append(filter_)
494
481
  return filters
495
482
 
496
- # Set both signature and annotations
497
483
  provide_filters.__signature__ = inspect.Signature( # type: ignore
498
484
  parameters=list(parameters.values()), return_annotation=list[FilterTypes]
499
485
  )
sqlspec/loader.py CHANGED
@@ -26,7 +26,7 @@ logger = get_logger("loader")
26
26
  # Matches: -- name: query_name (supports hyphens and special suffixes)
27
27
  # We capture the name plus any trailing special characters
28
28
  QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
29
-
29
+ TRIM_TRAILING_SPECIAL_CHARS = re.compile(r"[^\w-]+$")
30
30
  MIN_QUERY_PARTS = 3
31
31
 
32
32
 
@@ -42,10 +42,8 @@ def _normalize_query_name(name: str) -> str:
42
42
  Returns:
43
43
  Normalized query name suitable as Python identifier
44
44
  """
45
- # First strip any trailing special characters
46
- name = re.sub(r"[^\w-]+$", "", name)
47
- # Then replace hyphens with underscores
48
- return name.replace("-", "_")
45
+ # Strip trailing non-alphanumeric characters (excluding underscore) and replace hyphens
46
+ return TRIM_TRAILING_SPECIAL_CHARS.sub("", name).replace("-", "_")
49
47
 
50
48
 
51
49
  @dataclass
@@ -113,7 +111,7 @@ class SQLFileLoader:
113
111
  self._query_to_file: dict[str, str] = {} # Maps query name to file path
114
112
 
115
113
  def _read_file_content(self, path: Union[str, Path]) -> str:
116
- """Read file content using appropriate backend.
114
+ """Read file content using storage backend.
117
115
 
118
116
  Args:
119
117
  path: File path (can be local path or URI).
@@ -126,37 +124,13 @@ class SQLFileLoader:
126
124
  """
127
125
  path_str = str(path)
128
126
 
129
- # Use storage backend for URIs (anything with a scheme)
130
- if "://" in path_str:
131
- try:
132
- backend = self.storage_registry.get(path_str)
133
- return backend.read_text(path_str, encoding=self.encoding)
134
- except KeyError as e:
135
- raise SQLFileNotFoundError(path_str) from e
136
- except Exception as e:
137
- raise SQLFileParseError(path_str, path_str, e) from e
138
-
139
- # Handle local file paths
140
- local_path = Path(path_str)
141
- self._check_file_path(local_path)
142
- content_bytes = self._read_file_content_bytes(local_path)
143
- return content_bytes.decode(self.encoding)
144
-
145
- @staticmethod
146
- def _read_file_content_bytes(path: Path) -> bytes:
147
127
  try:
148
- return path.read_bytes()
128
+ backend = self.storage_registry.get(path)
129
+ return backend.read_text(path_str, encoding=self.encoding)
130
+ except KeyError as e:
131
+ raise SQLFileNotFoundError(path_str) from e
149
132
  except Exception as e:
150
- raise SQLFileParseError(str(path), str(path), e) from e
151
-
152
- @staticmethod
153
- def _check_file_path(path: Union[str, Path]) -> None:
154
- """Ensure the file exists and is a valid path."""
155
- path_obj = Path(path).resolve()
156
- if not path_obj.exists():
157
- raise SQLFileNotFoundError(str(path_obj))
158
- if not path_obj.is_file():
159
- raise SQLFileParseError(str(path_obj), str(path_obj), ValueError("Path is not a file"))
133
+ raise SQLFileParseError(path_str, path_str, e) from e
160
134
 
161
135
  @staticmethod
162
136
  def _strip_leading_comments(sql_text: str) -> str:
@@ -173,48 +147,27 @@ class SQLFileLoader:
173
147
 
174
148
  @staticmethod
175
149
  def _parse_sql_content(content: str, file_path: str) -> dict[str, str]:
176
- """Parse SQL content and extract named queries.
177
-
178
- Args:
179
- content: SQL file content.
180
- file_path: Path to the file (for error messages).
181
-
182
- Returns:
183
- Dictionary mapping query names to SQL text.
184
-
185
- Raises:
186
- SQLFileParseError: If no named queries found.
187
- """
150
+ """Parse SQL content and extract named queries."""
188
151
  queries: dict[str, str] = {}
189
-
190
- # Split content by query name patterns
191
- parts = QUERY_NAME_PATTERN.split(content)
192
-
193
- if len(parts) < MIN_QUERY_PARTS:
194
- # No named queries found
152
+ matches = list(QUERY_NAME_PATTERN.finditer(content))
153
+ if not matches:
195
154
  raise SQLFileParseError(
196
155
  file_path, file_path, ValueError("No named SQL statements found (-- name: query_name)")
197
156
  )
198
157
 
199
- # Process each named query
200
- for i in range(1, len(parts), 2):
201
- if i + 1 >= len(parts):
202
- break
203
-
204
- raw_query_name = parts[i].strip()
205
- sql_text = parts[i + 1].strip()
158
+ for i, match in enumerate(matches):
159
+ raw_query_name = match.group(1).strip()
160
+ start_pos = match.end()
161
+ end_pos = matches[i + 1].start() if i + 1 < len(matches) else len(content)
206
162
 
163
+ sql_text = content[start_pos:end_pos].strip()
207
164
  if not raw_query_name or not sql_text:
208
165
  continue
209
166
 
210
167
  clean_sql = SQLFileLoader._strip_leading_comments(sql_text)
211
-
212
168
  if clean_sql:
213
- # Normalize to Python-compatible identifier
214
169
  query_name = _normalize_query_name(raw_query_name)
215
-
216
170
  if query_name in queries:
217
- # Duplicate query name
218
171
  raise SQLFileParseError(file_path, file_path, ValueError(f"Duplicate query name: {raw_query_name}"))
219
172
  queries[query_name] = clean_sql
220
173
 
@@ -243,19 +196,13 @@ class SQLFileLoader:
243
196
  try:
244
197
  for path in paths:
245
198
  path_str = str(path)
246
-
247
- # Check if it's a URI
248
199
  if "://" in path_str:
249
- # URIs are always treated as files, not directories
250
200
  self._load_single_file(path, None)
251
201
  loaded_count += 1
252
202
  else:
253
- # Local path - check if it's a directory or file
254
203
  path_obj = Path(path)
255
204
  if path_obj.is_dir():
256
- file_count_before = len(self._files)
257
- self._load_directory(path_obj)
258
- loaded_count += len(self._files) - file_count_before
205
+ loaded_count += self._load_directory(path_obj)
259
206
  else:
260
207
  self._load_single_file(path_obj, None)
261
208
  loaded_count += 1
@@ -289,31 +236,18 @@ class SQLFileLoader:
289
236
  )
290
237
  raise
291
238
 
292
- def _load_directory(self, dir_path: Path) -> None:
293
- """Load all SQL files from a directory with namespacing.
294
-
295
- Args:
296
- dir_path: Directory path to scan for SQL files.
297
-
298
- Raises:
299
- SQLFileParseError: If directory contains no SQL files.
300
- """
239
+ def _load_directory(self, dir_path: Path) -> int:
240
+ """Load all SQL files from a directory with namespacing."""
301
241
  sql_files = list(dir_path.rglob("*.sql"))
302
-
303
242
  if not sql_files:
304
- raise SQLFileParseError(
305
- str(dir_path), str(dir_path), ValueError(f"No SQL files found in directory: {dir_path}")
306
- )
243
+ return 0
307
244
 
308
245
  for file_path in sql_files:
309
- # Calculate namespace based on relative path from base directory
310
246
  relative_path = file_path.relative_to(dir_path)
311
247
  namespace_parts = relative_path.parent.parts
312
-
313
- # Create namespace (empty for root-level files)
314
248
  namespace = ".".join(namespace_parts) if namespace_parts else None
315
-
316
249
  self._load_single_file(file_path, namespace)
250
+ return len(sql_files)
317
251
 
318
252
  def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
319
253
  """Load a single SQL file with optional namespace.
@@ -324,42 +258,24 @@ class SQLFileLoader:
324
258
  """
325
259
  path_str = str(file_path)
326
260
 
327
- # Check if already loaded
328
261
  if path_str in self._files:
329
- # File already loaded, just ensure queries are in the main dict
330
- file_obj = self._files[path_str]
331
- queries = self._parse_sql_content(file_obj.content, path_str)
332
- for name in queries:
333
- namespaced_name = f"{namespace}.{name}" if namespace else name
334
- if namespaced_name not in self._queries:
335
- self._queries[namespaced_name] = queries[name]
336
- self._query_to_file[namespaced_name] = path_str
337
- return
338
-
339
- # Read file content
340
- content = self._read_file_content(file_path)
262
+ return # Already loaded
341
263
 
342
- # Create SQLFile object
264
+ content = self._read_file_content(file_path)
343
265
  sql_file = SQLFile(content=content, path=path_str)
344
-
345
- # Cache the file
346
266
  self._files[path_str] = sql_file
347
267
 
348
- # Parse and cache queries
349
268
  queries = self._parse_sql_content(content, path_str)
350
-
351
- # Merge into main query dictionary with namespace
352
269
  for name, sql in queries.items():
353
270
  namespaced_name = f"{namespace}.{name}" if namespace else name
354
-
355
- if namespaced_name in self._queries and self._query_to_file.get(namespaced_name) != path_str:
356
- # Query name exists from a different file
271
+ if namespaced_name in self._queries:
357
272
  existing_file = self._query_to_file.get(namespaced_name, "unknown")
358
- raise SQLFileParseError(
359
- path_str,
360
- path_str,
361
- ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
362
- )
273
+ if existing_file != path_str:
274
+ raise SQLFileParseError(
275
+ path_str,
276
+ path_str,
277
+ ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
278
+ )
363
279
  self._queries[namespaced_name] = sql
364
280
  self._query_to_file[namespaced_name] = path_str
365
281
 
@@ -379,7 +295,6 @@ class SQLFileLoader:
379
295
  raise ValueError(msg)
380
296
 
381
297
  self._queries[name] = sql.strip()
382
- # Use special marker for directly added queries
383
298
  self._query_to_file[name] = "<directly added>"
384
299
 
385
300
  def get_sql(self, name: str, parameters: "Optional[Any]" = None, **kwargs: "Any") -> "SQL":
@@ -427,12 +342,10 @@ class SQLFileLoader:
427
342
  )
428
343
  raise SQLFileNotFoundError(name, path=f"Query '{name}' not found. Available queries: {available}")
429
344
 
430
- # Merge parameters and kwargs for SQL object creation
431
345
  sql_kwargs = dict(kwargs)
432
346
  if parameters is not None:
433
347
  sql_kwargs["parameters"] = parameters
434
348
 
435
- # Get source file for additional context
436
349
  source_file = self._query_to_file.get(safe_name, "unknown")
437
350
 
438
351
  logger.debug(