sqlspec 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_typing.py +24 -32
- sqlspec/adapters/adbc/config.py +1 -1
- sqlspec/adapters/adbc/driver.py +336 -165
- sqlspec/adapters/aiosqlite/driver.py +211 -126
- sqlspec/adapters/asyncmy/driver.py +164 -68
- sqlspec/adapters/asyncpg/config.py +3 -1
- sqlspec/adapters/asyncpg/driver.py +190 -231
- sqlspec/adapters/bigquery/driver.py +178 -169
- sqlspec/adapters/duckdb/driver.py +175 -84
- sqlspec/adapters/oracledb/driver.py +224 -90
- sqlspec/adapters/psqlpy/driver.py +267 -187
- sqlspec/adapters/psycopg/driver.py +138 -184
- sqlspec/adapters/sqlite/driver.py +153 -121
- sqlspec/base.py +57 -45
- sqlspec/extensions/litestar/__init__.py +3 -12
- sqlspec/extensions/litestar/config.py +22 -7
- sqlspec/extensions/litestar/handlers.py +142 -85
- sqlspec/extensions/litestar/plugin.py +9 -8
- sqlspec/extensions/litestar/providers.py +521 -0
- sqlspec/filters.py +214 -11
- sqlspec/mixins.py +152 -2
- sqlspec/statement.py +276 -271
- sqlspec/typing.py +18 -1
- sqlspec/utils/__init__.py +2 -2
- sqlspec/utils/singleton.py +35 -0
- sqlspec/utils/sync_tools.py +90 -151
- sqlspec/utils/text.py +68 -5
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/METADATA +5 -1
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/RECORD +32 -30
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.10.0.dist-info → sqlspec-0.11.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
# ruff: noqa: RUF100, PLR6301, PLR0912, PLR0915, C901, PLR0911, PLR0914, N806
|
|
2
2
|
import logging
|
|
3
|
-
import
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
from functools import cached_property
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from dataclasses import dataclass, field
|
|
6
5
|
from typing import (
|
|
6
|
+
TYPE_CHECKING,
|
|
7
7
|
Any,
|
|
8
8
|
Optional,
|
|
9
9
|
Union,
|
|
@@ -15,24 +15,13 @@ from sqlglot import exp
|
|
|
15
15
|
from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError
|
|
16
16
|
from sqlspec.typing import StatementParameterType
|
|
17
17
|
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from sqlspec.filters import StatementFilter
|
|
20
|
+
|
|
18
21
|
__all__ = ("SQLStatement",)
|
|
19
22
|
|
|
20
23
|
logger = logging.getLogger("sqlspec")
|
|
21
24
|
|
|
22
|
-
# Regex to find :param style placeholders, skipping those inside quotes or SQL comments
|
|
23
|
-
# Adapted from previous version in psycopg adapter
|
|
24
|
-
PARAM_REGEX = re.compile(
|
|
25
|
-
r"""(?<![:\w]) # Negative lookbehind to avoid matching things like ::type or \:escaped
|
|
26
|
-
(?:
|
|
27
|
-
(?P<dquote>"(?:[^"]|"")*") | # Double-quoted strings (support SQL standard escaping "")
|
|
28
|
-
(?P<squote>'(?:[^']|'')*') | # Single-quoted strings (support SQL standard escaping '')
|
|
29
|
-
(?P<comment>--.*?\n|\/\*.*?\*\/) | # SQL comments (single line or multi-line)
|
|
30
|
-
: (?P<var_name>[a-zA-Z_][a-zA-Z0-9_]*) # :var_name identifier
|
|
31
|
-
)
|
|
32
|
-
""",
|
|
33
|
-
re.VERBOSE | re.DOTALL,
|
|
34
|
-
)
|
|
35
|
-
|
|
36
25
|
|
|
37
26
|
@dataclass()
|
|
38
27
|
class SQLStatement:
|
|
@@ -42,16 +31,18 @@ class SQLStatement:
|
|
|
42
31
|
a clean interface for parameter binding and SQL statement formatting.
|
|
43
32
|
"""
|
|
44
33
|
|
|
45
|
-
dialect: str
|
|
46
|
-
"""The SQL dialect to use for parsing (e.g., 'postgres', 'mysql'). Defaults to 'postgres' if None."""
|
|
47
34
|
sql: str
|
|
48
35
|
"""The raw SQL statement."""
|
|
49
36
|
parameters: Optional[StatementParameterType] = None
|
|
50
37
|
"""The parameters for the SQL statement."""
|
|
51
38
|
kwargs: Optional[dict[str, Any]] = None
|
|
52
39
|
"""Keyword arguments passed for parameter binding."""
|
|
40
|
+
dialect: Optional[str] = None
|
|
41
|
+
"""SQL dialect to use for parsing. If not provided, sqlglot will try to auto-detect."""
|
|
53
42
|
|
|
54
|
-
_merged_parameters: Optional[Union[StatementParameterType, dict[str, Any]]] = None
|
|
43
|
+
_merged_parameters: Optional[Union[StatementParameterType, dict[str, Any]]] = field(default=None, init=False)
|
|
44
|
+
_parsed_expression: Optional[exp.Expression] = field(default=None, init=False)
|
|
45
|
+
_param_counter: int = field(default=0, init=False)
|
|
55
46
|
|
|
56
47
|
def __post_init__(self) -> None:
|
|
57
48
|
"""Merge parameters and kwargs after initialization."""
|
|
@@ -70,48 +61,72 @@ class SQLStatement:
|
|
|
70
61
|
|
|
71
62
|
self._merged_parameters = merged_params
|
|
72
63
|
|
|
73
|
-
def process(
|
|
64
|
+
def process(
|
|
65
|
+
self,
|
|
66
|
+
) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]], Optional[exp.Expression]]":
|
|
74
67
|
"""Process the SQL statement and merged parameters for execution.
|
|
75
68
|
|
|
69
|
+
This method validates the parameters against the SQL statement using sqlglot
|
|
70
|
+
parsing but returns the *original* SQL string, the merged parameters,
|
|
71
|
+
and the parsed sqlglot expression if successful.
|
|
72
|
+
The actual formatting of SQL placeholders and parameter structures for the
|
|
73
|
+
DBAPI driver is delegated to the specific adapter.
|
|
74
|
+
|
|
76
75
|
Returns:
|
|
77
|
-
A tuple containing the
|
|
78
|
-
|
|
76
|
+
A tuple containing the *original* SQL string, the merged/validated
|
|
77
|
+
parameters (dict, tuple, list, or None), and the parsed sqlglot expression
|
|
78
|
+
(or None if parsing failed).
|
|
79
79
|
|
|
80
80
|
Raises:
|
|
81
|
-
SQLParsingError: If the SQL statement contains parameter placeholders
|
|
82
|
-
|
|
83
|
-
Returns:
|
|
84
|
-
A tuple containing the processed SQL string and the processed parameters
|
|
85
|
-
ready for database driver execution.
|
|
81
|
+
SQLParsingError: If the SQL statement contains parameter placeholders
|
|
82
|
+
but no parameters were provided, or if parsing fails unexpectedly.
|
|
86
83
|
"""
|
|
84
|
+
# Parse the SQL to find expected parameters
|
|
85
|
+
try:
|
|
86
|
+
expression = self._parse_sql()
|
|
87
|
+
# Find all parameter expressions (:name, ?, @name, $1, etc.)
|
|
88
|
+
# These are nodes that sqlglot considers as bind parameters.
|
|
89
|
+
all_sqlglot_placeholders = list(expression.find_all(exp.Placeholder, exp.Parameter))
|
|
90
|
+
except SQLParsingError as e:
|
|
91
|
+
logger.debug(
|
|
92
|
+
"SQL parsing failed during validation: %s. Returning original SQL and parameters for adapter.", e
|
|
93
|
+
)
|
|
94
|
+
self._parsed_expression = None
|
|
95
|
+
return self.sql, self._merged_parameters, None
|
|
96
|
+
|
|
87
97
|
if self._merged_parameters is None:
|
|
88
|
-
#
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
98
|
+
# If no parameters were provided, but the parsed SQL expects them, raise an error.
|
|
99
|
+
if all_sqlglot_placeholders:
|
|
100
|
+
placeholder_types_desc = []
|
|
101
|
+
for p_node in all_sqlglot_placeholders:
|
|
102
|
+
if isinstance(p_node, exp.Parameter) and p_node.name:
|
|
103
|
+
placeholder_types_desc.append(f"named (e.g., :{p_node.name}, @{p_node.name})")
|
|
104
|
+
elif (
|
|
105
|
+
isinstance(p_node, exp.Placeholder)
|
|
106
|
+
and p_node.this
|
|
107
|
+
and not isinstance(p_node.this, (exp.Identifier, exp.Literal))
|
|
108
|
+
and not str(p_node.this).isdigit()
|
|
109
|
+
):
|
|
110
|
+
placeholder_types_desc.append(f"named (e.g., :{p_node.this})")
|
|
111
|
+
elif isinstance(p_node, exp.Parameter) and p_node.name and p_node.name.isdigit():
|
|
112
|
+
placeholder_types_desc.append("positional (e.g., $1, :1)")
|
|
113
|
+
elif isinstance(p_node, exp.Placeholder) and p_node.this is None:
|
|
114
|
+
placeholder_types_desc.append("positional (?)")
|
|
115
|
+
desc_str = ", ".join(sorted(set(placeholder_types_desc))) or "unknown"
|
|
116
|
+
msg = f"SQL statement contains {desc_str} parameter placeholders, but no parameters were provided. SQL: {self.sql}"
|
|
101
117
|
raise SQLParsingError(msg)
|
|
102
|
-
return self.sql, None
|
|
118
|
+
return self.sql, None, self._parsed_expression
|
|
103
119
|
|
|
120
|
+
# Validate provided parameters against parsed SQL parameters
|
|
104
121
|
if isinstance(self._merged_parameters, dict):
|
|
105
|
-
|
|
106
|
-
|
|
122
|
+
self._validate_dict_params(all_sqlglot_placeholders, self._merged_parameters)
|
|
123
|
+
elif isinstance(self._merged_parameters, (tuple, list)):
|
|
124
|
+
self._validate_sequence_params(all_sqlglot_placeholders, self._merged_parameters)
|
|
125
|
+
else: # Scalar parameter
|
|
126
|
+
self._validate_scalar_param(all_sqlglot_placeholders, self._merged_parameters)
|
|
107
127
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
return self._process_sequence_params(self._merged_parameters)
|
|
111
|
-
|
|
112
|
-
# Assume it's a single scalar value otherwise
|
|
113
|
-
# Pass only the value, parsing happens inside for validation
|
|
114
|
-
return self._process_scalar_param(self._merged_parameters)
|
|
128
|
+
# Return the original SQL and the merged parameters for the adapter to process
|
|
129
|
+
return self.sql, self._merged_parameters, self._parsed_expression
|
|
115
130
|
|
|
116
131
|
def _parse_sql(self) -> exp.Expression:
|
|
117
132
|
"""Parse the SQL using sqlglot.
|
|
@@ -122,252 +137,242 @@ class SQLStatement:
|
|
|
122
137
|
Returns:
|
|
123
138
|
The parsed SQL expression.
|
|
124
139
|
"""
|
|
125
|
-
parse_dialect = self.dialect or "postgres"
|
|
126
140
|
try:
|
|
127
|
-
|
|
128
|
-
|
|
141
|
+
if not self.sql.strip():
|
|
142
|
+
self._parsed_expression = exp.Select()
|
|
143
|
+
return self._parsed_expression
|
|
144
|
+
# Use the provided dialect if available, otherwise sqlglot will try to auto-detect
|
|
145
|
+
self._parsed_expression = sqlglot.parse_one(self.sql, dialect=self.dialect)
|
|
146
|
+
if self._parsed_expression is None:
|
|
147
|
+
self._parsed_expression = exp.Select() # type: ignore[unreachable]
|
|
129
148
|
except Exception as e:
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
msg = f"Failed to parse SQL with dialect '{parse_dialect or 'auto-detected'}': {error_detail}\nSQL: {self.sql}"
|
|
149
|
+
msg = f"Failed to parse SQL for validation: {e!s}\nSQL: {self.sql}"
|
|
150
|
+
self._parsed_expression = None
|
|
133
151
|
raise SQLParsingError(msg) from e
|
|
152
|
+
else:
|
|
153
|
+
return self._parsed_expression
|
|
154
|
+
|
|
155
|
+
def _validate_dict_params(
|
|
156
|
+
self, all_sqlglot_placeholders: Sequence[exp.Expression], parameter_dict: dict[str, Any]
|
|
157
|
+
) -> None:
|
|
158
|
+
sqlglot_named_params: dict[str, Union[exp.Parameter, exp.Placeholder]] = {}
|
|
159
|
+
has_positional_qmark = False
|
|
160
|
+
|
|
161
|
+
for p_node in all_sqlglot_placeholders:
|
|
162
|
+
if (
|
|
163
|
+
isinstance(p_node, exp.Parameter) and p_node.name and not p_node.name.isdigit()
|
|
164
|
+
): # @name, $name (non-numeric)
|
|
165
|
+
sqlglot_named_params[p_node.name] = p_node
|
|
166
|
+
elif (
|
|
167
|
+
isinstance(p_node, exp.Placeholder)
|
|
168
|
+
and p_node.this
|
|
169
|
+
and not isinstance(p_node.this, (exp.Identifier, exp.Literal))
|
|
170
|
+
and not str(p_node.this).isdigit()
|
|
171
|
+
): # :name
|
|
172
|
+
sqlglot_named_params[str(p_node.this)] = p_node
|
|
173
|
+
elif isinstance(p_node, exp.Placeholder) and p_node.this is None: # ?
|
|
174
|
+
has_positional_qmark = True
|
|
175
|
+
# Ignores numeric placeholders like $1, :1 for dict validation for now
|
|
176
|
+
|
|
177
|
+
if has_positional_qmark:
|
|
178
|
+
msg = f"Dictionary parameters provided, but found unnamed placeholders ('?') in SQL: {self.sql}"
|
|
179
|
+
raise ParameterStyleMismatchError(msg)
|
|
180
|
+
|
|
181
|
+
if not sqlglot_named_params and parameter_dict:
|
|
182
|
+
msg = f"Dictionary parameters provided, but no named placeholders (e.g., ':name', '$name', '@name') found by sqlglot in SQL: {self.sql}"
|
|
183
|
+
raise ParameterStyleMismatchError(msg)
|
|
134
184
|
|
|
135
|
-
|
|
185
|
+
missing_keys = set(sqlglot_named_params.keys()) - set(parameter_dict.keys())
|
|
186
|
+
if missing_keys:
|
|
187
|
+
msg = f"Named parameters found in SQL by sqlglot but not provided: {missing_keys}. SQL: {self.sql}"
|
|
188
|
+
raise SQLParsingError(msg)
|
|
189
|
+
|
|
190
|
+
def _validate_sequence_params(
|
|
136
191
|
self,
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
192
|
+
all_sqlglot_placeholders: Sequence[exp.Expression],
|
|
193
|
+
params: Union[tuple[Any, ...], list[Any]],
|
|
194
|
+
) -> None:
|
|
195
|
+
sqlglot_named_param_names = [] # For detecting named params
|
|
196
|
+
sqlglot_positional_count = 0 # For counting ?, $1, :1 etc.
|
|
197
|
+
|
|
198
|
+
for p_node in all_sqlglot_placeholders:
|
|
199
|
+
if isinstance(p_node, exp.Parameter) and p_node.name and not p_node.name.isdigit(): # @name, $name
|
|
200
|
+
sqlglot_named_param_names.append(p_node.name)
|
|
201
|
+
elif (
|
|
202
|
+
isinstance(p_node, exp.Placeholder)
|
|
203
|
+
and p_node.this
|
|
204
|
+
and not isinstance(p_node.this, (exp.Identifier, exp.Literal))
|
|
205
|
+
and not str(p_node.this).isdigit()
|
|
206
|
+
): # :name
|
|
207
|
+
sqlglot_named_param_names.append(str(p_node.this))
|
|
208
|
+
elif isinstance(p_node, exp.Placeholder) and p_node.this is None: # ?
|
|
209
|
+
sqlglot_positional_count += 1
|
|
210
|
+
elif isinstance(p_node, exp.Parameter) and ( # noqa: PLR0916
|
|
211
|
+
(p_node.name and p_node.name.isdigit())
|
|
212
|
+
or (
|
|
213
|
+
not p_node.name
|
|
214
|
+
and p_node.this
|
|
215
|
+
and isinstance(p_node.this, (str, exp.Identifier, exp.Literal))
|
|
216
|
+
and str(p_node.this).isdigit()
|
|
217
|
+
)
|
|
218
|
+
):
|
|
219
|
+
# $1, :1 style (parsed as Parameter with name="1" or this="1" or this=Identifier(this="1") or this=Literal(this=1))
|
|
220
|
+
sqlglot_positional_count += 1
|
|
221
|
+
elif (
|
|
222
|
+
isinstance(p_node, exp.Placeholder) and p_node.this and str(p_node.this).isdigit()
|
|
223
|
+
): # :1 style (Placeholder with this="1")
|
|
224
|
+
sqlglot_positional_count += 1
|
|
225
|
+
|
|
226
|
+
if sqlglot_named_param_names:
|
|
227
|
+
msg = f"Sequence parameters provided, but found named placeholders ({', '.join(sorted(set(sqlglot_named_param_names)))}) in SQL: {self.sql}"
|
|
228
|
+
raise ParameterStyleMismatchError(msg)
|
|
140
229
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
230
|
+
actual_count_provided = len(params)
|
|
231
|
+
|
|
232
|
+
if sqlglot_positional_count != actual_count_provided:
|
|
233
|
+
msg = (
|
|
234
|
+
f"Parameter count mismatch. SQL expects {sqlglot_positional_count} (sqlglot) positional "
|
|
235
|
+
f"parameters, but {actual_count_provided} were provided. SQL: {self.sql}"
|
|
236
|
+
)
|
|
237
|
+
raise SQLParsingError(msg)
|
|
238
|
+
|
|
239
|
+
def _validate_scalar_param(self, all_sqlglot_placeholders: Sequence[exp.Expression], param_value: Any) -> None:
|
|
240
|
+
"""Validates a single scalar parameter against parsed SQL parameters."""
|
|
241
|
+
self._validate_sequence_params(
|
|
242
|
+
all_sqlglot_placeholders, (param_value,)
|
|
243
|
+
) # Treat scalar as a single-element sequence
|
|
244
|
+
|
|
245
|
+
def get_expression(self) -> exp.Expression:
|
|
246
|
+
"""Get the parsed SQLglot expression, parsing if necessary.
|
|
144
247
|
|
|
145
248
|
Returns:
|
|
146
|
-
|
|
147
|
-
ready for database driver execution.
|
|
249
|
+
The SQLglot expression.
|
|
148
250
|
"""
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
if self.dialect == "sqlite": # Handles :name natively
|
|
155
|
-
return self.sql, parameter_dict
|
|
156
|
-
|
|
157
|
-
# Add bypass for postgres handled by specific adapters (e.g., asyncpg)
|
|
158
|
-
if self.dialect == "postgres":
|
|
159
|
-
# The adapter (e.g., asyncpg) will handle :name -> $n conversion.
|
|
160
|
-
# SQLStatement just validates parameters against the original SQL here.
|
|
161
|
-
# Perform validation using regex if sqlglot parsing fails, otherwise use sqlglot.
|
|
162
|
-
try:
|
|
163
|
-
expression = self._parse_sql()
|
|
164
|
-
sql_params = list(expression.find_all(exp.Parameter))
|
|
165
|
-
named_sql_params = [p for p in sql_params if p.name]
|
|
166
|
-
unnamed_sql_params = [p for p in sql_params if not p.name]
|
|
167
|
-
|
|
168
|
-
if unnamed_sql_params:
|
|
169
|
-
msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot for postgres."
|
|
170
|
-
raise ParameterStyleMismatchError(msg)
|
|
171
|
-
|
|
172
|
-
# Validate keys using sqlglot results
|
|
173
|
-
required_keys = {p.name for p in named_sql_params}
|
|
174
|
-
provided_keys = set(parameter_dict.keys())
|
|
175
|
-
missing_keys = required_keys - provided_keys
|
|
176
|
-
if missing_keys:
|
|
177
|
-
msg = (
|
|
178
|
-
f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}"
|
|
179
|
-
)
|
|
180
|
-
raise SQLParsingError(msg) # noqa: TRY301
|
|
181
|
-
# Allow extra keys
|
|
182
|
-
|
|
183
|
-
except SQLParsingError as e:
|
|
184
|
-
logger.debug("SQLglot parsing failed for postgres dict params, attempting regex validation: %s", e)
|
|
185
|
-
# Regex validation fallback (without conversion)
|
|
186
|
-
postgres_found_params_regex: list[str] = []
|
|
187
|
-
for match in PARAM_REGEX.finditer(self.sql):
|
|
188
|
-
if match.group("dquote") or match.group("squote") or match.group("comment"):
|
|
189
|
-
continue
|
|
190
|
-
if match.group("var_name"):
|
|
191
|
-
var_name = match.group("var_name")
|
|
192
|
-
postgres_found_params_regex.append(var_name)
|
|
193
|
-
if var_name not in parameter_dict:
|
|
194
|
-
msg = f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}"
|
|
195
|
-
raise SQLParsingError(msg) # noqa: B904
|
|
196
|
-
|
|
197
|
-
if not postgres_found_params_regex and parameter_dict:
|
|
198
|
-
msg = f"Dictionary parameters provided, but no named placeholders (:name) found via regex. SQL: {self.sql}"
|
|
199
|
-
raise ParameterStyleMismatchError(msg) # noqa: B904
|
|
200
|
-
# Allow extra keys with regex check too
|
|
201
|
-
|
|
202
|
-
# Return the *original* SQL and the processed dict for the adapter to handle
|
|
203
|
-
return self.sql, parameter_dict
|
|
204
|
-
|
|
205
|
-
if self.dialect == "duckdb": # Handles $name natively (and :name via driver? Check driver docs)
|
|
206
|
-
# Bypass sqlglot/regex checks. Trust user SQL ($name or ?) + dict for DuckDB driver.
|
|
207
|
-
# We lose :name -> $name conversion *if* sqlglot parsing fails, but avoid errors on valid $name SQL.
|
|
208
|
-
return self.sql, parameter_dict
|
|
209
|
-
# --- End Bypasses ---
|
|
251
|
+
if self._parsed_expression is None:
|
|
252
|
+
self._parse_sql()
|
|
253
|
+
if self._parsed_expression is None: # Still None after parsing attempt
|
|
254
|
+
return exp.Select() # Return an empty SELECT as fallback
|
|
255
|
+
return self._parsed_expression
|
|
210
256
|
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
sql_params = list(expression.find_all(exp.Parameter))
|
|
214
|
-
named_sql_params = [p for p in sql_params if p.name]
|
|
215
|
-
unnamed_sql_params = [p for p in sql_params if not p.name]
|
|
216
|
-
sqlglot_parsed_ok = True
|
|
217
|
-
logger.debug("SQLglot parsed dict params successfully for: %s", self.sql)
|
|
218
|
-
except SQLParsingError as e:
|
|
219
|
-
logger.debug("SQLglot parsing failed for dict params, attempting regex fallback: %s", e)
|
|
220
|
-
# Proceed using regex fallback below
|
|
257
|
+
def generate_param_name(self, base_name: str) -> str:
|
|
258
|
+
"""Generates a unique parameter name.
|
|
221
259
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot."
|
|
225
|
-
raise ParameterStyleMismatchError(msg)
|
|
260
|
+
Args:
|
|
261
|
+
base_name: The base name for the parameter.
|
|
226
262
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
if
|
|
232
|
-
|
|
233
|
-
# ... (regex fallback code as implemented previously) ...
|
|
234
|
-
logger.debug("Using regex fallback for dict param processing: %s", self.sql)
|
|
235
|
-
# --- Regex Fallback Logic ---
|
|
236
|
-
regex_processed_sql_parts: list[str] = []
|
|
237
|
-
ordered_params = []
|
|
238
|
-
last_end = 0
|
|
239
|
-
regex_found_params: list[str] = []
|
|
240
|
-
|
|
241
|
-
for match in PARAM_REGEX.finditer(self.sql):
|
|
242
|
-
# Skip matches that are comments or quoted strings
|
|
243
|
-
if match.group("dquote") or match.group("squote") or match.group("comment"):
|
|
244
|
-
continue
|
|
245
|
-
|
|
246
|
-
if match.group("var_name"):
|
|
247
|
-
var_name = match.group("var_name")
|
|
248
|
-
regex_found_params.append(var_name)
|
|
249
|
-
# Get start and end from the match object for the :var_name part
|
|
250
|
-
# The var_name group itself doesn't include the leading :, so adjust start.
|
|
251
|
-
start = match.start("var_name") - 1
|
|
252
|
-
end = match.end("var_name")
|
|
253
|
-
|
|
254
|
-
if var_name not in parameter_dict:
|
|
255
|
-
msg = (
|
|
256
|
-
f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}"
|
|
257
|
-
)
|
|
258
|
-
raise SQLParsingError(msg)
|
|
259
|
-
|
|
260
|
-
regex_processed_sql_parts.extend((self.sql[last_end:start], self.param_style)) # Use target style
|
|
261
|
-
ordered_params.append(parameter_dict[var_name])
|
|
262
|
-
last_end = end
|
|
263
|
-
|
|
264
|
-
regex_processed_sql_parts.append(self.sql[last_end:])
|
|
265
|
-
|
|
266
|
-
# Validation with regex results
|
|
267
|
-
if not regex_found_params and parameter_dict:
|
|
268
|
-
msg = f"Dictionary parameters provided, but no named placeholders (e.g., :name) found via regex in the SQL query for dialect '{self.dialect}'. SQL: {self.sql}"
|
|
269
|
-
raise ParameterStyleMismatchError(msg)
|
|
270
|
-
|
|
271
|
-
provided_keys = set(parameter_dict.keys())
|
|
272
|
-
missing_keys = set(regex_found_params) - provided_keys # Should be caught above, but double check
|
|
273
|
-
if missing_keys:
|
|
274
|
-
msg = f"Named parameters found in SQL (via regex) but not provided: {missing_keys}. SQL: {self.sql}"
|
|
275
|
-
raise SQLParsingError(msg)
|
|
276
|
-
|
|
277
|
-
extra_keys = provided_keys - set(regex_found_params)
|
|
278
|
-
if extra_keys:
|
|
279
|
-
# Allow extra keys
|
|
280
|
-
pass
|
|
263
|
+
Returns:
|
|
264
|
+
The generated parameter name.
|
|
265
|
+
"""
|
|
266
|
+
self._param_counter += 1
|
|
267
|
+
safe_base_name = "".join(c if c.isalnum() else "_" for c in base_name if c.isalnum() or c == "_")
|
|
268
|
+
return f"param_{safe_base_name}_{self._param_counter}"
|
|
281
269
|
|
|
282
|
-
|
|
270
|
+
def add_condition(self, condition: exp.Condition, params: Optional[dict[str, Any]] = None) -> None:
|
|
271
|
+
"""Adds a condition to the WHERE clause of the query.
|
|
283
272
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
273
|
+
Args:
|
|
274
|
+
condition: The condition to add to the WHERE clause.
|
|
275
|
+
params: The parameters to add to the statement parameters.
|
|
276
|
+
"""
|
|
277
|
+
expression = self.get_expression()
|
|
278
|
+
if not isinstance(expression, (exp.Select, exp.Update, exp.Delete)):
|
|
279
|
+
return # Cannot add WHERE to some expressions
|
|
280
|
+
|
|
281
|
+
# Update the expression
|
|
282
|
+
expression.where(condition, copy=False)
|
|
283
|
+
|
|
284
|
+
# Update the parameters
|
|
285
|
+
if params:
|
|
286
|
+
if self._merged_parameters is None:
|
|
287
|
+
self._merged_parameters = params
|
|
288
|
+
elif isinstance(self._merged_parameters, dict):
|
|
289
|
+
self._merged_parameters.update(params)
|
|
290
|
+
else:
|
|
291
|
+
# Convert to dict if not already
|
|
292
|
+
self._merged_parameters = params
|
|
287
293
|
|
|
288
|
-
#
|
|
289
|
-
|
|
294
|
+
# Update the SQL string
|
|
295
|
+
self.sql = expression.sql(dialect=self.dialect)
|
|
290
296
|
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
return self.sql, ()
|
|
297
|
+
def add_order_by(self, field_name: str, direction: str = "asc") -> None:
|
|
298
|
+
"""Adds an ORDER BY clause.
|
|
294
299
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
300
|
+
Args:
|
|
301
|
+
field_name: The name of the field to order by.
|
|
302
|
+
direction: The direction to order by ("asc" or "desc").
|
|
303
|
+
"""
|
|
304
|
+
expression = self.get_expression()
|
|
305
|
+
if not isinstance(expression, exp.Select):
|
|
306
|
+
return
|
|
298
307
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
msg = f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}"
|
|
302
|
-
raise SQLParsingError(msg)
|
|
308
|
+
expression.order_by(exp.Ordered(this=exp.column(field_name), desc=direction.lower() == "desc"), copy=False)
|
|
309
|
+
self.sql = expression.sql(dialect=self.dialect)
|
|
303
310
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
pass # Allow extra keys
|
|
307
|
-
|
|
308
|
-
# Note: DuckDB handled by bypass above if sqlglot fails.
|
|
309
|
-
# This block handles successful sqlglot parse for other dialects.
|
|
310
|
-
# We don't need the specific DuckDB $name conversion here anymore,
|
|
311
|
-
# as the bypass handles the native $name case.
|
|
312
|
-
# The general logic converts :name -> self.param_style for dialects like postgres.
|
|
313
|
-
# if self.dialect == "duckdb": ... (Removed specific block here)
|
|
314
|
-
|
|
315
|
-
# For other dialects requiring positional conversion (using sqlglot param info):
|
|
316
|
-
sqlglot_processed_parts: list[str] = []
|
|
317
|
-
ordered_params = []
|
|
318
|
-
last_end = 0
|
|
319
|
-
for param in active_named_params: # Use active_named_params
|
|
320
|
-
start = param.this.this.start
|
|
321
|
-
end = param.this.this.end
|
|
322
|
-
sqlglot_processed_parts.extend((self.sql[last_end:start], self.param_style))
|
|
323
|
-
ordered_params.append(parameter_dict[param.name])
|
|
324
|
-
last_end = end
|
|
325
|
-
sqlglot_processed_parts.append(self.sql[last_end:])
|
|
326
|
-
return "".join(sqlglot_processed_parts), tuple(ordered_params)
|
|
327
|
-
|
|
328
|
-
def _process_sequence_params(
|
|
329
|
-
self, params: Union[tuple[Any, ...], list[Any]]
|
|
330
|
-
) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
|
|
331
|
-
"""Processes a sequence of parameters.
|
|
311
|
+
def add_limit(self, limit_val: int, param_name: Optional[str] = None) -> None:
|
|
312
|
+
"""Adds a LIMIT clause.
|
|
332
313
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
314
|
+
Args:
|
|
315
|
+
limit_val: The value for the LIMIT clause.
|
|
316
|
+
param_name: Optional name for the parameter.
|
|
317
|
+
"""
|
|
318
|
+
expression = self.get_expression()
|
|
319
|
+
if not isinstance(expression, exp.Select):
|
|
320
|
+
return
|
|
321
|
+
|
|
322
|
+
if param_name:
|
|
323
|
+
expression.limit(exp.Placeholder(this=param_name), copy=False)
|
|
324
|
+
if self._merged_parameters is None:
|
|
325
|
+
self._merged_parameters = {param_name: limit_val}
|
|
326
|
+
elif isinstance(self._merged_parameters, dict):
|
|
327
|
+
self._merged_parameters[param_name] = limit_val
|
|
328
|
+
else:
|
|
329
|
+
expression.limit(exp.Literal.number(limit_val), copy=False)
|
|
330
|
+
|
|
331
|
+
self.sql = expression.sql(dialect=self.dialect)
|
|
332
|
+
|
|
333
|
+
def add_offset(self, offset_val: int, param_name: Optional[str] = None) -> None:
|
|
334
|
+
"""Adds an OFFSET clause.
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
offset_val: The value for the OFFSET clause.
|
|
338
|
+
param_name: Optional name for the parameter.
|
|
336
339
|
"""
|
|
337
|
-
|
|
340
|
+
expression = self.get_expression()
|
|
341
|
+
if not isinstance(expression, exp.Select):
|
|
342
|
+
return
|
|
338
343
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
344
|
+
if param_name:
|
|
345
|
+
expression.offset(exp.Placeholder(this=param_name), copy=False)
|
|
346
|
+
if self._merged_parameters is None:
|
|
347
|
+
self._merged_parameters = {param_name: offset_val}
|
|
348
|
+
elif isinstance(self._merged_parameters, dict):
|
|
349
|
+
self._merged_parameters[param_name] = offset_val
|
|
350
|
+
else:
|
|
351
|
+
expression.offset(exp.Literal.number(offset_val), copy=False)
|
|
352
|
+
|
|
353
|
+
self.sql = expression.sql(dialect=self.dialect)
|
|
354
|
+
|
|
355
|
+
def apply_filter(self, filter_obj: "StatementFilter") -> "SQLStatement":
|
|
356
|
+
"""Apply a statement filter to this statement.
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
filter_obj: The filter to apply.
|
|
343
360
|
|
|
344
361
|
Returns:
|
|
345
|
-
|
|
346
|
-
ready for database driver execution.
|
|
362
|
+
The modified statement.
|
|
347
363
|
"""
|
|
348
|
-
|
|
364
|
+
from sqlspec.filters import apply_filter
|
|
365
|
+
|
|
366
|
+
return apply_filter(self, filter_obj)
|
|
367
|
+
|
|
368
|
+
def to_sql(self, dialect: Optional[str] = None) -> str:
|
|
369
|
+
"""Generate SQL string using the specified dialect.
|
|
349
370
|
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
"""Get the parameter style based on the dialect.
|
|
371
|
+
Args:
|
|
372
|
+
dialect: SQL dialect to use for SQL generation. If None, uses the statement's dialect.
|
|
353
373
|
|
|
354
374
|
Returns:
|
|
355
|
-
|
|
375
|
+
SQL string in the specified dialect.
|
|
356
376
|
"""
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
# Map dialects to parameter styles for placeholder replacement
|
|
360
|
-
# Note: Used when converting named params (:name) for dialects needing positional.
|
|
361
|
-
# Dialects supporting named params natively (SQLite, DuckDB) are handled via bypasses.
|
|
362
|
-
dialect_to_param_style = {
|
|
363
|
-
"postgres": "%s",
|
|
364
|
-
"mysql": "%s",
|
|
365
|
-
"oracle": ":1",
|
|
366
|
-
"mssql": "?",
|
|
367
|
-
"bigquery": "?",
|
|
368
|
-
"snowflake": "?",
|
|
369
|
-
"cockroach": "%s",
|
|
370
|
-
"db2": "?",
|
|
371
|
-
}
|
|
372
|
-
# Default to '?' for unknown/unhandled dialects or when dialect=None is forced
|
|
373
|
-
return dialect_to_param_style.get(dialect, "?")
|
|
377
|
+
expression = self.get_expression()
|
|
378
|
+
return expression.sql(dialect=dialect or self.dialect)
|