sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/statement/sql.py
ADDED
|
@@ -0,0 +1,1198 @@
|
|
|
1
|
+
"""SQL statement handling with centralized parameter management."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field, replace
|
|
4
|
+
from typing import Any, Optional, Union
|
|
5
|
+
|
|
6
|
+
import sqlglot
|
|
7
|
+
import sqlglot.expressions as exp
|
|
8
|
+
from sqlglot.dialects.dialect import DialectType
|
|
9
|
+
from sqlglot.errors import ParseError
|
|
10
|
+
|
|
11
|
+
from sqlspec.exceptions import RiskLevel, SQLValidationError
|
|
12
|
+
from sqlspec.statement.filters import StatementFilter
|
|
13
|
+
from sqlspec.statement.parameters import ParameterConverter, ParameterStyle, ParameterValidator
|
|
14
|
+
from sqlspec.statement.pipelines.base import StatementPipeline
|
|
15
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
16
|
+
from sqlspec.statement.pipelines.transformers import CommentRemover, ParameterizeLiterals
|
|
17
|
+
from sqlspec.statement.pipelines.validators import DMLSafetyValidator, ParameterStyleValidator
|
|
18
|
+
from sqlspec.typing import is_dict
|
|
19
|
+
from sqlspec.utils.logging import get_logger
|
|
20
|
+
|
|
21
|
+
__all__ = ("SQL", "SQLConfig", "Statement")
|
|
22
|
+
|
|
23
|
+
logger = get_logger("sqlspec.statement")
|
|
24
|
+
|
|
25
|
+
Statement = Union[str, exp.Expression, "SQL"]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
|
|
29
|
+
class _ProcessedState:
|
|
30
|
+
"""Cached state from pipeline processing."""
|
|
31
|
+
|
|
32
|
+
processed_expression: exp.Expression
|
|
33
|
+
processed_sql: str
|
|
34
|
+
merged_parameters: Any
|
|
35
|
+
validation_errors: list[Any] = field(default_factory=list)
|
|
36
|
+
analysis_results: dict[str, Any] = field(default_factory=dict)
|
|
37
|
+
transformation_results: dict[str, Any] = field(default_factory=dict)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
|
|
41
|
+
class SQLConfig:
|
|
42
|
+
"""Configuration for SQL statement behavior."""
|
|
43
|
+
|
|
44
|
+
# Behavior flags
|
|
45
|
+
enable_parsing: bool = True
|
|
46
|
+
enable_validation: bool = True
|
|
47
|
+
enable_transformations: bool = True
|
|
48
|
+
enable_analysis: bool = False
|
|
49
|
+
enable_normalization: bool = True
|
|
50
|
+
strict_mode: bool = False
|
|
51
|
+
cache_parsed_expression: bool = True
|
|
52
|
+
|
|
53
|
+
# Component lists for explicit staging
|
|
54
|
+
transformers: Optional[list[Any]] = None
|
|
55
|
+
validators: Optional[list[Any]] = None
|
|
56
|
+
analyzers: Optional[list[Any]] = None
|
|
57
|
+
|
|
58
|
+
# Other configs
|
|
59
|
+
parameter_converter: ParameterConverter = field(default_factory=ParameterConverter)
|
|
60
|
+
parameter_validator: ParameterValidator = field(default_factory=ParameterValidator)
|
|
61
|
+
analysis_cache_size: int = 1000
|
|
62
|
+
input_sql_had_placeholders: bool = False # Populated by SQL.__init__
|
|
63
|
+
|
|
64
|
+
# Parameter style configuration
|
|
65
|
+
allowed_parameter_styles: Optional[tuple[str, ...]] = None
|
|
66
|
+
"""Allowed parameter styles for this SQL configuration (e.g., ('qmark', 'named_colon'))."""
|
|
67
|
+
|
|
68
|
+
target_parameter_style: Optional[str] = None
|
|
69
|
+
"""Target parameter style for SQL generation."""
|
|
70
|
+
|
|
71
|
+
allow_mixed_parameter_styles: bool = False
|
|
72
|
+
"""Whether to allow mixing named and positional parameters in same query."""
|
|
73
|
+
|
|
74
|
+
def validate_parameter_style(self, style: Union[ParameterStyle, str]) -> bool:
|
|
75
|
+
"""Check if a parameter style is allowed.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
style: Parameter style to validate (can be ParameterStyle enum or string)
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
True if the style is allowed, False otherwise
|
|
82
|
+
"""
|
|
83
|
+
if self.allowed_parameter_styles is None:
|
|
84
|
+
return True # No restrictions
|
|
85
|
+
style_str = str(style)
|
|
86
|
+
return style_str in self.allowed_parameter_styles
|
|
87
|
+
|
|
88
|
+
def get_statement_pipeline(self) -> StatementPipeline:
|
|
89
|
+
"""Get the configured statement pipeline.
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
StatementPipeline configured with transformers, validators, and analyzers
|
|
93
|
+
"""
|
|
94
|
+
# Import here to avoid circular dependencies
|
|
95
|
+
|
|
96
|
+
# Create transformers based on config
|
|
97
|
+
transformers = []
|
|
98
|
+
if self.transformers is not None:
|
|
99
|
+
# Use explicit transformers if provided
|
|
100
|
+
transformers = list(self.transformers)
|
|
101
|
+
# Use default transformers
|
|
102
|
+
elif self.enable_transformations:
|
|
103
|
+
# Use target_parameter_style if available, otherwise default to "?"
|
|
104
|
+
placeholder_style = self.target_parameter_style or "?"
|
|
105
|
+
transformers = [CommentRemover(), ParameterizeLiterals(placeholder_style=placeholder_style)]
|
|
106
|
+
|
|
107
|
+
# Create validators based on config
|
|
108
|
+
validators = []
|
|
109
|
+
if self.validators is not None:
|
|
110
|
+
# Use explicit validators if provided
|
|
111
|
+
validators = list(self.validators)
|
|
112
|
+
# Use default validators
|
|
113
|
+
elif self.enable_validation:
|
|
114
|
+
validators = [ParameterStyleValidator(fail_on_violation=self.strict_mode), DMLSafetyValidator()]
|
|
115
|
+
|
|
116
|
+
# Create analyzers based on config
|
|
117
|
+
analyzers = []
|
|
118
|
+
if self.analyzers is not None:
|
|
119
|
+
# Use explicit analyzers if provided
|
|
120
|
+
analyzers = list(self.analyzers)
|
|
121
|
+
# Use default analyzers
|
|
122
|
+
elif self.enable_analysis:
|
|
123
|
+
# Currently no default analyzers
|
|
124
|
+
analyzers = []
|
|
125
|
+
|
|
126
|
+
return StatementPipeline(transformers=transformers, validators=validators, analyzers=analyzers)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class SQL:
|
|
130
|
+
"""Immutable SQL statement with centralized parameter management.
|
|
131
|
+
|
|
132
|
+
The SQL class is the single source of truth for:
|
|
133
|
+
- SQL expression/statement
|
|
134
|
+
- Positional parameters
|
|
135
|
+
- Named parameters
|
|
136
|
+
- Applied filters
|
|
137
|
+
|
|
138
|
+
All methods that modify state return new SQL instances.
|
|
139
|
+
"""
|
|
140
|
+
|
|
141
|
+
__slots__ = (
|
|
142
|
+
"_builder_result_type", # Optional[type] - for query builders
|
|
143
|
+
"_config", # SQLConfig - configuration
|
|
144
|
+
"_dialect", # DialectType - SQL dialect
|
|
145
|
+
"_filters", # list[StatementFilter] - filters to apply
|
|
146
|
+
"_is_many", # bool - for executemany operations
|
|
147
|
+
"_is_script", # bool - for script execution
|
|
148
|
+
"_named_params", # dict[str, Any] - named parameters
|
|
149
|
+
"_original_parameters", # Any - original parameters as passed in
|
|
150
|
+
"_original_sql", # str - original SQL before normalization
|
|
151
|
+
"_placeholder_mapping", # dict[str, Union[str, int]] - placeholder normalization mapping
|
|
152
|
+
"_positional_params", # list[Any] - positional parameters
|
|
153
|
+
"_processed_state", # Cached processed state
|
|
154
|
+
"_processing_context", # SQLProcessingContext - context from pipeline processing
|
|
155
|
+
"_raw_sql", # str - original SQL string for compatibility
|
|
156
|
+
"_statement", # exp.Expression - the SQL expression
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
def __init__(
|
|
160
|
+
self,
|
|
161
|
+
statement: Union[str, exp.Expression, "SQL"],
|
|
162
|
+
*parameters: Union[Any, StatementFilter, list[Union[Any, StatementFilter]]],
|
|
163
|
+
_dialect: DialectType = None,
|
|
164
|
+
_config: Optional[SQLConfig] = None,
|
|
165
|
+
_builder_result_type: Optional[type] = None,
|
|
166
|
+
_existing_state: Optional[dict[str, Any]] = None,
|
|
167
|
+
**kwargs: Any,
|
|
168
|
+
) -> None:
|
|
169
|
+
"""Initialize SQL with centralized parameter management."""
|
|
170
|
+
self._config = _config or SQLConfig()
|
|
171
|
+
self._dialect = _dialect
|
|
172
|
+
self._builder_result_type = _builder_result_type
|
|
173
|
+
self._processed_state: Optional[_ProcessedState] = None
|
|
174
|
+
self._processing_context: Optional[SQLProcessingContext] = None
|
|
175
|
+
self._positional_params: list[Any] = []
|
|
176
|
+
self._named_params: dict[str, Any] = {}
|
|
177
|
+
self._filters: list[StatementFilter] = []
|
|
178
|
+
self._statement: exp.Expression
|
|
179
|
+
self._raw_sql: str = ""
|
|
180
|
+
self._original_parameters: Any = None
|
|
181
|
+
self._original_sql: str = ""
|
|
182
|
+
self._placeholder_mapping: dict[str, Union[str, int]] = {}
|
|
183
|
+
self._is_many: bool = False
|
|
184
|
+
self._is_script: bool = False
|
|
185
|
+
|
|
186
|
+
if isinstance(statement, SQL):
|
|
187
|
+
self._init_from_sql_object(statement, _dialect, _config, _builder_result_type)
|
|
188
|
+
else:
|
|
189
|
+
self._init_from_str_or_expression(statement)
|
|
190
|
+
|
|
191
|
+
if _existing_state:
|
|
192
|
+
self._load_from_existing_state(_existing_state)
|
|
193
|
+
|
|
194
|
+
if not isinstance(statement, SQL):
|
|
195
|
+
self._set_original_parameters(*parameters)
|
|
196
|
+
|
|
197
|
+
self._process_parameters(*parameters, **kwargs)
|
|
198
|
+
|
|
199
|
+
def _init_from_sql_object(
|
|
200
|
+
self, statement: "SQL", dialect: DialectType, config: Optional[SQLConfig], builder_result_type: Optional[type]
|
|
201
|
+
) -> None:
|
|
202
|
+
"""Initialize attributes from an existing SQL object."""
|
|
203
|
+
self._statement = statement._statement
|
|
204
|
+
self._dialect = dialect or statement._dialect
|
|
205
|
+
self._config = config or statement._config
|
|
206
|
+
self._builder_result_type = builder_result_type or statement._builder_result_type
|
|
207
|
+
self._is_many = statement._is_many
|
|
208
|
+
self._is_script = statement._is_script
|
|
209
|
+
self._raw_sql = statement._raw_sql
|
|
210
|
+
self._original_parameters = statement._original_parameters
|
|
211
|
+
self._original_sql = statement._original_sql
|
|
212
|
+
self._placeholder_mapping = statement._placeholder_mapping.copy()
|
|
213
|
+
self._positional_params.extend(statement._positional_params)
|
|
214
|
+
self._named_params.update(statement._named_params)
|
|
215
|
+
self._filters.extend(statement._filters)
|
|
216
|
+
|
|
217
|
+
def _init_from_str_or_expression(self, statement: Union[str, exp.Expression]) -> None:
|
|
218
|
+
"""Initialize attributes from a SQL string or expression."""
|
|
219
|
+
if isinstance(statement, str):
|
|
220
|
+
self._raw_sql = statement
|
|
221
|
+
if self._raw_sql and not self._config.input_sql_had_placeholders:
|
|
222
|
+
param_info = self._config.parameter_validator.extract_parameters(self._raw_sql)
|
|
223
|
+
if param_info:
|
|
224
|
+
self._config = replace(self._config, input_sql_had_placeholders=True)
|
|
225
|
+
self._statement = self._to_expression(statement)
|
|
226
|
+
else:
|
|
227
|
+
self._raw_sql = statement.sql(dialect=self._dialect) # pyright: ignore
|
|
228
|
+
self._statement = statement
|
|
229
|
+
|
|
230
|
+
def _load_from_existing_state(self, existing_state: dict[str, Any]) -> None:
|
|
231
|
+
"""Load state from a dictionary (used by copy)."""
|
|
232
|
+
self._positional_params = list(existing_state.get("positional_params", self._positional_params))
|
|
233
|
+
self._named_params = dict(existing_state.get("named_params", self._named_params))
|
|
234
|
+
self._filters = list(existing_state.get("filters", self._filters))
|
|
235
|
+
self._is_many = existing_state.get("is_many", self._is_many)
|
|
236
|
+
self._is_script = existing_state.get("is_script", self._is_script)
|
|
237
|
+
self._raw_sql = existing_state.get("raw_sql", self._raw_sql)
|
|
238
|
+
|
|
239
|
+
def _set_original_parameters(self, *parameters: Any) -> None:
|
|
240
|
+
"""Store the original parameters for compatibility."""
|
|
241
|
+
if len(parameters) == 1 and not isinstance(parameters[0], StatementFilter):
|
|
242
|
+
self._original_parameters = parameters[0]
|
|
243
|
+
elif len(parameters) > 1:
|
|
244
|
+
self._original_parameters = parameters
|
|
245
|
+
else:
|
|
246
|
+
self._original_parameters = None
|
|
247
|
+
|
|
248
|
+
def _process_parameters(self, *parameters: Any, **kwargs: Any) -> None:
|
|
249
|
+
"""Process positional and keyword arguments for parameters and filters."""
|
|
250
|
+
for param in parameters:
|
|
251
|
+
self._process_parameter_item(param)
|
|
252
|
+
|
|
253
|
+
if "parameters" in kwargs:
|
|
254
|
+
param_value = kwargs.pop("parameters")
|
|
255
|
+
if isinstance(param_value, (list, tuple)):
|
|
256
|
+
self._positional_params.extend(param_value)
|
|
257
|
+
elif isinstance(param_value, dict):
|
|
258
|
+
self._named_params.update(param_value)
|
|
259
|
+
else:
|
|
260
|
+
self._positional_params.append(param_value)
|
|
261
|
+
|
|
262
|
+
for key, value in kwargs.items():
|
|
263
|
+
if not key.startswith("_"):
|
|
264
|
+
self._named_params[key] = value
|
|
265
|
+
|
|
266
|
+
def _process_parameter_item(self, item: Any) -> None:
|
|
267
|
+
"""Process a single item from the parameters list."""
|
|
268
|
+
if isinstance(item, StatementFilter):
|
|
269
|
+
self._filters.append(item)
|
|
270
|
+
pos_params, named_params = self._extract_filter_parameters(item)
|
|
271
|
+
self._positional_params.extend(pos_params)
|
|
272
|
+
self._named_params.update(named_params)
|
|
273
|
+
elif isinstance(item, list):
|
|
274
|
+
for sub_item in item:
|
|
275
|
+
self._process_parameter_item(sub_item)
|
|
276
|
+
elif isinstance(item, dict):
|
|
277
|
+
self._named_params.update(item)
|
|
278
|
+
elif isinstance(item, tuple):
|
|
279
|
+
self._positional_params.extend(item)
|
|
280
|
+
else:
|
|
281
|
+
self._positional_params.append(item)
|
|
282
|
+
|
|
283
|
+
def _ensure_processed(self) -> None:
|
|
284
|
+
"""Ensure the SQL has been processed through the pipeline (lazy initialization).
|
|
285
|
+
|
|
286
|
+
This method implements the facade pattern with lazy processing.
|
|
287
|
+
It's called by public methods that need processed state.
|
|
288
|
+
"""
|
|
289
|
+
if self._processed_state is not None:
|
|
290
|
+
return
|
|
291
|
+
|
|
292
|
+
# Get the final expression and parameters after filters
|
|
293
|
+
final_expr, final_params = self._build_final_state()
|
|
294
|
+
|
|
295
|
+
# Check if the raw SQL has placeholders
|
|
296
|
+
if self._raw_sql:
|
|
297
|
+
validator = self._config.parameter_validator
|
|
298
|
+
raw_param_info = validator.extract_parameters(self._raw_sql)
|
|
299
|
+
has_placeholders = bool(raw_param_info)
|
|
300
|
+
else:
|
|
301
|
+
has_placeholders = self._config.input_sql_had_placeholders
|
|
302
|
+
|
|
303
|
+
# Update config if we detected placeholders
|
|
304
|
+
if has_placeholders and not self._config.input_sql_had_placeholders:
|
|
305
|
+
self._config = replace(self._config, input_sql_had_placeholders=True)
|
|
306
|
+
|
|
307
|
+
# Create processing context
|
|
308
|
+
context = SQLProcessingContext(
|
|
309
|
+
initial_sql_string=self._raw_sql or final_expr.sql(dialect=self._dialect),
|
|
310
|
+
dialect=self._dialect,
|
|
311
|
+
config=self._config,
|
|
312
|
+
current_expression=final_expr,
|
|
313
|
+
initial_expression=final_expr,
|
|
314
|
+
merged_parameters=final_params,
|
|
315
|
+
input_sql_had_placeholders=has_placeholders,
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
# Extract parameter info from the SQL
|
|
319
|
+
validator = self._config.parameter_validator
|
|
320
|
+
context.parameter_info = validator.extract_parameters(context.initial_sql_string)
|
|
321
|
+
|
|
322
|
+
# Run the pipeline
|
|
323
|
+
pipeline = self._config.get_statement_pipeline()
|
|
324
|
+
result = pipeline.execute_pipeline(context)
|
|
325
|
+
|
|
326
|
+
# Store the processing context for later use
|
|
327
|
+
self._processing_context = result.context
|
|
328
|
+
|
|
329
|
+
# Extract processed state
|
|
330
|
+
processed_expr = result.expression
|
|
331
|
+
if isinstance(processed_expr, exp.Anonymous):
|
|
332
|
+
processed_sql = self._raw_sql or context.initial_sql_string
|
|
333
|
+
else:
|
|
334
|
+
processed_sql = processed_expr.sql(dialect=self._dialect, comments=False)
|
|
335
|
+
logger.debug("Processed expression SQL: '%s'", processed_sql)
|
|
336
|
+
|
|
337
|
+
# Check if we need to denormalize pyformat placeholders
|
|
338
|
+
if self._placeholder_mapping and self._original_sql:
|
|
339
|
+
# We normalized pyformat placeholders before parsing, need to denormalize
|
|
340
|
+
original_sql = self._original_sql
|
|
341
|
+
# Extract parameter info from the original SQL to get the original styles
|
|
342
|
+
param_info = self._config.parameter_validator.extract_parameters(original_sql)
|
|
343
|
+
|
|
344
|
+
# Find the target style (should be pyformat)
|
|
345
|
+
from sqlspec.statement.parameters import ParameterStyle
|
|
346
|
+
|
|
347
|
+
target_styles = {p.style for p in param_info}
|
|
348
|
+
logger.debug(
|
|
349
|
+
"Denormalizing SQL: before='%s', original='%s', styles=%s",
|
|
350
|
+
processed_sql,
|
|
351
|
+
original_sql,
|
|
352
|
+
target_styles,
|
|
353
|
+
)
|
|
354
|
+
if ParameterStyle.POSITIONAL_PYFORMAT in target_styles:
|
|
355
|
+
# Denormalize back to %s
|
|
356
|
+
processed_sql = self._config.parameter_converter._denormalize_sql(
|
|
357
|
+
processed_sql, param_info, ParameterStyle.POSITIONAL_PYFORMAT
|
|
358
|
+
)
|
|
359
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
360
|
+
elif ParameterStyle.NAMED_PYFORMAT in target_styles:
|
|
361
|
+
# Denormalize back to %(name)s
|
|
362
|
+
processed_sql = self._config.parameter_converter._denormalize_sql(
|
|
363
|
+
processed_sql, param_info, ParameterStyle.NAMED_PYFORMAT
|
|
364
|
+
)
|
|
365
|
+
logger.debug("Denormalized SQL to: '%s'", processed_sql)
|
|
366
|
+
else:
|
|
367
|
+
logger.debug(
|
|
368
|
+
"No denormalization needed: mapping=%s, original=%s",
|
|
369
|
+
bool(self._placeholder_mapping),
|
|
370
|
+
bool(self._original_sql),
|
|
371
|
+
)
|
|
372
|
+
|
|
373
|
+
# Merge parameters from pipeline
|
|
374
|
+
merged_params = final_params
|
|
375
|
+
# Only merge extracted parameters if the original SQL didn't have placeholders
|
|
376
|
+
# If it already had placeholders, the parameters should already be provided
|
|
377
|
+
if result.context.extracted_parameters_from_pipeline and not context.input_sql_had_placeholders:
|
|
378
|
+
if isinstance(merged_params, dict):
|
|
379
|
+
for i, param in enumerate(result.context.extracted_parameters_from_pipeline):
|
|
380
|
+
param_name = f"param_{i}"
|
|
381
|
+
merged_params[param_name] = param
|
|
382
|
+
elif isinstance(merged_params, list):
|
|
383
|
+
merged_params.extend(result.context.extracted_parameters_from_pipeline)
|
|
384
|
+
elif merged_params is None:
|
|
385
|
+
merged_params = result.context.extracted_parameters_from_pipeline
|
|
386
|
+
else:
|
|
387
|
+
# Single value, convert to list
|
|
388
|
+
merged_params = [merged_params, *list(result.context.extracted_parameters_from_pipeline)]
|
|
389
|
+
|
|
390
|
+
# Cache the processed state
|
|
391
|
+
self._processed_state = _ProcessedState(
|
|
392
|
+
processed_expression=processed_expr,
|
|
393
|
+
processed_sql=processed_sql,
|
|
394
|
+
merged_parameters=merged_params,
|
|
395
|
+
validation_errors=list(result.context.validation_errors),
|
|
396
|
+
analysis_results={}, # Can be populated from analysis_findings if needed
|
|
397
|
+
transformation_results={}, # Can be populated from transformations if needed
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
# Check strict mode
|
|
401
|
+
if self._config.strict_mode and self._processed_state.validation_errors:
|
|
402
|
+
# Find the highest risk error
|
|
403
|
+
highest_risk_error = max(
|
|
404
|
+
self._processed_state.validation_errors,
|
|
405
|
+
key=lambda e: e.risk_level.value if hasattr(e, "risk_level") else 0,
|
|
406
|
+
)
|
|
407
|
+
raise SQLValidationError(
|
|
408
|
+
message=highest_risk_error.message,
|
|
409
|
+
sql=self._raw_sql or processed_sql,
|
|
410
|
+
risk_level=getattr(highest_risk_error, "risk_level", RiskLevel.HIGH),
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
def _to_expression(self, statement: Union[str, exp.Expression]) -> exp.Expression:
|
|
414
|
+
"""Convert string to sqlglot expression."""
|
|
415
|
+
if isinstance(statement, exp.Expression):
|
|
416
|
+
return statement
|
|
417
|
+
|
|
418
|
+
# Handle empty string
|
|
419
|
+
if not statement or not statement.strip():
|
|
420
|
+
# Return an empty select instead of Anonymous for empty strings
|
|
421
|
+
return exp.Select()
|
|
422
|
+
|
|
423
|
+
# Check if parsing is disabled
|
|
424
|
+
if not self._config.enable_parsing:
|
|
425
|
+
# Return an anonymous expression that preserves the raw SQL
|
|
426
|
+
return exp.Anonymous(this=statement)
|
|
427
|
+
|
|
428
|
+
# Check if SQL contains pyformat placeholders that need normalization
|
|
429
|
+
from sqlspec.statement.parameters import ParameterStyle
|
|
430
|
+
|
|
431
|
+
validator = self._config.parameter_validator
|
|
432
|
+
param_info = validator.extract_parameters(statement)
|
|
433
|
+
|
|
434
|
+
# Check if we have pyformat placeholders
|
|
435
|
+
has_pyformat = any(
|
|
436
|
+
p.style in {ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT} for p in param_info
|
|
437
|
+
)
|
|
438
|
+
|
|
439
|
+
normalized_sql = statement
|
|
440
|
+
placeholder_mapping: dict[str, Any] = {}
|
|
441
|
+
|
|
442
|
+
if has_pyformat:
|
|
443
|
+
# Normalize pyformat placeholders to named placeholders for SQLGlot
|
|
444
|
+
converter = self._config.parameter_converter
|
|
445
|
+
normalized_sql, placeholder_mapping = converter._transform_sql_for_parsing(statement, param_info)
|
|
446
|
+
# Store the original SQL before normalization
|
|
447
|
+
self._original_sql = statement
|
|
448
|
+
self._placeholder_mapping = placeholder_mapping
|
|
449
|
+
|
|
450
|
+
try:
|
|
451
|
+
# Parse with sqlglot
|
|
452
|
+
expressions = sqlglot.parse(normalized_sql, dialect=self._dialect) # pyright: ignore
|
|
453
|
+
if not expressions:
|
|
454
|
+
# Empty statement
|
|
455
|
+
return exp.Anonymous(this=statement)
|
|
456
|
+
first_expr = expressions[0]
|
|
457
|
+
if first_expr is None:
|
|
458
|
+
# Could not parse
|
|
459
|
+
return exp.Anonymous(this=statement)
|
|
460
|
+
|
|
461
|
+
except ParseError as e:
|
|
462
|
+
# If parsing fails, wrap in a RawString expression
|
|
463
|
+
logger.debug("Failed to parse SQL: %s", e)
|
|
464
|
+
return exp.Anonymous(this=statement)
|
|
465
|
+
return first_expr
|
|
466
|
+
|
|
467
|
+
@staticmethod
|
|
468
|
+
def _extract_filter_parameters(filter_obj: StatementFilter) -> tuple[list[Any], dict[str, Any]]:
|
|
469
|
+
"""Extract parameters from a filter object."""
|
|
470
|
+
if hasattr(filter_obj, "extract_parameters"):
|
|
471
|
+
return filter_obj.extract_parameters()
|
|
472
|
+
# Fallback for filters that don't implement the new method yet
|
|
473
|
+
return [], {}
|
|
474
|
+
|
|
475
|
+
def copy(
|
|
476
|
+
self,
|
|
477
|
+
statement: Optional[Union[str, exp.Expression]] = None,
|
|
478
|
+
parameters: Optional[Any] = None,
|
|
479
|
+
dialect: DialectType = None,
|
|
480
|
+
config: Optional[SQLConfig] = None,
|
|
481
|
+
**kwargs: Any,
|
|
482
|
+
) -> "SQL":
|
|
483
|
+
"""Create a copy with optional modifications.
|
|
484
|
+
|
|
485
|
+
This is the primary method for creating modified SQL objects.
|
|
486
|
+
"""
|
|
487
|
+
# Prepare existing state
|
|
488
|
+
existing_state = {
|
|
489
|
+
"positional_params": list(self._positional_params),
|
|
490
|
+
"named_params": dict(self._named_params),
|
|
491
|
+
"filters": list(self._filters),
|
|
492
|
+
"is_many": self._is_many,
|
|
493
|
+
"is_script": self._is_script,
|
|
494
|
+
"raw_sql": self._raw_sql,
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
# Create new instance
|
|
498
|
+
new_statement = statement if statement is not None else self._statement
|
|
499
|
+
new_dialect = dialect if dialect is not None else self._dialect
|
|
500
|
+
new_config = config if config is not None else self._config
|
|
501
|
+
|
|
502
|
+
# If parameters are explicitly provided, they replace existing ones
|
|
503
|
+
if parameters is not None:
|
|
504
|
+
# Clear existing state so only new parameters are used
|
|
505
|
+
existing_state["positional_params"] = []
|
|
506
|
+
existing_state["named_params"] = {}
|
|
507
|
+
# Pass parameters through normal processing
|
|
508
|
+
return SQL(
|
|
509
|
+
new_statement,
|
|
510
|
+
parameters,
|
|
511
|
+
_dialect=new_dialect,
|
|
512
|
+
_config=new_config,
|
|
513
|
+
_builder_result_type=self._builder_result_type,
|
|
514
|
+
_existing_state=None, # Don't use existing state
|
|
515
|
+
**kwargs,
|
|
516
|
+
)
|
|
517
|
+
|
|
518
|
+
return SQL(
|
|
519
|
+
new_statement,
|
|
520
|
+
_dialect=new_dialect,
|
|
521
|
+
_config=new_config,
|
|
522
|
+
_builder_result_type=self._builder_result_type,
|
|
523
|
+
_existing_state=existing_state,
|
|
524
|
+
**kwargs,
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
def add_named_parameter(self, name: str, value: Any) -> "SQL":
    """Return a copy of this statement with one more named parameter bound.

    Args:
        name: Parameter name; an existing binding with the same name on the copy is replaced.
        value: Value bound to ``name``.

    Returns:
        The object produced by ``self.copy()``, with the binding added to its named parameters.
    """
    clone = self.copy()
    clone._named_params.update({name: value})
    return clone
|
|
532
|
+
|
|
533
|
+
def get_unique_parameter_name(
    self, base_name: str, namespace: Optional[str] = None, preserve_original: bool = False
) -> str:
    """Return a parameter name not yet used by this statement's named parameters.

    Args:
        base_name: Desired parameter name.
        namespace: Optional prefix (e.g. ``'cte'``, ``'subquery'``), joined to
            ``base_name`` with an underscore.
        preserve_original: Accepted for API compatibility. The unsuffixed candidate is
            returned whenever it is free, whether or not this flag is set.

    Returns:
        ``[namespace_]base_name`` when free, otherwise the first free
        ``<candidate>_<n>`` for n = 1, 2, ...
    """
    # Only named parameters carry names that can collide.
    taken = set(self._named_params.keys())
    stem = base_name if not namespace else f"{namespace}_{base_name}"
    if stem not in taken:
        return stem

    suffix = 1
    while f"{stem}_{suffix}" in taken:
        suffix += 1
    return f"{stem}_{suffix}"
|
|
567
|
+
|
|
568
|
+
def where(self, condition: "Union[str, exp.Expression, exp.Condition]") -> "SQL":
    """Return a copy of this statement with ``condition`` added as a WHERE predicate.

    Args:
        condition: A SQL fragment string (converted with ``self._to_expression``) or an
            already-built sqlglot expression, which is used as-is.

    Returns:
        A new SQL object produced by ``self.copy(statement=...)``.
    """
    predicate = condition if not isinstance(condition, str) else self._to_expression(condition)
    current = self._statement
    if not hasattr(current, "where"):
        # Statements without a ``where`` builder are wrapped in a new SELECT ... FROM it.
        current = exp.Select().from_(current)  # pyright: ignore
    return self.copy(statement=current.where(predicate))  # pyright: ignore
|
|
581
|
+
|
|
582
|
+
def filter(self, filter_obj: StatementFilter) -> "SQL":
    """Return a copy of this statement with ``filter_obj`` attached.

    The filter is appended to the copy's filter list. Any parameters it carries, as
    reported by ``self._extract_filter_parameters``, are merged into the copy:
    positional ones are appended and named ones are updated.

    Args:
        filter_obj: Filter to attach.

    Returns:
        The copy carrying the filter and its parameters.
    """
    clone = self.copy()
    clone._filters.append(filter_obj)
    extra_positional, extra_named = self._extract_filter_parameters(filter_obj)
    clone._positional_params.extend(extra_positional)
    clone._named_params.update(extra_named)
    return clone
|
|
592
|
+
|
|
593
|
+
def as_many(self, parameters: "Optional[list[Any]]" = None) -> "SQL":
    """Return a copy of this statement flagged for executemany.

    Args:
        parameters: Optional parameter sets. When given, they replace all positional
            and named parameters on the copy. The sequence is copied, so later mutation
            of the caller's list does not change the statement.

    Returns:
        A new SQL object with its ``_is_many`` flag set.
    """
    clone = self.copy()
    clone._is_many = True
    if parameters is not None:
        clone._named_params = {}
        # Copy defensively: keeping a reference to the caller's list would let later
        # mutations of that list silently alter this statement's parameters.
        clone._positional_params = list(parameters)
    return clone
|
|
603
|
+
|
|
604
|
+
def as_script(self) -> "SQL":
    """Return a copy of this statement flagged for script execution.

    Returns:
        The object produced by ``self.copy()`` with its ``_is_script`` flag set.
    """
    scripted = self.copy()
    scripted._is_script = True
    return scripted
|
|
609
|
+
|
|
610
|
+
def _build_final_state(self) -> "tuple[exp.Expression, Any]":
    """Apply the attached filters to the statement and merge its parameters.

    Returns:
        ``(expression, parameters)``, where ``parameters`` is:

        - a dict when only named parameters exist;
        - a list when only positional ones exist;
        - a dict with each positional value added under a generated ``arg_<i>`` key
          when both kinds exist;
        - ``None`` when there are none.
    """
    final_expr = self._statement

    # Filters that can rewrite a statement are applied in order; each sees the
    # previous filter's output.
    for filter_obj in self._filters:
        if not hasattr(filter_obj, "append_to_statement"):
            continue
        # Use the same private constructor keywords that copy() passes to SQL(...).
        # NOTE(review): the original passed ``config=``/``dialect=`` here; confirm
        # against SQL.__init__ that ``_config``/``_dialect`` are the intended names.
        temp_sql = SQL(final_expr, _config=self._config, _dialect=self._dialect)
        temp_sql._positional_params = list(self._positional_params)
        temp_sql._named_params = dict(self._named_params)
        result = filter_obj.append_to_statement(temp_sql)
        final_expr = result._statement if isinstance(result, SQL) else result

    named = self._named_params
    positional = self._positional_params
    if named and not positional:
        return final_expr, dict(named)
    if positional and not named:
        # Always a list, so callers see a consistent sequence type.
        return final_expr, list(positional)
    if not positional:
        # Neither kind of parameter is present.
        return final_expr, None

    # Mixed: merge positional values into the named dict under generated names.
    merged = dict(named)
    for index, value in enumerate(positional):
        key = f"arg_{index}"
        if key in merged:
            # Fall back to an identity-based name and keep suffixing while it is
            # taken. The previous loop reassigned the same fallback name forever
            # when that name already existed, hanging the call.
            base = f"arg_{index}_{id(value)}"
            key = base
            suffix = 1
            while key in merged:
                key = f"{base}_{suffix}"
                suffix += 1
        merged[key] = value
    return final_expr, merged
|
|
646
|
+
|
|
647
|
+
# Properties for compatibility
|
|
648
|
+
@property
def sql(self) -> str:
    """Return the SQL text for this statement.

    Missing or whitespace-only raw SQL yields ``""``. Scripts, and statements with
    parsing disabled, return the raw text unchanged. Everything else returns the
    processed SQL.
    """
    raw = self._raw_sql
    if not (raw and raw.strip()):
        return ""
    # Scripts keep the raw text so multi-statement input is not re-rendered.
    if self._is_script or not self._config.enable_parsing:
        return raw
    self._ensure_processed()
    assert self._processed_state is not None
    return self._processed_state.processed_sql
|
|
666
|
+
|
|
667
|
+
@property
def expression(self) -> Optional[exp.Expression]:
    """Return the processed expression, or ``None`` when parsing is disabled."""
    if self._config.enable_parsing:
        self._ensure_processed()
        assert self._processed_state is not None
        return self._processed_state.processed_expression
    return None
|
|
676
|
+
|
|
677
|
+
@property
def parameters(self) -> Any:
    """Return the merged parameters produced by statement processing."""
    self._ensure_processed()
    state = self._processed_state
    assert state is not None
    return state.merged_parameters
|
|
683
|
+
|
|
684
|
+
@property
def is_many(self) -> bool:
    """Whether this statement is flagged for executemany."""
    flag = self._is_many
    return flag
|
|
688
|
+
|
|
689
|
+
@property
def is_script(self) -> bool:
    """Whether this statement is flagged for script execution."""
    flag = self._is_script
    return flag
|
|
693
|
+
|
|
694
|
+
def to_sql(self, placeholder_style: Optional[str] = None) -> str:
    """Render this statement as SQL text.

    Args:
        placeholder_style: Target placeholder style forwarded to ``compile``. It is
            ignored for scripts.

    Returns:
        The compiled SQL string. Scripts return ``self.sql`` directly.
    """
    if not self._is_script:
        rendered, _unused_params = self.compile(placeholder_style=placeholder_style)
        return rendered
    return self.sql
|
|
700
|
+
|
|
701
|
+
def get_parameters(self, style: Optional[str] = None) -> Any:
    """Return the parameters half of ``compile(placeholder_style=style)``.

    Args:
        style: Target placeholder style, or ``None`` for no conversion.

    Returns:
        The compiled parameters.
    """
    _compiled_sql, compiled_params = self.compile(placeholder_style=style)
    return compiled_params
|
|
706
|
+
|
|
707
|
+
def compile(self, placeholder_style: Optional[str] = None) -> tuple[str, Any]:
    """Compile this statement to ``(sql, parameters)``.

    Args:
        placeholder_style: Optional target placeholder style. ``None`` returns the
            processed SQL and parameters without any placeholder conversion.

    Returns:
        The SQL text and its parameters:

        - Scripts return ``(self.sql, None)``.
        - When parsing is disabled and raw SQL exists, the raw SQL and the original
          parameters are returned untouched.
        - Otherwise, parameters are reordered when the processing context recorded a
          ``parameter_position_mapping``.
    """
    if self._is_script:
        return self.sql, None

    if not self._config.enable_parsing and self._raw_sql:
        return self._raw_sql, self._raw_parameters

    self._ensure_processed()
    state = self._processed_state
    assert state is not None
    compiled_sql = state.processed_sql
    compiled_params = state.merged_parameters

    # Processing may have moved placeholders; apply the recorded position mapping.
    if compiled_params is not None:
        context = getattr(self, "_processing_context", None)
        if context:
            mapping = context.metadata.get("parameter_position_mapping")
            if mapping:
                compiled_params = self._reorder_parameters(compiled_params, mapping)

    if placeholder_style is None:
        return compiled_sql, compiled_params

    if placeholder_style:
        compiled_sql, compiled_params = self._convert_placeholder_style(
            compiled_sql, compiled_params, placeholder_style
        )

    logger.debug("Final compiled SQL: '%s'", compiled_sql)
    return compiled_sql, compiled_params
|
|
743
|
+
|
|
744
|
+
@staticmethod
def _reorder_parameters(params: Any, mapping: dict[int, int]) -> Any:
    """Reorder parameters according to a position mapping.

    Args:
        params: Original parameters (list, tuple, or dict).
        mapping: Maps each new position to the original position whose value it takes.

    Returns:
        Reordered parameters in the same container type as the input.

        - For sequences, slots that are not a target in ``mapping`` keep their
          original value. Entries whose positions fall outside the sequence are
          ignored instead of raising ``IndexError``.
        - Dicts are reordered only when every key is a string of the form
          ``param_<n>``.
        - Any other input is returned unchanged.
    """
    if isinstance(params, (list, tuple)):
        size = len(params)
        reordered: list[Any] = [None] * size
        for new_pos, old_pos in mapping.items():
            # Guard both positions: an out-of-range target previously raised IndexError.
            if new_pos < size and old_pos < size:
                reordered[new_pos] = params[old_pos]

        # Slots that no mapping entry targets keep their original value.
        for i in range(size):
            if reordered[i] is None and i not in mapping:
                reordered[i] = params[i]

        return tuple(reordered) if isinstance(params, tuple) else reordered

    if isinstance(params, dict):
        # Only dicts keyed param_0, param_1, ... can be reordered; the isinstance
        # check keeps non-string keys from raising AttributeError.
        if not all(isinstance(key, str) and key.startswith("param_") and key[6:].isdigit() for key in params):
            return params

        reordered_dict: dict[str, Any] = {}
        for new_pos, old_pos in mapping.items():
            old_key = f"param_{old_pos}"
            if old_key in params:
                reordered_dict[f"param_{new_pos}"] = params[old_key]

        # Keys whose index is not a mapping target keep their value (mirrors the sequence branch).
        for key, value in params.items():
            if key not in reordered_dict and int(key[6:]) not in mapping:
                reordered_dict[key] = value
        return reordered_dict

    # Single value or unknown format.
    return params
|
|
794
|
+
|
|
795
|
+
def _convert_placeholder_style(self, sql: str, params: Any, placeholder_style: str) -> tuple[str, Any]:
    """Rewrite SQL placeholders and reshape parameters for a target placeholder style.

    Args:
        sql: SQL text whose placeholders are rewritten.
        params: Parameters to reshape for the new placeholders.
        placeholder_style: Target style; a string is converted with ``ParameterStyle(...)``.

    Returns:
        ``(converted_sql, converted_params)``. When no placeholders are detected in
        ``sql``, both inputs are returned unchanged.
    """
    detected = self._config.parameter_converter.validator.extract_parameters(sql)
    if not detected:
        return sql, params

    from sqlspec.statement.parameters import ParameterStyle

    style = placeholder_style
    if isinstance(style, str):
        style = ParameterStyle(style)

    converted_sql = self._replace_placeholders_in_sql(sql, detected, style)
    converted_params = self._convert_parameters_format(params, detected, style)
    return converted_sql, converted_params
|
|
825
|
+
|
|
826
|
+
def _replace_placeholders_in_sql(self, sql: str, param_info: list[Any], target_style: "ParameterStyle") -> str:
    """Rewrite every detected placeholder in ``sql`` using ``target_style``.

    Args:
        sql: The SQL string.
        param_info: Placeholder records; each one's ``position`` and
            ``placeholder_text`` are read.
        target_style: Style passed to ``self._generate_placeholder``.

    Returns:
        SQL text with each placeholder replaced.
    """
    # Replace from the end backwards so offsets of not-yet-replaced placeholders
    # stay valid when replacement text has a different length.
    rewritten = sql
    for info in sorted(param_info, key=lambda item: item.position, reverse=True):
        replacement = self._generate_placeholder(info, target_style)
        head = rewritten[: info.position]
        tail = rewritten[info.position + len(info.placeholder_text) :]
        rewritten = head + replacement + tail
    return rewritten
|
|
848
|
+
|
|
849
|
+
@staticmethod
def _generate_placeholder(param: Any, target_style: "ParameterStyle") -> str:
    """Build the placeholder text for one parameter in ``target_style``.

    Args:
        param: Placeholder record; depending on the style, ``name``, ``ordinal`` or
            ``placeholder_text`` is read.
        target_style: Style to render.

    Returns:
        The placeholder string. Unknown styles keep the original placeholder text.
    """
    if target_style == ParameterStyle.QMARK:
        return "?"
    if target_style == ParameterStyle.NUMERIC:
        # $1, $2, ...: numbering is 1-based.
        return f"${param.ordinal + 1}"
    if target_style == ParameterStyle.NAMED_COLON:
        # A purely numeric name is not used as a bind name; it becomes arg_<ordinal>.
        # The prefix has no leading underscore (per the original note, Oracle rejects those).
        if param.name and not param.name.isdigit():
            return f":{param.name}"
        return f":arg_{param.ordinal}"
    if target_style == ParameterStyle.NAMED_AT:
        # @name style (BigQuery); the fallback name starts with a letter.
        label = param.name or f"param_{param.ordinal}"
        return f"@{label}"
    if target_style == ParameterStyle.POSITIONAL_COLON:
        # :1, :2, ... (Oracle positional).
        return f":{param.ordinal + 1}"
    if target_style == ParameterStyle.POSITIONAL_PYFORMAT:
        return "%s"
    if target_style == ParameterStyle.NAMED_PYFORMAT:
        # NOTE(review): this fallback name must match the dict key that
        # _convert_to_named_pyformat_format uses for unnamed list values.
        label = param.name or f"_arg_{param.ordinal}"
        return f"%({label})s"
    return str(param.placeholder_text)
|
|
888
|
+
|
|
889
|
+
def _convert_parameters_format(self, params: Any, param_info: list[Any], target_style: "ParameterStyle") -> Any:
    """Reshape ``params`` for ``target_style`` by delegating to the matching converter.

    Args:
        params: Original parameters.
        param_info: Placeholder records extracted from the SQL.
        target_style: Target parameter style.

    Returns:
        Converted parameters. Styles without a converter here (e.g. NAMED_AT) return
        ``params`` unchanged.
    """
    if target_style == ParameterStyle.POSITIONAL_COLON:
        return self._convert_to_positional_colon_format(params, param_info)

    positional_styles = {ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_PYFORMAT}
    if target_style in positional_styles:
        return self._convert_to_positional_format(params, param_info)

    converter = None
    if target_style == ParameterStyle.NAMED_COLON:
        converter = self._convert_to_named_colon_format
    elif target_style == ParameterStyle.NAMED_PYFORMAT:
        converter = self._convert_to_named_pyformat_format
    return params if converter is None else converter(params, param_info)
|
|
909
|
+
|
|
910
|
+
def _convert_to_positional_colon_format(self, params: Any, param_info: list[Any]) -> Any:
    """Convert parameters to a dict for Oracle positional colon style (``:1``, ``:2``).

    Oracle expects these parameters as a dict keyed by position strings
    ``"1"``, ``"2"``, and so on.

    Args:
        params: Original parameters: a sequence, a dict, or a single scalar value.
        param_info: Placeholder records extracted from the SQL.

    Returns:
        A dict keyed by position strings. An execute_many batch (a non-empty list of
        sequences on an ``is_many`` statement) is returned unchanged, as is ``params``
        when no conversion applies.
    """
    # execute_many: keep the batch of parameter sets as-is.
    if self._is_many and isinstance(params, list) and params and isinstance(params[0], (list, tuple)):
        return params

    result_dict: dict[str, Any] = {}

    if isinstance(params, (list, tuple)):
        if not param_info:
            # No placeholder info: default to 1-based numbering.
            for i, value in enumerate(params):
                result_dict[str(i + 1)] = value
            return result_dict

        if all(p.name and p.name.isdigit() for p in param_info):
            # Bind values in numeric placeholder order, not SQL appearance order. Each
            # distinct number consumes one value. Previously a repeated placeholder
            # (``:1 ... :1``) consumed its own slot, so later values overwrote
            # earlier ones under the same key and were lost.
            ordered_names = sorted(dict.fromkeys(p.name for p in param_info), key=int)
            for i, value in enumerate(params):
                key = ordered_names[i] if i < len(ordered_names) else str(i + 1)
                result_dict[key] = value
        else:
            # Non-numeric names: map by position within param_info.
            for i, value in enumerate(params):
                key = (param_info[i].name or str(i + 1)) if i < len(param_info) else str(i + 1)
                result_dict[key] = value
        return result_dict

    if not is_dict(params) and param_info:
        # Single scalar value: bind it to the first placeholder's numeric name
        # (e.g. "0") when it has one, otherwise to "1".
        first_name = param_info[0].name
        result_dict[first_name if first_name and first_name.isdigit() else "1"] = params
        return result_dict

    if isinstance(params, dict):
        # Already keyed by position strings. Non-string keys count as "not converted"
        # rather than raising AttributeError.
        if all(isinstance(key, str) and key.isdigit() for key in params):
            return params

        for p in sorted(param_info, key=lambda x: x.ordinal):
            oracle_key = str(p.ordinal + 1)  # Oracle positions are 1-based
            if p.name and p.name in params:
                result_dict[oracle_key] = params[p.name]
            elif f"arg_{p.ordinal}" in params:
                result_dict[oracle_key] = params[f"arg_{p.ordinal}"]
            elif f"param_{p.ordinal}" in params:
                result_dict[oracle_key] = params[f"param_{p.ordinal}"]
        return result_dict

    return params
|
|
990
|
+
|
|
991
|
+
@staticmethod
def _convert_to_positional_format(params: Any, param_info: list[Any]) -> Any:
    """Convert parameters to a list for positional styles (qmark, numeric, ``%s``).

    Values that expose a ``value`` attribute (e.g. typed parameter wrappers) are
    unwrapped to that value.

    Args:
        params: Original parameters.
        param_info: Placeholder records extracted from the SQL.

    Returns:
        A list in placeholder order for dict or sequence input; ``params`` itself
        otherwise. For dict input, each placeholder is looked up by its name, then
        ``arg_<ordinal>``, then ``param_<ordinal>``; missing values become ``None``.
    """

    def unwrap(val: Any) -> Any:
        # Extract the raw value from a TypedParameter-like wrapper.
        return val.value if hasattr(val, "value") else val

    # isinstance(dict) matches the dict check used by the sibling positional-colon converter.
    if isinstance(params, dict):
        result_list: list[Any] = []
        for p in param_info:
            # Besides the placeholder's own name, accept the generated key forms that
            # the other converters also accept (arg_<n> and param_<n>). Previously a
            # dict keyed param_<n> produced None for every unnamed placeholder.
            for key in (p.name, f"arg_{p.ordinal}", f"param_{p.ordinal}"):
                if key and key in params:
                    result_list.append(unwrap(params[key]))
                    break
            else:
                result_list.append(None)
        return result_list

    if isinstance(params, (list, tuple)):
        return [unwrap(item) for item in params]

    return params
|
|
1036
|
+
|
|
1037
|
+
@staticmethod
def _convert_to_named_colon_format(params: Any, param_info: list[Any]) -> Any:
    """Convert parameters to a dict for ``:name`` placeholders.

    Bind names follow the named-colon placeholder rule: a non-numeric name is kept,
    while a missing or purely numeric name (from ``?``, ``$1`` or ``:1`` input)
    becomes ``arg_<ordinal>``. Previously a numeric name such as ``"1"`` was used as
    the dict key even though the placeholder written for it is ``:arg_<ordinal>``.

    Args:
        params: Original parameters.
        param_info: Placeholder records extracted from the SQL.

    Returns:
        A dict keyed by bind name for dict or sequence input. ``params`` is returned
        unchanged when it is a dict that already contains every bind name, or when
        it is neither a dict nor a sequence.
    """

    def bind_name(p: Any) -> str:
        # Same rule the named-colon placeholder generator applies.
        return p.name if p.name and not p.name.isdigit() else f"arg_{p.ordinal}"

    if isinstance(params, dict):
        if all(bind_name(p) in params for p in param_info):
            return params
        result_dict: dict[str, Any] = {}
        for p in param_info:
            target = bind_name(p)
            # Accept the bind name itself, the raw SQL name, or a param_<n> key.
            for source in (target, p.name, f"param_{p.ordinal}"):
                if source and source in params:
                    result_dict[target] = params[source]
                    break
        return result_dict

    if isinstance(params, (list, tuple)):
        # Values beyond the detected placeholders are dropped, as before.
        return {bind_name(p): value for p, value in zip(param_info, params)}

    return params
|
|
1074
|
+
|
|
1075
|
+
@staticmethod
def _convert_to_named_pyformat_format(params: Any, param_info: list[Any]) -> Any:
    """Convert sequence parameters to a dict for ``%(name)s`` placeholders.

    Unnamed placeholders are keyed ``_arg_<ordinal>``, the same fallback name the
    placeholder generator writes into the SQL for this style. The previous
    ``param_<i>`` key never matched the rewritten ``%(_arg_<n>)s`` placeholder, so
    execution failed with a missing-key error.

    Args:
        params: Original parameters.
        param_info: Placeholder records extracted from the SQL.

    Returns:
        A dict for list/tuple input. Values beyond the detected placeholders are
        dropped. Dicts and scalars are returned unchanged.
    """
    if not isinstance(params, (list, tuple)):
        return params
    return {(p.name or f"_arg_{p.ordinal}"): value for p, value in zip(param_info, params)}
|
|
1095
|
+
|
|
1096
|
+
# Validation properties for compatibility
|
|
1097
|
+
@property
def validation_errors(self) -> list[Any]:
    """Return validation errors recorded during processing.

    Returns an empty list when validation is disabled in the config.
    """
    if self._config.enable_validation:
        self._ensure_processed()
        assert self._processed_state
        return self._processed_state.validation_errors
    return []
|
|
1105
|
+
|
|
1106
|
+
@property
def has_errors(self) -> bool:
    """True when ``validation_errors`` is non-empty."""
    errors = self.validation_errors
    return bool(errors)
|
|
1110
|
+
|
|
1111
|
+
@property
def is_safe(self) -> bool:
    """True when no validation errors were reported (the inverse of ``has_errors``)."""
    flagged = self.has_errors
    return not flagged
|
|
1115
|
+
|
|
1116
|
+
# Additional compatibility methods
|
|
1117
|
+
def validate(self) -> list[Any]:
    """Return the validation errors (the same value as ``validation_errors``)."""
    errors = self.validation_errors
    return errors
|
|
1120
|
+
|
|
1121
|
+
@property
def parameter_info(self) -> list[Any]:
    """Extract placeholder information from this statement's SQL.

    When parsing is enabled and the statement has been processed, the expression is
    rendered in this statement's dialect and scanned. Otherwise ``self.sql`` is
    scanned.
    """
    validator = self._config.parameter_validator
    if not (self._config.enable_parsing and self._processed_state):
        return validator.extract_parameters(self.sql)
    if self.expression:
        return validator.extract_parameters(self.expression.sql(dialect=self._dialect))  # pyright: ignore
    return validator.extract_parameters(self.sql)
|
|
1130
|
+
|
|
1131
|
+
@property
def _raw_parameters(self) -> Any:
    """Return the parameters exactly as originally supplied (compatibility alias)."""
    supplied = self._original_parameters
    return supplied
|
|
1136
|
+
|
|
1137
|
+
@property
def _sql(self) -> str:
    """Compatibility alias for the ``sql`` property."""
    text = self.sql
    return text
|
|
1141
|
+
|
|
1142
|
+
@property
def _expression(self) -> Optional[exp.Expression]:
    """Compatibility alias for the ``expression`` property."""
    resolved = self.expression
    return resolved
|
|
1146
|
+
|
|
1147
|
+
@property
def statement(self) -> exp.Expression:
    """Return the stored ``_statement`` (not the processed expression); kept for compatibility."""
    current = self._statement
    return current
|
|
1151
|
+
|
|
1152
|
+
def limit(self, count: int, use_parameter: bool = False) -> "SQL":
    """Return a copy of this statement with a LIMIT clause.

    Args:
        count: Row limit.
        use_parameter: When True, bind ``count`` under a unique name derived from
            ``"limit"`` and reference it through a placeholder instead of inlining it.

    Returns:
        A new SQL object with the LIMIT applied. Statements without a ``limit``
        builder are first wrapped in a new SELECT ... FROM them.
    """
    if not use_parameter:
        base = self._statement
        if not hasattr(base, "limit"):
            base = exp.Select().from_(base)  # pyright: ignore
        return self.copy(statement=base.limit(count))  # pyright: ignore

    name = self.get_unique_parameter_name("limit")
    bound = self.add_named_parameter(name, count)
    placeholder = exp.Placeholder(this=name)
    target = bound._statement
    if not hasattr(target, "limit"):
        target = exp.Select().from_(target)  # pyright: ignore
    return bound.copy(statement=target.limit(placeholder))  # pyright: ignore
|
|
1171
|
+
|
|
1172
|
+
def offset(self, count: int, use_parameter: bool = False) -> "SQL":
    """Return a copy of this statement with an OFFSET clause.

    Args:
        count: Number of rows to skip.
        use_parameter: When True, bind ``count`` under a unique name derived from
            ``"offset"`` and reference it through a placeholder instead of inlining it.

    Returns:
        A new SQL object with the OFFSET applied. Statements without an ``offset``
        builder are first wrapped in a new SELECT ... FROM them.
    """
    if not use_parameter:
        base = self._statement
        if not hasattr(base, "offset"):
            base = exp.Select().from_(base)  # pyright: ignore
        return self.copy(statement=base.offset(count))  # pyright: ignore

    name = self.get_unique_parameter_name("offset")
    bound = self.add_named_parameter(name, count)
    placeholder = exp.Placeholder(this=name)
    target = bound._statement
    if not hasattr(target, "offset"):
        target = exp.Select().from_(target)  # pyright: ignore
    return bound.copy(statement=target.offset(placeholder))  # pyright: ignore
|
|
1191
|
+
|
|
1192
|
+
def order_by(self, expression: exp.Expression) -> "SQL":
    """Return a copy of this statement with ``expression`` added to ORDER BY.

    Statements without an ``order_by`` builder are first wrapped in a new
    SELECT ... FROM them.
    """
    current = self._statement
    if not hasattr(current, "order_by"):
        current = exp.Select().from_(current)  # pyright: ignore
    return self.copy(statement=current.order_by(expression))  # pyright: ignore
|