sqlspec 0.10.1__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

sqlspec/statement.py CHANGED
@@ -1,9 +1,9 @@
1
1
  # ruff: noqa: RUF100, PLR6301, PLR0912, PLR0915, C901, PLR0911, PLR0914, N806
2
2
  import logging
3
- import re
4
- from dataclasses import dataclass
5
- from functools import cached_property
3
+ from collections.abc import Sequence
4
+ from dataclasses import dataclass, field
6
5
  from typing import (
6
+ TYPE_CHECKING,
7
7
  Any,
8
8
  Optional,
9
9
  Union,
@@ -15,24 +15,13 @@ from sqlglot import exp
15
15
  from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError
16
16
  from sqlspec.typing import StatementParameterType
17
17
 
18
+ if TYPE_CHECKING:
19
+ from sqlspec.filters import StatementFilter
20
+
18
21
  __all__ = ("SQLStatement",)
19
22
 
20
23
  logger = logging.getLogger("sqlspec")
21
24
 
22
- # Regex to find :param style placeholders, skipping those inside quotes or SQL comments
23
- # Adapted from previous version in psycopg adapter
24
- PARAM_REGEX = re.compile(
25
- r"""(?<![:\w]) # Negative lookbehind to avoid matching things like ::type or \:escaped
26
- (?:
27
- (?P<dquote>"(?:[^"]|"")*") | # Double-quoted strings (support SQL standard escaping "")
28
- (?P<squote>'(?:[^']|'')*') | # Single-quoted strings (support SQL standard escaping '')
29
- (?P<comment>--.*?\n|\/\*.*?\*\/) | # SQL comments (single line or multi-line)
30
- : (?P<var_name>[a-zA-Z_][a-zA-Z0-9_]*) # :var_name identifier
31
- )
32
- """,
33
- re.VERBOSE | re.DOTALL,
34
- )
35
-
36
25
 
37
26
  @dataclass()
38
27
  class SQLStatement:
@@ -42,16 +31,18 @@ class SQLStatement:
42
31
  a clean interface for parameter binding and SQL statement formatting.
43
32
  """
44
33
 
45
- dialect: str
46
- """The SQL dialect to use for parsing (e.g., 'postgres', 'mysql'). Defaults to 'postgres' if None."""
47
34
  sql: str
48
35
  """The raw SQL statement."""
49
36
  parameters: Optional[StatementParameterType] = None
50
37
  """The parameters for the SQL statement."""
51
38
  kwargs: Optional[dict[str, Any]] = None
52
39
  """Keyword arguments passed for parameter binding."""
40
+ dialect: Optional[str] = None
41
+ """SQL dialect to use for parsing. If not provided, sqlglot will try to auto-detect."""
53
42
 
54
- _merged_parameters: Optional[Union[StatementParameterType, dict[str, Any]]] = None
43
+ _merged_parameters: Optional[Union[StatementParameterType, dict[str, Any]]] = field(default=None, init=False)
44
+ _parsed_expression: Optional[exp.Expression] = field(default=None, init=False)
45
+ _param_counter: int = field(default=0, init=False)
55
46
 
56
47
  def __post_init__(self) -> None:
57
48
  """Merge parameters and kwargs after initialization."""
@@ -70,48 +61,72 @@ class SQLStatement:
70
61
 
71
62
  self._merged_parameters = merged_params
72
63
 
73
- def process(self) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
64
+ def process(
65
+ self,
66
+ ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]], Optional[exp.Expression]]":
74
67
  """Process the SQL statement and merged parameters for execution.
75
68
 
69
+ This method validates the parameters against the SQL statement using sqlglot
70
+ parsing but returns the *original* SQL string, the merged parameters,
71
+ and the parsed sqlglot expression if successful.
72
+ The actual formatting of SQL placeholders and parameter structures for the
73
+ DBAPI driver is delegated to the specific adapter.
74
+
76
75
  Returns:
77
- A tuple containing the processed SQL string and the processed parameters
78
- ready for database driver execution.
76
+ A tuple containing the *original* SQL string, the merged/validated
77
+ parameters (dict, tuple, list, or None), and the parsed sqlglot expression
78
+ (or None if parsing failed).
79
79
 
80
80
  Raises:
81
- SQLParsingError: If the SQL statement contains parameter placeholders, but no parameters were provided.
82
-
83
- Returns:
84
- A tuple containing the processed SQL string and the processed parameters
85
- ready for database driver execution.
81
+ SQLParsingError: If the SQL statement contains parameter placeholders
82
+ but no parameters were provided, or if parsing fails unexpectedly.
86
83
  """
84
+ # Parse the SQL to find expected parameters
85
+ try:
86
+ expression = self._parse_sql()
87
+ # Find all parameter expressions (:name, ?, @name, $1, etc.)
88
+ # These are nodes that sqlglot considers as bind parameters.
89
+ all_sqlglot_placeholders = list(expression.find_all(exp.Placeholder, exp.Parameter))
90
+ except SQLParsingError as e:
91
+ logger.debug(
92
+ "SQL parsing failed during validation: %s. Returning original SQL and parameters for adapter.", e
93
+ )
94
+ self._parsed_expression = None
95
+ return self.sql, self._merged_parameters, None
96
+
87
97
  if self._merged_parameters is None:
88
- # Validate that the SQL doesn't expect parameters if none were provided
89
- # Parse ONLY if we need to validate
90
- try: # Add try/except in case parsing fails even here
91
- expression = self._parse_sql()
92
- except SQLParsingError:
93
- # If parsing fails, we can't validate, but maybe that's okay if no params were passed?
94
- # Log a warning? For now, let the original error propagate if needed.
95
- # Or, maybe assume it's okay if _merged_parameters is None?
96
- # Let's re-raise for now, as unparsable SQL is usually bad.
97
- logger.warning("SQL statement is unparsable: %s", self.sql)
98
- return self.sql, None
99
- if list(expression.find_all(exp.Parameter)):
100
- msg = "SQL statement contains parameter placeholders, but no parameters were provided."
98
+ # If no parameters were provided, but the parsed SQL expects them, raise an error.
99
+ if all_sqlglot_placeholders:
100
+ placeholder_types_desc = []
101
+ for p_node in all_sqlglot_placeholders:
102
+ if isinstance(p_node, exp.Parameter) and p_node.name:
103
+ placeholder_types_desc.append(f"named (e.g., :{p_node.name}, @{p_node.name})")
104
+ elif (
105
+ isinstance(p_node, exp.Placeholder)
106
+ and p_node.this
107
+ and not isinstance(p_node.this, (exp.Identifier, exp.Literal))
108
+ and not str(p_node.this).isdigit()
109
+ ):
110
+ placeholder_types_desc.append(f"named (e.g., :{p_node.this})")
111
+ elif isinstance(p_node, exp.Parameter) and p_node.name and p_node.name.isdigit():
112
+ placeholder_types_desc.append("positional (e.g., $1, :1)")
113
+ elif isinstance(p_node, exp.Placeholder) and p_node.this is None:
114
+ placeholder_types_desc.append("positional (?)")
115
+ desc_str = ", ".join(sorted(set(placeholder_types_desc))) or "unknown"
116
+ msg = f"SQL statement contains {desc_str} parameter placeholders, but no parameters were provided. SQL: {self.sql}"
101
117
  raise SQLParsingError(msg)
102
- return self.sql, None
118
+ return self.sql, None, self._parsed_expression
103
119
 
120
+ # Validate provided parameters against parsed SQL parameters
104
121
  if isinstance(self._merged_parameters, dict):
105
- # Pass only the dict, parsing happens inside
106
- return self._process_dict_params(self._merged_parameters)
122
+ self._validate_dict_params(all_sqlglot_placeholders, self._merged_parameters)
123
+ elif isinstance(self._merged_parameters, (tuple, list)):
124
+ self._validate_sequence_params(all_sqlglot_placeholders, self._merged_parameters)
125
+ else: # Scalar parameter
126
+ self._validate_scalar_param(all_sqlglot_placeholders, self._merged_parameters)
107
127
 
108
- if isinstance(self._merged_parameters, (tuple, list)):
109
- # Pass only the sequence, parsing happens inside if needed for validation
110
- return self._process_sequence_params(self._merged_parameters)
111
-
112
- # Assume it's a single scalar value otherwise
113
- # Pass only the value, parsing happens inside for validation
114
- return self._process_scalar_param(self._merged_parameters)
128
+ # Return the original SQL and the merged parameters for the adapter to process
129
+ return self.sql, self._merged_parameters, self._parsed_expression
115
130
 
116
131
  def _parse_sql(self) -> exp.Expression:
117
132
  """Parse the SQL using sqlglot.
@@ -122,252 +137,242 @@ class SQLStatement:
122
137
  Returns:
123
138
  The parsed SQL expression.
124
139
  """
125
- parse_dialect = self.dialect or "postgres"
126
140
  try:
127
- read_dialect = parse_dialect or None
128
- return sqlglot.parse_one(self.sql, read=read_dialect)
141
+ if not self.sql.strip():
142
+ self._parsed_expression = exp.Select()
143
+ return self._parsed_expression
144
+ # Use the provided dialect if available, otherwise sqlglot will try to auto-detect
145
+ self._parsed_expression = sqlglot.parse_one(self.sql, dialect=self.dialect)
146
+ if self._parsed_expression is None:
147
+ self._parsed_expression = exp.Select() # type: ignore[unreachable]
129
148
  except Exception as e:
130
- # Ensure the original sqlglot error message is included
131
- error_detail = str(e)
132
- msg = f"Failed to parse SQL with dialect '{parse_dialect or 'auto-detected'}': {error_detail}\nSQL: {self.sql}"
149
+ msg = f"Failed to parse SQL for validation: {e!s}\nSQL: {self.sql}"
150
+ self._parsed_expression = None
133
151
  raise SQLParsingError(msg) from e
152
+ else:
153
+ return self._parsed_expression
154
+
155
+ def _validate_dict_params(
156
+ self, all_sqlglot_placeholders: Sequence[exp.Expression], parameter_dict: dict[str, Any]
157
+ ) -> None:
158
+ sqlglot_named_params: dict[str, Union[exp.Parameter, exp.Placeholder]] = {}
159
+ has_positional_qmark = False
160
+
161
+ for p_node in all_sqlglot_placeholders:
162
+ if (
163
+ isinstance(p_node, exp.Parameter) and p_node.name and not p_node.name.isdigit()
164
+ ): # @name, $name (non-numeric)
165
+ sqlglot_named_params[p_node.name] = p_node
166
+ elif (
167
+ isinstance(p_node, exp.Placeholder)
168
+ and p_node.this
169
+ and not isinstance(p_node.this, (exp.Identifier, exp.Literal))
170
+ and not str(p_node.this).isdigit()
171
+ ): # :name
172
+ sqlglot_named_params[str(p_node.this)] = p_node
173
+ elif isinstance(p_node, exp.Placeholder) and p_node.this is None: # ?
174
+ has_positional_qmark = True
175
+ # Ignores numeric placeholders like $1, :1 for dict validation for now
176
+
177
+ if has_positional_qmark:
178
+ msg = f"Dictionary parameters provided, but found unnamed placeholders ('?') in SQL: {self.sql}"
179
+ raise ParameterStyleMismatchError(msg)
180
+
181
+ if not sqlglot_named_params and parameter_dict:
182
+ msg = f"Dictionary parameters provided, but no named placeholders (e.g., ':name', '$name', '@name') found by sqlglot in SQL: {self.sql}"
183
+ raise ParameterStyleMismatchError(msg)
134
184
 
135
- def _process_dict_params(
185
+ missing_keys = set(sqlglot_named_params.keys()) - set(parameter_dict.keys())
186
+ if missing_keys:
187
+ msg = f"Named parameters found in SQL by sqlglot but not provided: {missing_keys}. SQL: {self.sql}"
188
+ raise SQLParsingError(msg)
189
+
190
+ def _validate_sequence_params(
136
191
  self,
137
- parameter_dict: dict[str, Any],
138
- ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
139
- """Processes dictionary parameters based on dialect capabilities.
192
+ all_sqlglot_placeholders: Sequence[exp.Expression],
193
+ params: Union[tuple[Any, ...], list[Any]],
194
+ ) -> None:
195
+ sqlglot_named_param_names = [] # For detecting named params
196
+ sqlglot_positional_count = 0 # For counting ?, $1, :1 etc.
197
+
198
+ for p_node in all_sqlglot_placeholders:
199
+ if isinstance(p_node, exp.Parameter) and p_node.name and not p_node.name.isdigit(): # @name, $name
200
+ sqlglot_named_param_names.append(p_node.name)
201
+ elif (
202
+ isinstance(p_node, exp.Placeholder)
203
+ and p_node.this
204
+ and not isinstance(p_node.this, (exp.Identifier, exp.Literal))
205
+ and not str(p_node.this).isdigit()
206
+ ): # :name
207
+ sqlglot_named_param_names.append(str(p_node.this))
208
+ elif isinstance(p_node, exp.Placeholder) and p_node.this is None: # ?
209
+ sqlglot_positional_count += 1
210
+ elif isinstance(p_node, exp.Parameter) and ( # noqa: PLR0916
211
+ (p_node.name and p_node.name.isdigit())
212
+ or (
213
+ not p_node.name
214
+ and p_node.this
215
+ and isinstance(p_node.this, (str, exp.Identifier, exp.Literal))
216
+ and str(p_node.this).isdigit()
217
+ )
218
+ ):
219
+ # $1, :1 style (parsed as Parameter with name="1" or this="1" or this=Identifier(this="1") or this=Literal(this=1))
220
+ sqlglot_positional_count += 1
221
+ elif (
222
+ isinstance(p_node, exp.Placeholder) and p_node.this and str(p_node.this).isdigit()
223
+ ): # :1 style (Placeholder with this="1")
224
+ sqlglot_positional_count += 1
225
+
226
+ if sqlglot_named_param_names:
227
+ msg = f"Sequence parameters provided, but found named placeholders ({', '.join(sorted(set(sqlglot_named_param_names)))}) in SQL: {self.sql}"
228
+ raise ParameterStyleMismatchError(msg)
140
229
 
141
- Raises:
142
- ParameterStyleMismatchError: If the SQL statement contains unnamed placeholders (e.g., '?') in the SQL query.
143
- SQLParsingError: If the SQL statement contains named parameters, but no parameters were provided.
230
+ actual_count_provided = len(params)
231
+
232
+ if sqlglot_positional_count != actual_count_provided:
233
+ msg = (
234
+ f"Parameter count mismatch. SQL expects {sqlglot_positional_count} (sqlglot) positional "
235
+ f"parameters, but {actual_count_provided} were provided. SQL: {self.sql}"
236
+ )
237
+ raise SQLParsingError(msg)
238
+
239
+ def _validate_scalar_param(self, all_sqlglot_placeholders: Sequence[exp.Expression], param_value: Any) -> None:
240
+ """Validates a single scalar parameter against parsed SQL parameters."""
241
+ self._validate_sequence_params(
242
+ all_sqlglot_placeholders, (param_value,)
243
+ ) # Treat scalar as a single-element sequence
244
+
245
+ def get_expression(self) -> exp.Expression:
246
+ """Get the parsed SQLglot expression, parsing if necessary.
144
247
 
145
248
  Returns:
146
- A tuple containing the processed SQL string and the processed parameters
147
- ready for database driver execution.
249
+ The SQLglot expression.
148
250
  """
149
- # Attempt to parse with sqlglot first (for other dialects like postgres, mysql)
150
- named_sql_params: Optional[list[exp.Parameter]] = None
151
- unnamed_sql_params: Optional[list[exp.Parameter]] = None
152
- sqlglot_parsed_ok = False
153
- # --- Dialect-Specific Bypasses for Native Handling ---
154
- if self.dialect == "sqlite": # Handles :name natively
155
- return self.sql, parameter_dict
156
-
157
- # Add bypass for postgres handled by specific adapters (e.g., asyncpg)
158
- if self.dialect == "postgres":
159
- # The adapter (e.g., asyncpg) will handle :name -> $n conversion.
160
- # SQLStatement just validates parameters against the original SQL here.
161
- # Perform validation using regex if sqlglot parsing fails, otherwise use sqlglot.
162
- try:
163
- expression = self._parse_sql()
164
- sql_params = list(expression.find_all(exp.Parameter))
165
- named_sql_params = [p for p in sql_params if p.name]
166
- unnamed_sql_params = [p for p in sql_params if not p.name]
167
-
168
- if unnamed_sql_params:
169
- msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot for postgres."
170
- raise ParameterStyleMismatchError(msg)
171
-
172
- # Validate keys using sqlglot results
173
- required_keys = {p.name for p in named_sql_params}
174
- provided_keys = set(parameter_dict.keys())
175
- missing_keys = required_keys - provided_keys
176
- if missing_keys:
177
- msg = (
178
- f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}"
179
- )
180
- raise SQLParsingError(msg) # noqa: TRY301
181
- # Allow extra keys
182
-
183
- except SQLParsingError as e:
184
- logger.debug("SQLglot parsing failed for postgres dict params, attempting regex validation: %s", e)
185
- # Regex validation fallback (without conversion)
186
- postgres_found_params_regex: list[str] = []
187
- for match in PARAM_REGEX.finditer(self.sql):
188
- if match.group("dquote") or match.group("squote") or match.group("comment"):
189
- continue
190
- if match.group("var_name"):
191
- var_name = match.group("var_name")
192
- postgres_found_params_regex.append(var_name)
193
- if var_name not in parameter_dict:
194
- msg = f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}"
195
- raise SQLParsingError(msg) # noqa: B904
196
-
197
- if not postgres_found_params_regex and parameter_dict:
198
- msg = f"Dictionary parameters provided, but no named placeholders (:name) found via regex. SQL: {self.sql}"
199
- raise ParameterStyleMismatchError(msg) # noqa: B904
200
- # Allow extra keys with regex check too
201
-
202
- # Return the *original* SQL and the processed dict for the adapter to handle
203
- return self.sql, parameter_dict
204
-
205
- if self.dialect == "duckdb": # Handles $name natively (and :name via driver? Check driver docs)
206
- # Bypass sqlglot/regex checks. Trust user SQL ($name or ?) + dict for DuckDB driver.
207
- # We lose :name -> $name conversion *if* sqlglot parsing fails, but avoid errors on valid $name SQL.
208
- return self.sql, parameter_dict
209
- # --- End Bypasses ---
251
+ if self._parsed_expression is None:
252
+ self._parse_sql()
253
+ if self._parsed_expression is None: # Still None after parsing attempt
254
+ return exp.Select() # Return an empty SELECT as fallback
255
+ return self._parsed_expression
210
256
 
211
- try:
212
- expression = self._parse_sql()
213
- sql_params = list(expression.find_all(exp.Parameter))
214
- named_sql_params = [p for p in sql_params if p.name]
215
- unnamed_sql_params = [p for p in sql_params if not p.name]
216
- sqlglot_parsed_ok = True
217
- logger.debug("SQLglot parsed dict params successfully for: %s", self.sql)
218
- except SQLParsingError as e:
219
- logger.debug("SQLglot parsing failed for dict params, attempting regex fallback: %s", e)
220
- # Proceed using regex fallback below
257
+ def generate_param_name(self, base_name: str) -> str:
258
+ """Generates a unique parameter name.
221
259
 
222
- # Check for unnamed placeholders if parsing worked
223
- if sqlglot_parsed_ok and unnamed_sql_params:
224
- msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot."
225
- raise ParameterStyleMismatchError(msg)
260
+ Args:
261
+ base_name: The base name for the parameter.
226
262
 
227
- # Determine if we need to use regex fallback
228
- # Use fallback if: parsing failed OR (parsing worked BUT found no named params when a dict was provided)
229
- use_regex_fallback = not sqlglot_parsed_ok or (not named_sql_params and parameter_dict)
230
-
231
- if use_regex_fallback:
232
- # Regex fallback logic for :name -> self.param_style conversion
233
- # ... (regex fallback code as implemented previously) ...
234
- logger.debug("Using regex fallback for dict param processing: %s", self.sql)
235
- # --- Regex Fallback Logic ---
236
- regex_processed_sql_parts: list[str] = []
237
- ordered_params = []
238
- last_end = 0
239
- regex_found_params: list[str] = []
240
-
241
- for match in PARAM_REGEX.finditer(self.sql):
242
- # Skip matches that are comments or quoted strings
243
- if match.group("dquote") or match.group("squote") or match.group("comment"):
244
- continue
245
-
246
- if match.group("var_name"):
247
- var_name = match.group("var_name")
248
- regex_found_params.append(var_name)
249
- # Get start and end from the match object for the :var_name part
250
- # The var_name group itself doesn't include the leading :, so adjust start.
251
- start = match.start("var_name") - 1
252
- end = match.end("var_name")
253
-
254
- if var_name not in parameter_dict:
255
- msg = (
256
- f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}"
257
- )
258
- raise SQLParsingError(msg)
259
-
260
- regex_processed_sql_parts.extend((self.sql[last_end:start], self.param_style)) # Use target style
261
- ordered_params.append(parameter_dict[var_name])
262
- last_end = end
263
-
264
- regex_processed_sql_parts.append(self.sql[last_end:])
265
-
266
- # Validation with regex results
267
- if not regex_found_params and parameter_dict:
268
- msg = f"Dictionary parameters provided, but no named placeholders (e.g., :name) found via regex in the SQL query for dialect '{self.dialect}'. SQL: {self.sql}"
269
- raise ParameterStyleMismatchError(msg)
270
-
271
- provided_keys = set(parameter_dict.keys())
272
- missing_keys = set(regex_found_params) - provided_keys # Should be caught above, but double check
273
- if missing_keys:
274
- msg = f"Named parameters found in SQL (via regex) but not provided: {missing_keys}. SQL: {self.sql}"
275
- raise SQLParsingError(msg)
276
-
277
- extra_keys = provided_keys - set(regex_found_params)
278
- if extra_keys:
279
- # Allow extra keys
280
- pass
263
+ Returns:
264
+ The generated parameter name.
265
+ """
266
+ self._param_counter += 1
267
+ safe_base_name = "".join(c if c.isalnum() else "_" for c in base_name if c.isalnum() or c == "_")
268
+ return f"param_{safe_base_name}_{self._param_counter}"
281
269
 
282
- return "".join(regex_processed_sql_parts), tuple(ordered_params)
270
+ def add_condition(self, condition: exp.Condition, params: Optional[dict[str, Any]] = None) -> None:
271
+ """Adds a condition to the WHERE clause of the query.
283
272
 
284
- # Sqlglot Logic (if parsing worked and found params)
285
- # ... (sqlglot logic as implemented previously, including :name -> %s conversion) ...
286
- logger.debug("Using sqlglot results for dict param processing: %s", self.sql)
273
+ Args:
274
+ condition: The condition to add to the WHERE clause.
275
+ params: The parameters to add to the statement parameters.
276
+ """
277
+ expression = self.get_expression()
278
+ if not isinstance(expression, (exp.Select, exp.Update, exp.Delete)):
279
+ return # Cannot add WHERE to some expressions
280
+
281
+ # Update the expression
282
+ expression.where(condition, copy=False)
283
+
284
+ # Update the parameters
285
+ if params:
286
+ if self._merged_parameters is None:
287
+ self._merged_parameters = params
288
+ elif isinstance(self._merged_parameters, dict):
289
+ self._merged_parameters.update(params)
290
+ else:
291
+ # Convert to dict if not already
292
+ self._merged_parameters = params
287
293
 
288
- # Ensure named_sql_params is iterable, default to empty list if None (shouldn't happen ideally)
289
- active_named_params = named_sql_params or []
294
+ # Update the SQL string
295
+ self.sql = expression.sql(dialect=self.dialect)
290
296
 
291
- if not active_named_params and not parameter_dict:
292
- # No SQL params found by sqlglot, no provided params dict -> OK
293
- return self.sql, ()
297
+ def add_order_by(self, field_name: str, direction: str = "asc") -> None:
298
+ """Adds an ORDER BY clause.
294
299
 
295
- # Validation with sqlglot results
296
- required_keys = {p.name for p in active_named_params} # Use active_named_params
297
- provided_keys = set(parameter_dict.keys())
300
+ Args:
301
+ field_name: The name of the field to order by.
302
+ direction: The direction to order by ("asc" or "desc").
303
+ """
304
+ expression = self.get_expression()
305
+ if not isinstance(expression, exp.Select):
306
+ return
298
307
 
299
- missing_keys = required_keys - provided_keys
300
- if missing_keys:
301
- msg = f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}"
302
- raise SQLParsingError(msg)
308
+ expression.order_by(exp.Ordered(this=exp.column(field_name), desc=direction.lower() == "desc"), copy=False)
309
+ self.sql = expression.sql(dialect=self.dialect)
303
310
 
304
- extra_keys = provided_keys - required_keys
305
- if extra_keys:
306
- pass # Allow extra keys
307
-
308
- # Note: DuckDB handled by bypass above if sqlglot fails.
309
- # This block handles successful sqlglot parse for other dialects.
310
- # We don't need the specific DuckDB $name conversion here anymore,
311
- # as the bypass handles the native $name case.
312
- # The general logic converts :name -> self.param_style for dialects like postgres.
313
- # if self.dialect == "duckdb": ... (Removed specific block here)
314
-
315
- # For other dialects requiring positional conversion (using sqlglot param info):
316
- sqlglot_processed_parts: list[str] = []
317
- ordered_params = []
318
- last_end = 0
319
- for param in active_named_params: # Use active_named_params
320
- start = param.this.this.start
321
- end = param.this.this.end
322
- sqlglot_processed_parts.extend((self.sql[last_end:start], self.param_style))
323
- ordered_params.append(parameter_dict[param.name])
324
- last_end = end
325
- sqlglot_processed_parts.append(self.sql[last_end:])
326
- return "".join(sqlglot_processed_parts), tuple(ordered_params)
327
-
328
- def _process_sequence_params(
329
- self, params: Union[tuple[Any, ...], list[Any]]
330
- ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
331
- """Processes a sequence of parameters.
311
+ def add_limit(self, limit_val: int, param_name: Optional[str] = None) -> None:
312
+ """Adds a LIMIT clause.
332
313
 
333
- Returns:
334
- A tuple containing the processed SQL string and the processed parameters
335
- ready for database driver execution.
314
+ Args:
315
+ limit_val: The value for the LIMIT clause.
316
+ param_name: Optional name for the parameter.
317
+ """
318
+ expression = self.get_expression()
319
+ if not isinstance(expression, exp.Select):
320
+ return
321
+
322
+ if param_name:
323
+ expression.limit(exp.Placeholder(this=param_name), copy=False)
324
+ if self._merged_parameters is None:
325
+ self._merged_parameters = {param_name: limit_val}
326
+ elif isinstance(self._merged_parameters, dict):
327
+ self._merged_parameters[param_name] = limit_val
328
+ else:
329
+ expression.limit(exp.Literal.number(limit_val), copy=False)
330
+
331
+ self.sql = expression.sql(dialect=self.dialect)
332
+
333
+ def add_offset(self, offset_val: int, param_name: Optional[str] = None) -> None:
334
+ """Adds an OFFSET clause.
335
+
336
+ Args:
337
+ offset_val: The value for the OFFSET clause.
338
+ param_name: Optional name for the parameter.
336
339
  """
337
- return self.sql, params
340
+ expression = self.get_expression()
341
+ if not isinstance(expression, exp.Select):
342
+ return
338
343
 
339
- def _process_scalar_param(
340
- self, param_value: Any
341
- ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]":
342
- """Processes a single scalar parameter value.
344
+ if param_name:
345
+ expression.offset(exp.Placeholder(this=param_name), copy=False)
346
+ if self._merged_parameters is None:
347
+ self._merged_parameters = {param_name: offset_val}
348
+ elif isinstance(self._merged_parameters, dict):
349
+ self._merged_parameters[param_name] = offset_val
350
+ else:
351
+ expression.offset(exp.Literal.number(offset_val), copy=False)
352
+
353
+ self.sql = expression.sql(dialect=self.dialect)
354
+
355
+ def apply_filter(self, filter_obj: "StatementFilter") -> "SQLStatement":
356
+ """Apply a statement filter to this statement.
357
+
358
+ Args:
359
+ filter_obj: The filter to apply.
343
360
 
344
361
  Returns:
345
- A tuple containing the processed SQL string and the processed parameters
346
- ready for database driver execution.
362
+ The modified statement.
347
363
  """
348
- return self.sql, (param_value,)
364
+ from sqlspec.filters import apply_filter
365
+
366
+ return apply_filter(self, filter_obj)
367
+
368
+ def to_sql(self, dialect: Optional[str] = None) -> str:
369
+ """Generate SQL string using the specified dialect.
349
370
 
350
- @cached_property
351
- def param_style(self) -> str:
352
- """Get the parameter style based on the dialect.
371
+ Args:
372
+ dialect: SQL dialect to use for SQL generation. If None, uses the statement's dialect.
353
373
 
354
374
  Returns:
355
- The parameter style placeholder for the dialect.
375
+ SQL string in the specified dialect.
356
376
  """
357
- dialect = self.dialect
358
-
359
- # Map dialects to parameter styles for placeholder replacement
360
- # Note: Used when converting named params (:name) for dialects needing positional.
361
- # Dialects supporting named params natively (SQLite, DuckDB) are handled via bypasses.
362
- dialect_to_param_style = {
363
- "postgres": "%s",
364
- "mysql": "%s",
365
- "oracle": ":1",
366
- "mssql": "?",
367
- "bigquery": "?",
368
- "snowflake": "?",
369
- "cockroach": "%s",
370
- "db2": "?",
371
- }
372
- # Default to '?' for unknown/unhandled dialects or when dialect=None is forced
373
- return dialect_to_param_style.get(dialect, "?")
377
+ expression = self.get_expression()
378
+ return expression.sql(dialect=dialect or self.dialect)