sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -7,6 +7,8 @@ TypedParameter objects and perform appropriate type conversions.
|
|
|
7
7
|
from decimal import Decimal
|
|
8
8
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
9
9
|
|
|
10
|
+
from sqlspec.utils.type_guards import has_parameter_value
|
|
11
|
+
|
|
10
12
|
if TYPE_CHECKING:
|
|
11
13
|
from sqlspec.typing import SQLParameterType
|
|
12
14
|
|
|
@@ -68,13 +70,10 @@ class TypeCoercionMixin:
|
|
|
68
70
|
Returns:
|
|
69
71
|
Coerced parameter value suitable for the database
|
|
70
72
|
"""
|
|
71
|
-
|
|
72
|
-
if hasattr(param, "__class__") and param.__class__.__name__ == "TypedParameter":
|
|
73
|
-
# Extract value and type hint
|
|
73
|
+
if has_parameter_value(param):
|
|
74
74
|
value = param.value
|
|
75
75
|
type_hint = param.type_hint
|
|
76
76
|
|
|
77
|
-
# Apply driver-specific coercion based on type hint
|
|
78
77
|
return self._apply_type_coercion(value, type_hint)
|
|
79
78
|
# Regular parameter - apply default coercion
|
|
80
79
|
return self._apply_type_coercion(param, None)
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Consolidated parameter processing utilities for database drivers.
|
|
2
|
+
|
|
3
|
+
This module provides centralized parameter handling logic to avoid duplication
|
|
4
|
+
across sync and async driver implementations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
8
|
+
|
|
9
|
+
from sqlspec.statement.filters import StatementFilter
|
|
10
|
+
from sqlspec.utils.type_guards import is_sync_transaction_capable
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from sqlspec.typing import StatementParameters
|
|
14
|
+
|
|
15
|
+
__all__ = (
|
|
16
|
+
"convert_parameters_to_positional",
|
|
17
|
+
"normalize_parameter_sequence",
|
|
18
|
+
"process_execute_many_parameters",
|
|
19
|
+
"separate_filters_and_parameters",
|
|
20
|
+
"should_use_transaction",
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def separate_filters_and_parameters(
    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
) -> "tuple[list[StatementFilter], list[Any]]":
    """Split a mixed argument tuple into statement filters and plain parameter values.

    Args:
        parameters: Positional arguments that may interleave ``StatementFilter``
            instances with ordinary parameter values.

    Returns:
        A ``(filters, values)`` pair. Both lists keep the relative order in which
        their items appeared in ``parameters``.
    """
    found_filters: list[StatementFilter] = []
    remaining: list[Any] = []

    for item in parameters:
        # Route each argument to the matching bucket in a single pass.
        bucket = found_filters if isinstance(item, StatementFilter) else remaining
        bucket.append(item)

    return found_filters, remaining
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def process_execute_many_parameters(
    parameters: "tuple[Union[StatementParameters, StatementFilter], ...]",
) -> "tuple[list[StatementFilter], Optional[list[Any]]]":
    """Prepare the arguments of an ``execute_many`` call.

    Args:
        parameters: Positional arguments that may interleave ``StatementFilter``
            instances with parameter values.

    Returns:
        A ``(filters, parameter_sequence)`` pair. ``parameter_sequence`` is the
        first non-filter argument normalized by ``normalize_parameter_sequence``,
        or ``None`` when no non-filter argument was supplied.
    """
    filters, values = separate_filters_and_parameters(parameters)

    # Only the first non-filter argument is treated as the batch; any further
    # non-filter arguments are not used.
    batch = values[0] if values else None

    return filters, normalize_parameter_sequence(batch)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def normalize_parameter_sequence(params: Any) -> Optional[list[Any]]:
    """Normalize a parameter sequence to a list.

    Args:
        params: ``None``, a list, a tuple, any other iterable, or a single value.

    Returns:
        - ``None`` when ``params`` is ``None``.
        - The same list object when ``params`` is already a list.
        - A new list of the items for tuples and other iterables.
        - ``[params]`` for everything else.

        Text (``str``), binary data (``bytes``, ``bytearray``, ``memoryview``)
        and ``dict`` objects are iterable, but they are treated as single values
        and wrapped. This prevents a binary payload from being exploded into a
        list of integers.
    """
    if params is None:
        return None

    if isinstance(params, list):
        return params

    if isinstance(params, tuple):
        return list(params)

    # Iterable, but these represent one parameter value rather than a sequence
    # of values (iterating bytes-like objects would yield ints).
    if isinstance(params, (str, bytes, bytearray, memoryview, dict)):
        return [params]

    try:
        items = iter(params)
    except TypeError:
        # Not iterable: a single scalar parameter.
        return [params]

    return list(items)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def convert_parameters_to_positional(params: "dict[str, Any]", parameter_info: "list[Any]") -> list[Any]:
    """Convert named parameters to a positional list.

    Args:
        params: Mapping of parameter name to value.
        parameter_info: Parameter descriptors from SQL parsing, in SQL order.
            Only entries exposing a non-``None`` ``name`` attribute are used.

    Returns:
        The parameter values in positional order. An empty ``params`` yields ``[]``.

        - If the keys are exactly ``param_0`` .. ``param_{n-1}``, values are
          ordered by that numeric suffix.
        - Otherwise each descriptor's ``name`` is looked up in ``params``, and
          ``None`` is used for names missing from ``params``.
    """
    if not params:
        return []

    # Auto-generated names (param_0, param_1, ...) encode their own position.
    # Require the exact contiguous set: a user-supplied name that merely starts
    # with "param_" (e.g. "param_user") or a gap in the numbering must fall
    # through to the parameter_info ordering instead of raising KeyError.
    generated_keys = [f"param_{i}" for i in range(len(params))]
    if params.keys() == set(generated_keys):
        return [params[key] for key in generated_keys]

    result: list[Any] = []
    for info in parameter_info:
        param_name = getattr(info, "name", None)
        if param_name is not None:
            result.append(params.get(param_name))
    return result
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def should_use_transaction(connection: Any, auto_commit: bool = True) -> bool:
    """Decide whether an explicit transaction should wrap the operation.

    Args:
        connection: Database connection object.
        auto_commit: Whether auto-commit is enabled. Defaults to ``True``.

    Returns:
        ``False`` whenever ``auto_commit`` is true. Otherwise returns the result of
        ``is_sync_transaction_capable(connection)``; presumably that reports whether
        the connection exposes synchronous transaction methods (confirm in
        ``sqlspec.utils.type_guards``).
    """
    if auto_commit:
        # Auto-commit already finalizes each statement; no wrapping needed.
        return False
    return is_sync_transaction_capable(connection)
|
sqlspec/exceptions.py
CHANGED
|
@@ -19,6 +19,7 @@ __all__ = (
|
|
|
19
19
|
"RepositoryError",
|
|
20
20
|
"RiskLevel",
|
|
21
21
|
"SQLBuilderError",
|
|
22
|
+
"SQLCompilationError",
|
|
22
23
|
"SQLConversionError",
|
|
23
24
|
"SQLFileNotFoundError",
|
|
24
25
|
"SQLFileParseError",
|
|
@@ -122,6 +123,15 @@ class SQLBuilderError(SQLSpecError):
|
|
|
122
123
|
super().__init__(message)
|
|
123
124
|
|
|
124
125
|
|
|
126
|
+
class SQLCompilationError(SQLSpecError):
    """Error for issues encountered while compiling SQL statements."""

    def __init__(self, message: Optional[str] = None) -> None:
        # Fall back to a generic description when no message is supplied.
        super().__init__("Issues compiling SQL statement." if message is None else message)
|
|
133
|
+
|
|
134
|
+
|
|
125
135
|
class SQLConversionError(SQLSpecError):
|
|
126
136
|
"""Issues converting SQL statements."""
|
|
127
137
|
|
|
@@ -374,7 +384,6 @@ def wrap_exceptions(
|
|
|
374
384
|
yield
|
|
375
385
|
|
|
376
386
|
except Exception as exc:
|
|
377
|
-
# Handle suppression first
|
|
378
387
|
if suppress is not None and (
|
|
379
388
|
(isinstance(suppress, type) and isinstance(exc, suppress))
|
|
380
389
|
or (isinstance(suppress, tuple) and isinstance(exc, suppress))
|
|
@@ -385,7 +394,6 @@ def wrap_exceptions(
|
|
|
385
394
|
if isinstance(exc, SQLSpecError):
|
|
386
395
|
raise
|
|
387
396
|
|
|
388
|
-
# Handle wrapping
|
|
389
397
|
if wrap_exceptions is False:
|
|
390
398
|
raise
|
|
391
399
|
msg = "An error occurred during the operation."
|
|
@@ -40,11 +40,9 @@ def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
|
|
|
40
40
|
Returns:
|
|
41
41
|
Normalized dialect name
|
|
42
42
|
"""
|
|
43
|
-
# Handle different dialect types
|
|
44
43
|
if dialect is None:
|
|
45
44
|
return "sql"
|
|
46
45
|
|
|
47
|
-
# Extract string from dialect class or instance
|
|
48
46
|
if hasattr(dialect, "__name__"): # It's a class
|
|
49
47
|
dialect_str = str(dialect.__name__).lower() # pyright: ignore
|
|
50
48
|
elif hasattr(dialect, "name"): # It's an instance with name attribute
|
|
@@ -134,7 +132,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
|
|
|
134
132
|
"Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
|
|
135
133
|
)
|
|
136
134
|
|
|
137
|
-
# Create SQL object and apply filters
|
|
138
135
|
sql_obj = self._create_sql_object(sql, parameters)
|
|
139
136
|
# Execute using SQLSpec driver
|
|
140
137
|
result = self.driver.execute(sql_obj, connection=conn)
|
|
@@ -192,12 +189,9 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
|
|
|
192
189
|
return None
|
|
193
190
|
|
|
194
191
|
if isinstance(row, dict):
|
|
195
|
-
# Return first value from dict
|
|
196
192
|
return next(iter(row.values())) if row else None
|
|
197
193
|
if hasattr(row, "__getitem__"):
|
|
198
|
-
# Handle tuple/list-like objects
|
|
199
194
|
return row[0] if len(row) > 0 else None
|
|
200
|
-
# Handle scalar or object with attributes
|
|
201
195
|
return row
|
|
202
196
|
|
|
203
197
|
@contextmanager
|
|
@@ -216,7 +210,6 @@ class AiosqlSyncAdapter(_AiosqlAdapterBase):
|
|
|
216
210
|
sql_obj = self._create_sql_object(sql, parameters)
|
|
217
211
|
result = self.driver.execute(sql_obj, connection=conn)
|
|
218
212
|
|
|
219
|
-
# Create a cursor-like object
|
|
220
213
|
class CursorLike:
|
|
221
214
|
def __init__(self, result: Any) -> None:
|
|
222
215
|
self.result = result
|
|
@@ -386,12 +379,9 @@ class AiosqlAsyncAdapter(_AiosqlAdapterBase):
|
|
|
386
379
|
return None
|
|
387
380
|
|
|
388
381
|
if isinstance(row, dict):
|
|
389
|
-
# Return first value from dict
|
|
390
382
|
return next(iter(row.values())) if row else None
|
|
391
383
|
if hasattr(row, "__getitem__"):
|
|
392
|
-
# Handle tuple/list-like objects
|
|
393
384
|
return row[0] if len(row) > 0 else None
|
|
394
|
-
# Handle scalar or object with attributes
|
|
395
385
|
return row
|
|
396
386
|
|
|
397
387
|
@asynccontextmanager
|
|
@@ -69,7 +69,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
|
69
69
|
[SQLSpec, ConnectionT, PoolT, DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT]
|
|
70
70
|
)
|
|
71
71
|
|
|
72
|
-
# Create signature namespace for connection types
|
|
73
72
|
signature_namespace = {}
|
|
74
73
|
|
|
75
74
|
for c in self._plugin_configs:
|
|
@@ -78,7 +77,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
|
78
77
|
app_config.signature_types.append(c.config.connection_type) # type: ignore[union-attr]
|
|
79
78
|
app_config.signature_types.append(c.config.driver_type) # type: ignore[union-attr]
|
|
80
79
|
|
|
81
|
-
# Get signature namespace from the config
|
|
82
80
|
if hasattr(c.config, "get_signature_namespace"):
|
|
83
81
|
config_namespace = c.config.get_signature_namespace() # type: ignore[attr-defined]
|
|
84
82
|
signature_namespace.update(config_namespace)
|
|
@@ -93,7 +91,6 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
|
|
|
93
91
|
}
|
|
94
92
|
)
|
|
95
93
|
|
|
96
|
-
# Update app config with signature namespace
|
|
97
94
|
if signature_namespace:
|
|
98
95
|
app_config.signature_namespace.update(signature_namespace)
|
|
99
96
|
|
|
@@ -173,7 +173,6 @@ def _make_hashable(value: Any) -> HashableType:
|
|
|
173
173
|
A hashable version of the value.
|
|
174
174
|
"""
|
|
175
175
|
if isinstance(value, dict):
|
|
176
|
-
# Convert dict to tuple of tuples with sorted keys
|
|
177
176
|
items = []
|
|
178
177
|
for k in sorted(value.keys()): # pyright: ignore
|
|
179
178
|
v = value[k] # pyright: ignore
|
|
@@ -261,7 +260,6 @@ def _create_statement_filters(
|
|
|
261
260
|
required=False,
|
|
262
261
|
),
|
|
263
262
|
) -> SearchFilter:
|
|
264
|
-
# Handle both string and set input types for search fields
|
|
265
263
|
field_names = set(search_fields.split(",")) if isinstance(search_fields, str) else set(search_fields)
|
|
266
264
|
|
|
267
265
|
return SearchFilter(
|
|
@@ -286,9 +284,7 @@ def _create_statement_filters(
|
|
|
286
284
|
|
|
287
285
|
filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)
|
|
288
286
|
|
|
289
|
-
# Add not_in filter providers
|
|
290
287
|
if not_in_fields := config.get("not_in_fields"):
|
|
291
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
292
288
|
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
|
|
293
289
|
|
|
294
290
|
for field_def in not_in_fields:
|
|
@@ -313,9 +309,7 @@ def _create_statement_filters(
|
|
|
313
309
|
provider = create_not_in_filter_provider(field_def) # pyright: ignore
|
|
314
310
|
filters[f"{field_def.name}_not_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
|
|
315
311
|
|
|
316
|
-
# Add in filter providers
|
|
317
312
|
if in_fields := config.get("in_fields"):
|
|
318
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
319
313
|
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
|
|
320
314
|
|
|
321
315
|
for field_def in in_fields:
|
|
@@ -361,7 +355,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
361
355
|
parameters: dict[str, inspect.Parameter] = {}
|
|
362
356
|
annotations: dict[str, Any] = {}
|
|
363
357
|
|
|
364
|
-
# Build parameters based on config
|
|
365
358
|
if cls := config.get("id_filter"):
|
|
366
359
|
parameters["id_filter"] = inspect.Parameter(
|
|
367
360
|
name="id_filter",
|
|
@@ -416,7 +409,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
416
409
|
)
|
|
417
410
|
annotations["order_by_filter"] = OrderByFilter
|
|
418
411
|
|
|
419
|
-
# Add parameters for not_in filters
|
|
420
412
|
if not_in_fields := config.get("not_in_fields"):
|
|
421
413
|
for field_def in not_in_fields:
|
|
422
414
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -428,7 +420,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
428
420
|
)
|
|
429
421
|
annotations[f"{field_def.name}_not_in_filter"] = NotInCollectionFilter[field_def.type_hint] # type: ignore
|
|
430
422
|
|
|
431
|
-
# Add parameters for in filters
|
|
432
423
|
if in_fields := config.get("in_fields"):
|
|
433
424
|
for field_def in in_fields:
|
|
434
425
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -472,9 +463,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
472
463
|
):
|
|
473
464
|
filters.append(order_by)
|
|
474
465
|
|
|
475
|
-
# Add not_in filters
|
|
476
466
|
if not_in_fields := config.get("not_in_fields"):
|
|
477
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
478
467
|
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
|
|
479
468
|
for field_def in not_in_fields:
|
|
480
469
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -482,9 +471,7 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
482
471
|
if filter_ is not None:
|
|
483
472
|
filters.append(filter_)
|
|
484
473
|
|
|
485
|
-
# Add in filters
|
|
486
474
|
if in_fields := config.get("in_fields"):
|
|
487
|
-
# Get all field names, handling both strings and FieldNameType objects
|
|
488
475
|
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
|
|
489
476
|
for field_def in in_fields:
|
|
490
477
|
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
|
|
@@ -493,7 +480,6 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
|
|
|
493
480
|
filters.append(filter_)
|
|
494
481
|
return filters
|
|
495
482
|
|
|
496
|
-
# Set both signature and annotations
|
|
497
483
|
provide_filters.__signature__ = inspect.Signature( # type: ignore
|
|
498
484
|
parameters=list(parameters.values()), return_annotation=list[FilterTypes]
|
|
499
485
|
)
|
sqlspec/loader.py
CHANGED
|
@@ -26,7 +26,7 @@ logger = get_logger("loader")
|
|
|
26
26
|
# Matches: -- name: query_name (supports hyphens and special suffixes)
|
|
27
27
|
# We capture the name plus any trailing special characters
|
|
28
28
|
QUERY_NAME_PATTERN = re.compile(r"^\s*--\s*name\s*:\s*([\w-]+[^\w\s]*)\s*$", re.MULTILINE | re.IGNORECASE)
|
|
29
|
-
|
|
29
|
+
TRIM_TRAILING_SPECIAL_CHARS = re.compile(r"[^\w-]+$")
|
|
30
30
|
MIN_QUERY_PARTS = 3
|
|
31
31
|
|
|
32
32
|
|
|
@@ -42,10 +42,8 @@ def _normalize_query_name(name: str) -> str:
|
|
|
42
42
|
Returns:
|
|
43
43
|
Normalized query name suitable as Python identifier
|
|
44
44
|
"""
|
|
45
|
-
#
|
|
46
|
-
|
|
47
|
-
# Then replace hyphens with underscores
|
|
48
|
-
return name.replace("-", "_")
|
|
45
|
+
# Strip trailing non-alphanumeric characters (excluding underscore) and replace hyphens
|
|
46
|
+
return TRIM_TRAILING_SPECIAL_CHARS.sub("", name).replace("-", "_")
|
|
49
47
|
|
|
50
48
|
|
|
51
49
|
@dataclass
|
|
@@ -113,7 +111,7 @@ class SQLFileLoader:
|
|
|
113
111
|
self._query_to_file: dict[str, str] = {} # Maps query name to file path
|
|
114
112
|
|
|
115
113
|
def _read_file_content(self, path: Union[str, Path]) -> str:
|
|
116
|
-
"""Read file content using
|
|
114
|
+
"""Read file content using storage backend.
|
|
117
115
|
|
|
118
116
|
Args:
|
|
119
117
|
path: File path (can be local path or URI).
|
|
@@ -126,37 +124,13 @@ class SQLFileLoader:
|
|
|
126
124
|
"""
|
|
127
125
|
path_str = str(path)
|
|
128
126
|
|
|
129
|
-
# Use storage backend for URIs (anything with a scheme)
|
|
130
|
-
if "://" in path_str:
|
|
131
|
-
try:
|
|
132
|
-
backend = self.storage_registry.get(path_str)
|
|
133
|
-
return backend.read_text(path_str, encoding=self.encoding)
|
|
134
|
-
except KeyError as e:
|
|
135
|
-
raise SQLFileNotFoundError(path_str) from e
|
|
136
|
-
except Exception as e:
|
|
137
|
-
raise SQLFileParseError(path_str, path_str, e) from e
|
|
138
|
-
|
|
139
|
-
# Handle local file paths
|
|
140
|
-
local_path = Path(path_str)
|
|
141
|
-
self._check_file_path(local_path)
|
|
142
|
-
content_bytes = self._read_file_content_bytes(local_path)
|
|
143
|
-
return content_bytes.decode(self.encoding)
|
|
144
|
-
|
|
145
|
-
@staticmethod
|
|
146
|
-
def _read_file_content_bytes(path: Path) -> bytes:
|
|
147
127
|
try:
|
|
148
|
-
|
|
128
|
+
backend = self.storage_registry.get(path)
|
|
129
|
+
return backend.read_text(path_str, encoding=self.encoding)
|
|
130
|
+
except KeyError as e:
|
|
131
|
+
raise SQLFileNotFoundError(path_str) from e
|
|
149
132
|
except Exception as e:
|
|
150
|
-
raise SQLFileParseError(
|
|
151
|
-
|
|
152
|
-
@staticmethod
|
|
153
|
-
def _check_file_path(path: Union[str, Path]) -> None:
|
|
154
|
-
"""Ensure the file exists and is a valid path."""
|
|
155
|
-
path_obj = Path(path).resolve()
|
|
156
|
-
if not path_obj.exists():
|
|
157
|
-
raise SQLFileNotFoundError(str(path_obj))
|
|
158
|
-
if not path_obj.is_file():
|
|
159
|
-
raise SQLFileParseError(str(path_obj), str(path_obj), ValueError("Path is not a file"))
|
|
133
|
+
raise SQLFileParseError(path_str, path_str, e) from e
|
|
160
134
|
|
|
161
135
|
@staticmethod
|
|
162
136
|
def _strip_leading_comments(sql_text: str) -> str:
|
|
@@ -173,48 +147,27 @@ class SQLFileLoader:
|
|
|
173
147
|
|
|
174
148
|
@staticmethod
|
|
175
149
|
def _parse_sql_content(content: str, file_path: str) -> dict[str, str]:
|
|
176
|
-
"""Parse SQL content and extract named queries.
|
|
177
|
-
|
|
178
|
-
Args:
|
|
179
|
-
content: SQL file content.
|
|
180
|
-
file_path: Path to the file (for error messages).
|
|
181
|
-
|
|
182
|
-
Returns:
|
|
183
|
-
Dictionary mapping query names to SQL text.
|
|
184
|
-
|
|
185
|
-
Raises:
|
|
186
|
-
SQLFileParseError: If no named queries found.
|
|
187
|
-
"""
|
|
150
|
+
"""Parse SQL content and extract named queries."""
|
|
188
151
|
queries: dict[str, str] = {}
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
parts = QUERY_NAME_PATTERN.split(content)
|
|
192
|
-
|
|
193
|
-
if len(parts) < MIN_QUERY_PARTS:
|
|
194
|
-
# No named queries found
|
|
152
|
+
matches = list(QUERY_NAME_PATTERN.finditer(content))
|
|
153
|
+
if not matches:
|
|
195
154
|
raise SQLFileParseError(
|
|
196
155
|
file_path, file_path, ValueError("No named SQL statements found (-- name: query_name)")
|
|
197
156
|
)
|
|
198
157
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
raw_query_name = parts[i].strip()
|
|
205
|
-
sql_text = parts[i + 1].strip()
|
|
158
|
+
for i, match in enumerate(matches):
|
|
159
|
+
raw_query_name = match.group(1).strip()
|
|
160
|
+
start_pos = match.end()
|
|
161
|
+
end_pos = matches[i + 1].start() if i + 1 < len(matches) else len(content)
|
|
206
162
|
|
|
163
|
+
sql_text = content[start_pos:end_pos].strip()
|
|
207
164
|
if not raw_query_name or not sql_text:
|
|
208
165
|
continue
|
|
209
166
|
|
|
210
167
|
clean_sql = SQLFileLoader._strip_leading_comments(sql_text)
|
|
211
|
-
|
|
212
168
|
if clean_sql:
|
|
213
|
-
# Normalize to Python-compatible identifier
|
|
214
169
|
query_name = _normalize_query_name(raw_query_name)
|
|
215
|
-
|
|
216
170
|
if query_name in queries:
|
|
217
|
-
# Duplicate query name
|
|
218
171
|
raise SQLFileParseError(file_path, file_path, ValueError(f"Duplicate query name: {raw_query_name}"))
|
|
219
172
|
queries[query_name] = clean_sql
|
|
220
173
|
|
|
@@ -243,19 +196,13 @@ class SQLFileLoader:
|
|
|
243
196
|
try:
|
|
244
197
|
for path in paths:
|
|
245
198
|
path_str = str(path)
|
|
246
|
-
|
|
247
|
-
# Check if it's a URI
|
|
248
199
|
if "://" in path_str:
|
|
249
|
-
# URIs are always treated as files, not directories
|
|
250
200
|
self._load_single_file(path, None)
|
|
251
201
|
loaded_count += 1
|
|
252
202
|
else:
|
|
253
|
-
# Local path - check if it's a directory or file
|
|
254
203
|
path_obj = Path(path)
|
|
255
204
|
if path_obj.is_dir():
|
|
256
|
-
|
|
257
|
-
self._load_directory(path_obj)
|
|
258
|
-
loaded_count += len(self._files) - file_count_before
|
|
205
|
+
loaded_count += self._load_directory(path_obj)
|
|
259
206
|
else:
|
|
260
207
|
self._load_single_file(path_obj, None)
|
|
261
208
|
loaded_count += 1
|
|
@@ -289,31 +236,18 @@ class SQLFileLoader:
|
|
|
289
236
|
)
|
|
290
237
|
raise
|
|
291
238
|
|
|
292
|
-
def _load_directory(self, dir_path: Path) ->
|
|
293
|
-
"""Load all SQL files from a directory with namespacing.
|
|
294
|
-
|
|
295
|
-
Args:
|
|
296
|
-
dir_path: Directory path to scan for SQL files.
|
|
297
|
-
|
|
298
|
-
Raises:
|
|
299
|
-
SQLFileParseError: If directory contains no SQL files.
|
|
300
|
-
"""
|
|
239
|
+
def _load_directory(self, dir_path: Path) -> int:
|
|
240
|
+
"""Load all SQL files from a directory with namespacing."""
|
|
301
241
|
sql_files = list(dir_path.rglob("*.sql"))
|
|
302
|
-
|
|
303
242
|
if not sql_files:
|
|
304
|
-
|
|
305
|
-
str(dir_path), str(dir_path), ValueError(f"No SQL files found in directory: {dir_path}")
|
|
306
|
-
)
|
|
243
|
+
return 0
|
|
307
244
|
|
|
308
245
|
for file_path in sql_files:
|
|
309
|
-
# Calculate namespace based on relative path from base directory
|
|
310
246
|
relative_path = file_path.relative_to(dir_path)
|
|
311
247
|
namespace_parts = relative_path.parent.parts
|
|
312
|
-
|
|
313
|
-
# Create namespace (empty for root-level files)
|
|
314
248
|
namespace = ".".join(namespace_parts) if namespace_parts else None
|
|
315
|
-
|
|
316
249
|
self._load_single_file(file_path, namespace)
|
|
250
|
+
return len(sql_files)
|
|
317
251
|
|
|
318
252
|
def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
|
|
319
253
|
"""Load a single SQL file with optional namespace.
|
|
@@ -324,42 +258,24 @@ class SQLFileLoader:
|
|
|
324
258
|
"""
|
|
325
259
|
path_str = str(file_path)
|
|
326
260
|
|
|
327
|
-
# Check if already loaded
|
|
328
261
|
if path_str in self._files:
|
|
329
|
-
#
|
|
330
|
-
file_obj = self._files[path_str]
|
|
331
|
-
queries = self._parse_sql_content(file_obj.content, path_str)
|
|
332
|
-
for name in queries:
|
|
333
|
-
namespaced_name = f"{namespace}.{name}" if namespace else name
|
|
334
|
-
if namespaced_name not in self._queries:
|
|
335
|
-
self._queries[namespaced_name] = queries[name]
|
|
336
|
-
self._query_to_file[namespaced_name] = path_str
|
|
337
|
-
return
|
|
338
|
-
|
|
339
|
-
# Read file content
|
|
340
|
-
content = self._read_file_content(file_path)
|
|
262
|
+
return # Already loaded
|
|
341
263
|
|
|
342
|
-
|
|
264
|
+
content = self._read_file_content(file_path)
|
|
343
265
|
sql_file = SQLFile(content=content, path=path_str)
|
|
344
|
-
|
|
345
|
-
# Cache the file
|
|
346
266
|
self._files[path_str] = sql_file
|
|
347
267
|
|
|
348
|
-
# Parse and cache queries
|
|
349
268
|
queries = self._parse_sql_content(content, path_str)
|
|
350
|
-
|
|
351
|
-
# Merge into main query dictionary with namespace
|
|
352
269
|
for name, sql in queries.items():
|
|
353
270
|
namespaced_name = f"{namespace}.{name}" if namespace else name
|
|
354
|
-
|
|
355
|
-
if namespaced_name in self._queries and self._query_to_file.get(namespaced_name) != path_str:
|
|
356
|
-
# Query name exists from a different file
|
|
271
|
+
if namespaced_name in self._queries:
|
|
357
272
|
existing_file = self._query_to_file.get(namespaced_name, "unknown")
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
273
|
+
if existing_file != path_str:
|
|
274
|
+
raise SQLFileParseError(
|
|
275
|
+
path_str,
|
|
276
|
+
path_str,
|
|
277
|
+
ValueError(f"Query name '{namespaced_name}' already exists in file: {existing_file}"),
|
|
278
|
+
)
|
|
363
279
|
self._queries[namespaced_name] = sql
|
|
364
280
|
self._query_to_file[namespaced_name] = path_str
|
|
365
281
|
|
|
@@ -379,7 +295,6 @@ class SQLFileLoader:
|
|
|
379
295
|
raise ValueError(msg)
|
|
380
296
|
|
|
381
297
|
self._queries[name] = sql.strip()
|
|
382
|
-
# Use special marker for directly added queries
|
|
383
298
|
self._query_to_file[name] = "<directly added>"
|
|
384
299
|
|
|
385
300
|
def get_sql(self, name: str, parameters: "Optional[Any]" = None, **kwargs: "Any") -> "SQL":
|
|
@@ -427,12 +342,10 @@ class SQLFileLoader:
|
|
|
427
342
|
)
|
|
428
343
|
raise SQLFileNotFoundError(name, path=f"Query '{name}' not found. Available queries: {available}")
|
|
429
344
|
|
|
430
|
-
# Merge parameters and kwargs for SQL object creation
|
|
431
345
|
sql_kwargs = dict(kwargs)
|
|
432
346
|
if parameters is not None:
|
|
433
347
|
sql_kwargs["parameters"] = parameters
|
|
434
348
|
|
|
435
|
-
# Get source file for additional context
|
|
436
349
|
source_file = self._query_to_file.get(safe_name, "unknown")
|
|
437
350
|
|
|
438
351
|
logger.debug(
|