sqlspec 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +40 -7
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +183 -17
- sqlspec/adapters/adbc/driver.py +392 -0
- sqlspec/adapters/aiosqlite/__init__.py +5 -1
- sqlspec/adapters/aiosqlite/config.py +24 -6
- sqlspec/adapters/aiosqlite/driver.py +264 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +71 -11
- sqlspec/adapters/asyncmy/driver.py +246 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +102 -25
- sqlspec/adapters/asyncpg/driver.py +444 -0
- sqlspec/adapters/duckdb/__init__.py +5 -1
- sqlspec/adapters/duckdb/config.py +194 -12
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +7 -4
- sqlspec/adapters/oracledb/config/__init__.py +4 -4
- sqlspec/adapters/oracledb/config/_asyncio.py +96 -12
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +96 -12
- sqlspec/adapters/oracledb/driver.py +571 -0
- sqlspec/adapters/psqlpy/__init__.py +0 -0
- sqlspec/adapters/psqlpy/config.py +258 -0
- sqlspec/adapters/psqlpy/driver.py +335 -0
- sqlspec/adapters/psycopg/__init__.py +16 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +107 -15
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +107 -15
- sqlspec/adapters/psycopg/driver.py +578 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +24 -6
- sqlspec/adapters/sqlite/driver.py +305 -0
- sqlspec/base.py +565 -63
- sqlspec/exceptions.py +30 -0
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +87 -0
- sqlspec/extensions/litestar/handlers.py +213 -0
- sqlspec/extensions/litestar/plugin.py +105 -11
- sqlspec/statement.py +373 -0
- sqlspec/typing.py +81 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/METADATA +4 -1
- sqlspec-0.9.0.dist-info/RECORD +61 -0
- sqlspec-0.7.1.dist-info/RECORD +0 -46
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
# ruff: noqa: RUF100, PLR6301, PLR0912, PLR0915, C901, PLR0911, PLR0914, N806
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from functools import cached_property
|
|
6
|
+
from typing import (
|
|
7
|
+
Any,
|
|
8
|
+
Optional,
|
|
9
|
+
Union,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
import sqlglot
|
|
13
|
+
from sqlglot import exp
|
|
14
|
+
|
|
15
|
+
from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError
|
|
16
|
+
from sqlspec.typing import StatementParameterType
|
|
17
|
+
|
|
18
|
+
__all__ = ("SQLStatement",)

# Logger obtained by the explicit name "sqlspec" (not __name__), used for the
# parse/validation diagnostics emitted by this module.
logger = logging.getLogger(name="sqlspec")
|
|
21
|
+
|
|
22
|
+
# Regex that locates ``:name`` placeholders while skipping text inside quoted
# strings/identifiers and SQL comments (adapted from an earlier psycopg adapter).
#
# Every match fills exactly one named group:
#   dquote / squote / comment -> a span to skip (callers ``continue`` on these)
#   var_name                  -> a placeholder name, without the leading colon
#
# The negative lookbehind is attached only to the placeholder branch, so
# ``::type`` casts and ``word:name`` are not treated as placeholders, while
# quoted spans and comments are recognised regardless of the preceding
# character (e.g. ``N'... :x'``). Line comments run to the end of the line OR
# the end of the input, so a trailing ``-- ... :name`` without a final newline
# is still skipped.
PARAM_REGEX = re.compile(
    r"""
    (?P<dquote>"(?:[^"]|"")*")                           # double-quoted text; "" escapes a quote
    | (?P<squote>'(?:[^']|'')*')                         # single-quoted text; '' escapes a quote
    | (?P<comment>--[^\n]*|/\*.*?\*/)                    # -- line comment (to EOL/EOF) or /* block */
    | (?<![:\w]):(?P<var_name>[a-zA-Z_][a-zA-Z0-9_]*)    # :var_name, but not ::cast or word:name
    """,
    re.VERBOSE | re.DOTALL,
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass()
class SQLStatement:
    """A SQL statement bundled with its parameters, prepared for a database driver.

    The class merges ``parameters`` and ``kwargs``, validates dictionary parameters
    against the placeholders found in the SQL (via sqlglot, with a regex fallback),
    and, for dialects that need it, rewrites ``:name`` placeholders into the
    dialect's positional placeholder style (see :attr:`param_style`).

    Note:
        Instances are meant to be treated as read-only, but the dataclass is not
        frozen: ``__post_init__`` assigns ``_merged_parameters`` and
        :attr:`param_style` is cached on first access.
    """

    dialect: str
    """The SQL dialect (e.g., 'postgres', 'mysql'). Parsing falls back to 'postgres' when this is empty or None."""
    sql: str
    """The raw SQL statement."""
    parameters: Optional[StatementParameterType] = None
    """The parameters for the SQL statement."""
    kwargs: Optional[dict[str, Any]] = None
    """Keyword arguments passed for parameter binding."""

    # Computed in __post_init__ from ``parameters`` and ``kwargs``; any value passed
    # to the constructor is overwritten.
    _merged_parameters: Optional[Union[StatementParameterType, dict[str, Any]]] = None

    def __post_init__(self) -> None:
        """Merge ``parameters`` and ``kwargs`` into ``_merged_parameters``.

        * no (or empty) ``kwargs``: ``parameters`` is used unchanged;
        * ``parameters`` is a dict: ``kwargs`` are merged on top (kwargs win on clashes);
        * otherwise (``None``, a sequence, or a scalar): ``kwargs`` replace ``parameters``.
        """
        merged_params = self.parameters
        if self.kwargs:
            if isinstance(merged_params, dict):
                merged_params = {**merged_params, **self.kwargs}
            else:
                # NOTE(review): positional/scalar parameters are silently dropped in
                # favour of kwargs here; consider warning if this surprises callers.
                merged_params = self.kwargs
        self._merged_parameters = merged_params

    def process(self) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Process the SQL statement and merged parameters for execution.

        Dispatches on the type of the merged parameters: ``None`` (check the SQL
        does not expect any), ``dict`` (named), ``tuple``/``list`` (positional);
        anything else is treated as a single scalar value.

        Raises:
            SQLParsingError: If no parameters were provided but the parsed SQL
                contains ``exp.Parameter`` placeholders, or if dictionary
                parameters do not cover the SQL's named placeholders.
            ParameterStyleMismatchError: If dictionary parameters are used with
                SQL whose placeholder style does not match.

        Returns:
            A tuple of the SQL string to execute and the parameters in the form the
            database driver should receive (``None`` when there are none).
        """
        params = self._merged_parameters

        if params is None:
            try:
                expression = self._parse_sql()
            except SQLParsingError:
                # Unparsable SQL cannot be validated; pass it through unchanged and
                # let the driver report any problem.
                logger.warning("SQL statement is unparsable: %s", self.sql)
                return self.sql, None
            # NOTE(review): only exp.Parameter nodes are checked; placeholder styles
            # that sqlglot represents with other node types are not detected here.
            if next(expression.find_all(exp.Parameter), None) is not None:
                msg = "SQL statement contains parameter placeholders, but no parameters were provided."
                raise SQLParsingError(msg)
            return self.sql, None

        if isinstance(params, dict):
            return self._process_dict_params(params)

        if isinstance(params, (tuple, list)):
            return self._process_sequence_params(params)

        return self._process_scalar_param(params)

    def _parse_sql(self) -> exp.Expression:
        """Parse ``self.sql`` with sqlglot, reading it as ``self.dialect`` (default 'postgres').

        Raises:
            SQLParsingError: If sqlglot fails to parse the statement; the original
                exception is chained as ``__cause__``.

        Returns:
            The parsed SQL expression.
        """
        parse_dialect = self.dialect or "postgres"
        try:
            return sqlglot.parse_one(self.sql, read=parse_dialect)
        except Exception as e:
            # Catch broadly so that any sqlglot failure surfaces as SQLParsingError,
            # keeping the original sqlglot message in the text.
            msg = f"Failed to parse SQL with dialect '{parse_dialect}': {e}\nSQL: {self.sql}"
            raise SQLParsingError(msg) from e

    @staticmethod
    def _split_sqlglot_params(
        expression: exp.Expression,
    ) -> "tuple[list[exp.Parameter], list[exp.Parameter]]":
        """Split the ``exp.Parameter`` nodes of ``expression`` into (named, unnamed) lists."""
        sql_params = list(expression.find_all(exp.Parameter))
        named = [p for p in sql_params if p.name]
        unnamed = [p for p in sql_params if not p.name]
        return named, unnamed

    def _find_named_placeholders(self) -> "list[tuple[str, int, int]]":
        """Scan ``self.sql`` with PARAM_REGEX for ``:name`` placeholders.

        Quoted strings and comments matched by the regex are skipped.

        Returns:
            ``(name, start, end)`` per placeholder in textual order, where
            ``self.sql[start:end]`` is the whole ``:name`` token (colon included).
        """
        found: list[tuple[str, int, int]] = []
        for match in PARAM_REGEX.finditer(self.sql):
            var_name = match.group("var_name")
            if var_name:  # quote/comment matches leave var_name unset
                # var_name excludes the leading ':', so step back one character.
                found.append((var_name, match.start("var_name") - 1, match.end("var_name")))
        return found

    def _require_named_params(
        self,
        found: "list[tuple[str, int, int]]",
        parameter_dict: dict[str, Any],
    ) -> None:
        """Raise SQLParsingError for the first regex-found placeholder missing from ``parameter_dict``."""
        for var_name, _start, _end in found:
            if var_name not in parameter_dict:
                msg = f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}"
                raise SQLParsingError(msg)

    def _convert_named_params_with_regex(
        self,
        parameter_dict: dict[str, Any],
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Rewrite ``:name`` placeholders to :attr:`param_style` and order values to match.

        Raises:
            SQLParsingError: If a placeholder has no value in ``parameter_dict``.
            ParameterStyleMismatchError: If ``parameter_dict`` is non-empty but no
                ``:name`` placeholders are found.

        Returns:
            The rewritten SQL and a tuple of values in placeholder order. Extra keys
            in ``parameter_dict`` are ignored.
        """
        found = self._find_named_placeholders()
        self._require_named_params(found, parameter_dict)
        if not found and parameter_dict:
            msg = f"Dictionary parameters provided, but no named placeholders (e.g., :name) found via regex in the SQL query for dialect '{self.dialect}'. SQL: {self.sql}"
            raise ParameterStyleMismatchError(msg)

        sql_parts: list[str] = []
        ordered_params: list[Any] = []
        last_end = 0
        for var_name, start, end in found:
            sql_parts.extend((self.sql[last_end:start], self.param_style))
            ordered_params.append(parameter_dict[var_name])
            last_end = end
        sql_parts.append(self.sql[last_end:])
        return "".join(sql_parts), tuple(ordered_params)

    def _validate_postgres_dict_params(self, parameter_dict: dict[str, Any]) -> None:
        """Validate ``parameter_dict`` against the SQL without rewriting it (postgres).

        sqlglot results are used when the SQL parses; otherwise the regex scan is
        used. Only parse failures trigger the regex fallback, so validation errors
        raised from the sqlglot results propagate to the caller.

        Raises:
            ParameterStyleMismatchError: If sqlglot reports unnamed placeholders, or
                (regex fallback) a non-empty dict is given but no ``:name`` is found.
            SQLParsingError: If named placeholders are missing from ``parameter_dict``.
        """
        try:
            expression = self._parse_sql()
        except SQLParsingError as e:
            logger.debug("SQLglot parsing failed for postgres dict params, attempting regex validation: %s", e)
            found = self._find_named_placeholders()
            self._require_named_params(found, parameter_dict)
            if not found and parameter_dict:
                msg = f"Dictionary parameters provided, but no named placeholders (:name) found via regex. SQL: {self.sql}"
                raise ParameterStyleMismatchError(msg) from e
            return  # extra keys are allowed

        named_sql_params, unnamed_sql_params = self._split_sqlglot_params(expression)
        if unnamed_sql_params:
            msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot for postgres."
            raise ParameterStyleMismatchError(msg)

        missing_keys = {p.name for p in named_sql_params} - parameter_dict.keys()
        if missing_keys:
            msg = f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}"
            raise SQLParsingError(msg)
        # Extra keys are allowed.

    def _process_dict_params(
        self,
        parameter_dict: dict[str, Any],
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Process dictionary parameters based on dialect capabilities.

        * ``sqlite`` / ``duckdb``: the SQL and dict are returned unchanged.
        * ``postgres``: placeholders are validated, then the SQL and dict are
          returned unchanged (conversion is left to the adapter).
        * any other dialect: ``:name`` placeholders are rewritten to
          :attr:`param_style` and the values returned as a tuple in placeholder order.

        Args:
            parameter_dict: Named parameter values. Extra keys are allowed.

        Raises:
            ParameterStyleMismatchError: If sqlglot reports unnamed placeholders
                (e.g. '?'), or a non-empty dict is given but no ``:name``
                placeholders can be found by the regex scan.
            SQLParsingError: If the SQL references named parameters missing from
                ``parameter_dict``.

        Returns:
            A tuple containing the processed SQL string and the processed parameters
            ready for database driver execution.
        """
        # --- Dialect-specific bypasses for native handling ---
        if self.dialect == "sqlite":  # sqlite binds :name natively
            return self.sql, parameter_dict

        if self.dialect == "postgres":
            # Validate only; the adapter performs any :name -> $n conversion, so the
            # original SQL is returned.
            self._validate_postgres_dict_params(parameter_dict)
            return self.sql, parameter_dict

        if self.dialect == "duckdb":
            # SQL and dict are passed through as written, without validation.
            # NOTE(review): confirm against the DuckDB driver docs which named
            # placeholder styles it accepts with a dict.
            return self.sql, parameter_dict
        # --- End bypasses ---

        try:
            expression = self._parse_sql()
        except SQLParsingError as e:
            logger.debug("SQLglot parsing failed for dict params, attempting regex fallback: %s", e)
            logger.debug("Using regex fallback for dict param processing: %s", self.sql)
            return self._convert_named_params_with_regex(parameter_dict)

        named_sql_params, unnamed_sql_params = self._split_sqlglot_params(expression)
        logger.debug("SQLglot parsed dict params successfully for: %s", self.sql)

        if unnamed_sql_params:
            msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot."
            raise ParameterStyleMismatchError(msg)

        if not named_sql_params:
            if not parameter_dict:
                # No named params in the SQL and none provided: nothing to bind.
                return self.sql, ()
            # sqlglot reported no named exp.Parameter nodes although a dict was
            # given; look for :name placeholders with the regex instead.
            logger.debug("Using regex fallback for dict param processing: %s", self.sql)
            return self._convert_named_params_with_regex(parameter_dict)

        logger.debug("Using sqlglot results for dict param processing: %s", self.sql)
        missing_keys = {p.name for p in named_sql_params} - parameter_dict.keys()
        if missing_keys:
            msg = f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}"
            raise SQLParsingError(msg)

        # The rewrite itself uses the regex scan: it yields exact source offsets in
        # textual order, whereas sqlglot nodes come back in traversal order and do
        # not expose offsets on their inner values.
        return self._convert_named_params_with_regex(parameter_dict)

    def _process_sequence_params(
        self, params: Union[tuple[Any, ...], list[Any]]
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Return the SQL and the positional parameters unchanged.

        Returns:
            A tuple containing the SQL string and the given sequence.
        """
        return self.sql, params

    def _process_scalar_param(
        self, param_value: Any
    ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
        """Wrap a single scalar parameter value in a one-element tuple.

        Returns:
            A tuple containing the SQL string and ``(param_value,)``.
        """
        return self.sql, (param_value,)

    @cached_property
    def param_style(self) -> str:
        """Positional placeholder used when rewriting ``:name`` parameters.

        Dialects that bind named parameters natively (sqlite, duckdb) and postgres
        return early from dict processing and do not use this value there.

        Returns:
            The placeholder for ``self.dialect``; ``"?"`` for unmapped dialects
            (including ``None``).
        """
        dialect_to_param_style = {
            "postgres": "%s",
            "mysql": "%s",
            # NOTE(review): every occurrence becomes ":1"; confirm the Oracle driver
            # binds repeated ":1" placeholders by position as intended.
            "oracle": ":1",
            "mssql": "?",
            "bigquery": "?",
            "snowflake": "?",
            "cockroach": "%s",
            "db2": "?",
        }
        return dialect_to_param_style.get(self.dialect, "?")
|
sqlspec/typing.py
CHANGED
|
@@ -7,10 +7,13 @@ from typing_extensions import TypeAlias, TypeGuard
|
|
|
7
7
|
from sqlspec._typing import (
|
|
8
8
|
LITESTAR_INSTALLED,
|
|
9
9
|
MSGSPEC_INSTALLED,
|
|
10
|
+
PYARROW_INSTALLED,
|
|
10
11
|
PYDANTIC_INSTALLED,
|
|
11
12
|
UNSET,
|
|
13
|
+
ArrowTable,
|
|
12
14
|
BaseModel,
|
|
13
15
|
DataclassProtocol,
|
|
16
|
+
DTOData,
|
|
14
17
|
Empty,
|
|
15
18
|
EmptyType,
|
|
16
19
|
Struct,
|
|
@@ -38,26 +41,53 @@ FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter")
|
|
|
38
41
|
|
|
39
42
|
:class:`~advanced_alchemy.filters.StatementFilter`
|
|
40
43
|
"""
|
|
44
|
+
SupportedSchemaModel: TypeAlias = "Union[Struct, BaseModel, DataclassProtocol]"
|
|
45
|
+
"""Type alias for pydantic or msgspec models.
|
|
41
46
|
|
|
47
|
+
:class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`DataclassProtocol`
|
|
48
|
+
"""
|
|
49
|
+
ModelDTOT = TypeVar("ModelDTOT", bound="SupportedSchemaModel")
|
|
50
|
+
"""Type variable for model DTOs.
|
|
42
51
|
|
|
43
|
-
|
|
52
|
+
:class:`msgspec.Struct`|:class:`pydantic.BaseModel`
|
|
53
|
+
"""
|
|
54
|
+
PydanticOrMsgspecT = SupportedSchemaModel
|
|
44
55
|
"""Type alias for pydantic or msgspec models.
|
|
45
56
|
|
|
46
57
|
:class:`msgspec.Struct` or :class:`pydantic.BaseModel`
|
|
47
58
|
"""
|
|
48
|
-
|
|
59
|
+
ModelDict: TypeAlias = "Union[dict[str, Any], SupportedSchemaModel, DTOData[SupportedSchemaModel]]"
|
|
49
60
|
"""Type alias for model dictionaries.
|
|
50
61
|
|
|
51
62
|
Represents:
|
|
52
63
|
- :type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`
|
|
53
64
|
"""
|
|
54
|
-
|
|
65
|
+
ModelDictList: TypeAlias = "Sequence[Union[dict[str, Any], SupportedSchemaModel]]"
|
|
55
66
|
"""Type alias for model dictionary lists.
|
|
56
67
|
|
|
57
68
|
A list or sequence of any of the following:
|
|
58
69
|
- :type:`Sequence`[:type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`]
|
|
59
70
|
|
|
60
71
|
"""
|
|
72
|
+
BulkModelDict: TypeAlias = (
|
|
73
|
+
"Union[Sequence[Union[dict[str, Any], SupportedSchemaModel]], DTOData[list[SupportedSchemaModel]]]"
|
|
74
|
+
)
|
|
75
|
+
"""Type alias for bulk model dictionaries.
|
|
76
|
+
|
|
77
|
+
Represents:
|
|
78
|
+
- :type:`Sequence`[:type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`]
|
|
79
|
+
- :class:`DTOData`[:type:`list[ModelT]`]
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
StatementParameterType: TypeAlias = "Union[Any, dict[str, Any], list[Any], tuple[Any, ...], None]"
|
|
83
|
+
"""Type alias for parameter types.
|
|
84
|
+
|
|
85
|
+
Represents:
|
|
86
|
+
- :type:`dict[str, Any]`
|
|
87
|
+
- :type:`list[Any]`
|
|
88
|
+
- :type:`tuple[Any, ...]`
|
|
89
|
+
- :type:`None`
|
|
90
|
+
"""
|
|
61
91
|
|
|
62
92
|
|
|
63
93
|
def is_dataclass_instance(obj: Any) -> "TypeGuard[DataclassProtocol]":
|
|
@@ -286,7 +316,14 @@ def is_schema_or_dict_without_field(
|
|
|
286
316
|
|
|
287
317
|
|
|
288
318
|
def is_dataclass(obj: "Any") -> "TypeGuard[DataclassProtocol]":
    """Return whether ``obj`` is a dataclass instance.

    This is a thin alias that delegates entirely to :func:`is_dataclass_instance`.

    Args:
        obj: The value to inspect.

    Returns:
        bool: The result of ``is_dataclass_instance(obj)``.
    """
    verdict = is_dataclass_instance(obj)
    return verdict
|
|
291
328
|
|
|
292
329
|
|
|
@@ -294,17 +331,33 @@ def is_dataclass_with_field(
|
|
|
294
331
|
obj: "Any",
|
|
295
332
|
field_name: str,
|
|
296
333
|
) -> "TypeGuard[object]": # Can't specify dataclass type directly
|
|
297
|
-
"""Check if an object is a dataclass and has a specific field.
|
|
334
|
+
"""Check if an object is a dataclass and has a specific field.
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
obj: Value to check.
|
|
338
|
+
field_name: Field name to check for.
|
|
339
|
+
|
|
340
|
+
Returns:
|
|
341
|
+
bool
|
|
342
|
+
"""
|
|
298
343
|
return is_dataclass(obj) and hasattr(obj, field_name)
|
|
299
344
|
|
|
300
345
|
|
|
301
346
|
def is_dataclass_without_field(obj: "Any", field_name: str) -> "TypeGuard[object]":
    """Return whether ``obj`` is a dataclass that lacks the attribute ``field_name``.

    Args:
        obj: The value to inspect.
        field_name: Attribute name that must be absent.

    Returns:
        bool: The falsy result of :func:`is_dataclass` when ``obj`` is not a
        dataclass; otherwise ``True`` exactly when ``hasattr(obj, field_name)`` is false.
    """
    dataclass_check = is_dataclass(obj)
    if not dataclass_check:
        return dataclass_check
    return not hasattr(obj, field_name)
|
|
304
357
|
|
|
305
358
|
|
|
306
359
|
def extract_dataclass_fields(
|
|
307
|
-
|
|
360
|
+
obj: "DataclassProtocol",
|
|
308
361
|
exclude_none: bool = False,
|
|
309
362
|
exclude_empty: bool = False,
|
|
310
363
|
include: "Optional[AbstractSet[str]]" = None,
|
|
@@ -313,12 +366,14 @@ def extract_dataclass_fields(
|
|
|
313
366
|
"""Extract dataclass fields.
|
|
314
367
|
|
|
315
368
|
Args:
|
|
316
|
-
|
|
369
|
+
obj: A dataclass instance.
|
|
317
370
|
exclude_none: Whether to exclude None values.
|
|
318
371
|
exclude_empty: Whether to exclude Empty values.
|
|
319
372
|
include: An iterable of fields to include.
|
|
320
373
|
exclude: An iterable of fields to exclude.
|
|
321
374
|
|
|
375
|
+
Raises:
|
|
376
|
+
ValueError: If there are fields that are both included and excluded.
|
|
322
377
|
|
|
323
378
|
Returns:
|
|
324
379
|
A tuple of dataclass fields.
|
|
@@ -330,11 +385,11 @@ def extract_dataclass_fields(
|
|
|
330
385
|
msg = f"Fields {common} are both included and excluded."
|
|
331
386
|
raise ValueError(msg)
|
|
332
387
|
|
|
333
|
-
dataclass_fields: Iterable[Field[Any]] = fields(
|
|
388
|
+
dataclass_fields: Iterable[Field[Any]] = fields(obj)
|
|
334
389
|
if exclude_none:
|
|
335
|
-
dataclass_fields = (field for field in dataclass_fields if getattr(
|
|
390
|
+
dataclass_fields = (field for field in dataclass_fields if getattr(obj, field.name) is not None)
|
|
336
391
|
if exclude_empty:
|
|
337
|
-
dataclass_fields = (field for field in dataclass_fields if getattr(
|
|
392
|
+
dataclass_fields = (field for field in dataclass_fields if getattr(obj, field.name) is not Empty)
|
|
338
393
|
if include:
|
|
339
394
|
dataclass_fields = (field for field in dataclass_fields if field.name in include)
|
|
340
395
|
if exclude:
|
|
@@ -344,7 +399,7 @@ def extract_dataclass_fields(
|
|
|
344
399
|
|
|
345
400
|
|
|
346
401
|
def extract_dataclass_items(
|
|
347
|
-
|
|
402
|
+
obj: "DataclassProtocol",
|
|
348
403
|
exclude_none: bool = False,
|
|
349
404
|
exclude_empty: bool = False,
|
|
350
405
|
include: "Optional[AbstractSet[str]]" = None,
|
|
@@ -355,7 +410,7 @@ def extract_dataclass_items(
|
|
|
355
410
|
Unlike the 'asdict' method exports by the stdlib, this function does not pickle values.
|
|
356
411
|
|
|
357
412
|
Args:
|
|
358
|
-
|
|
413
|
+
obj: A dataclass instance.
|
|
359
414
|
exclude_none: Whether to exclude None values.
|
|
360
415
|
exclude_empty: Whether to exclude Empty values.
|
|
361
416
|
include: An iterable of fields to include.
|
|
@@ -364,8 +419,8 @@ def extract_dataclass_items(
|
|
|
364
419
|
Returns:
|
|
365
420
|
A tuple of key/value pairs.
|
|
366
421
|
"""
|
|
367
|
-
dataclass_fields = extract_dataclass_fields(
|
|
368
|
-
return tuple((field.name, getattr(
|
|
422
|
+
dataclass_fields = extract_dataclass_fields(obj, exclude_none, exclude_empty, include, exclude)
|
|
423
|
+
return tuple((field.name, getattr(obj, field.name)) for field in dataclass_fields)
|
|
369
424
|
|
|
370
425
|
|
|
371
426
|
def dataclass_to_dict(
|
|
@@ -433,18 +488,22 @@ def schema_dump( # noqa: PLR0911
|
|
|
433
488
|
__all__ = (
|
|
434
489
|
"LITESTAR_INSTALLED",
|
|
435
490
|
"MSGSPEC_INSTALLED",
|
|
491
|
+
"PYARROW_INSTALLED",
|
|
436
492
|
"PYDANTIC_INSTALLED",
|
|
437
493
|
"PYDANTIC_USE_FAILFAST",
|
|
438
494
|
"UNSET",
|
|
495
|
+
"ArrowTable",
|
|
439
496
|
"BaseModel",
|
|
440
497
|
"DataclassProtocol",
|
|
441
498
|
"Empty",
|
|
442
499
|
"EmptyType",
|
|
443
500
|
"FailFast",
|
|
444
501
|
"FilterTypeT",
|
|
445
|
-
"
|
|
446
|
-
"
|
|
502
|
+
"ModelDict",
|
|
503
|
+
"ModelDictList",
|
|
504
|
+
"StatementParameterType",
|
|
447
505
|
"Struct",
|
|
506
|
+
"SupportedSchemaModel",
|
|
448
507
|
"TypeAdapter",
|
|
449
508
|
"UnsetType",
|
|
450
509
|
"convert",
|
|
@@ -484,3 +543,8 @@ if TYPE_CHECKING:
|
|
|
484
543
|
from sqlspec._typing import UNSET, Struct, UnsetType, convert
|
|
485
544
|
else:
|
|
486
545
|
from msgspec import UNSET, Struct, UnsetType, convert # noqa: TC004
|
|
546
|
+
|
|
547
|
+
if not PYARROW_INSTALLED:
|
|
548
|
+
from sqlspec._typing import ArrowTable
|
|
549
|
+
else:
|
|
550
|
+
from pyarrow import Table as ArrowTable # noqa: TC004
|
sqlspec/utils/__init__.py
CHANGED
sqlspec/utils/fixtures.py
CHANGED
|
@@ -19,7 +19,7 @@ def open_fixture(fixtures_path: "Union[Path, AsyncPath]", fixture_name: str) ->
|
|
|
19
19
|
fixture_name (str): The fixture name to load.
|
|
20
20
|
|
|
21
21
|
Raises:
|
|
22
|
-
:
|
|
22
|
+
FileNotFoundError: Fixtures not found.
|
|
23
23
|
|
|
24
24
|
Returns:
|
|
25
25
|
Any: The parsed JSON data
|
|
@@ -43,8 +43,8 @@ async def open_fixture_async(fixtures_path: "Union[Path, AsyncPath]", fixture_na
|
|
|
43
43
|
fixture_name (str): The fixture name to load.
|
|
44
44
|
|
|
45
45
|
Raises:
|
|
46
|
-
:
|
|
47
|
-
:
|
|
46
|
+
FileNotFoundError: Fixtures not found.
|
|
47
|
+
MissingDependencyError: The `anyio` library is required to use this function.
|
|
48
48
|
|
|
49
49
|
Returns:
|
|
50
50
|
Any: The parsed JSON data
|
|
@@ -52,8 +52,7 @@ async def open_fixture_async(fixtures_path: "Union[Path, AsyncPath]", fixture_na
|
|
|
52
52
|
try:
|
|
53
53
|
from anyio import Path as AsyncPath
|
|
54
54
|
except ImportError as exc:
|
|
55
|
-
|
|
56
|
-
raise MissingDependencyError(msg) from exc
|
|
55
|
+
raise MissingDependencyError(package="anyio") from exc
|
|
57
56
|
|
|
58
57
|
fixture = AsyncPath(fixtures_path / f"{fixture_name}.json")
|
|
59
58
|
if await fixture.exists():
|