sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Consolidated parameter processing utilities for database drivers.
|
|
2
|
+
|
|
3
|
+
This module provides centralized parameter handling logic to avoid duplication
|
|
4
|
+
across sync and async driver implementations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
8
|
+
|
|
9
|
+
from sqlspec.statement.filters import StatementFilter
|
|
10
|
+
from sqlspec.utils.type_guards import is_sync_transaction_capable
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from sqlspec.typing import StatementParameters
|
|
14
|
+
|
|
15
|
+
# Public helpers exported by this module (shared by sync and async drivers).
__all__ = (
    "convert_parameters_to_positional",
    "normalize_parameter_sequence",
    "process_execute_many_parameters",
    "separate_filters_and_parameters",
    "should_use_transaction",
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def separate_filters_and_parameters(
    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
) -> "tuple[list[StatementFilter], list[Any]]":
    """Split a mixed argument tuple into statement filters and plain parameter values.

    Each item is tested with ``isinstance(item, StatementFilter)``; anything that is
    not a filter is treated as a parameter value. Relative order is preserved within
    each of the two output lists.

    Args:
        parameters: Mixed tuple of parameters and filters.

    Returns:
        A ``(filters, parameter_values)`` tuple of lists.
    """
    found_filters: "list[StatementFilter]" = []
    plain_values: "list[Any]" = []
    for item in parameters:
        # Route every item into exactly one bucket.
        target = found_filters if isinstance(item, StatementFilter) else plain_values
        target.append(item)
    return found_filters, plain_values
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def process_execute_many_parameters(
    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
) -> "tuple[list[StatementFilter], Optional[list[Any]]]":
    """Prepare the arguments of an ``execute_many`` call.

    Filters are separated out. Only the first remaining (non-filter) argument is
    used as the batch of parameter sets; it is normalized to a list, or ``None``
    when no parameter argument was supplied.

    Args:
        parameters: Mixed tuple of parameters and filters.

    Returns:
        A ``(filters, parameter_sequence)`` tuple.
    """
    filter_list, value_list = separate_filters_and_parameters(parameters)
    if not value_list:
        return filter_list, normalize_parameter_sequence(None)
    # Any values after the first one are not used for execute_many.
    return filter_list, normalize_parameter_sequence(value_list[0])
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def normalize_parameter_sequence(params: Any) -> Optional[list[Any]]:
    """Normalize a parameter sequence to a list format.

    Rules, in order:

    * ``None`` is returned unchanged.
    * A ``list`` is returned as-is (same object, not copied).
    * A ``tuple`` is converted to a list.
    * ``str``, ``bytes``, ``bytearray``, ``memoryview`` and ``dict`` values are
      iterable but represent a single value, so they are wrapped in a one-element
      list rather than exploded element by element.
    * Any other iterable is materialized with ``list()``.
    * Non-iterable values are wrapped in a one-element list.

    Args:
        params: Parameter sequence in various formats.

    Returns:
        Normalized list of parameters, or ``None`` if ``params`` is ``None``.
    """
    if params is None:
        return None

    if isinstance(params, list):
        return params

    if isinstance(params, tuple):
        return list(params)

    # Text, binary payloads and mappings are iterable, but iterating them would
    # split one logical value into characters / ints / keys.
    if isinstance(params, (str, bytes, bytearray, memoryview, dict)):
        return [params]

    # Only the iterability probe is guarded. Errors raised while consuming the
    # iterator propagate instead of being mistaken for "not iterable".
    try:
        iterator = iter(params)
    except TypeError:
        # Single scalar parameter.
        return [params]
    return list(iterator)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def convert_parameters_to_positional(params: "dict[str, Any]", parameter_info: "list[Any]") -> list[Any]:
    """Convert named parameters to positional based on SQL order.

    If the keys are exactly ``param_0`` .. ``param_{n-1}``, the values are returned
    in index order. Otherwise the order comes from ``parameter_info``. Each entry's
    ``name`` attribute is looked up in ``params``, with ``None`` used for missing
    names. Entries without a ``name`` attribute (or with ``name=None``) are skipped.

    Args:
        params: Dictionary of named parameters.
        parameter_info: List of parameter info from SQL parsing.

    Returns:
        List of positional parameters.
    """
    if not params:
        return []

    # Fast path for auto-generated names. It requires the exact contiguous set
    # param_0..param_{n-1}. User names that merely start with "param_"
    # (e.g. "param_user") fall through to the SQL-order mapping below instead of
    # raising KeyError.
    generated_keys = [f"param_{index}" for index in range(len(params))]
    if all(key in params for key in generated_keys):
        return [params[key] for key in generated_keys]

    # Map by the parameter order discovered while parsing the SQL.
    result: list[Any] = []
    for info in parameter_info:
        param_name = getattr(info, "name", None)
        if param_name is not None:
            result.append(params.get(param_name))
    return result
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def should_use_transaction(connection: Any, auto_commit: bool = True) -> bool:
    """Decide whether an explicit transaction should be used for ``connection``.

    Args:
        connection: Database connection object.
        auto_commit: Whether auto-commit is enabled.

    Returns:
        ``False`` whenever ``auto_commit`` is truthy. Otherwise, the result of
        ``is_sync_transaction_capable(connection)``.
    """
    if auto_commit:
        # Auto-commit mode: never wrap the work in an explicit transaction here.
        return False
    return is_sync_transaction_capable(connection)
|
sqlspec/exceptions.py
CHANGED
|
@@ -19,6 +19,7 @@ __all__ = (
|
|
|
19
19
|
"RepositoryError",
|
|
20
20
|
"RiskLevel",
|
|
21
21
|
"SQLBuilderError",
|
|
22
|
+
"SQLCompilationError",
|
|
22
23
|
"SQLConversionError",
|
|
23
24
|
"SQLFileNotFoundError",
|
|
24
25
|
"SQLFileParseError",
|
|
@@ -122,6 +123,15 @@ class SQLBuilderError(SQLSpecError):
|
|
|
122
123
|
super().__init__(message)
|
|
123
124
|
|
|
124
125
|
|
|
126
|
+
class SQLCompilationError(SQLSpecError):
    """Issues Compiling SQL statements."""

    def __init__(self, message: Optional[str] = None) -> None:
        # Fall back to a generic message only when none was given; an explicit
        # empty string is passed through unchanged.
        super().__init__("Issues compiling SQL statement." if message is None else message)
|
|
133
|
+
|
|
134
|
+
|
|
125
135
|
class SQLConversionError(SQLSpecError):
|
|
126
136
|
"""Issues converting SQL statements."""
|
|
127
137
|
|
|
@@ -374,7 +384,6 @@ def wrap_exceptions(
|
|
|
374
384
|
yield
|
|
375
385
|
|
|
376
386
|
except Exception as exc:
|
|
377
|
-
# Handle suppression first
|
|
378
387
|
if suppress is not None and (
|
|
379
388
|
(isinstance(suppress, type) and isinstance(exc, suppress))
|
|
380
389
|
or (isinstance(suppress, tuple) and isinstance(exc, suppress))
|
|
@@ -385,7 +394,6 @@ def wrap_exceptions(
|
|
|
385
394
|
if isinstance(exc, SQLSpecError):
|
|
386
395
|
raise
|
|
387
396
|
|
|
388
|
-
# Handle wrapping
|
|
389
397
|
if wrap_exceptions is False:
|
|
390
398
|
raise
|
|
391
399
|
msg = "An error occurred during the operation."
|
|
@@ -40,11 +40,9 @@ def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
|
|
|
40
40
|
Returns:
|
|
41
41
|
Normalized dialect name
|
|
42
42
|
"""
|
|
43
|
-
# Handle different dialect types
|
|
44
43
|
if dialect is None:
|
|
45
44
|
return "sql"
|
|
46
45
|
|
|
47
|
-
# Extract string from dialect class or instance
|
|
48
46
|
if hasattr(dialect, "__name__"): # It's a class
|
|
49
47
|
dialect_str = str(dialect.__name__).lower() # pyright: ignore
|
|
50
48
|
elif hasattr(dialect, "name"): # It's an instance with name attribute
|
|
@@ -134,7 +132,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
|
|
|
134
132
|
"Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
|
|
135
133
|
)
|
|
136
134
|
|
|
137
|
-
# Create SQL object and apply filters
|
|
138
135
|
sql_obj = self._create_sql_object(sql, parameters)
|
|
139
136
|
# Execute using SQLSpec driver
|
|
140
137
|
result = self.driver.execute(sql_obj, connection=conn)
|
|
@@ -192,12 +189,9 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
|
|
|
192
189
|
return None
|
|
193
190
|
|
|
194
191
|
if isinstance(row, dict):
|
|
195
|
-
# Return first value from dict
|
|
196
192
|
return next(iter(row.values())) if row else None
|
|
197
193
|
if hasattr(row, "__getitem__"):
|
|
198
|
-
# Handle tuple/list-like objects
|
|
199
194
|
return row[0] if len(row) > 0 else None
|
|
200
|
-
# Handle scalar or object with attributes
|
|
201
195
|
return row
|
|
202
196
|
|
|
203
197
|
@contextmanager
|
|
@@ -216,7 +210,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
|
|
|
216
210
|
sql_obj = self._create_sql_object(sql, parameters)
|
|
217
211
|
result = self.driver.execute(sql_obj, connection=conn)
|
|
218
212
|
|
|
219
|
-
# Create a cursor-like object
|
|
220
213
|
class CursorLike:
|
|
221
214
|
def __init__(self, result: Any) -> None:
|
|
222
215
|
self.result = result
|
|
@@ -386,12 +379,9 @@ class AiosqlAsyncAdapter(_AiosqlAdapterBase):
|
|
|
386
379
|
return None
|
|
387
380
|
|
|
388
381
|
if isinstance(row, dict):
|
|
389
|
-
# Return first value from dict
|
|
390
382
|
return next(iter(row.values())) if row else None
|
|
391
383
|
if hasattr(row, "__getitem__"):
|
|
392
|
-
# Handle tuple/list-like objects
|
|
393
384
|
return row[0] if len(row) > 0 else None
|
|
394
|
-
# Handle scalar or object with attributes
|
|
395
385
|
return row
|
|
396
386
|
|
|
397
387
|
@asynccontextmanager
|
|
@@ -69,7 +69,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
|
69
69
|
[SQLSpec, ConnectionT, PoolT, DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT]
|
|
70
70
|
)
|
|
71
71
|
|
|
72
|
-
# Create signature namespace for connection types
|
|
73
72
|
signature_namespace = {}
|
|
74
73
|
|
|
75
74
|
for c in self._plugin_configs:
|
|
@@ -78,7 +77,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
|
78
77
|
app_config.signature_types.append(c.config.connection_type) # type: ignore[union-attr]
|
|
79
78
|
app_config.signature_types.append(c.config.driver_type) # type: ignore[union-attr]
|
|
80
79
|
|
|
81
|
-
# Get signature namespace from the config
|
|
82
80
|
if hasattr(c.config, "get_signature_namespace"):
|
|
83
81
|
config_namespace = c.config.get_signature_namespace() # type: ignore[attr-defined]
|
|
84
82
|
signature_namespace.update(config_namespace)
|
|
@@ -93,7 +91,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
|
93
91
|
}
|
|
94
92
|
)
|
|
95
93
|
|
|
96
|
-
# Update app config with signature namespace
|
|
97
94
|
if signature_namespace:
|
|
98
95
|
app_config.signature_namespace.update(signature_namespace)
|
|
99
96
|
|
|
@@ -173,7 +173,6 @@ def _make_hashable(value: Any) -> HashableType:
|
|
|
173
173
|
A hashable version of the value.
|
|
174
174
|
"""
|
|
175
175
|
if isinstance(value, dict):
|
|
176
|
-
# Convert dict to tuple of tuples with sorted keys
|
|
177
176
|
items = []
|
|
178
177
|
for k in sorted(value.keys()): # pyright: ignore
|
|
179
178
|
v = value[k] # pyright: ignore
|
|
@@ -261,7 +260,6 @@ def _create_statement_filters(
|
|
|
261
260
|
required=False,
|
|
262
261
|
),
|
|
263
262
|
) -> SearchFilter:
|
|
264
|
-
# Handle both string and set input types for search fields
|
|
265
263
|
field_names = set(search_fields.split(",")) if isinstance(search_fields, str) else set(search_fields)
|
|
266
264
|
|
|
267
265
|
return SearchFilter(
|
|
@@ -286,9 +284,7 @@ def _create_statement_filters(
|
|
|
286
284
|
|
|
287
285
|
filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)
|
|
288
286
|
|
|
289
|
-
# Add not_in filter providers
|
|
290
287
|
if not_in_fields := config.get("not_in_fields"):
|
|
291
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
292
288
|
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
|
|
293
289
|
|
|
294
290
|
for field_def in not_in_fields:
|
|
@@ -313,9 +309,7 @@ def _create_statement_filters(
|
|
|
313
309
|
provider = create_not_in_filter_provider(field_def) # pyright: ignore
|
|
314
310
|
filters[f"{field_def.name}_not_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
|
|
315
311
|
|
|
316
|
-
# Add in filter providers
|
|
317
312
|
if in_fields := config.get("in_fields"):
|
|
318
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
319
313
|
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
|
|
320
314
|
|
|
321
315
|
for field_def in in_fields:
|
|
@@ -361,7 +355,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
361
355
|
parameters: dict[str, inspect.Parameter] = {}
|
|
362
356
|
annotations: dict[str, Any] = {}
|
|
363
357
|
|
|
364
|
-
# Build parameters based on config
|
|
365
358
|
if cls := config.get("id_filter"):
|
|
366
359
|
parameters["id_filter"] = inspect.Parameter(
|
|
367
360
|
name="id_filter",
|
|
@@ -416,7 +409,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
416
409
|
)
|
|
417
410
|
annotations["order_by_filter"] = OrderByFilter
|
|
418
411
|
|
|
419
|
-
# Add parameters for not_in filters
|
|
420
412
|
if not_in_fields := config.get("not_in_fields"):
|
|
421
413
|
for field_def in not_in_fields:
|
|
422
414
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -428,7 +420,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
428
420
|
)
|
|
429
421
|
annotations[f"{field_def.name}_not_in_filter"] = NotInCollectionFilter[field_def.type_hint] # type: ignore
|
|
430
422
|
|
|
431
|
-
# Add parameters for in filters
|
|
432
423
|
if in_fields := config.get("in_fields"):
|
|
433
424
|
for field_def in in_fields:
|
|
434
425
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -472,9 +463,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
472
463
|
):
|
|
473
464
|
filters.append(order_by)
|
|
474
465
|
|
|
475
|
-
# Add not_in filters
|
|
476
466
|
if not_in_fields := config.get("not_in_fields"):
|
|
477
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
478
467
|
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
|
|
479
468
|
for field_def in not_in_fields:
|
|
480
469
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -482,9 +471,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
482
471
|
if filter_ is not None:
|
|
483
472
|
filters.append(filter_)
|
|
484
473
|
|
|
485
|
-
# Add in filters
|
|
486
474
|
if in_fields := config.get("in_fields"):
|
|
487
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
488
475
|
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
|
|
489
476
|
for field_def in in_fields:
|
|
490
477
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -493,7 +480,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
493
480
|
filters.append(filter_)
|
|
494
481
|
return filters
|
|
495
482
|
|
|
496
|
-
# Set both signature and annotations
|
|
497
483
|
provide_filters.__signature__ = inspect.Signature( # type: ignore
|
|
498
484
|
parameters=list(parameters.values()), return_annotation=list[FilterTypes]
|
|
499
485
|
)
|
sqlspec/loader.py
CHANGED
|
@@ -26,7 +26,7 @@ logger = get_logger("loader")
|
|
|
26
26
|
# Matches: -- name: query_name (supports hyphens and special suffixes)
|
|
27
27
|
# We capture the name plus any trailing special characters
|
|
28
28
|
QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
|
|
29
|
-
|
|
29
|
+
TRIM_TRAILING_SPECIAL_CHARS = re.compile(r"[^\w-]+$")
|
|
30
30
|
MIN_QUERY_PARTS = 3
|
|
31
31
|
|
|
32
32
|
|
|
@@ -42,10 +42,8 @@ def _normalize_query_name(name: str) -> str:
|
|
|
42
42
|
Returns:
|
|
43
43
|
Normalized query name suitable as Python identifier
|
|
44
44
|
"""
|
|
45
|
-
#
|
|
46
|
-
|
|
47
|
-
# Then replace hyphens with underscores
|
|
48
|
-
return name.replace("-", "_")
|
|
45
|
+
# Strip trailing non-alphanumeric characters (excluding underscore) and replace hyphens
|
|
46
|
+
return TRIM_TRAILING_SPECIAL_CHARS.sub("", name).replace("-", "_")
|
|
49
47
|
|
|
50
48
|
|
|
51
49
|
@dataclass
|
|
@@ -127,8 +125,6 @@ class SQLFileLoader:
|
|
|
127
125
|
path_str = str(path)
|
|
128
126
|
|
|
129
127
|
try:
|
|
130
|
-
# Always use storage backend for consistent behavior
|
|
131
|
-
# Pass the original path object to allow storage registry to handle Path -> file:// conversion
|
|
132
128
|
backend = self.storage_registry.get(path)
|
|
133
129
|
return backend.read_text(path_str, encoding=self.encoding)
|
|
134
130
|
except KeyError as e:
|
|
@@ -151,48 +147,27 @@ class SQLFileLoader:
|
|
|
151
147
|
|
|
152
148
|
@staticmethod
|
|
153
149
|
def _parse_sql_content(content: str, file_path: str) -> dict[str, str]:
|
|
154
|
-
"""Parse SQL content and extract named queries.
|
|
155
|
-
|
|
156
|
-
Args:
|
|
157
|
-
content: SQL file content.
|
|
158
|
-
file_path: Path to the file (for error messages).
|
|
159
|
-
|
|
160
|
-
Returns:
|
|
161
|
-
Dictionary mapping query names to SQL text.
|
|
162
|
-
|
|
163
|
-
Raises:
|
|
164
|
-
SQLFileParseError: If no named queries found.
|
|
165
|
-
"""
|
|
150
|
+
"""Parse SQL content and extract named queries."""
|
|
166
151
|
queries: dict[str, str] = {}
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
parts = QUERY_NAME_PATTERN.split(content)
|
|
170
|
-
|
|
171
|
-
if len(parts) < MIN_QUERY_PARTS:
|
|
172
|
-
# No named queries found
|
|
152
|
+
matches = list(QUERY_NAME_PATTERN.finditer(content))
|
|
153
|
+
if not matches:
|
|
173
154
|
raise SQLFileParseError(
|
|
174
155
|
file_path, file_path, ValueError("No named SQL statements found (-- name: query_name)")
|
|
175
156
|
)
|
|
176
157
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
raw_query_name = parts[i].strip()
|
|
183
|
-
sql_text = parts[i + 1].strip()
|
|
158
|
+
for i, match in enumerate(matches):
|
|
159
|
+
raw_query_name = match.group(1).strip()
|
|
160
|
+
start_pos = match.end()
|
|
161
|
+
end_pos = matches[i + 1].start() if i + 1 < len(matches) else len(content)
|
|
184
162
|
|
|
163
|
+
sql_text = content[start_pos:end_pos].strip()
|
|
185
164
|
if not raw_query_name or not sql_text:
|
|
186
165
|
continue
|
|
187
166
|
|
|
188
167
|
clean_sql = SQLFileLoader._strip_leading_comments(sql_text)
|
|
189
|
-
|
|
190
168
|
if clean_sql:
|
|
191
|
-
# Normalize to Python-compatible identifier
|
|
192
169
|
query_name = _normalize_query_name(raw_query_name)
|
|
193
|
-
|
|
194
170
|
if query_name in queries:
|
|
195
|
-
# Duplicate query name
|
|
196
171
|
raise SQLFileParseError(file_path, file_path, ValueError(f"Duplicate query name: {raw_query_name}"))
|
|
197
172
|
queries[query_name] = clean_sql
|
|
198
173
|
|
|
@@ -221,19 +196,13 @@ class SQLFileLoader:
|
|
|
221
196
|
try:
|
|
222
197
|
for path in paths:
|
|
223
198
|
path_str = str(path)
|
|
224
|
-
|
|
225
|
-
# Check if it's a URI
|
|
226
199
|
if "://" in path_str:
|
|
227
|
-
# URIs are always treated as files, not directories
|
|
228
200
|
self._load_single_file(path, None)
|
|
229
201
|
loaded_count += 1
|
|
230
202
|
else:
|
|
231
|
-
# Local path - check if it's a directory or file
|
|
232
203
|
path_obj = Path(path)
|
|
233
204
|
if path_obj.is_dir():
|
|
234
|
-
|
|
235
|
-
self._load_directory(path_obj)
|
|
236
|
-
loaded_count += len(self._files) - file_count_before
|
|
205
|
+
loaded_count += self._load_directory(path_obj)
|
|
237
206
|
else:
|
|
238
207
|
self._load_single_file(path_obj, None)
|
|
239
208
|
loaded_count += 1
|
|
@@ -267,31 +236,18 @@ class SQLFileLoader:
|
|
|
267
236
|
)
|
|
268
237
|
raise
|
|
269
238
|
|
|
270
|
-
def _load_directory(self, dir_path: Path) ->
|
|
271
|
-
"""Load all SQL files from a directory with namespacing.
|
|
272
|
-
|
|
273
|
-
Args:
|
|
274
|
-
dir_path: Directory path to scan for SQL files.
|
|
275
|
-
|
|
276
|
-
Raises:
|
|
277
|
-
SQLFileParseError: If directory contains no SQL files.
|
|
278
|
-
"""
|
|
239
|
+
def _load_directory(self, dir_path: Path) -> int:
|
|
240
|
+
"""Load all SQL files from a directory with namespacing."""
|
|
279
241
|
sql_files = list(dir_path.rglob("*.sql"))
|
|
280
|
-
|
|
281
242
|
if not sql_files:
|
|
282
|
-
|
|
283
|
-
str(dir_path), str(dir_path), ValueError(f"No SQL files found in directory: {dir_path}")
|
|
284
|
-
)
|
|
243
|
+
return 0
|
|
285
244
|
|
|
286
245
|
for file_path in sql_files:
|
|
287
|
-
# Calculate namespace based on relative path from base directory
|
|
288
246
|
relative_path = file_path.relative_to(dir_path)
|
|
289
247
|
namespace_parts = relative_path.parent.parts
|
|
290
|
-
|
|
291
|
-
# Create namespace (empty for root-level files)
|
|
292
248
|
namespace = ".".join(namespace_parts) if namespace_parts else None
|
|
293
|
-
|
|
294
249
|
self._load_single_file(file_path, namespace)
|
|
250
|
+
return len(sql_files)
|
|
295
251
|
|
|
296
252
|
def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
|
|
297
253
|
"""Load a single SQL file with optional namespace.
|
|
@@ -302,42 +258,24 @@ class SQLFileLoader:
|
|
|
302
258
|
"""
|
|
303
259
|
path_str = str(file_path)
|
|
304
260
|
|
|
305
|
-
# Check if already loaded
|
|
306
261
|
if path_str in self._files:
|
|
307
|
-
#
|
|
308
|
-
file_obj = self._files[path_str]
|
|
309
|
-
queries = self._parse_sql_content(file_obj.content, path_str)
|
|
310
|
-
for name in queries:
|
|
311
|
-
namespaced_name = f"{namespace}.{name}" if namespace else name
|
|
312
|
-
if namespaced_name not in self._queries:
|
|
313
|
-
self._queries[namespaced_name] = queries[name]
|
|
314
|
-
self._query_to_file[namespaced_name] = path_str
|
|
315
|
-
return
|
|
316
|
-
|
|
317
|
-
# Read file content
|
|
318
|
-
content = self._read_file_content(file_path)
|
|
262
|
+
return # Already loaded
|
|
319
263
|
|
|
320
|
-
|
|
264
|
+
content = self._read_file_content(file_path)
|
|
321
265
|
sql_file = SQLFile(content=content, path=path_str)
|
|
322
|
-
|
|
323
|
-
# Cache the file
|
|
324
266
|
self._files[path_str] = sql_file
|
|
325
267
|
|
|
326
|
-
# Parse and cache queries
|
|
327
268
|
queries = self._parse_sql_content(content, path_str)
|
|
328
|
-
|
|
329
|
-
# Merge into main query dictionary with namespace
|
|
330
269
|
for name, sql in queries.items():
|
|
331
270
|
namespaced_name = f"{namespace}.{name}" if namespace else name
|
|
332
|
-
|
|
333
|
-
if namespaced_name in self._queries and self._query_to_file.get(namespaced_name) != path_str:
|
|
334
|
-
# Query name exists from a different file
|
|
271
|
+
if namespaced_name in self._queries:
|
|
335
272
|
existing_file = self._query_to_file.get(namespaced_name, "unknown")
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
273
|
+
if existing_file != path_str:
|
|
274
|
+
raise SQLFileParseError(
|
|
275
|
+
path_str,
|
|
276
|
+
path_str,
|
|
277
|
+
ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
|
|
278
|
+
)
|
|
341
279
|
self._queries[namespaced_name] = sql
|
|
342
280
|
self._query_to_file[namespaced_name] = path_str
|
|
343
281
|
|
|
@@ -357,7 +295,6 @@ class SQLFileLoader:
|
|
|
357
295
|
raise ValueError(msg)
|
|
358
296
|
|
|
359
297
|
self._queries[name] = sql.strip()
|
|
360
|
-
# Use special marker for directly added queries
|
|
361
298
|
self._query_to_file[name] = "<directly added>"
|
|
362
299
|
|
|
363
300
|
def get_sql(self, name: str, parameters: "Optional[Any]" = None, **kwargs: "Any") -> "SQL":
|
|
@@ -405,12 +342,10 @@ class SQLFileLoader:
|
|
|
405
342
|
)
|
|
406
343
|
raise SQLFileNotFoundError(name, path=f"Query '{name}' not found. Available queries: {available}")
|
|
407
344
|
|
|
408
|
-
# Merge parameters and kwargs for SQL object creation
|
|
409
345
|
sql_kwargs = dict(kwargs)
|
|
410
346
|
if parameters is not None:
|
|
411
347
|
sql_kwargs["parameters"] = parameters
|
|
412
348
|
|
|
413
|
-
# Get source file for additional context
|
|
414
349
|
source_file = self._query_to_file.get(safe_name, "unknown")
|
|
415
350
|
|
|
416
351
|
logger.debug(
|